|  |  | @@ -5,11 +5,12 @@ | 
		
	
		
			
			|  |  |  |  | 
		
	
		
			
			|  |  |  |  | 
		
	
		
			
			|  |  |  | from collections import defaultdict | 
		
	
		
			
			|  |  |  | from dataclasses import dataclass | 
		
	
		
			
			|  |  |  | from dataclasses import dataclass, field | 
		
	
		
			
			|  |  |  | import torch | 
		
	
		
			
			|  |  |  | from typing import List, \ | 
		
	
		
			
			|  |  |  | Dict, \ | 
		
	
		
			
			|  |  |  | Tuple | 
		
	
		
			
			|  |  |  | Tuple, \ | 
		
	
		
			
			|  |  |  | Any | 
		
	
		
			
			|  |  |  |  | 
		
	
		
			
			|  |  |  |  | 
		
	
		
			
			|  |  |  | @dataclass | 
		
	
	
		
			
				|  |  | @@ -24,7 +25,7 @@ class RelationType(object): | 
		
	
		
			
			|  |  |  | node_type_row: int | 
		
	
		
			
			|  |  |  | node_type_column: int | 
		
	
		
			
			|  |  |  | adjacency_matrix: torch.Tensor | 
		
	
		
			
			|  |  |  | is_autogenerated: bool = False | 
		
	
		
			
			|  |  |  | hints: Dict[str, Any] = field(default_factory=dict) | 
		
	
		
			
			|  |  |  |  | 
		
	
		
			
			|  |  |  |  | 
		
	
		
			
			|  |  |  | class Data(object): | 
		
	
	
		
			
				|  |  | @@ -81,14 +82,16 @@ class Data(object): | 
		
	
		
			
			|  |  |  |  | 
		
	
		
			
			|  |  |  | self.relation_types[node_type_row, node_type_column].append( | 
		
	
		
			
			|  |  |  | RelationType(name, node_type_row, node_type_column, | 
		
	
		
			
			|  |  |  | adjacency_matrix, False)) | 
		
	
		
			
			|  |  |  | adjacency_matrix)) | 
		
	
		
			
			|  |  |  |  | 
		
	
		
			
			|  |  |  | if node_type_row != node_type_column and two_way: | 
		
	
		
			
			|  |  |  | hints = { 'display': False } | 
		
	
		
			
			|  |  |  | if adjacency_matrix_backward is None: | 
		
	
		
			
			|  |  |  | adjacency_matrix_backward = adjacency_matrix.transpose(0, 1) | 
		
	
		
			
			|  |  |  | hints['symmetric'] = True | 
		
	
		
			
			|  |  |  | self.relation_types[node_type_column, node_type_row].append( | 
		
	
		
			
			|  |  |  | RelationType(name, node_type_column, node_type_row, | 
		
	
		
			
			|  |  |  | adjacency_matrix_backward, True)) | 
		
	
		
			
			|  |  |  | adjacency_matrix_backward, hints)) | 
		
	
		
			
			|  |  |  |  | 
		
	
		
			
			|  |  |  | def __repr__(self): | 
		
	
		
			
			|  |  |  | n = len(self.node_types) | 
		
	
	
		
			
				|  |  | @@ -114,7 +117,7 @@ class Data(object): | 
		
	
		
			
			|  |  |  | self.node_types[node_type_column].name + ':\n' | 
		
	
		
			
			|  |  |  |  | 
		
	
		
			
			|  |  |  | for r in self.relation_types[node_type_row, node_type_column]: | 
		
	
		
			
			|  |  |  | if r.is_autogenerated: | 
		
	
		
			
			|  |  |  | if not r.hints.get('display', True): | 
		
	
		
			
			|  |  |  | continue | 
		
	
		
			
			|  |  |  | s_1 += '    - ' + r.name + '\n' | 
		
	
		
			
			|  |  |  | count += 1 | 
		
	
	
		
			
				|  |  | 
 |