|  |  | @@ -0,0 +1,38 @@ | 
		
	
		
			
			|  |  |  | class Data(object): | 
		
	
		
			
			|  |  |  | def __init__(self): | 
		
	
		
			
			|  |  |  | self.node_types = [] | 
		
	
		
			
			|  |  |  | self.relation_types = [] | 
		
	
		
			
			|  |  |  |  | 
		
	
		
			
			|  |  |  | def add_node_type(self, name): | 
		
	
		
			
			|  |  |  | self.node_types.append(name) | 
		
	
		
			
			|  |  |  |  | 
		
	
		
			
			|  |  |  | def add_relation(self, node_type_row, node_type_column, adjacency_matrix, name): | 
		
	
		
			
			|  |  |  | n = len(self.node_types) | 
		
	
		
			
			|  |  |  | if node_type_row >= n or node_type_column >= n: | 
		
	
		
			
			|  |  |  | raise ValueError('Node type index out of bounds, add node type first') | 
		
	
		
			
			|  |  |  | self.relation_types.append((node_type_row, node_type_column, adjacency_matrix, name)) | 
		
	
		
			
			|  |  |  |  | 
		
	
		
			
			|  |  |  | def __repr__(self): | 
		
	
		
			
			|  |  |  | n = len(self.node_types) | 
		
	
		
			
			|  |  |  | if n == 0: | 
		
	
		
			
			|  |  |  | return 'Empty GNN Data' | 
		
	
		
			
			|  |  |  | s = '' | 
		
	
		
			
			|  |  |  | s += 'GNN Data with:\n' | 
		
	
		
			
			|  |  |  | s += '- ' + str(n) + ' node type(s):\n' | 
		
	
		
			
			|  |  |  | for nt in self.node_types: | 
		
	
		
			
			|  |  |  | s += '  - ' + nt + '\n' | 
		
	
		
			
			|  |  |  | if len(self.relation_types) == 0: | 
		
	
		
			
			|  |  |  | s += '- No relation types\n' | 
		
	
		
			
			|  |  |  | return s.strip() | 
		
	
		
			
			|  |  |  | s += '- ' + str(len(self.relation_types)) + ' relation type(s):\n' | 
		
	
		
			
			|  |  |  | for i in range(n): | 
		
	
		
			
			|  |  |  | for j in range(n): | 
		
	
		
			
			|  |  |  | rels = list(filter(lambda a: a[0] == i and a[1] == j, self.relation_types)) | 
		
	
		
			
			|  |  |  | if len(rels) == 0: | 
		
	
		
			
			|  |  |  | continue | 
		
	
		
			
			|  |  |  | # dir = '<->' if i == j else '->' | 
		
	
		
			
			|  |  |  | dir = '--' | 
		
	
		
			
			|  |  |  | s += '  - ' + self.node_types[i] + ' ' + dir + ' ' + self.node_types[j] + ':\n' | 
		
	
		
			
			|  |  |  | for r in rels: | 
		
	
		
			
			|  |  |  | s += '    - ' + r[3] + '\n' | 
		
	
		
			
			|  |  |  | return s.strip() |