from .data import Data, \
    EdgeType
import torch
from dataclasses import dataclass
from .weights import init_glorot
import types
from typing import List, \
    Dict, \
    Callable, \
    Tuple
from .util import _sparse_coo_tensor, \
    _sparse_diag_cat, \
    _mm
from .normalize import norm_adj_mat_one_node_type, \
    norm_adj_mat_two_node_types
from .dropout import dropout


@dataclass
class TrainingBatch(object):
    vertex_type_row: int
    vertex_type_column: int
    relation_type_index: int
    edges: torch.Tensor
    target_values: torch.Tensor


def _per_layer_required_vertices(data: Data, batch: TrainingBatch,
    num_layers: int) -> List[List[torch.Tensor]]:

    # Start from the row and column vertices referenced by the batch and walk
    # the connectivity backwards, layer by layer, collecting for every vertex
    # type the vertices whose representations are required at that layer.
    Q = [
        ( batch.vertex_type_row, batch.edges[:, 0] ),
        ( batch.vertex_type_column, batch.edges[:, 1] )
    ]
    res = []

    for _ in range(num_layers):
        R = []
        required_rows = [ [] for _ in range(len(data.vertex_types)) ]

        for vertex_type, vertices in Q:
            for et in data.edge_types.values():
                if et.vertex_type_row == vertex_type:
                    required_rows[vertex_type].append(vertices)
                    indices = et.total_connectivity.indices()
                    mask = torch.zeros(et.total_connectivity.shape[0])
                    mask[vertices] = 1
                    mask = torch.nonzero(mask[indices[0]], as_tuple=True)[0]
                    R.append((et.vertex_type_column,
                        indices[1, mask]))
                else:
                    pass # required_rows[et.vertex_type_row].append(torch.zeros(0))

        required_rows = [ torch.unique(torch.cat(x)) \
            if len(x) > 0 \
            else None \
            for x in required_rows ]

        res.append(required_rows)
        Q = R

    return res


class Model(torch.nn.Module):
    def __init__(self, data: Data, layer_dimensions: List[int],
        keep_prob: float,
        conv_activation: Callable[[torch.Tensor], torch.Tensor],
        dec_activation: Callable[[torch.Tensor], torch.Tensor],
        **kwargs) -> None:

        super().__init__(**kwargs)

        if not isinstance(data, Data):
            raise TypeError('data must be an instance of Data')

        if not callable(conv_activation):
            raise TypeError('conv_activation must be callable')

        if not callable(dec_activation):
            raise TypeError('dec_activation must be callable')

        self.data = data
        self.layer_dimensions = list(layer_dimensions)
        self.keep_prob = float(keep_prob)
        self.conv_activation = conv_activation
        self.dec_activation = dec_activation

        self.adj_matrices = None
        self.conv_weights = None
        self.dec_weights = None

        self.build()

    def build(self) -> None:
        # Normalized adjacency matrices, one block-diagonal (non-trainable)
        # parameter per edge type.
        self.adj_matrices = torch.nn.ParameterDict()
        for _, et in self.data.edge_types.items():
            adj_matrices = [
                norm_adj_mat_one_node_type(x) \
                    if et.vertex_type_row == et.vertex_type_column \
                    else norm_adj_mat_two_node_types(x) \
                    for x in et.adjacency_matrices
            ]
            adj_matrices = _sparse_diag_cat(adj_matrices)
            self.adj_matrices['%d-%d' % (et.vertex_type_row, et.vertex_type_column)] = \
                torch.nn.Parameter(adj_matrices, requires_grad=False)

        # Convolution weights, one parameter per edge type and layer; the
        # per-relation weight matrices are concatenated horizontally.
        self.conv_weights = torch.nn.ParameterDict()
        for i in range(len(self.layer_dimensions) - 1):
            in_dimension = self.layer_dimensions[i]
            out_dimension = self.layer_dimensions[i + 1]

            for _, et in self.data.edge_types.items():
                weights = [ init_glorot(in_dimension, out_dimension) \
                    for _ in range(len(et.adjacency_matrices)) ]
                weights = torch.cat(weights, dim=1)
                self.conv_weights['%d-%d-%d' % (et.vertex_type_row, et.vertex_type_column, i)] = \
                    torch.nn.Parameter(weights)

        # Decoder weights provided by each edge type's decoder factory.
        self.dec_weights = torch.nn.ParameterDict()
        for _, et in self.data.edge_types.items():
            global_interaction, local_variation = \
                et.decoder_factory(self.layer_dimensions[-1],
                    len(et.adjacency_matrices))
            self.dec_weights['%d-%d-global-interaction' % (et.vertex_type_row, et.vertex_type_column)] = \
                torch.nn.Parameter(global_interaction)
            for i in range(len(local_variation)):
                self.dec_weights['%d-%d-local-variation-%d' % (et.vertex_type_row, et.vertex_type_column, i)] = \
                    torch.nn.Parameter(local_variation[i])
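    # Shape bookkeeping in convolve() below, for an edge type with K relation
    # types, n_row / n_col vertices and layer dimensions d_in -> d_out:
    #   x:                          [n_col, d_in]
    #   after _mm(x, conv_weights): [n_col, K * d_out]
    #   after split + cat:          [K * n_col, d_out]
    #   after _mm(adj_matrices, x): [K * n_row, d_out]  (block-diagonal adjacency)
    #   after view():               [K, n_row, d_out]
    #   after sum(dim=0):           [n_row, d_out]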
    def convolve(self, in_layer_repr: List[torch.Tensor]) -> \
        List[torch.Tensor]:

        cur_layer_repr = in_layer_repr

        for i in range(len(self.layer_dimensions) - 1):
            next_layer_repr = [ [] for _ in range(len(self.data.vertex_types)) ]

            for _, et in self.data.edge_types.items():
                vt_row, vt_col = et.vertex_type_row, et.vertex_type_column
                adj_matrices = self.adj_matrices['%d-%d' % (vt_row, vt_col)]
                conv_weights = self.conv_weights['%d-%d-%d' % (vt_row, vt_col, i)]

                num_relation_types = len(et.adjacency_matrices)

                x = cur_layer_repr[vt_col]
                if self.keep_prob != 1:
                    x = dropout(x, self.keep_prob)

                # Apply all per-relation weight matrices at once, then rearrange
                # so the block-diagonal adjacency multiplies each relation's slice.
                x = _mm(x, conv_weights)
                x = torch.split(x,
                    x.shape[1] // num_relation_types,
                    dim=1)
                x = torch.cat(x)
                x = _mm(adj_matrices, x)
                x = x.view(num_relation_types,
                    self.data.vertex_types[vt_row].count,
                    self.layer_dimensions[i + 1])

                # Sum the contributions of the individual relation types and
                # L2-normalize each row vertex representation.
                x = x.sum(dim=0)
                x = torch.nn.functional.normalize(x, p=2, dim=1)

                next_layer_repr[vt_row].append(x)

            next_layer_repr = [ self.conv_activation(sum(x)) \
                for x in next_layer_repr ]

            cur_layer_repr = next_layer_repr

        return cur_layer_repr

    def decode(self, last_layer_repr: List[torch.Tensor],
        batch: TrainingBatch) -> torch.Tensor:

        vt_row = batch.vertex_type_row
        vt_col = batch.vertex_type_column
        rel_idx = batch.relation_type_index
        global_interaction = \
            self.dec_weights['%d-%d-global-interaction' % (vt_row, vt_col)]
        local_variation = \
            self.dec_weights['%d-%d-local-variation-%d' % (vt_row, vt_col, rel_idx)]

        in_row = last_layer_repr[vt_row]
        in_col = last_layer_repr[vt_col]

        if in_row.is_sparse or in_col.is_sparse:
            raise ValueError('Inputs to Model.decode() must be dense')

        in_row = in_row[batch.edges[:, 0]]
        in_col = in_col[batch.edges[:, 1]]

        in_row = dropout(in_row, self.keep_prob)
        in_col = dropout(in_col, self.keep_prob)

        # Bilinear scoring per edge:
        #   score_e = row_e . local_variation . global_interaction . local_variation . col_e
        x = torch.mm(in_row, local_variation)
        x = torch.mm(x, global_interaction)
        x = torch.mm(x, local_variation)

        x = torch.bmm(x.view(x.shape[0], 1, x.shape[1]),
            in_col.view(in_col.shape[0], in_col.shape[1], 1))
        x = torch.flatten(x)

        x = self.dec_activation(x)

        return x

    def convolve_old(self, batch: TrainingBatch) -> List[torch.Tensor]:
        # Earlier per-batch variant of convolve(); note that _edges_for_rows
        # is not defined or imported in this module.
        edges = []
        cur_edges = batch.edges
        for _ in range(len(self.layer_dimensions) - 1):
            edges.append(cur_edges)
            key = (batch.vertex_type_row, batch.vertex_type_column)
            tot_conn = self.data.relation_types[key].total_connectivity
            cur_edges = _edges_for_rows(tot_conn, cur_edges[:, 1])

    def temporary_adjacency_matrix(self, adjacency_matrix: torch.Tensor,
        batch: TrainingBatch,
        total_connectivity: torch.Tensor) -> torch.Tensor:

        col = batch.vertex_type_column
        rows = batch.edges[:, 0]
        columns = batch.edges[:, 1].sum(dim=0).flatten()
        columns = torch.nonzero(columns)

        for i in range(len(self.layer_dimensions) - 1):
            pass # columns =
            # TODO: finish

        return None

    def temporary_adjacency_matrices(self, batch: TrainingBatch) -> \
        Dict[Tuple[int, int], List[List[torch.Tensor]]]:

        col = batch.vertex_type_column
        batch.edges[:, 1]

        res = {}

        for _, et in self.data.edge_types.items():
            # Note: _nonzero_sum is not defined or imported in this module.
            sum_nonzero = _nonzero_sum(et.adjacency_matrices)
            res[et.vertex_type_row, et.vertex_type_column] = \
                [ self.temporary_adjacency_matrix(adj_mat, batch,
                    et.total_connectivity) \
                    for adj_mat in et.adjacency_matrices ]

        return res

    def forward(self, initial_repr: List[torch.Tensor],
        batch: TrainingBatch) -> torch.Tensor:

        if not isinstance(initial_repr, list):
            raise TypeError('initial_repr must be a list')

        if len(initial_repr) != len(self.data.vertex_types):
            raise ValueError('initial_repr must contain representations for all vertex types')

        if not isinstance(batch, TrainingBatch):
            raise TypeError('batch must be an instance of TrainingBatch')

        adj_matrices = self.temporary_adjacency_matrices(batch)

        row_vertices = initial_repr[batch.vertex_type_row]
        column_vertices = initial_repr[batch.vertex_type_column]
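

# The sketch below is not part of the original module; it only illustrates how
# Model, TrainingBatch, convolve() and decode() fit together under the
# assumptions stated in its docstring. Since forward() above is unfinished,
# convolve() and decode() are called directly.
def _example_usage(data: Data,
    initial_repr: List[torch.Tensor]) -> torch.Tensor:
    """Minimal usage sketch. It assumes that `data` is a fully populated Data
    instance containing an edge type between vertex types 0 and 1, and that
    every initial_repr[i] has layer_dimensions[0] columns; the edge indices,
    layer sizes and keep_prob below are illustrative placeholders only."""

    model = Model(data,
        layer_dimensions=[initial_repr[0].shape[1], 32, 16],
        keep_prob=0.5,
        conv_activation=torch.sigmoid,
        dec_activation=torch.sigmoid)

    # One batch of candidate edges between vertex type 0 (rows) and vertex
    # type 1 (columns), scored with the first relation type of that edge type.
    batch = TrainingBatch(vertex_type_row=0,
        vertex_type_column=1,
        relation_type_index=0,
        edges=torch.tensor([[0, 3], [1, 2]], dtype=torch.long),
        target_values=torch.tensor([1., 0.]))

    last_layer_repr = model.convolve(initial_repr)
    return model.decode(last_layer_repr, batch)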