@@ -8,7 +8,12 @@ from typing import List, \
    Dict, \
    Callable, \
    Tuple
from .util import _sparse_coo_tensor
from .util import _sparse_coo_tensor, \
    _sparse_diag_cat, \
    _mm
from .normalize import norm_adj_mat_one_node_type, \
    norm_adj_mat_two_node_types
from .dropout import dropout


@dataclass
@@ -19,6 +24,44 @@ class TrainingBatch(object):
    edges: torch.Tensor


def _per_layer_required_vertices(data: Data, batch: TrainingBatch,
    num_layers: int) -> List[List[EdgeType]]:

    Q = [
        ( batch.vertex_type_row, batch.edges[:, 0] ),
        ( batch.vertex_type_column, batch.edges[:, 1] )
    ]
    print('Q:', Q)
    res = []

    for _ in range(num_layers):
        R = []
        required_rows = [ [] for _ in range(len(data.vertex_types)) ]

        for vertex_type, vertices in Q:
            for et in data.edge_types.values():
                if et.vertex_type_row == vertex_type:
                    required_rows[vertex_type].append(vertices)
                    indices = et.total_connectivity.indices()
                    mask = torch.zeros(et.total_connectivity.shape[0])
                    mask[vertices] = 1
                    mask = torch.nonzero(mask[indices[0]], as_tuple=True)[0]
                    R.append((et.vertex_type_column,
                        indices[1, mask]))
                else:
                    pass # required_rows[et.vertex_type_row].append(torch.zeros(0))

        required_rows = [ torch.unique(torch.cat(x)) \
            if len(x) > 0 \
            else None \
            for x in required_rows ]

        res.append(required_rows)
        Q = R

    return res


class Model(torch.nn.Module):
    def __init__(self, data: Data, layer_dimensions: List[int],
        keep_prob: float,
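
Note: the masking step inside _per_layer_required_vertices() selects, for a
given set of required row vertices, the column vertices they reach in a sparse
COO connectivity matrix. Below is a minimal sketch of that step with toy data,
using plain torch.sparse (the actual code reads the matrix from
et.total_connectivity):

    import torch

    # Toy connectivity: 4 row vertices x 5 column vertices.
    conn = torch.tensor([
        [0, 1, 0, 0, 1],
        [0, 0, 0, 0, 0],
        [1, 0, 0, 1, 0],
        [0, 0, 1, 0, 0],
    ], dtype=torch.float32).to_sparse()

    vertices = torch.tensor([0, 2])      # rows required at the current layer

    indices = conn.indices()             # 2 x nnz: (row, column) of every edge
    mask = torch.zeros(conn.shape[0])
    mask[vertices] = 1                   # flag the required rows
    hit = torch.nonzero(mask[indices[0]], as_tuple=True)[0]
    neighbours = indices[1, hit]         # columns needed one layer deeper
    print(neighbours)                    # tensor([1, 4, 0, 3])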
		
	
	
		
			
@@ -30,11 +73,11 @@ class Model(torch.nn.Module):
        if not isinstance(data, Data):
            raise TypeError('data must be an instance of Data')

        if not isinstance(conv_activation, types.FunctionType):
            raise TypeError('conv_activation must be a function')
        if not callable(conv_activation):
            raise TypeError('conv_activation must be callable')

        if not isinstance(dec_activation, types.FunctionType):
            raise TypeError('dec_activation must be a function')
        if not callable(dec_activation):
            raise TypeError('dec_activation must be callable')

        self.data = data
        self.layer_dimensions = list(layer_dimensions)
@@ -42,35 +85,90 @@ class Model(torch.nn.Module):
        self.conv_activation = conv_activation
        self.dec_activation = dec_activation

        self.adj_matrices = None
        self.conv_weights = None
        self.dec_weights = None
        self.build()


    def build(self) -> None:
        self.adj_matrices = torch.nn.ParameterDict()
        for _, et in self.data.edge_types.items():
            adj_matrices = [
                norm_adj_mat_one_node_type(x) \
                    if et.vertex_type_row == et.vertex_type_column \
                    else norm_adj_mat_two_node_types(x) \
                    for x in et.adjacency_matrices
            ]
            adj_matrices = _sparse_diag_cat(et.adjacency_matrices)
            print('adj_matrices:', adj_matrices)
            self.adj_matrices['%d-%d' % (et.vertex_type_row, et.vertex_type_column)] = \
                torch.nn.Parameter(adj_matrices, requires_grad=False)

        self.conv_weights = torch.nn.ParameterDict()
        for i in range(len(self.layer_dimensions) - 1):
            in_dimension = self.layer_dimensions[i]
            out_dimension = self.layer_dimensions[i + 1]

            for _, et in self.data.edge_types.items():
                weight = init_glorot(in_dimension, out_dimension)
                self.conv_weights[et.vertex_type_row, et.vertex_type_column, i] = \
                    torch.nn.Parameter(weight)
                weights = [ init_glorot(in_dimension, out_dimension) \
                    for _ in range(len(et.adjacency_matrices)) ]
                weights = torch.cat(weights, dim=1)
                self.conv_weights['%d-%d-%d' % (et.vertex_type_row, et.vertex_type_column, i)] = \
                    torch.nn.Parameter(weights)

        self.dec_weights = torch.nn.ParameterDict()
        for _, et in self.data.edge_types.items():
            global_interaction, local_variation = \
                et.decoder_factory(self.layer_dimensions[-1],
                    len(et.adjacency_matrices))
            self.dec_weights[et.vertex_type_row, et.vertex_type_column] = \
                torch.nn.ParameterList([
                    torch.nn.Parameter(global_interaction),
                    torch.nn.Parameter(local_variation)
                ])
            self.dec_weights['%d-%d-global-interaction' % (et.vertex_type_row, et.vertex_type_column)] = \
                torch.nn.Parameter(global_interaction)
            for i in range(len(local_variation)):
                self.dec_weights['%d-%d-local-variation-%d' % (et.vertex_type_row, et.vertex_type_column, i)] = \
                    torch.nn.Parameter(local_variation[i])


    def convolve(self, in_layer_repr: List[torch.Tensor]) -> \
        List[torch.Tensor]:

    def convolve(self, batch: TrainingBatch) -> List[torch.Tensor]:
        cur_layer_repr = in_layer_repr
        next_layer_repr = [ None ] * len(self.data.vertex_types)

        for i in range(len(self.layer_dimensions) - 1):
            for _, et in self.data.edge_types.items():
                vt_row, vt_col = et.vertex_type_row, et.vertex_type_column
                adj_matrices = self.adj_matrices['%d-%d' % (vt_row, vt_col)]
                conv_weights = self.conv_weights['%d-%d-%d' % (vt_row, vt_col, i)]

                num_relation_types = len(et.adjacency_matrices)
                x = cur_layer_repr[vt_col]
                if self.keep_prob != 1:
                    x = dropout(x, self.keep_prob)

                print('a, Layer:', i, 'x.shape:', x.shape)

                x = _mm(x, conv_weights)
                x = torch.split(x,
                    x.shape[1] // num_relation_types,
                    dim=1)
                x = torch.cat(x)
                x = _mm(adj_matrices, x)
                x = x.view(num_relation_types,
                    self.data.vertex_types[vt_row].count,
                    self.layer_dimensions[i + 1])

                print('b, Layer:', i, 'x.shape:', x.shape)

                x = x.sum(dim=0)
                x = torch.nn.functional.normalize(x, p=2, dim=1)
                x = self.conv_activation(x)

                next_layer_repr[vt_row] = x
            cur_layer_repr = next_layer_repr
        return next_layer_repr

    def convolve_old(self, batch: TrainingBatch) -> List[torch.Tensor]:
        edges = []
        cur_edges = batch.edges
        for _ in range(len(self.layer_dimensions) - 1):
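
Note: in the updated build()/convolve() path, the per-relation-type adjacency
matrices of an edge type are stacked into one block-diagonal sparse matrix
(_sparse_diag_cat) and the per-relation-type weight matrices are concatenated
along dim=1, so each layer needs only a single sparse matmul per edge type.
Below is a small sketch checking that this fused path matches a plain
per-relation loop; the block-diagonal helper here is an assumed stand-in for
_sparse_diag_cat:

    import torch

    def sparse_diag_cat(matrices):
        # Assumed behaviour of _sparse_diag_cat: place the sparse matrices on
        # the diagonal of one large block-diagonal sparse matrix.
        row_off, col_off, idx, val = 0, 0, [], []
        for m in matrices:
            m = m.coalesce()
            i = m.indices()
            idx.append(torch.stack([ i[0] + row_off, i[1] + col_off ]))
            val.append(m.values())
            row_off += m.shape[0]
            col_off += m.shape[1]
        return torch.sparse_coo_tensor(torch.cat(idx, dim=1), torch.cat(val),
            (row_off, col_off))

    num_rel, n_row, n_col, in_dim, out_dim = 3, 5, 4, 8, 2
    adj = [ torch.rand(n_row, n_col).round().to_sparse() for _ in range(num_rel) ]
    weights = [ torch.randn(in_dim, out_dim) for _ in range(num_rel) ]
    x = torch.randn(n_col, in_dim)

    # Fused path, mirroring Model.convolve():
    w_cat = torch.cat(weights, dim=1)                # (in_dim, num_rel * out_dim)
    h = x @ w_cat                                    # (n_col, num_rel * out_dim)
    h = torch.cat(torch.split(h, out_dim, dim=1))    # (num_rel * n_col, out_dim)
    h = torch.sparse.mm(sparse_diag_cat(adj), h)     # (num_rel * n_row, out_dim)
    h = h.view(num_rel, n_row, out_dim).sum(dim=0)   # sum over relation types

    # Reference path: one sparse matmul per relation type.
    ref = sum(torch.sparse.mm(adj[k], x @ weights[k]) for k in range(num_rel))
    assert torch.allclose(h, ref, atol=1e-5)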