from .data import Data, \
    EdgeType
import torch
from dataclasses import dataclass
from .weights import init_glorot
import types
from typing import List, \
    Dict, \
    Callable, \
    Tuple, \
    Optional
from .util import _sparse_coo_tensor


@dataclass
class TrainingBatch(object):
    """A batch of training edges belonging to a single relation.

    ``edges`` is indexed as ``edges[:, 0]`` (vertices of type
    ``vertex_type_row``) and ``edges[:, 1]`` (vertices of type
    ``vertex_type_column``); presumably an (N, 2) integer tensor -- confirm
    against callers.
    """
    vertex_type_row: int
    vertex_type_column: int
    relation_type_index: int
    edges: torch.Tensor


def _per_layer_required_rows(data: Data, batch: TrainingBatch,
    num_layers: int) -> List[List[Optional[torch.Tensor]]]:
    """Compute, for each layer, which rows of each vertex type are required.

    The frontier starts as the row vertices (``batch.edges[:, 0]``) and the
    column vertices (``batch.edges[:, 1]``) of the batch. At every one of
    ``num_layers`` steps, each frontier entry is expanded along every edge
    type whose ``vertex_type_row`` equals the entry's vertex type. The
    expansion collects the column indices of the nonzeros of
    ``et.total_connectivity`` whose row index lies in the frontier.
    Vertices are recorded for a vertex type only when at least one edge type
    has that type as its row type.

    Args:
        data: Graph description providing ``vertex_types`` and
            ``edge_types``.
        batch: The training batch whose endpoints seed the frontier.
        num_layers: Number of expansion steps.

    Returns:
        One entry per layer. Each entry is a list indexed by vertex type,
        holding the sorted unique vertex indices required for that type, or
        ``None`` when no vertices of that type were recorded.
    """
    # Frontier of (vertex_type, vertex_indices) pairs still to be expanded.
    frontier = [
        (batch.vertex_type_row, batch.edges[:, 0]),
        (batch.vertex_type_column, batch.edges[:, 1]),
    ]

    res = []
    for _ in range(num_layers):
        next_frontier = []
        required_rows = [[] for _ in range(len(data.vertex_types))]
        for vertex_type, vertices in frontier:
            for et in data.edge_types.values():
                if et.vertex_type_row != vertex_type:
                    continue
                required_rows[vertex_type].append(vertices)
                indices = et.total_connectivity.indices()
                # Membership mask over the rows of this edge type's matrix.
                mask = torch.zeros(et.total_connectivity.shape[0],
                    dtype=torch.bool)
                mask[vertices] = True
                # Column indices of every nonzero whose row is in the frontier.
                neighbors = indices[1][mask[indices[0]]]
                next_frontier.append((et.vertex_type_column, neighbors))
        required_rows = [
            torch.unique(torch.cat(x)) if len(x) > 0 else None
            for x in required_rows
        ]
        res.append(required_rows)
        frontier = next_frontier
    return res


class Model(torch.nn.Module):
    """Graph model holding per-edge-type convolution and decoder weights.

    NOTE(review): work in progress -- ``convolve``,
    ``temporary_adjacency_matrix`` and ``forward`` are not finished.
    """

    def __init__(self, data: Data, layer_dimensions: List[int],
        keep_prob: float,
        conv_activation: Callable[[torch.Tensor], torch.Tensor],
        dec_activation: Callable[[torch.Tensor], torch.Tensor],
        **kwargs) -> None:
        """Validate arguments, store configuration and create weights.

        Args:
            data: Graph description (must be a ``Data`` instance).
            layer_dimensions: Representation sizes; consecutive pairs define
                the convolution layers, the last one feeds the decoders.
            keep_prob: Stored as ``float``; presumably a dropout keep
                probability -- not used in the visible code.
            conv_activation: Callable applied in convolution layers.
            dec_activation: Callable applied by the decoder.
            **kwargs: Forwarded to ``torch.nn.Module.__init__``.

        Raises:
            TypeError: If ``data`` is not a ``Data`` or an activation is not
                callable.
        """
        super().__init__(**kwargs)

        if not isinstance(data, Data):
            raise TypeError('data must be an instance of Data')

        # Any callable is accepted: plain functions, torch builtins such as
        # torch.sigmoid, functools.partial objects and nn.Module instances.
        # A types.FunctionType check would reject all but Python functions.
        if not callable(conv_activation):
            raise TypeError('conv_activation must be a function')

        if not callable(dec_activation):
            raise TypeError('dec_activation must be a function')

        self.data = data
        self.layer_dimensions = list(layer_dimensions)
        self.keep_prob = float(keep_prob)
        self.conv_activation = conv_activation
        self.dec_activation = dec_activation

        self.conv_weights = None
        self.dec_weights = None
        self.build()

    @staticmethod
    def _weight_key(*parts: int) -> str:
        """Join key parts into a container key, e.g. ``(0, 1, 2) -> '0_1_2'``."""
        return '_'.join(str(p) for p in parts)

    def build(self) -> None:
        """Create convolution and decoder weights for every edge type.

        ``conv_weights`` is a ``ParameterDict`` keyed by
        ``'<row>_<column>_<layer>'``. ``dec_weights`` is a ``ModuleDict``
        keyed by ``'<row>_<column>'``, whose values are
        ``ParameterList([global_interaction, local_variation])``. Both torch
        containers only accept string keys, so tuple keys would raise
        ``TypeError``.

        NOTE(review): edge types sharing the same (row, column) vertex types
        map to the same key, so the later one overwrites the earlier one --
        confirm that cannot happen.
        """
        self.conv_weights = torch.nn.ParameterDict()
        for i in range(len(self.layer_dimensions) - 1):
            in_dimension = self.layer_dimensions[i]
            out_dimension = self.layer_dimensions[i + 1]
            for et in self.data.edge_types.values():
                weight = init_glorot(in_dimension, out_dimension)
                key = self._weight_key(et.vertex_type_row,
                    et.vertex_type_column, i)
                self.conv_weights[key] = torch.nn.Parameter(weight)

        # A ParameterList is a Module, so ModuleDict is the right container;
        # older ParameterDict versions reject non-Parameter values.
        self.dec_weights = torch.nn.ModuleDict()
        for et in self.data.edge_types.values():
            global_interaction, local_variation = \
                et.decoder_factory(self.layer_dimensions[-1],
                    len(et.adjacency_matrices))
            key = self._weight_key(et.vertex_type_row, et.vertex_type_column)
            self.dec_weights[key] = torch.nn.ParameterList([
                torch.nn.Parameter(global_interaction),
                torch.nn.Parameter(local_variation),
            ])

    def convolve(self, batch: TrainingBatch) -> List[torch.Tensor]:
        """Collect the edge sets reached at each convolution layer.

        NOTE(review): incomplete. ``_edges_for_rows`` is not defined or
        imported in the visible source. ``self.data.relation_types`` differs
        from the ``edge_types`` attribute used elsewhere. The collected
        ``edges`` list is never returned despite the annotation. Confirm the
        intent before relying on this method.
        """
        edges = []
        cur_edges = batch.edges
        for _ in range(len(self.layer_dimensions) - 1):
            edges.append(cur_edges)
            key = (batch.vertex_type_row, batch.vertex_type_column)
            tot_conn = self.data.relation_types[key].total_connectivity
            cur_edges = _edges_for_rows(tot_conn, cur_edges[:, 1])

    def temporary_adjacency_matrix(self, adjacency_matrix: torch.Tensor,
        batch: TrainingBatch, total_connectivity: torch.Tensor) -> torch.Tensor:
        """Stub for a batch-restricted adjacency matrix; currently returns ``None``.

        NOTE(review): unfinished. The locals below are scaffolding and are not
        used yet. ``batch.edges[:, 1].sum(dim=0)`` reduces the column vertices
        to a single scalar, which is probably not the intended selection --
        confirm.
        """
        col = batch.vertex_type_column
        rows = batch.edges[:, 0]

        columns = batch.edges[:, 1].sum(dim=0).flatten()
        columns = torch.nonzero(columns)

        for i in range(len(self.layer_dimensions) - 1):
            pass # columns =
            # TODO: finish

        return None

    def temporary_adjacency_matrices(self, batch: TrainingBatch) -> Dict[Tuple[int, int], List[torch.Tensor]]:
        """Build the temporary adjacency matrices for every edge type.

        Returns:
            A dict keyed by ``(vertex_type_row, vertex_type_column)``. Each
            value holds one ``temporary_adjacency_matrix`` result per entry of
            ``et.adjacency_matrices``. These are currently ``None``, because
            that helper is still a stub.
        """
        res = {}
        for et in self.data.edge_types.values():
            res[et.vertex_type_row, et.vertex_type_column] = [
                self.temporary_adjacency_matrix(adj_mat, batch,
                    et.total_connectivity)
                for adj_mat in et.adjacency_matrices
            ]
        return res

    def forward(self, initial_repr: List[torch.Tensor],
        batch: TrainingBatch) -> torch.Tensor:
        """Run the model on ``batch``.

        Args:
            initial_repr: One representation per vertex type, indexed by
                vertex type.
            batch: The training batch to process.

        Raises:
            TypeError: If ``initial_repr`` is not a list or ``batch`` is not a
                ``TrainingBatch``.
            ValueError: If ``initial_repr`` does not have one entry per vertex
                type.

        NOTE(review): unfinished. It validates inputs and gathers
        intermediates, but returns ``None`` despite the annotation.
        """
        if not isinstance(initial_repr, list):
            raise TypeError('initial_repr must be a list')

        if len(initial_repr) != len(self.data.vertex_types):
            raise ValueError('initial_repr must contain representations for all vertex types')

        if not isinstance(batch, TrainingBatch):
            raise TypeError('batch must be an instance of TrainingBatch')

        adj_matrices = self.temporary_adjacency_matrices(batch)

        row_vertices = initial_repr[batch.vertex_type_row]
        column_vertices = initial_repr[batch.vertex_type_column]
        # TODO: combine adj_matrices, row_vertices and column_vertices into
        # the model output.