@@ -0,0 +1,209 @@
#
# Copyright (C) Stanislaw Adaszewski, 2020
# License: GPLv3
#

from collections import defaultdict
from dataclasses import dataclass, field
import torch
from typing import List, \
    Dict, \
    Tuple, \
    Any, \
    Type
from .decode import DEDICOMDecoder, \
    BilinearDecoder
import numpy as np


def _equal(x: torch.Tensor, y: torch.Tensor):
    if x.is_sparse ^ y.is_sparse:
        raise ValueError('Cannot mix sparse and dense tensors')

    if not x.is_sparse:
        return (x == y)

    return ((x - y).coalesce().values() == 0)

@dataclass
class NodeType(object):
    name: str
    count: int


@dataclass
class RelationTypeBase(object):
    name: str
    node_type_row: int
    node_type_column: int
    adjacency_matrix: torch.Tensor
    adjacency_matrix_backward: torch.Tensor


@dataclass
class RelationType(RelationTypeBase):
    pass


@dataclass
class RelationFamilyBase(object):
    data: 'Data'
    name: str
    node_type_row: int
    node_type_column: int
    is_symmetric: bool
    decoder_class: Type


@dataclass
class RelationFamily(RelationFamilyBase):
    relation_types: List[RelationType] = None

    def __post_init__(self) -> None:
        if not self.is_symmetric and \
            self.decoder_class != DEDICOMDecoder and \
            self.decoder_class != BilinearDecoder:
            raise TypeError('Family is asymmetric but the specified decoder_class supports symmetric relations only')

        self.relation_types = []

    def add_relation_type(self,
        name: str, adjacency_matrix: torch.Tensor,
        adjacency_matrix_backward: torch.Tensor = None) -> None:

        name = str(name)
        node_type_row = self.node_type_row
        node_type_column = self.node_type_column

        if adjacency_matrix is None and adjacency_matrix_backward is None:
            raise ValueError('adjacency_matrix and adjacency_matrix_backward cannot both be None')

        if adjacency_matrix is not None and \
            not isinstance(adjacency_matrix, torch.Tensor):
            raise ValueError('adjacency_matrix must be a torch.Tensor')

        if adjacency_matrix_backward is not None \
            and not isinstance(adjacency_matrix_backward, torch.Tensor):
            raise ValueError('adjacency_matrix_backward must be a torch.Tensor')

        if adjacency_matrix is not None and \
            adjacency_matrix.shape != (self.data.node_types[node_type_row].count,
                self.data.node_types[node_type_column].count):
            raise ValueError('adjacency_matrix shape must be (num_row_nodes, num_column_nodes)')

        if adjacency_matrix_backward is not None and \
            adjacency_matrix_backward.shape != (self.data.node_types[node_type_column].count,
                self.data.node_types[node_type_row].count):
            raise ValueError('adjacency_matrix_backward shape must be (num_column_nodes, num_row_nodes)')

        if node_type_row == node_type_column and \
            adjacency_matrix_backward is not None:
            raise ValueError('Relation between nodes of the same type must be expressed using a single matrix')

        if self.is_symmetric and adjacency_matrix_backward is not None:
            raise ValueError('Cannot use a custom adjacency_matrix_backward in a symmetric relation family')

        if self.is_symmetric and node_type_row == node_type_column and \
            not torch.all(_equal(adjacency_matrix,
                adjacency_matrix.transpose(0, 1))):
            raise ValueError('Relation family is symmetric but adjacency_matrix is asymmetric')

        if not self.is_symmetric and node_type_row != node_type_column and \
            adjacency_matrix_backward is None:
            raise ValueError('Relation is asymmetric but adjacency_matrix_backward is None')

        # in a symmetric family between two distinct node types, the backward
        # matrix is implied by the forward one
        if self.is_symmetric and node_type_row != node_type_column:
            adjacency_matrix_backward = adjacency_matrix.transpose(0, 1)

        self.relation_types.append(RelationType(name,
            node_type_row, node_type_column,
            adjacency_matrix, adjacency_matrix_backward))

    def node_name(self, index):
        return self.data.node_types[index].name

    def __repr__(self):
        s = 'Relation family %s' % self.name

        for r in self.relation_types:
            s += '\n - %s%s' % (r.name, ' (two-way)' \
                if (r.adjacency_matrix is not None \
                    and r.adjacency_matrix_backward is not None) \
                    or self.node_type_row == self.node_type_column \
                else ': %s <- %s' % (self.node_name(self.node_type_row),
                    self.node_name(self.node_type_column)))

        return s

    def repr_indented(self):
        s = ' - %s' % self.name

        for r in self.relation_types:
            s += '\n - %s%s' % (r.name, ' (two-way)' \
                if (r.adjacency_matrix is not None \
                    and r.adjacency_matrix_backward is not None) \
                    or self.node_type_row == self.node_type_column \
                else ': %s <- %s' % (self.node_name(self.node_type_row),
                    self.node_name(self.node_type_column)))

        return s


class Data(object):
    node_types: List[NodeType]
    relation_families: List[RelationFamily]

    def __init__(self) -> None:
        self.node_types = []
        self.relation_families = []

    def add_node_type(self, name: str, count: int) -> None:
        name = str(name)
        count = int(count)
        if not name:
            raise ValueError('You must provide a non-empty node type name')
        if count <= 0:
            raise ValueError('You must provide a positive node count')
        self.node_types.append(NodeType(name, count))

    def add_relation_family(self, name: str, node_type_row: int,
        node_type_column: int, is_symmetric: bool,
        decoder_class: Type = DEDICOMDecoder):

        name = str(name)
        node_type_row = int(node_type_row)
        node_type_column = int(node_type_column)
        is_symmetric = bool(is_symmetric)

        if node_type_row < 0 or node_type_row >= len(self.node_types):
            raise ValueError('node_type_row outside of the valid range of node types')

        if node_type_column < 0 or node_type_column >= len(self.node_types):
            raise ValueError('node_type_column outside of the valid range of node types')

        fam = RelationFamily(self, name, node_type_row, node_type_column,
            is_symmetric, decoder_class)
        self.relation_families.append(fam)

        return fam

    def __repr__(self):
        n = len(self.node_types)
        if n == 0:
            return 'Empty Icosagon Data'
        s = ''
        s += 'Icosagon Data with:\n'
        s += '- ' + str(n) + ' node type(s):\n'
        for nt in self.node_types:
            s += ' - ' + nt.name + '\n'
        if len(self.relation_families) == 0:
            s += '- No relation families\n'
            return s.strip()
        s += '- %d relation families:\n' % len(self.relation_families)
        for fam in self.relation_families:
            s += fam.repr_indented() + '\n'
        return s.strip()
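

# A minimal usage sketch (illustrative addition, not part of the original
# module); it assumes the package layout implied by the relative imports
# above, e.g. run with `python -m icosagon.data`:
if __name__ == '__main__':
    d = Data()
    d.add_node_type('Gene', 1000)
    d.add_node_type('Drug', 100)

    # an asymmetric drug-gene family with a random adjacency matrix and its
    # transpose as the backward matrix
    fam = d.add_relation_family('Drug-Gene', 1, 0, is_symmetric=False)
    adj = (torch.rand(100, 1000) < 0.001).to(torch.float32)
    fam.add_relation_type('Target', adj, adj.transpose(0, 1))
    print(d)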

@@ -0,0 +1,123 @@
#
# Copyright (C) Stanislaw Adaszewski, 2020
# License: GPLv3
#

import torch
from .weights import init_glorot
from .dropout import dropout


class DEDICOMDecoder(torch.nn.Module):
    """DEDICOM Tensor Factorization Decoder model layer for link prediction."""
    def __init__(self, input_dim, num_relation_types, keep_prob=1.,
        activation=torch.sigmoid, **kwargs):

        super().__init__(**kwargs)
        self.input_dim = input_dim
        self.num_relation_types = num_relation_types
        self.keep_prob = keep_prob
        self.activation = activation

        # one global interaction matrix shared by all relation types, plus a
        # diagonal local variation vector per relation type
        self.global_interaction = torch.nn.Parameter(init_glorot(input_dim, input_dim))
        self.local_variation = torch.nn.ParameterList([
            torch.nn.Parameter(torch.flatten(init_glorot(input_dim, 1))) \
                for _ in range(num_relation_types)
        ])

    def forward(self, inputs_row, inputs_col, relation_index):
        inputs_row = dropout(inputs_row, self.keep_prob)
        inputs_col = dropout(inputs_col, self.keep_prob)
        relation = torch.diag(self.local_variation[relation_index])

        # score for pair i: row_i @ R @ G @ R @ col_i
        product1 = torch.mm(inputs_row, relation)
        product2 = torch.mm(product1, self.global_interaction)
        product3 = torch.mm(product2, relation)
        rec = torch.bmm(product3.view(product3.shape[0], 1, product3.shape[1]),
            inputs_col.view(inputs_col.shape[0], inputs_col.shape[1], 1))
        rec = torch.flatten(rec)

        return self.activation(rec)


class DistMultDecoder(torch.nn.Module):
    """DistMult Tensor Factorization Decoder model layer for link prediction."""
    def __init__(self, input_dim, num_relation_types, keep_prob=1.,
        activation=torch.sigmoid, **kwargs):

        super().__init__(**kwargs)
        self.input_dim = input_dim
        self.num_relation_types = num_relation_types
        self.keep_prob = keep_prob
        self.activation = activation

        # one diagonal relation vector per relation type
        self.relation = torch.nn.ParameterList([
            torch.nn.Parameter(torch.flatten(init_glorot(input_dim, 1))) \
                for _ in range(num_relation_types)
        ])

    def forward(self, inputs_row, inputs_col, relation_index):
        inputs_row = dropout(inputs_row, self.keep_prob)
        inputs_col = dropout(inputs_col, self.keep_prob)
        relation = torch.diag(self.relation[relation_index])

        intermediate_product = torch.mm(inputs_row, relation)
        rec = torch.bmm(intermediate_product.view(intermediate_product.shape[0], 1, intermediate_product.shape[1]),
            inputs_col.view(inputs_col.shape[0], inputs_col.shape[1], 1))
        rec = torch.flatten(rec)

        return self.activation(rec)


class BilinearDecoder(torch.nn.Module):
    """Bilinear Decoder model layer for link prediction."""
    def __init__(self, input_dim, num_relation_types, keep_prob=1.,
        activation=torch.sigmoid, **kwargs):

        super().__init__(**kwargs)
        self.input_dim = input_dim
        self.num_relation_types = num_relation_types
        self.keep_prob = keep_prob
        self.activation = activation

        # one full relation matrix per relation type
        self.relation = torch.nn.ParameterList([
            torch.nn.Parameter(init_glorot(input_dim, input_dim)) \
                for _ in range(num_relation_types)
        ])

    def forward(self, inputs_row, inputs_col, relation_index):
        inputs_row = dropout(inputs_row, self.keep_prob)
        inputs_col = dropout(inputs_col, self.keep_prob)

        intermediate_product = torch.mm(inputs_row, self.relation[relation_index])
        rec = torch.bmm(intermediate_product.view(intermediate_product.shape[0], 1, intermediate_product.shape[1]),
            inputs_col.view(inputs_col.shape[0], inputs_col.shape[1], 1))
        rec = torch.flatten(rec)

        return self.activation(rec)


class InnerProductDecoder(torch.nn.Module):
    """Inner Product Decoder model layer for link prediction."""
    def __init__(self, input_dim, num_relation_types, keep_prob=1.,
        activation=torch.sigmoid, **kwargs):

        super().__init__(**kwargs)
        self.input_dim = input_dim
        self.num_relation_types = num_relation_types
        self.keep_prob = keep_prob
        self.activation = activation

    def forward(self, inputs_row, inputs_col, _):
        inputs_row = dropout(inputs_row, self.keep_prob)
        inputs_col = dropout(inputs_col, self.keep_prob)

        rec = torch.bmm(inputs_row.view(inputs_row.shape[0], 1, inputs_row.shape[1]),
            inputs_col.view(inputs_col.shape[0], inputs_col.shape[1], 1))
        rec = torch.flatten(rec)

        return self.activation(rec)
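

# An illustrative smoke test (added; not part of the original module):
if __name__ == '__main__':
    dec = DistMultDecoder(input_dim=16, num_relation_types=3)
    rows = torch.rand(8, 16)  # embeddings of 8 row nodes
    cols = torch.rand(8, 16)  # embeddings of the 8 paired column nodes
    scores = dec(rows, cols, relation_index=1)
    print(scores.shape)  # torch.Size([8])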

@@ -0,0 +1,42 @@
#
# Copyright (C) Stanislaw Adaszewski, 2020
# License: GPLv3
#

import torch
from .normalize import _sparse_coo_tensor


def dropout_sparse(x, keep_prob):
    x = x.coalesce()
    i = x._indices()
    v = x._values()
    size = x.size()

    # keep each nonzero entry with probability keep_prob
    n = keep_prob + torch.rand(len(v))
    n = torch.floor(n).to(torch.bool)
    i = i[:, n]
    v = v[n]
    x = _sparse_coo_tensor(i, v, size=size)

    return x * (1./keep_prob)


def dropout_dense(x, keep_prob):
    x = x.clone()
    i = torch.nonzero(x)

    # zero out each nonzero entry with probability 1 - keep_prob
    n = keep_prob + torch.rand(len(i))
    n = (1. - torch.floor(n)).to(torch.bool)
    x[i[n, 0], i[n, 1]] = 0.

    return x * (1./keep_prob)


def dropout(x, keep_prob):
    if x.is_sparse:
        return dropout_sparse(x, keep_prob)
    else:
        return dropout_dense(x, keep_prob)
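

# An illustrative check (added; not part of the original module): with
# keep_prob=0.5 roughly half of the nonzero entries survive, and the
# survivors are scaled by 1/keep_prob.
if __name__ == '__main__':
    x = torch.rand(5, 5)
    print(dropout(x, keep_prob=0.5))
    print(dropout(x.to_sparse(), keep_prob=0.5).to_dense())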

@@ -0,0 +1,255 @@
from typing import List, \
    Union, \
    Callable
from .data import Data, \
    RelationFamily
from .trainprep import PreparedData, \
    PreparedRelationFamily
import torch
from .weights import init_glorot
from .normalize import _sparse_coo_tensor
from .dropout import dropout


def _sparse_diag_cat(matrices: List[torch.Tensor]):
    if len(matrices) == 0:
        raise ValueError('The list of matrices must be non-empty')

    if not all(m.is_sparse for m in matrices):
        raise ValueError('All matrices must be sparse')

    if not all(len(m.shape) == 2 for m in matrices):
        raise ValueError('All matrices must be 2D')

    # stack the matrices along the diagonal of one big block-diagonal
    # sparse matrix
    indices = []
    values = []
    row_offset = 0
    col_offset = 0

    for m in matrices:
        ind = m._indices().clone()
        ind[0] += row_offset
        ind[1] += col_offset
        indices.append(ind)
        values.append(m._values())
        row_offset += m.shape[0]
        col_offset += m.shape[1]

    indices = torch.cat(indices, dim=1)
    values = torch.cat(values)

    return _sparse_coo_tensor(indices, values, size=(row_offset, col_offset))


def _cat(matrices: List[torch.Tensor]):
    if len(matrices) == 0:
        raise ValueError('Empty list passed to _cat()')

    n = sum(a.is_sparse for a in matrices)
    if n != 0 and n != len(matrices):
        raise ValueError('All matrices must have the same layout (dense or sparse)')

    if not all(a.shape[1:] == matrices[0].shape[1:] for a in matrices):
        raise ValueError('All matrices must have the same dimensions apart from dimension 0')

    if not matrices[0].is_sparse:
        return torch.cat(matrices)

    # sparse case: concatenate along dimension 0 by shifting row indices
    indices = []
    values = []
    row_offset = 0

    for a in matrices:
        ind = a._indices().clone()
        val = a._values()
        ind[0] += row_offset
        ind = ind.transpose(0, 1)
        indices.append(ind)
        values.append(val)
        row_offset += a.shape[0]

    indices = torch.cat(indices).transpose(0, 1)
    values = torch.cat(values)

    res = _sparse_coo_tensor(indices, values, size=(row_offset, matrices[0].shape[1]))

    return res


class FastGraphConv(torch.nn.Module):
    def __init__(self,
        in_channels: int,
        out_channels: int,
        adjacency_matrices: List[torch.Tensor],
        keep_prob: float = 1.,
        activation: Callable[[torch.Tensor], torch.Tensor] = lambda x: x,
        **kwargs) -> None:

        super().__init__(**kwargs)

        in_channels = int(in_channels)
        out_channels = int(out_channels)
        if not isinstance(adjacency_matrices, list):
            raise TypeError('adjacency_matrices must be a list')
        if len(adjacency_matrices) == 0:
            raise ValueError('adjacency_matrices must not be empty')
        if not all(isinstance(m, torch.Tensor) for m in adjacency_matrices):
            raise TypeError('adjacency_matrices elements must be of class torch.Tensor')
        if not all(m.is_sparse for m in adjacency_matrices):
            raise ValueError('adjacency_matrices elements must be sparse')
        keep_prob = float(keep_prob)
        if not callable(activation):
            raise TypeError('activation must be callable')

        self.in_channels = in_channels
        self.out_channels = out_channels
        self.keep_prob = keep_prob
        self.activation = activation

        self.num_row_nodes = len(adjacency_matrices[0])
        self.num_relation_types = len(adjacency_matrices)

        # all relation types are processed in one go, using a block-diagonal
        # adjacency matrix and horizontally concatenated weight matrices
        self.adjacency_matrices = _sparse_diag_cat(adjacency_matrices)

        # wrapped in a Parameter so the weights are registered with the
        # module and visible to the optimizer
        self.weights = torch.nn.Parameter(torch.cat([
            init_glorot(in_channels, out_channels) \
                for _ in range(self.num_relation_types)
        ], dim=1).detach())

    def forward(self, x) -> torch.Tensor:
        if self.keep_prob < 1.:
            x = dropout(x, self.keep_prob)

        res = torch.sparse.mm(x, self.weights) \
            if x.is_sparse \
            else torch.mm(x, self.weights)
        res = torch.split(res, res.shape[1] // self.num_relation_types, dim=1)
        res = torch.cat(res)
        res = torch.sparse.mm(self.adjacency_matrices, res) \
            if self.adjacency_matrices.is_sparse \
            else torch.mm(self.adjacency_matrices, res)
        res = res.view(self.num_relation_types, self.num_row_nodes, self.out_channels)

        if self.activation is not None:
            res = self.activation(res)

        return res


class FastConvLayer(torch.nn.Module):
    def __init__(self,
        input_dim: List[int],
        output_dim: List[int],
        data: Union[Data, PreparedData],
        keep_prob: float = 1.,
        rel_activation: Callable[[torch.Tensor], torch.Tensor] = lambda x: x,
        layer_activation: Callable[[torch.Tensor], torch.Tensor] = torch.nn.functional.relu,
        **kwargs):

        super().__init__(**kwargs)

        self._check_params(input_dim, output_dim, data, keep_prob,
            rel_activation, layer_activation)

        self.input_dim = input_dim
        self.output_dim = output_dim
        self.data = data
        self.keep_prob = keep_prob
        self.rel_activation = rel_activation
        self.layer_activation = layer_activation

        self.is_sparse = False
        self.next_layer_repr = None
        self.build()

    def build(self):
        self.next_layer_repr = torch.nn.ModuleList([
            torch.nn.ModuleList() \
                for _ in range(len(self.data.node_types))
        ])
        for fam in self.data.relation_families:
            self.build_family(fam)

    def build_family(self, fam) -> None:
        if fam.node_type_row == fam.node_type_column:
            self.build_fam_one_node_type(fam)
        else:
            self.build_fam_two_node_types(fam)

    def build_fam_one_node_type(self, fam) -> None:
        adjacency_matrices = [
            r.adjacency_matrix \
                for r in fam.relation_types
        ]
        conv = FastGraphConv(self.input_dim[fam.node_type_column],
            self.output_dim[fam.node_type_row],
            adjacency_matrices,
            self.keep_prob,
            self.rel_activation)
        conv.input_node_type = fam.node_type_column
        self.next_layer_repr[fam.node_type_row].append(conv)

    def build_fam_two_node_types(self, fam) -> None:
        adjacency_matrices = [
            r.adjacency_matrix \
                for r in fam.relation_types \
                if r.adjacency_matrix is not None
        ]

        adjacency_matrices_backward = [
            r.adjacency_matrix_backward \
                for r in fam.relation_types \
                if r.adjacency_matrix_backward is not None
        ]

        conv = FastGraphConv(self.input_dim[fam.node_type_column],
            self.output_dim[fam.node_type_row],
            adjacency_matrices,
            self.keep_prob,
            self.rel_activation)

        conv_backward = FastGraphConv(self.input_dim[fam.node_type_row],
            self.output_dim[fam.node_type_column],
            adjacency_matrices_backward,
            self.keep_prob,
            self.rel_activation)

        conv.input_node_type = fam.node_type_column
        conv_backward.input_node_type = fam.node_type_row

        self.next_layer_repr[fam.node_type_row].append(conv)
        self.next_layer_repr[fam.node_type_column].append(conv_backward)

    def forward(self, prev_layer_repr):
        next_layer_repr = [ [] \
            for _ in range(len(self.data.node_types)) ]

        for output_node_type in range(len(self.data.node_types)):
            for conv in self.next_layer_repr[output_node_type]:
                rep = conv(prev_layer_repr[conv.input_node_type])
                rep = torch.sum(rep, dim=0)
                rep = torch.nn.functional.normalize(rep, p=2, dim=1)
                next_layer_repr[output_node_type].append(rep)
            if len(next_layer_repr[output_node_type]) == 0:
                next_layer_repr[output_node_type] = \
                    torch.zeros(self.data.node_types[output_node_type].count, self.output_dim[output_node_type])
            else:
                next_layer_repr[output_node_type] = \
                    sum(next_layer_repr[output_node_type])
                next_layer_repr[output_node_type] = \
                    self.layer_activation(next_layer_repr[output_node_type])

        return next_layer_repr

    @staticmethod
    def _check_params(input_dim, output_dim, data, keep_prob,
        rel_activation, layer_activation):

        if not isinstance(input_dim, list):
            raise ValueError('input_dim must be a list')

        if not output_dim:
            raise ValueError('output_dim must be specified')

        # the constructor indexes output_dim per node type, so a scalar
        # cannot be silently accepted here
        if not isinstance(output_dim, list):
            raise ValueError('output_dim must be a list')

        if not isinstance(data, Data) and not isinstance(data, PreparedData):
            raise ValueError('data must be of type Data or PreparedData')
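

# An illustrative smoke test (added; not part of the original module): two
# relation types over a 10-node graph, mapped from 4 to 8 channels at once.
if __name__ == '__main__':
    adjs = [ (torch.rand(10, 10) < 0.2).to(torch.float32).to_sparse()
        for _ in range(2) ]
    conv = FastGraphConv(in_channels=4, out_channels=8,
        adjacency_matrices=adjs)
    out = conv(torch.rand(10, 4))
    print(out.shape)  # torch.Size([2, 10, 8])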

@@ -0,0 +1,138 @@
import torch
from typing import List
from .trainprep import PreparedData
from dataclasses import dataclass
import random
from collections import defaultdict


@dataclass
class TrainingBatch(object):
    relation_family_index: int
    relation_type_index: int
    node_type_row: int
    node_type_column: int
    edges: torch.Tensor


class FastBatcher(object):
    def __init__(self,
        prep_d: PreparedData,
        batch_size: int) -> None:

        if not isinstance(prep_d, PreparedData):
            raise TypeError('prep_d must be an instance of PreparedData')

        self.prep_d = prep_d
        self.batch_size = int(batch_size)

        self.edges = None
        self.build()

    def build(self):
        self.edges = []
        for fam_idx, fam in enumerate(self.prep_d.relation_families):
            edges = []
            targets = []
            edges_back = []
            targets_back = []
            for rel_idx, rel in enumerate(fam.relation_types):
                edges.append(rel.edges_pos.train)
                edges.append(rel.edges_neg.train)
                targets.append(torch.ones(len(rel.edges_pos.train)))
                targets.append(torch.zeros(len(rel.edges_neg.train)))

                edges_back.append(rel.edges_back_pos.train)
                edges_back.append(rel.edges_back_neg.train)
                targets_back.append(torch.ones(len(rel.edges_back_pos.train)))
                targets_back.append(torch.zeros(len(rel.edges_back_neg.train)))

            edges = torch.cat(edges)
            targets = torch.cat(targets)
            edges_back = torch.cat(edges_back)
            targets_back = torch.cat(targets_back)

            order = torch.randperm(len(edges))
            edges = edges[order]
            targets = targets[order]

            order_back = torch.randperm(len(edges_back))
            edges_back = edges_back[order_back]
            targets_back = targets_back[order_back]

            # NOTE: edges of all relation types in a family are concatenated
            # above, so rel_idx here is just the index of the family's last
            # relation type
            self.edges.append({'fam_idx': fam_idx, 'rel_idx': rel_idx, 'back': False,
                'edges': edges, 'targets': targets, 'ofs': 0})
            self.edges.append({'fam_idx': fam_idx, 'rel_idx': rel_idx, 'back': True,
                'edges': edges_back, 'targets': targets_back, 'ofs': 0})

    def __iter__(self):
        while True:
            edges = [ e for e in self.edges \
                if e['ofs'] < len(e['edges']) ]
            # TODO: need to finish this
            raise NotImplementedError('FastBatcher.__iter__() is unfinished')

    def __iter_old__(self):
        edge_types = ['edges_pos', 'edges_neg', 'edges_back_pos', 'edges_back_neg']

        offsets = {}
        orders = {}
        done = {}

        for fam_idx, fam in enumerate(self.prep_d.relation_families):
            for rel_idx, rel in enumerate(fam.relation_types):
                for et in edge_types:
                    done[fam_idx, rel_idx, et] = False

        while True:
            fam_idx = torch.randint(0, len(self.prep_d.relation_families), (1,)).item()
            fam = self.prep_d.relation_families[fam_idx]

            rel_idx = torch.randint(0, len(fam.relation_types), (1,)).item()
            rel = fam.relation_types[rel_idx]

            et = random.choice(edge_types)
            edges = getattr(rel, et).train

            key = (fam_idx, rel_idx, et)
            if key not in orders:
                orders[key] = torch.randperm(len(edges))
                offsets[key] = 0

            ord = orders[key]
            ofs = offsets[key]

            nt_row = rel.node_type_row
            nt_col = rel.node_type_column
            if 'back' in et:
                nt_row, nt_col = nt_col, nt_row

            if ofs < len(edges):
                offsets[key] += self.batch_size
                ord = ord[ofs:ofs+self.batch_size]
                edges = edges[ord]
                yield TrainingBatch(fam_idx, rel_idx, nt_row, nt_col, edges)
            else:
                done[key] = True

        for fam in self.prep_d.relation_families:
            edges = []
            for rel in fam.relation_types:
                edges.append(rel.edges_pos.train)
                edges.append(rel.edges_back_pos.train)
                edges.append(rel.edges_neg.train)
                edges.append(rel.edges_back_neg.train)
            edges = torch.cat(edges)


class FastDecLayer(torch.nn.Module):
    def __init__(self, **kwargs):
        super().__init__(**kwargs)

    def forward(self,
        last_layer_repr: List[torch.Tensor],
        training_batch: TrainingBatch):

        # TODO: implement decoding of a single TrainingBatch
        raise NotImplementedError

@@ -0,0 +1,166 @@
from .fastmodel import FastModel
from .trainprep import PreparedData
import torch
from typing import Callable
import time
import random


class FastBatcher(object):
    def __init__(self, prep_d: PreparedData, batch_size: int,
        shuffle: bool, generator: torch.Generator,
        part_type: str) -> None:

        if not isinstance(prep_d, PreparedData):
            raise TypeError('prep_d must be an instance of PreparedData')

        if not isinstance(generator, torch.Generator):
            raise TypeError('generator must be an instance of torch.Generator')

        if part_type not in ['train', 'val', 'test']:
            raise ValueError('part_type must be set to train, val or test')

        self.prep_d = prep_d
        self.batch_size = int(batch_size)
        self.shuffle = bool(shuffle)
        self.generator = generator
        self.part_type = part_type

        self.edges = None
        self.targets = None
        self.build()

    def build(self):
        self.edges = []
        self.targets = []

        for fam in self.prep_d.relation_families:
            edges = []
            targets = []
            for i, rel in enumerate(fam.relation_types):
                edges_pos = getattr(rel.edges_pos, self.part_type)
                edges_neg = getattr(rel.edges_neg, self.part_type)
                edges_back_pos = getattr(rel.edges_back_pos, self.part_type)
                edges_back_neg = getattr(rel.edges_back_neg, self.part_type)

                # backward edges are stored (column, row); swap the columns so
                # that all edges share the forward orientation, then prepend
                # the relation type index as the first column
                e = torch.cat([ edges_pos,
                    edges_back_pos[:, [1, 0]] ])
                e = torch.cat([ torch.ones(len(e), 1, dtype=torch.long) * i, e ], dim=1)
                t = torch.ones(len(e))
                edges.append(e)
                targets.append(t)

                e = torch.cat([ edges_neg,
                    edges_back_neg[:, [1, 0]] ])
                e = torch.cat([ torch.ones(len(e), 1, dtype=torch.long) * i, e ], dim=1)
                t = torch.zeros(len(e))
                edges.append(e)
                targets.append(t)

            edges = torch.cat(edges)
            targets = torch.cat(targets)

            self.edges.append(edges)
            self.targets.append(targets)

        if self.shuffle:
            self.shuffle_families()

    def shuffle_families(self):
        for i in range(len(self.edges)):
            edges = self.edges[i]
            targets = self.targets[i]
            order = torch.randperm(len(edges), generator=self.generator)
            self.edges[i] = edges[order]
            self.targets[i] = targets[order]

    def __iter__(self):
        offsets = [ 0 for _ in self.edges ]

        while True:
            choice = [ i for i in range(len(offsets)) \
                if offsets[i] < len(self.edges[i]) ]
            if len(choice) == 0:
                break
            # pick one of the families that still have edges left
            fam_idx = choice[torch.randint(len(choice), (1,), generator=self.generator).item()]
            ofs = offsets[fam_idx]
            edges = self.edges[fam_idx][ofs:ofs + self.batch_size]
            targets = self.targets[fam_idx][ofs:ofs + self.batch_size]
            offsets[fam_idx] += self.batch_size
            yield (fam_idx, edges, targets)


class FastLoop(object):
    def __init__(
        self,
        model: FastModel,
        lr: float = 0.001,
        loss: Callable[[torch.Tensor, torch.Tensor], torch.Tensor] = \
            torch.nn.functional.binary_cross_entropy_with_logits,
        batch_size: int = 100,
        shuffle: bool = True,
        generator: torch.Generator = None) -> None:

        self._check_params(model, loss, generator)

        self.model = model
        self.lr = float(lr)
        self.loss = loss
        self.batch_size = int(batch_size)
        self.shuffle = bool(shuffle)
        self.generator = generator or torch.default_generator

        self.opt = None
        self.build()

    def _check_params(self, model, loss, generator):
        if not isinstance(model, FastModel):
            raise TypeError('model must be an instance of FastModel')

        if not callable(loss):
            raise TypeError('loss must be callable')

        if generator is not None and not isinstance(generator, torch.Generator):
            raise TypeError('generator must be an instance of torch.Generator')

    def build(self) -> None:
        opt = torch.optim.Adam(self.model.parameters(), lr=self.lr)
        self.opt = opt

    def run_epoch(self):
        # assumption: an epoch always runs on the train partition
        batcher = FastBatcher(self.model.prep_d, batch_size=self.batch_size,
            shuffle=self.shuffle, generator=self.generator,
            part_type='train')

        loss_sum = 0
        for fam_idx, edges, targets in batcher:
            self.opt.zero_grad()
            pred = self.model(None)
            # select the predictions corresponding to this batch of edges
            batch_pred = pred[fam_idx][edges[:, 0], edges[:, 1]]
            loss = self.loss(batch_pred, targets)
            loss.backward()
            self.opt.step()
            loss_sum += loss.detach().cpu().item()
        return loss_sum

    def train(self, max_epochs):
        best_loss = None
        best_epoch = None
        for i in range(max_epochs):
            loss = self.run_epoch()
            if best_loss is None or loss < best_loss:
                best_loss = loss
                best_epoch = i
        return loss, best_loss, best_epoch
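

# An illustrative batching run (added; not part of the original module). It
# assumes the full icosagon package layout implied by the relative imports;
# run e.g. with `python -m icosagon.fastloop`:
if __name__ == '__main__':
    from .data import Data
    from .trainprep import TrainValTest, prepare_training

    d = Data()
    d.add_node_type('Gene', 50)
    fam = d.add_relation_family('Gene-Gene', 0, 0, is_symmetric=True)
    adj = (torch.rand(50, 50) < 0.2).to(torch.float32)
    adj = ((adj + adj.transpose(0, 1)) > 0).to(torch.float32)  # symmetrize
    fam.add_relation_type('Interacts', adj)

    prep_d = prepare_training(d, TrainValTest(.8, .1, .1))
    gen = torch.Generator().manual_seed(0)
    batcher = FastBatcher(prep_d, batch_size=16, shuffle=True,
        generator=gen, part_type='train')
    for fam_idx, edges, targets in batcher:
        print(fam_idx, edges.shape, targets.shape)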

@@ -0,0 +1,79 @@
from .fastconv import FastConvLayer
from .bulkdec import BulkDecodeLayer
from .input import OneHotInputLayer
from .trainprep import PreparedData
import torch
from typing import List, \
    Union, \
    Callable


class FastModel(torch.nn.Module):
    def __init__(self, prep_d: PreparedData,
        layer_dimensions: List[int] = [32, 64],
        keep_prob: float = 1.,
        rel_activation: Callable[[torch.Tensor], torch.Tensor] = lambda x: x,
        layer_activation: Callable[[torch.Tensor], torch.Tensor] = torch.nn.functional.relu,
        dec_activation: Callable[[torch.Tensor], torch.Tensor] = lambda x: x,
        **kwargs) -> None:

        super().__init__(**kwargs)

        self._check_params(prep_d, layer_dimensions, rel_activation,
            layer_activation, dec_activation)

        self.prep_d = prep_d
        self.layer_dimensions = layer_dimensions
        self.keep_prob = float(keep_prob)
        self.rel_activation = rel_activation
        self.layer_activation = layer_activation
        self.dec_activation = dec_activation

        self.seq = None
        self.build()

    def build(self):
        in_layer = OneHotInputLayer(self.prep_d)
        last_output_dim = in_layer.output_dim
        seq = [ in_layer ]

        for dim in self.layer_dimensions:
            conv_layer = FastConvLayer(input_dim = last_output_dim,
                output_dim = [dim] * len(self.prep_d.node_types),
                data = self.prep_d,
                keep_prob = self.keep_prob,
                rel_activation = self.rel_activation,
                layer_activation = self.layer_activation)
            last_output_dim = conv_layer.output_dim
            seq.append(conv_layer)

        dec_layer = BulkDecodeLayer(input_dim = last_output_dim,
            data = self.prep_d,
            keep_prob = self.keep_prob,
            activation = self.dec_activation)
        seq.append(dec_layer)

        seq = torch.nn.Sequential(*seq)
        self.seq = seq

    def forward(self, _):
        return self.seq(None)

    def _check_params(self, prep_d, layer_dimensions, rel_activation,
        layer_activation, dec_activation):

        if not isinstance(prep_d, PreparedData):
            raise TypeError('prep_d must be an instance of PreparedData')

        if not isinstance(layer_dimensions, list):
            raise TypeError('layer_dimensions must be a list')

        if not callable(rel_activation):
            raise TypeError('rel_activation must be callable')

        if not callable(layer_activation):
            raise TypeError('layer_activation must be callable')

        if not callable(dec_activation):
            raise TypeError('dec_activation must be callable')

@@ -0,0 +1,79 @@
#
# Copyright (C) Stanislaw Adaszewski, 2020
# License: GPLv3
#

import torch
from typing import Union, \
    List
from .data import Data


class InputLayer(torch.nn.Module):
    def __init__(self, data: Data, output_dim: Union[int, List[int]] = None,
        **kwargs) -> None:

        output_dim = output_dim or \
            list(map(lambda a: a.count, data.node_types))

        if not isinstance(output_dim, list):
            output_dim = [output_dim,] * len(data.node_types)

        super().__init__(**kwargs)
        self.output_dim = output_dim
        self.data = data

        self.is_sparse = False
        self.node_reps = None
        self.build()

    def build(self) -> None:
        self.node_reps = []
        for i, nt in enumerate(self.data.node_types):
            reps = torch.rand(nt.count, self.output_dim[i])
            reps = torch.nn.Parameter(reps)
            self.register_parameter('node_reps[%d]' % i, reps)
            self.node_reps.append(reps)

    def forward(self, x) -> List[torch.nn.Parameter]:
        return self.node_reps

    def __repr__(self) -> str:
        s = ''
        s += 'Icosagon input layer with output_dim: %s\n' % self.output_dim
        s += ' # of node types: %d\n' % len(self.data.node_types)
        for nt in self.data.node_types:
            s += ' - %s (%d)\n' % (nt.name, nt.count)
        return s.strip()


class OneHotInputLayer(torch.nn.Module):
    def __init__(self, data: Data, **kwargs) -> None:
        output_dim = [ a.count for a in data.node_types ]
        super().__init__(**kwargs)
        self.output_dim = output_dim
        self.data = data

        self.is_sparse = True
        self.node_reps = None
        self.build()

    def build(self) -> None:
        self.node_reps = torch.nn.ParameterList()
        for i, nt in enumerate(self.data.node_types):
            reps = torch.eye(nt.count).to_sparse()
            reps = torch.nn.Parameter(reps, requires_grad=False)
            # ParameterList registers each parameter with the module itself
            self.node_reps.append(reps)

    def forward(self, x) -> List[torch.nn.Parameter]:
        return self.node_reps

    def __repr__(self) -> str:
        s = ''
        s += 'Icosagon one-hot input layer\n'
        s += ' # of node types: %d\n' % len(self.data.node_types)
        for nt in self.data.node_types:
            s += ' - %s (%d)\n' % (nt.name, nt.count)
        return s.strip()
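

# An illustrative check (added; not part of the original module): the one-hot
# layer returns a sparse identity matrix per node type, whatever the input.
if __name__ == '__main__':
    d = Data()
    d.add_node_type('Gene', 100)
    in_layer = OneHotInputLayer(d)
    reps = in_layer(None)
    print(reps[0].shape)  # torch.Size([100, 100])
    print(in_layer)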

@@ -0,0 +1,145 @@
#
# Copyright (C) Stanislaw Adaszewski, 2020
# License: GPLv3
#

import numpy as np
import scipy.sparse as sp
import torch


def _check_tensor(adj_mat):
    if not isinstance(adj_mat, torch.Tensor):
        raise ValueError('adj_mat must be a torch.Tensor')


def _check_sparse(adj_mat):
    if not adj_mat.is_sparse:
        raise ValueError('adj_mat must be sparse')


def _check_dense(adj_mat):
    if adj_mat.is_sparse:
        raise ValueError('adj_mat must be dense')


def _check_square(adj_mat):
    if len(adj_mat.shape) != 2 or \
        adj_mat.shape[0] != adj_mat.shape[1]:
        raise ValueError('adj_mat must be a square matrix')


def _check_2d(adj_mat):
    if len(adj_mat.shape) != 2:
        raise ValueError('adj_mat must be a 2D matrix')


def _sparse_coo_tensor(indices, values, size):
    ctor = { torch.float32: torch.sparse.FloatTensor,
        torch.float64: torch.sparse.DoubleTensor,
        torch.uint8: torch.sparse.ByteTensor,
        torch.long: torch.sparse.LongTensor,
        torch.int: torch.sparse.IntTensor,
        torch.short: torch.sparse.ShortTensor,
        torch.bool: torch.sparse.ByteTensor }[values.dtype]
    return ctor(indices, values, size)


def add_eye_sparse(adj_mat: torch.Tensor) -> torch.Tensor:
    _check_tensor(adj_mat)
    _check_sparse(adj_mat)
    _check_square(adj_mat)

    adj_mat = adj_mat.coalesce()
    indices = adj_mat.indices()
    values = adj_mat.values()

    eye_indices = torch.arange(adj_mat.shape[0], dtype=indices.dtype,
        device=adj_mat.device).view(1, -1)
    eye_indices = torch.cat((eye_indices, eye_indices), 0)
    eye_values = torch.ones(adj_mat.shape[0], dtype=values.dtype,
        device=adj_mat.device)

    indices = torch.cat((indices, eye_indices), 1)
    values = torch.cat((values, eye_values), 0)

    adj_mat = _sparse_coo_tensor(indices, values, adj_mat.shape)

    return adj_mat


def norm_adj_mat_one_node_type_sparse(adj_mat: torch.Tensor) -> torch.Tensor:
    _check_tensor(adj_mat)
    _check_sparse(adj_mat)
    _check_square(adj_mat)

    adj_mat = add_eye_sparse(adj_mat)
    adj_mat = norm_adj_mat_two_node_types_sparse(adj_mat)

    return adj_mat


def norm_adj_mat_one_node_type_dense(adj_mat: torch.Tensor) -> torch.Tensor:
    _check_tensor(adj_mat)
    _check_dense(adj_mat)
    _check_square(adj_mat)

    adj_mat = adj_mat + torch.eye(adj_mat.shape[0], dtype=adj_mat.dtype,
        device=adj_mat.device)
    adj_mat = norm_adj_mat_two_node_types_dense(adj_mat)

    return adj_mat


def norm_adj_mat_one_node_type(adj_mat: torch.Tensor) -> torch.Tensor:
    _check_tensor(adj_mat)
    _check_square(adj_mat)

    if adj_mat.is_sparse:
        return norm_adj_mat_one_node_type_sparse(adj_mat)
    else:
        return norm_adj_mat_one_node_type_dense(adj_mat)


def norm_adj_mat_two_node_types_sparse(adj_mat: torch.Tensor) -> torch.Tensor:
    _check_tensor(adj_mat)
    _check_sparse(adj_mat)
    _check_2d(adj_mat)

    # symmetric normalization: A_norm = D_row^(-1/2) A D_col^(-1/2)
    adj_mat = adj_mat.coalesce()
    indices = adj_mat.indices()
    values = adj_mat.values()
    degrees_row = torch.zeros(adj_mat.shape[0], device=adj_mat.device)
    degrees_row = degrees_row.index_add(0, indices[0], values.to(degrees_row.dtype))
    degrees_col = torch.zeros(adj_mat.shape[1], device=adj_mat.device)
    degrees_col = degrees_col.index_add(0, indices[1], values.to(degrees_col.dtype))
    values = values.to(degrees_row.dtype) / torch.sqrt(degrees_row[indices[0]] * degrees_col[indices[1]])
    adj_mat = _sparse_coo_tensor(indices, values, adj_mat.shape)

    return adj_mat


def norm_adj_mat_two_node_types_dense(adj_mat: torch.Tensor) -> torch.Tensor:
    _check_tensor(adj_mat)
    _check_dense(adj_mat)
    _check_2d(adj_mat)

    # symmetric normalization: A_norm = D_row^(-1/2) A D_col^(-1/2)
    degrees_row = adj_mat.sum(1).view(-1, 1).to(torch.float32)
    degrees_col = adj_mat.sum(0).view(1, -1).to(torch.float32)
    degrees_row = torch.sqrt(degrees_row)
    degrees_col = torch.sqrt(degrees_col)
    adj_mat = adj_mat.to(degrees_row.dtype) / degrees_row
    adj_mat = adj_mat / degrees_col

    return adj_mat


def norm_adj_mat_two_node_types(adj_mat: torch.Tensor) -> torch.Tensor:
    _check_tensor(adj_mat)
    _check_2d(adj_mat)

    if adj_mat.is_sparse:
        return norm_adj_mat_two_node_types_sparse(adj_mat)
    else:
        return norm_adj_mat_two_node_types_dense(adj_mat)
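

# An illustrative check (added; not part of the original module): normalize a
# small dense adjacency matrix; self-loops are added first, then every entry
# (i, j) is divided by sqrt(deg_i * deg_j).
if __name__ == '__main__':
    adj = torch.tensor([[0., 1., 1.],
                        [1., 0., 0.],
                        [1., 0., 0.]])
    print(norm_adj_mat_one_node_type(adj))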

@@ -0,0 +1,47 @@
#
# Copyright (C) Stanislaw Adaszewski, 2020
# License: GPLv3
#

import numpy as np
import torch
import torch.utils.data
from typing import List, \
    Union


def fixed_unigram_candidate_sampler(
    true_classes: Union[np.ndarray, torch.Tensor],
    unigrams: List[Union[int, float]],
    distortion: float = 1.):

    if isinstance(true_classes, torch.Tensor):
        true_classes = true_classes.detach().cpu().numpy()

    if isinstance(unigrams, torch.Tensor):
        unigrams = unigrams.detach().cpu().numpy()

    if len(true_classes.shape) != 2:
        raise ValueError('true_classes must be a 2D matrix with shape (num_samples, num_true)')

    num_samples = true_classes.shape[0]
    unigrams = np.array(unigrams)
    if distortion != 1.:
        unigrams = unigrams.astype(np.float64) ** distortion

    indices = np.arange(num_samples)
    result = np.zeros(num_samples, dtype=np.int64)
    while len(indices) > 0:
        # draw one candidate per still-unresolved sample, weighted by unigrams
        sampler = torch.utils.data.WeightedRandomSampler(unigrams, len(indices))
        candidates = np.array(list(sampler))
        candidates = np.reshape(candidates, (len(indices), 1))
        result[indices] = candidates.T
        # resample the rows where the candidate collided with a true class
        mask = (candidates == true_classes[indices, :])
        mask = mask.sum(1).astype(bool)
        indices = indices[mask]

    return torch.tensor(result)
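

# An illustrative run (added; not part of the original module): draw one
# negative per row, avoiding the true class listed in that row, with
# degree-like unigram weights raised to the 0.75 power.
if __name__ == '__main__':
    true_classes = torch.tensor([[0], [1], [2], [3]])
    print(fixed_unigram_candidate_sampler(true_classes, [1, 2, 3, 4], 0.75))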

@@ -0,0 +1,215 @@
#
# Copyright (C) Stanislaw Adaszewski, 2020
# License: GPLv3
#

from .sampling import fixed_unigram_candidate_sampler
import torch
from dataclasses import dataclass, \
    field
from typing import Any, \
    List, \
    Tuple, \
    Dict
from .data import NodeType, \
    RelationType, \
    RelationTypeBase, \
    RelationFamily, \
    RelationFamilyBase, \
    Data
from collections import defaultdict
from .normalize import norm_adj_mat_one_node_type, \
    norm_adj_mat_two_node_types
import numpy as np


@dataclass
class TrainValTest(object):
    train: Any
    val: Any
    test: Any


@dataclass
class PreparedRelationType(RelationTypeBase):
    edges_pos: TrainValTest
    edges_neg: TrainValTest
    edges_back_pos: TrainValTest
    edges_back_neg: TrainValTest


@dataclass
class PreparedRelationFamily(RelationFamilyBase):
    relation_types: List[PreparedRelationType]


@dataclass
class PreparedData(object):
    node_types: List[NodeType]
    relation_families: List[PreparedRelationFamily]


def _empty_edge_list_tvt() -> TrainValTest:
    return TrainValTest(*[ torch.zeros((0, 2), dtype=torch.long) for _ in range(3) ])


def train_val_test_split_edges(edges: torch.Tensor,
    ratios: TrainValTest) -> TrainValTest:

    if not isinstance(edges, torch.Tensor):
        raise ValueError('edges must be a torch.Tensor')

    if len(edges.shape) != 2 or edges.shape[1] != 2:
        raise ValueError('edges shape must be (num_edges, 2)')

    if not isinstance(ratios, TrainValTest):
        raise ValueError('ratios must be a TrainValTest')

    if ratios.train + ratios.val + ratios.test != 1.0:
        raise ValueError('Train, validation and test ratios must add up to 1')

    order = torch.randperm(len(edges))
    edges = edges[order, :]
    n = round(len(edges) * ratios.train)
    edges_train = edges[:n]
    n_1 = round(len(edges) * (ratios.train + ratios.val))
    edges_val = edges[n:n_1]
    edges_test = edges[n_1:]

    return TrainValTest(edges_train, edges_val, edges_test)


def get_edges_and_degrees(adj_mat: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
    if adj_mat.is_sparse:
        adj_mat = adj_mat.coalesce()
        degrees = torch.zeros(adj_mat.shape[1], dtype=torch.int64,
            device=adj_mat.device)
        degrees = degrees.index_add(0, adj_mat.indices()[1],
            torch.ones(adj_mat.indices().shape[1], dtype=torch.int64,
                device=adj_mat.device))
        edges_pos = adj_mat.indices().transpose(0, 1)
    else:
        degrees = adj_mat.sum(0)
        edges_pos = torch.nonzero(adj_mat)
    return edges_pos, degrees


def prepare_adj_mat(adj_mat: torch.Tensor,
    ratios: TrainValTest) -> Tuple[torch.Tensor, TrainValTest, TrainValTest]:

    if not isinstance(adj_mat, torch.Tensor):
        raise ValueError('adj_mat must be a torch.Tensor')

    edges_pos, degrees = get_edges_and_degrees(adj_mat)

    # negative edges keep the row node and draw a column node from a
    # degree-based unigram distribution
    neg_neighbors = fixed_unigram_candidate_sampler(
        edges_pos[:, 1].view(-1, 1), degrees, 0.75).to(adj_mat.device)
    edges_neg = torch.cat((edges_pos[:, 0].view(-1, 1), neg_neighbors.view(-1, 1)), 1)

    edges_pos = train_val_test_split_edges(edges_pos, ratios)
    edges_neg = train_val_test_split_edges(edges_neg, ratios)

    adj_mat_train = torch.sparse_coo_tensor(indices = edges_pos.train.transpose(0, 1),
        values=torch.ones(len(edges_pos.train)), size=adj_mat.shape, dtype=adj_mat.dtype,
        device=adj_mat.device)

    return adj_mat_train, edges_pos, edges_neg


def prep_rel_one_node_type(r: RelationType,
    ratios: TrainValTest) -> PreparedRelationType:

    adj_mat = r.adjacency_matrix
    adj_mat_train, edges_pos, edges_neg = prepare_adj_mat(adj_mat, ratios)
    adj_mat_back_train, edges_back_pos, edges_back_neg = \
        None, _empty_edge_list_tvt(), _empty_edge_list_tvt()

    adj_mat_train = norm_adj_mat_one_node_type(adj_mat_train)

    return PreparedRelationType(r.name, r.node_type_row, r.node_type_column,
        adj_mat_train, adj_mat_back_train, edges_pos, edges_neg,
        edges_back_pos, edges_back_neg)


def prep_rel_two_node_types_sym(r: RelationType,
    ratios: TrainValTest) -> PreparedRelationType:

    adj_mat = r.adjacency_matrix
    adj_mat_train, edges_pos, edges_neg = prepare_adj_mat(adj_mat, ratios)
    edges_back_pos, edges_back_neg = \
        _empty_edge_list_tvt(), _empty_edge_list_tvt()

    return PreparedRelationType(r.name, r.node_type_row,
        r.node_type_column,
        norm_adj_mat_two_node_types(adj_mat_train),
        norm_adj_mat_two_node_types(adj_mat_train.transpose(0, 1)),
        edges_pos, edges_neg, edges_back_pos, edges_back_neg)


def prep_rel_two_node_types_asym(r: RelationType,
    ratios: TrainValTest) -> PreparedRelationType:

    if r.adjacency_matrix is not None:
        adj_mat_train, edges_pos, edges_neg = \
            prepare_adj_mat(r.adjacency_matrix, ratios)
    else:
        adj_mat_train, edges_pos, edges_neg = \
            None, _empty_edge_list_tvt(), _empty_edge_list_tvt()

    if r.adjacency_matrix_backward is not None:
        adj_mat_back_train, edges_back_pos, edges_back_neg = \
            prepare_adj_mat(r.adjacency_matrix_backward, ratios)
    else:
        adj_mat_back_train, edges_back_pos, edges_back_neg = \
            None, _empty_edge_list_tvt(), _empty_edge_list_tvt()

    # either matrix may be absent; normalize only the ones that exist
    return PreparedRelationType(r.name, r.node_type_row,
        r.node_type_column,
        norm_adj_mat_two_node_types(adj_mat_train) \
            if adj_mat_train is not None else None,
        norm_adj_mat_two_node_types(adj_mat_back_train) \
            if adj_mat_back_train is not None else None,
        edges_pos, edges_neg, edges_back_pos, edges_back_neg)


def prepare_relation_type(r: RelationType,
    ratios: TrainValTest, is_symmetric: bool) -> PreparedRelationType:

    if not isinstance(r, RelationType):
        raise ValueError('r must be a RelationType')

    if not isinstance(ratios, TrainValTest):
        raise ValueError('ratios must be a TrainValTest')

    if r.node_type_row == r.node_type_column:
        return prep_rel_one_node_type(r, ratios)
    elif is_symmetric:
        return prep_rel_two_node_types_sym(r, ratios)
    else:
        return prep_rel_two_node_types_asym(r, ratios)


def prepare_relation_family(fam: RelationFamily,
    ratios: TrainValTest) -> PreparedRelationFamily:

    relation_types = []

    for r in fam.relation_types:
        relation_types.append(prepare_relation_type(r, ratios, fam.is_symmetric))

    return PreparedRelationFamily(fam.data, fam.name,
        fam.node_type_row, fam.node_type_column,
        fam.is_symmetric, fam.decoder_class,
        relation_types)


def prepare_training(data: Data, ratios: TrainValTest) -> PreparedData:
    if not isinstance(data, Data):
        raise ValueError('data must be of class Data')

    relation_families = [ prepare_relation_family(fam, ratios) \
        for fam in data.relation_families ]

    return PreparedData(data.node_types, relation_families)
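

# An illustrative end-to-end preparation (added; not part of the original
# module); it assumes the package layout implied by the relative imports:
if __name__ == '__main__':
    d = Data()
    d.add_node_type('Gene', 50)
    fam = d.add_relation_family('Gene-Gene', 0, 0, is_symmetric=True)
    adj = (torch.rand(50, 50) < 0.2).to(torch.float32)
    adj = ((adj + adj.transpose(0, 1)) > 0).to(torch.float32)  # symmetrize
    fam.add_relation_type('Interacts', adj)

    prep_d = prepare_training(d, TrainValTest(.8, .1, .1))
    rel = prep_d.relation_families[0].relation_types[0]
    print(rel.edges_pos.train.shape, rel.edges_neg.train.shape)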

@@ -0,0 +1,19 @@
#
# Copyright (C) Stanislaw Adaszewski, 2020
# License: GPLv3
#

import torch
import numpy as np


def init_glorot(in_channels, out_channels, dtype=torch.float32):
    """Create a weight variable with Glorot & Bengio (AISTATS 2010)
    initialization.
    """
    init_range = np.sqrt(6.0 / (in_channels + out_channels))
    initial = -init_range + 2 * init_range * \
        torch.rand(( in_channels, out_channels ), dtype=dtype)
    initial = initial.requires_grad_(True)
    return initial
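

# An illustrative check (added; not part of the original module): values fall
# within [-sqrt(6/(fan_in+fan_out)), +sqrt(6/(fan_in+fan_out))].
if __name__ == '__main__':
    w = init_glorot(64, 32)
    print(w.shape, w.min().item(), w.max().item())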