|
- from .data import Data, \
- EdgeType
- import torch
- from dataclasses import dataclass
- from .weights import init_glorot
- import types
- from typing import List, \
- Dict, \
- Callable, \
- Tuple
- from .util import _sparse_coo_tensor, \
- _sparse_diag_cat, \
- _mm
- from .normalize import norm_adj_mat_one_node_type, \
- norm_adj_mat_two_node_types
- from .dropout import dropout
-
-
@dataclass
class TrainingBatch(object):
    """One minibatch of edges of a single relation type between one
    (row, column) vertex-type pair, plus the edges' target values."""

    # Index of the vertex type on the row side of the relation.
    vertex_type_row: int
    # Index of the vertex type on the column side of the relation.
    vertex_type_column: int
    # Index of the relation type within its edge type.
    relation_type_index: int
    # (num_edges, 2) integer tensor; column 0 holds row-vertex indices,
    # column 1 holds column-vertex indices (see uses of edges[:, 0/1]).
    edges: torch.Tensor
    # Ground-truth values for the edges -- presumably labels for the
    # decoder output; confirm against the training loop.
    target_values: torch.Tensor
-
-
- def _per_layer_required_vertices(data: Data, batch: TrainingBatch,
- num_layers: int) -> List[List[EdgeType]]:
-
- Q = [
- ( batch.vertex_type_row, batch.edges[:, 0] ),
- ( batch.vertex_type_column, batch.edges[:, 1] )
- ]
- print('Q:', Q)
- res = []
-
- for _ in range(num_layers):
- R = []
- required_rows = [ [] for _ in range(len(data.vertex_types)) ]
-
- for vertex_type, vertices in Q:
- for et in data.edge_types.values():
- if et.vertex_type_row == vertex_type:
- required_rows[vertex_type].append(vertices)
- indices = et.total_connectivity.indices()
- mask = torch.zeros(et.total_connectivity.shape[0])
- mask[vertices] = 1
- mask = torch.nonzero(mask[indices[0]], as_tuple=True)[0]
- R.append((et.vertex_type_column,
- indices[1, mask]))
- else:
- pass # required_rows[et.vertex_type_row].append(torch.zeros(0))
-
- required_rows = [ torch.unique(torch.cat(x)) \
- if len(x) > 0 \
- else None \
- for x in required_rows ]
-
- res.append(required_rows)
- Q = R
-
- return res
-
-
- class Model(torch.nn.Module):
- def __init__(self, data: Data, layer_dimensions: List[int],
- keep_prob: float,
- conv_activation: Callable[[torch.Tensor], torch.Tensor],
- dec_activation: Callable[[torch.Tensor], torch.Tensor],
- **kwargs) -> None:
- super().__init__(**kwargs)
-
- if not isinstance(data, Data):
- raise TypeError('data must be an instance of Data')
-
- if not callable(conv_activation):
- raise TypeError('conv_activation must be callable')
-
- if not callable(dec_activation):
- raise TypeError('dec_activation must be callable')
-
- self.data = data
- self.layer_dimensions = list(layer_dimensions)
- self.keep_prob = float(keep_prob)
- self.conv_activation = conv_activation
- self.dec_activation = dec_activation
-
- self.adj_matrices = None
- self.conv_weights = None
- self.dec_weights = None
- self.build()
-
-
- def build(self) -> None:
- self.adj_matrices = torch.nn.ParameterDict()
- for _, et in self.data.edge_types.items():
- adj_matrices = [
- norm_adj_mat_one_node_type(x) \
- if et.vertex_type_row == et.vertex_type_column \
- else norm_adj_mat_two_node_types(x) \
- for x in et.adjacency_matrices
- ]
- adj_matrices = _sparse_diag_cat(et.adjacency_matrices)
- print('adj_matrices:', adj_matrices)
- self.adj_matrices['%d-%d' % (et.vertex_type_row, et.vertex_type_column)] = \
- torch.nn.Parameter(adj_matrices, requires_grad=False)
-
- self.conv_weights = torch.nn.ParameterDict()
- for i in range(len(self.layer_dimensions) - 1):
- in_dimension = self.layer_dimensions[i]
- out_dimension = self.layer_dimensions[i + 1]
-
- for _, et in self.data.edge_types.items():
- weights = [ init_glorot(in_dimension, out_dimension) \
- for _ in range(len(et.adjacency_matrices)) ]
- weights = torch.cat(weights, dim=1)
- self.conv_weights['%d-%d-%d' % (et.vertex_type_row, et.vertex_type_column, i)] = \
- torch.nn.Parameter(weights)
-
- self.dec_weights = torch.nn.ParameterDict()
- for _, et in self.data.edge_types.items():
- global_interaction, local_variation = \
- et.decoder_factory(self.layer_dimensions[-1],
- len(et.adjacency_matrices))
- self.dec_weights['%d-%d-global-interaction' % (et.vertex_type_row, et.vertex_type_column)] = \
- torch.nn.Parameter(global_interaction)
- for i in range(len(local_variation)):
- self.dec_weights['%d-%d-local-variation-%d' % (et.vertex_type_row, et.vertex_type_column, i)] = \
- torch.nn.Parameter(local_variation[i])
-
-
    def convolve(self, in_layer_repr: List[torch.Tensor], eval_mode=False) -> \
        List[torch.Tensor]:
        """Apply the stack of graph-convolution layers.

        :param in_layer_repr: one dense tensor per vertex type --
            presumably shaped (vertex count, layer_dimensions[0]);
            confirm against callers
        :param eval_mode: if True, dropout is skipped
        :return: one tensor per vertex type with the last layer's
            dimensionality
        """

        cur_layer_repr = in_layer_repr

        for i in range(len(self.layer_dimensions) - 1):
            # Per-vertex-type lists of messages arriving from each
            # edge type whose row side is that vertex type.
            next_layer_repr = [ [] for _ in range(len(self.data.vertex_types)) ]

            for _, et in self.data.edge_types.items():
                vt_row, vt_col = et.vertex_type_row, et.vertex_type_column
                # Block-diagonal adjacency over all relation types of
                # this edge type and the matching column-concatenated
                # weights (see build()).
                adj_matrices = self.adj_matrices['%d-%d' % (vt_row, vt_col)]
                conv_weights = self.conv_weights['%d-%d-%d' % (vt_row, vt_col, i)]

                num_relation_types = len(et.adjacency_matrices)
                x = cur_layer_repr[vt_col]
                if self.keep_prob != 1 and not eval_mode:
                    x = dropout(x, self.keep_prob)

                # Project with all relation weights at once, then
                # rearrange so per-relation chunks line up with the
                # blocks of the block-diagonal adjacency: split along
                # columns, restack along rows.
                x = _mm(x, conv_weights)
                x = torch.split(x,
                    x.shape[1] // num_relation_types,
                    dim=1)
                x = torch.cat(x)
                x = _mm(adj_matrices, x)
                # One (row-vertex count, out_dim) slice per relation.
                x = x.view(num_relation_types,
                    self.data.vertex_types[vt_row].count,
                    self.layer_dimensions[i + 1])

                # Sum over relation types, then L2-normalize each
                # vertex's representation.
                x = x.sum(dim=0)
                x = torch.nn.functional.normalize(x, p=2, dim=1)

                next_layer_repr[vt_row].append(x)

            # Sum contributions from all edge types targeting a vertex
            # type and apply the convolution nonlinearity.
            next_layer_repr = [ self.conv_activation(sum(x)) \
                for x in next_layer_repr ]

            cur_layer_repr = next_layer_repr
        return cur_layer_repr
-
- def decode(self, last_layer_repr: List[torch.Tensor],
- batch: TrainingBatch, eval_mode=False) -> torch.Tensor:
-
- vt_row = batch.vertex_type_row
- vt_col = batch.vertex_type_column
- rel_idx = batch.relation_type_index
- global_interaction = \
- self.dec_weights['%d-%d-global-interaction' % (vt_row, vt_col)]
- local_variation = \
- self.dec_weights['%d-%d-local-variation-%d' % (vt_row, vt_col, rel_idx)]
-
- in_row = last_layer_repr[vt_row]
- in_col = last_layer_repr[vt_col]
-
- if in_row.is_sparse or in_col.is_sparse:
- raise ValueError('Inputs to Model.decode() must be dense')
-
- in_row = in_row[batch.edges[:, 0]]
- in_col = in_col[batch.edges[:, 1]]
-
- if self.keep_prob != 1 and not eval_mode:
- in_row = dropout(in_row, self.keep_prob)
- in_col = dropout(in_col, self.keep_prob)
-
- # in_row = in_row.to_dense()
- # in_col = in_col.to_dense()
-
- print('in_row.is_sparse:', in_row.is_sparse)
- print('in_col.is_sparse:', in_col.is_sparse)
-
- x = torch.mm(in_row, local_variation)
- x = torch.mm(x, global_interaction)
- x = torch.mm(x, local_variation)
- x = torch.bmm(x.view(x.shape[0], 1, x.shape[1]),
- in_col.view(in_col.shape[0], in_col.shape[1], 1))
- x = torch.flatten(x)
-
- x = self.dec_activation(x)
-
- return x
-
-
    def convolve_old(self, batch: TrainingBatch) -> List[torch.Tensor]:
        """Legacy/abandoned convolution walk.

        NOTE(review): this method is incomplete -- `_edges_for_rows` is
        not defined anywhere in this file, the accumulated `edges` list
        is never used, and the method implicitly returns None despite
        its annotation. `self.data.relation_types` is also not used by
        any other code in this chunk; verify the attribute exists.
        Kept verbatim; consider removing.
        """
        edges = []
        cur_edges = batch.edges
        for _ in range(len(self.layer_dimensions) - 1):
            edges.append(cur_edges)
            key = (batch.vertex_type_row, batch.vertex_type_column)
            tot_conn = self.data.relation_types[key].total_connectivity
            cur_edges = _edges_for_rows(tot_conn, cur_edges[:, 1])
-
-
    def temporary_adjacency_matrix(self, adjacency_matrix: torch.Tensor,
        batch: TrainingBatch, total_connectivity: torch.Tensor) -> torch.Tensor:
        """Unfinished stub -- currently always returns None.

        Judging by the name and its call site in
        temporary_adjacency_matrices(), this was meant to build a
        reduced adjacency matrix restricted to the vertices reachable
        from the batch; TODO: finish or remove.
        """

        col = batch.vertex_type_column
        rows = batch.edges[:, 0]

        # NOTE(review): summing the column vertex *indices* and taking
        # nonzero() looks like placeholder logic, not a column
        # selection -- confirm intent before finishing this.
        columns = batch.edges[:, 1].sum(dim=0).flatten()
        columns = torch.nonzero(columns)

        for i in range(len(self.layer_dimensions) - 1):
            pass # columns =
            # TODO: finish

        return None
-
-
    def temporary_adjacency_matrices(self, batch: TrainingBatch) -> Dict[Tuple[int, int], List[List[torch.Tensor]]]:
        """Build per-edge-type temporary adjacency matrices for `batch`.

        NOTE(review): relies on the unfinished
        temporary_adjacency_matrix() stub, so every dictionary value is
        currently a list of None entries.
        """

        col = batch.vertex_type_column
        # NOTE(review): expression statement with no effect -- probably
        # leftover from an edit; safe to delete once confirmed.
        batch.edges[:, 1]

        res = {}

        for _, et in self.data.edge_types.items():
            # NOTE(review): `_nonzero_sum` is not defined in this chunk
            # and `sum_nonzero` is never used afterwards -- verify
            # whether the call is needed before removing it.
            sum_nonzero = _nonzero_sum(et.adjacency_matrices)
            res[et.vertex_type_row, et.vertex_type_column] = \
                [ self.temporary_adjacency_matrix(adj_mat, batch,
                    et.total_connectivity) \
                    for adj_mat in et.adjacency_matrices ]

        return res
-
- def forward(self, initial_repr: List[torch.Tensor],
- batch: TrainingBatch) -> torch.Tensor:
-
- if not isinstance(initial_repr, list):
- raise TypeError('initial_repr must be a list')
-
- if len(initial_repr) != len(self.data.vertex_types):
- raise ValueError('initial_repr must contain representations for all vertex types')
-
- if not isinstance(batch, TrainingBatch):
- raise TypeError('batch must be an instance of TrainingBatch')
-
- adj_matrices = self.temporary_adjacency_matrices(batch)
-
- row_vertices = initial_repr[batch.vertex_type_row]
- column_vertices = initial_repr[batch.vertex_type_column]
|