@@ -0,0 +1,209 @@
#
# Copyright (C) Stanislaw Adaszewski, 2020
# License: GPLv3
#

from collections import defaultdict
from dataclasses import dataclass, field
import torch
from typing import List, \
    Dict, \
    Tuple, \
    Any, \
    Type
from .decode import DEDICOMDecoder, \
    BilinearDecoder
import numpy as np


def _equal(x: torch.Tensor, y: torch.Tensor):
    if x.is_sparse ^ y.is_sparse:
        raise ValueError('Cannot mix sparse and dense tensors')

    if not x.is_sparse:
        return (x == y)

    return ((x - y).coalesce().values() == 0)


@dataclass
class NodeType(object):
    name: str
    count: int


@dataclass
class RelationTypeBase(object):
    name: str
    node_type_row: int
    node_type_column: int
    adjacency_matrix: torch.Tensor
    adjacency_matrix_backward: torch.Tensor


@dataclass
class RelationType(RelationTypeBase):
    pass


@dataclass
class RelationFamilyBase(object):
    data: 'Data'
    name: str
    node_type_row: int
    node_type_column: int
    is_symmetric: bool
    decoder_class: Type


@dataclass
class RelationFamily(RelationFamilyBase):
    relation_types: List[RelationType] = None

    def __post_init__(self) -> None:
        if not self.is_symmetric and \
            self.decoder_class != DEDICOMDecoder and \
            self.decoder_class != BilinearDecoder:
            raise TypeError('Family is asymmetric but the specified decoder_class supports symmetric relations only')

        self.relation_types = []

    def add_relation_type(self,
        name: str, adjacency_matrix: torch.Tensor,
        adjacency_matrix_backward: torch.Tensor = None) -> None:

        name = str(name)
        node_type_row = self.node_type_row
        node_type_column = self.node_type_column

        if adjacency_matrix is None and adjacency_matrix_backward is None:
            raise ValueError('adjacency_matrix and adjacency_matrix_backward cannot both be None')

        if adjacency_matrix is not None and \
            not isinstance(adjacency_matrix, torch.Tensor):
            raise ValueError('adjacency_matrix must be a torch.Tensor')

        if adjacency_matrix_backward is not None \
            and not isinstance(adjacency_matrix_backward, torch.Tensor):
            raise ValueError('adjacency_matrix_backward must be a torch.Tensor')

        if adjacency_matrix is not None and \
            adjacency_matrix.shape != (self.data.node_types[node_type_row].count,
                self.data.node_types[node_type_column].count):
            raise ValueError('adjacency_matrix shape must be (num_row_nodes, num_column_nodes)')

        if adjacency_matrix_backward is not None and \
            adjacency_matrix_backward.shape != (self.data.node_types[node_type_column].count,
                self.data.node_types[node_type_row].count):
            raise ValueError('adjacency_matrix_backward shape must be (num_column_nodes, num_row_nodes)')

        if node_type_row == node_type_column and \
            adjacency_matrix_backward is not None:
            raise ValueError('Relation between nodes of the same type must be expressed using a single matrix')

        if self.is_symmetric and adjacency_matrix_backward is not None:
            raise ValueError('Cannot use a custom adjacency_matrix_backward in a symmetric relation family')

        if self.is_symmetric and node_type_row == node_type_column and \
            not torch.all(_equal(adjacency_matrix,
                adjacency_matrix.transpose(0, 1))):
            raise ValueError('Relation family is symmetric but adjacency_matrix is asymmetric')

        if not self.is_symmetric and node_type_row != node_type_column and \
            adjacency_matrix_backward is None:
            raise ValueError('Relation is asymmetric but adjacency_matrix_backward is None')

        if self.is_symmetric and node_type_row != node_type_column:
            adjacency_matrix_backward = adjacency_matrix.transpose(0, 1)

        self.relation_types.append(RelationType(name,
            node_type_row, node_type_column,
            adjacency_matrix, adjacency_matrix_backward))

    def node_name(self, index):
        return self.data.node_types[index].name

    def __repr__(self):
        s = 'Relation family %s' % self.name

        for r in self.relation_types:
            s += '\n - %s%s' % (r.name, ' (two-way)' \
                if (r.adjacency_matrix is not None \
                    and r.adjacency_matrix_backward is not None) \
                    or self.node_type_row == self.node_type_column \
                else '%s <- %s' % (self.node_name(self.node_type_row),
                    self.node_name(self.node_type_column)))

        return s

    def repr_indented(self):
        s = ' - %s' % self.name

        for r in self.relation_types:
            s += '\n - %s%s' % (r.name, ' (two-way)' \
                if (r.adjacency_matrix is not None \
                    and r.adjacency_matrix_backward is not None) \
                    or self.node_type_row == self.node_type_column \
                else '%s <- %s' % (self.node_name(self.node_type_row),
                    self.node_name(self.node_type_column)))

        return s


class Data(object):
    node_types: List[NodeType]
    relation_families: List[RelationFamily]

    def __init__(self) -> None:
        self.node_types = []
        self.relation_families = []

    def add_node_type(self, name: str, count: int) -> None:
        name = str(name)
        count = int(count)
        if not name:
            raise ValueError('You must provide a non-empty node type name')
        if count <= 0:
            raise ValueError('You must provide a positive node count')
        self.node_types.append(NodeType(name, count))

    def add_relation_family(self, name: str, node_type_row: int,
        node_type_column: int, is_symmetric: bool,
        decoder_class: Type = DEDICOMDecoder):

        name = str(name)
        node_type_row = int(node_type_row)
        node_type_column = int(node_type_column)
        is_symmetric = bool(is_symmetric)

        if node_type_row < 0 or node_type_row >= len(self.node_types):
            raise ValueError('node_type_row outside of the valid range of node types')

        if node_type_column < 0 or node_type_column >= len(self.node_types):
            raise ValueError('node_type_column outside of the valid range of node types')

        fam = RelationFamily(self, name, node_type_row, node_type_column,
            is_symmetric, decoder_class)
        self.relation_families.append(fam)

        return fam

    def __repr__(self):
        n = len(self.node_types)
        if n == 0:
            return 'Empty Icosagon Data'

        s = ''
        s += 'Icosagon Data with:\n'
        s += '- ' + str(n) + ' node type(s):\n'
        for nt in self.node_types:
            s += ' - ' + nt.name + '\n'

        if len(self.relation_families) == 0:
            s += '- No relation families\n'
            return s.strip()

        s += '- %d relation families:\n' % len(self.relation_families)
        for fam in self.relation_families:
            s += fam.repr_indented() + '\n'

        return s.strip()
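
# Illustrative usage sketch (not part of the original module). The node types,
# the family name and the random adjacency matrix below are all made up for
# the example:
#
#   d = Data()
#   d.add_node_type('Gene', 1000)
#   d.add_node_type('Drug', 100)
#
#   fam = d.add_relation_family('Drug-Gene (Target)', 1, 0,
#       is_symmetric=False, decoder_class=DEDICOMDecoder)
#   adj = torch.rand(100, 1000).round().to_sparse()
#   fam.add_relation_type('Target', adj, adj.transpose(0, 1))
#
#   print(d)  # summarizes node types and relation families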
@@ -0,0 +1,123 @@
#
# Copyright (C) Stanislaw Adaszewski, 2020
# License: GPLv3
#

import torch
from .weights import init_glorot
from .dropout import dropout


class DEDICOMDecoder(torch.nn.Module):
    """DEDICOM Tensor Factorization Decoder model layer for link prediction."""
    def __init__(self, input_dim, num_relation_types, keep_prob=1.,
        activation=torch.sigmoid, **kwargs):

        super().__init__(**kwargs)
        self.input_dim = input_dim
        self.num_relation_types = num_relation_types
        self.keep_prob = keep_prob
        self.activation = activation

        # one global interaction matrix shared by all relation types,
        # one diagonal local variation vector per relation type
        self.global_interaction = torch.nn.Parameter(init_glorot(input_dim, input_dim))
        self.local_variation = torch.nn.ParameterList([
            torch.nn.Parameter(torch.flatten(init_glorot(input_dim, 1))) \
                for _ in range(num_relation_types)
        ])

    def forward(self, inputs_row, inputs_col, relation_index):
        inputs_row = dropout(inputs_row, self.keep_prob)
        inputs_col = dropout(inputs_col, self.keep_prob)

        relation = torch.diag(self.local_variation[relation_index])

        product1 = torch.mm(inputs_row, relation)
        product2 = torch.mm(product1, self.global_interaction)
        product3 = torch.mm(product2, relation)

        # row-wise dot products: one score per (row, column) pair
        rec = torch.bmm(product3.view(product3.shape[0], 1, product3.shape[1]),
            inputs_col.view(inputs_col.shape[0], inputs_col.shape[1], 1))
        rec = torch.flatten(rec)

        return self.activation(rec)


class DistMultDecoder(torch.nn.Module):
    """DistMult Tensor Factorization Decoder model layer for link prediction."""
    def __init__(self, input_dim, num_relation_types, keep_prob=1.,
        activation=torch.sigmoid, **kwargs):

        super().__init__(**kwargs)
        self.input_dim = input_dim
        self.num_relation_types = num_relation_types
        self.keep_prob = keep_prob
        self.activation = activation

        self.relation = torch.nn.ParameterList([
            torch.nn.Parameter(torch.flatten(init_glorot(input_dim, 1))) \
                for _ in range(num_relation_types)
        ])

    def forward(self, inputs_row, inputs_col, relation_index):
        inputs_row = dropout(inputs_row, self.keep_prob)
        inputs_col = dropout(inputs_col, self.keep_prob)

        relation = torch.diag(self.relation[relation_index])

        intermediate_product = torch.mm(inputs_row, relation)
        rec = torch.bmm(intermediate_product.view(intermediate_product.shape[0], 1, intermediate_product.shape[1]),
            inputs_col.view(inputs_col.shape[0], inputs_col.shape[1], 1))
        rec = torch.flatten(rec)

        return self.activation(rec)


class BilinearDecoder(torch.nn.Module):
    """Bilinear Decoder model layer for link prediction."""
    def __init__(self, input_dim, num_relation_types, keep_prob=1.,
        activation=torch.sigmoid, **kwargs):

        super().__init__(**kwargs)
        self.input_dim = input_dim
        self.num_relation_types = num_relation_types
        self.keep_prob = keep_prob
        self.activation = activation

        # one full interaction matrix per relation type
        self.relation = torch.nn.ParameterList([
            torch.nn.Parameter(init_glorot(input_dim, input_dim)) \
                for _ in range(num_relation_types)
        ])

    def forward(self, inputs_row, inputs_col, relation_index):
        inputs_row = dropout(inputs_row, self.keep_prob)
        inputs_col = dropout(inputs_col, self.keep_prob)

        intermediate_product = torch.mm(inputs_row, self.relation[relation_index])
        rec = torch.bmm(intermediate_product.view(intermediate_product.shape[0], 1, intermediate_product.shape[1]),
            inputs_col.view(inputs_col.shape[0], inputs_col.shape[1], 1))
        rec = torch.flatten(rec)

        return self.activation(rec)


class InnerProductDecoder(torch.nn.Module):
    """Inner Product Decoder model layer for link prediction."""
    def __init__(self, input_dim, num_relation_types, keep_prob=1.,
        activation=torch.sigmoid, **kwargs):

        super().__init__(**kwargs)
        self.input_dim = input_dim
        self.num_relation_types = num_relation_types
        self.keep_prob = keep_prob
        self.activation = activation

    def forward(self, inputs_row, inputs_col, _):
        inputs_row = dropout(inputs_row, self.keep_prob)
        inputs_col = dropout(inputs_col, self.keep_prob)

        rec = torch.bmm(inputs_row.view(inputs_row.shape[0], 1, inputs_row.shape[1]),
            inputs_col.view(inputs_col.shape[0], inputs_col.shape[1], 1))
        rec = torch.flatten(rec)

        return self.activation(rec)
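
# Minimal shape sketch (illustrative, not part of the original module): every
# decoder maps two batches of node embeddings to one score per pair. The
# dimensions below are arbitrary:
#
#   dec = DEDICOMDecoder(input_dim=32, num_relation_types=7)
#   z_row = torch.rand(16, 32)   # 16 row-node embeddings
#   z_col = torch.rand(16, 32)   # 16 column-node embeddings
#   scores = dec(z_row, z_col, relation_index=3)
#   assert scores.shape == (16,)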
@@ -0,0 +1,42 @@
#
# Copyright (C) Stanislaw Adaszewski, 2020
# License: GPLv3
#

import torch
from .normalize import _sparse_coo_tensor


def dropout_sparse(x, keep_prob):
    x = x.coalesce()
    i = x._indices()
    v = x._values()
    size = x.size()

    # keep each stored entry with probability keep_prob
    n = keep_prob + torch.rand(len(v))
    n = torch.floor(n).to(torch.bool)
    i = i[:, n]
    v = v[n]
    x = _sparse_coo_tensor(i, v, size=size)

    return x * (1./keep_prob)


def dropout_dense(x, keep_prob):
    # print('dropout_dense()')
    x = x.clone()
    i = torch.nonzero(x)

    # zero out each non-zero entry with probability (1 - keep_prob)
    n = keep_prob + torch.rand(len(i))
    n = (1. - torch.floor(n)).to(torch.bool)
    x[i[n, 0], i[n, 1]] = 0.

    return x * (1./keep_prob)


def dropout(x, keep_prob):
    if x.is_sparse:
        return dropout_sparse(x, keep_prob)
    else:
        return dropout_dense(x, keep_prob)
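
# Quick sanity sketch (illustrative): entries survive with probability
# keep_prob and survivors are rescaled by 1/keep_prob, so the expected value
# of every entry is unchanged:
#
#   x = torch.rand(1000, 1000)
#   y = dropout(x, keep_prob=0.5)
#   # y.mean() should be close to x.mean()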
@@ -0,0 +1,255 @@
from typing import List, \
    Union, \
    Callable
from .data import Data, \
    RelationFamily
from .trainprep import PreparedData, \
    PreparedRelationFamily
import torch
from .weights import init_glorot
from .normalize import _sparse_coo_tensor
from .dropout import dropout
import types


def _sparse_diag_cat(matrices: List[torch.Tensor]):
    if len(matrices) == 0:
        raise ValueError('The list of matrices must be non-empty')

    if not all(m.is_sparse for m in matrices):
        raise ValueError('All matrices must be sparse')

    if not all(len(m.shape) == 2 for m in matrices):
        raise ValueError('All matrices must be 2D')

    indices = []
    values = []
    row_offset = 0
    col_offset = 0

    for m in matrices:
        ind = m._indices().clone()
        ind[0] += row_offset
        ind[1] += col_offset
        indices.append(ind)
        values.append(m._values())
        row_offset += m.shape[0]
        col_offset += m.shape[1]

    indices = torch.cat(indices, dim=1)
    values = torch.cat(values)

    return _sparse_coo_tensor(indices, values, size=(row_offset, col_offset))


def _cat(matrices: List[torch.Tensor]):
    if len(matrices) == 0:
        raise ValueError('Empty list passed to _cat()')

    n = sum(a.is_sparse for a in matrices)
    if n != 0 and n != len(matrices):
        raise ValueError('All matrices must have the same layout (dense or sparse)')

    if not all(a.shape[1:] == matrices[0].shape[1:] for a in matrices):
        raise ValueError('All matrices must have the same dimensions apart from dimension 0')

    if not matrices[0].is_sparse:
        return torch.cat(matrices)

    total_rows = sum(a.shape[0] for a in matrices)
    indices = []
    values = []
    row_offset = 0

    for a in matrices:
        ind = a._indices().clone()
        val = a._values()
        ind[0] += row_offset
        ind = ind.transpose(0, 1)
        indices.append(ind)
        values.append(val)
        row_offset += a.shape[0]

    indices = torch.cat(indices).transpose(0, 1)
    values = torch.cat(values)

    res = _sparse_coo_tensor(indices, values, size=(row_offset, matrices[0].shape[1]))
    return res
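
# Illustrative sketch of the two helpers above (not part of the original
# module): _sparse_diag_cat stacks matrices along the diagonal, _cat along
# dimension 0:
#
#   a = torch.eye(2).to_sparse()
#   b = torch.eye(3).to_sparse()
#   d = _sparse_diag_cat([a, b])
#   assert d.shape == (5, 5)        # block-diagonal
#   c = _cat([torch.rand(2, 4), torch.rand(3, 4)])
#   assert c.shape == (5, 4)        # row-wise concatenation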
class FastGraphConv(torch.nn.Module):
    def __init__(self,
        in_channels: int,
        out_channels: int,
        adjacency_matrices: List[torch.Tensor],
        keep_prob: float = 1.,
        activation: Callable[[torch.Tensor], torch.Tensor] = lambda x: x,
        **kwargs) -> None:

        super().__init__(**kwargs)

        in_channels = int(in_channels)
        out_channels = int(out_channels)

        if not isinstance(adjacency_matrices, list):
            raise TypeError('adjacency_matrices must be a list')

        if len(adjacency_matrices) == 0:
            raise ValueError('adjacency_matrices must not be empty')

        if not all(isinstance(m, torch.Tensor) for m in adjacency_matrices):
            raise TypeError('adjacency_matrices elements must be of class torch.Tensor')

        if not all(m.is_sparse for m in adjacency_matrices):
            raise ValueError('adjacency_matrices elements must be sparse')

        keep_prob = float(keep_prob)

        if not isinstance(activation, types.FunctionType):
            raise TypeError('activation must be a function')

        self.in_channels = in_channels
        self.out_channels = out_channels
        self.adjacency_matrices = adjacency_matrices
        self.keep_prob = keep_prob
        self.activation = activation

        self.num_row_nodes = len(adjacency_matrices[0])
        self.num_relation_types = len(adjacency_matrices)

        # all relation types are convolved at once using one
        # block-diagonal adjacency matrix
        self.adjacency_matrices = _sparse_diag_cat(adjacency_matrices)

        # registered as a Parameter so that the optimizer can see the weights
        self.weights = torch.nn.Parameter(torch.cat([
            init_glorot(in_channels, out_channels) \
                for _ in range(self.num_relation_types)
        ], dim=1))

    def forward(self, x) -> torch.Tensor:
        if self.keep_prob < 1.:
            x = dropout(x, self.keep_prob)

        # (num_nodes, in_channels) x (in_channels, num_relation_types * out_channels)
        res = torch.sparse.mm(x, self.weights) \
            if x.is_sparse \
            else torch.mm(x, self.weights)

        # stack the per-relation-type results along dimension 0 so that a
        # single multiplication by the block-diagonal adjacency matrix
        # performs the convolution for every relation type
        res = torch.split(res, res.shape[1] // self.num_relation_types, dim=1)
        res = torch.cat(res)
        res = torch.sparse.mm(self.adjacency_matrices, res) \
            if self.adjacency_matrices.is_sparse \
            else torch.mm(self.adjacency_matrices, res)

        res = res.view(self.num_relation_types, self.num_row_nodes, self.out_channels)

        if self.activation is not None:
            res = self.activation(res)

        return res
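
# Shape sketch (illustrative): with R relation types between the same pair of
# node types, one forward pass convolves all R adjacency matrices at once via
# the block-diagonal trick above:
#
#   adjs = [ torch.eye(10).to_sparse() for _ in range(3) ]   # R = 3
#   conv = FastGraphConv(10, 8, adjs)
#   out = conv(torch.rand(10, 10))
#   assert out.shape == (3, 10, 8)   # (R, num_row_nodes, out_channels)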
class FastConvLayer(torch.nn.Module):
    def __init__(self,
        input_dim: List[int],
        output_dim: List[int],
        data: Union[Data, PreparedData],
        keep_prob: float = 1.,
        rel_activation: Callable[[torch.Tensor], torch.Tensor] = lambda x: x,
        layer_activation: Callable[[torch.Tensor], torch.Tensor] = torch.nn.functional.relu,
        **kwargs):

        super().__init__(**kwargs)

        self._check_params(input_dim, output_dim, data, keep_prob,
            rel_activation, layer_activation)

        self.input_dim = input_dim
        self.output_dim = output_dim
        self.data = data
        self.keep_prob = keep_prob
        self.rel_activation = rel_activation
        self.layer_activation = layer_activation

        self.is_sparse = False
        self.next_layer_repr = None
        self.build()

    def build(self):
        self.next_layer_repr = torch.nn.ModuleList([
            torch.nn.ModuleList() \
                for _ in range(len(self.data.node_types))
        ])
        for fam in self.data.relation_families:
            self.build_family(fam)

    def build_family(self, fam) -> None:
        if fam.node_type_row == fam.node_type_column:
            self.build_fam_one_node_type(fam)
        else:
            self.build_fam_two_node_types(fam)

    def build_fam_one_node_type(self, fam) -> None:
        adjacency_matrices = [
            r.adjacency_matrix \
                for r in fam.relation_types
        ]
        conv = FastGraphConv(self.input_dim[fam.node_type_column],
            self.output_dim[fam.node_type_row],
            adjacency_matrices,
            self.keep_prob,
            self.rel_activation)
        conv.input_node_type = fam.node_type_column
        self.next_layer_repr[fam.node_type_row].append(conv)

    def build_fam_two_node_types(self, fam) -> None:
        adjacency_matrices = [
            r.adjacency_matrix \
                for r in fam.relation_types \
                if r.adjacency_matrix is not None
        ]

        adjacency_matrices_backward = [
            r.adjacency_matrix_backward \
                for r in fam.relation_types \
                if r.adjacency_matrix_backward is not None
        ]

        conv = FastGraphConv(self.input_dim[fam.node_type_column],
            self.output_dim[fam.node_type_row],
            adjacency_matrices,
            self.keep_prob,
            self.rel_activation)

        conv_backward = FastGraphConv(self.input_dim[fam.node_type_row],
            self.output_dim[fam.node_type_column],
            adjacency_matrices_backward,
            self.keep_prob,
            self.rel_activation)

        conv.input_node_type = fam.node_type_column
        conv_backward.input_node_type = fam.node_type_row

        self.next_layer_repr[fam.node_type_row].append(conv)
        self.next_layer_repr[fam.node_type_column].append(conv_backward)

    def forward(self, prev_layer_repr):
        next_layer_repr = [ [] \
            for _ in range(len(self.data.node_types)) ]

        for output_node_type in range(len(self.data.node_types)):
            for conv in self.next_layer_repr[output_node_type]:
                rep = conv(prev_layer_repr[conv.input_node_type])
                rep = torch.sum(rep, dim=0)
                rep = torch.nn.functional.normalize(rep, p=2, dim=1)
                next_layer_repr[output_node_type].append(rep)

            if len(next_layer_repr[output_node_type]) == 0:
                next_layer_repr[output_node_type] = \
                    torch.zeros(self.data.node_types[output_node_type].count, self.output_dim[output_node_type])
            else:
                next_layer_repr[output_node_type] = \
                    sum(next_layer_repr[output_node_type])

            next_layer_repr[output_node_type] = \
                self.layer_activation(next_layer_repr[output_node_type])

        return next_layer_repr

    @staticmethod
    def _check_params(input_dim, output_dim, data, keep_prob,
        rel_activation, layer_activation):

        if not isinstance(input_dim, list):
            raise ValueError('input_dim must be a list')

        if not output_dim:
            raise ValueError('output_dim must be specified')

        if not isinstance(output_dim, list):
            raise ValueError('output_dim must be a list')

        if not isinstance(data, Data) and not isinstance(data, PreparedData):
            raise ValueError('data must be of type Data or PreparedData')
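
# End-to-end sketch for FastConvLayer (illustrative; the Data object and its
# contents are made up for the example):
#
#   d = Data()
#   d.add_node_type('Gene', 100)
#   fam = d.add_relation_family('Gene-Gene', 0, 0, is_symmetric=True)
#   m = torch.rand(100, 100).round().to_sparse()
#   fam.add_relation_type('Interacts', (m + m.transpose(0, 1)).coalesce(), None)
#
#   layer = FastConvLayer(input_dim=[100], output_dim=[32], data=d)
#   out = layer([ torch.rand(100, 100) ])
#   assert out[0].shape == (100, 32)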
@@ -0,0 +1,138 @@
import torch
from typing import List
from .trainprep import PreparedData
from dataclasses import dataclass
import random
from collections import defaultdict


@dataclass
class TrainingBatch(object):
    relation_family_index: int
    relation_type_index: int
    node_type_row: int
    node_type_column: int
    edges: torch.Tensor


class FastBatcher(object):
    def __init__(self,
        prep_d: PreparedData,
        batch_size: int) -> None:

        if not isinstance(prep_d, PreparedData):
            raise TypeError('prep_d must be an instance of PreparedData')

        self.prep_d = prep_d
        self.batch_size = int(batch_size)

        self.edges = None
        self.build()

    def build(self):
        self.edges = []
        for fam_idx, fam in enumerate(self.prep_d.relation_families):
            edges = []
            targets = []
            edges_back = []
            targets_back = []

            for rel_idx, rel in enumerate(fam.relation_types):
                edges.append(rel.edges_pos.train)
                edges.append(rel.edges_neg.train)
                targets.append(torch.ones(len(rel.edges_pos.train)))
                targets.append(torch.zeros(len(rel.edges_neg.train)))

                edges_back.append(rel.edges_back_pos.train)
                edges_back.append(rel.edges_back_neg.train)
                targets_back.append(torch.ones(len(rel.edges_back_pos.train)))
                targets_back.append(torch.zeros(len(rel.edges_back_neg.train)))

            edges = torch.cat(edges)
            targets = torch.cat(targets)
            edges_back = torch.cat(edges_back)
            targets_back = torch.cat(targets_back)

            order = torch.randperm(len(edges))
            edges = edges[order]
            targets = targets[order]

            order_back = torch.randperm(len(edges_back))
            edges_back = edges_back[order_back]
            targets_back = targets_back[order_back]

            # edges aggregate all relation types in the family
            self.edges.append({'fam_idx': fam_idx, 'back': False,
                'edges': edges, 'targets': targets, 'ofs': 0})
            self.edges.append({'fam_idx': fam_idx, 'back': True,
                'edges': edges_back, 'targets': targets_back, 'ofs': 0})

    def __iter__(self):
        while True:
            edges = [ e for e in self.edges \
                if e['ofs'] < len(e['edges']) ]
            if len(edges) == 0:
                break
            # TODO: need to finish this - pick one entry and yield a batch
            raise NotImplementedError

    def __iter_old__(self):
        edge_types = ['edges_pos', 'edges_neg', 'edges_back_pos', 'edges_back_neg']

        offsets = {}
        orders = {}
        done = {}

        for fam_idx, fam in enumerate(self.prep_d.relation_families):
            for rel_idx, rel in enumerate(fam.relation_types):
                for et in edge_types:
                    done[fam_idx, rel_idx, et] = False

        while True:
            fam_idx = torch.randint(0, len(self.prep_d.relation_families), (1,)).item()
            fam = self.prep_d.relation_families[fam_idx]

            rel_idx = torch.randint(0, len(fam.relation_types), (1,)).item()
            rel = fam.relation_types[rel_idx]

            et = random.choice(edge_types)
            edges = getattr(rel, et).train

            key = (fam_idx, rel_idx, et)
            if key not in orders:
                orders[key] = torch.randperm(len(edges))
                offsets[key] = 0

            ord = orders[key]
            ofs = offsets[key]

            nt_row = rel.node_type_row
            nt_col = rel.node_type_column

            if 'back' in et:
                nt_row, nt_col = nt_col, nt_row

            if ofs < len(edges):
                offsets[key] += self.batch_size
                ord = ord[ofs:ofs+self.batch_size]
                edges = edges[ord]
                yield TrainingBatch(fam_idx, rel_idx, nt_row, nt_col, edges)
            else:
                done[key] = True
            # TODO: terminate once every (family, relation, edge-type) key is done

        # unreachable draft left over from development:
        for fam in self.prep_d.relation_families:
            edges = []
            for rel in fam.relation_types:
                edges.append(rel.edges_pos.train)
                edges.append(rel.edges_back_pos.train)
                edges.append(rel.edges_neg.train)
                edges.append(rel.edges_back_neg.train)
            edges = torch.cat(edges)


class FastDecLayer(torch.nn.Module):
    def __init__(self, **kwargs):
        super().__init__(**kwargs)

    def forward(self,
        last_layer_repr: List[torch.Tensor],
        training_batch: TrainingBatch):

        # TODO: decode the node representations for the edges in training_batch
        raise NotImplementedError
@@ -0,0 +1,166 @@
from .fastmodel import FastModel
from .trainprep import PreparedData
import torch
from typing import Callable
from types import FunctionType
import time
import random


class FastBatcher(object):
    def __init__(self, prep_d: PreparedData, batch_size: int,
        shuffle: bool, generator: torch.Generator,
        part_type: str) -> None:

        if not isinstance(prep_d, PreparedData):
            raise TypeError('prep_d must be an instance of PreparedData')

        if not isinstance(generator, torch.Generator):
            raise TypeError('generator must be an instance of torch.Generator')

        if part_type not in ['train', 'val', 'test']:
            raise ValueError('part_type must be set to train, val or test')

        self.prep_d = prep_d
        self.batch_size = int(batch_size)
        self.shuffle = bool(shuffle)
        self.generator = generator
        self.part_type = part_type

        self.edges = None
        self.targets = None
        self.build()

    def build(self):
        self.edges = []
        self.targets = []

        for fam in self.prep_d.relation_families:
            edges = []
            targets = []
            for i, rel in enumerate(fam.relation_types):
                edges_pos = getattr(rel.edges_pos, self.part_type)
                edges_neg = getattr(rel.edges_neg, self.part_type)
                edges_back_pos = getattr(rel.edges_back_pos, self.part_type)
                edges_back_neg = getattr(rel.edges_back_neg, self.part_type)

                # backward edges are swapped back into (row, column) order,
                # then the relation type index is prepended as the first column
                e = torch.cat([ edges_pos,
                    torch.stack([edges_back_pos[:, 1], edges_back_pos[:, 0]], dim=1) ])
                e = torch.cat([ torch.ones(len(e), 1, dtype=torch.long) * i, e ], dim=1)
                t = torch.ones(len(e))
                edges.append(e)
                targets.append(t)

                e = torch.cat([ edges_neg,
                    torch.stack([edges_back_neg[:, 1], edges_back_neg[:, 0]], dim=1) ])
                e = torch.cat([ torch.ones(len(e), 1, dtype=torch.long) * i, e ], dim=1)
                t = torch.zeros(len(e))
                edges.append(e)
                targets.append(t)

            edges = torch.cat(edges)
            targets = torch.cat(targets)

            self.edges.append(edges)
            self.targets.append(targets)

        # print(self.edges)
        # print(self.targets)

        if self.shuffle:
            self.shuffle_families()

    def shuffle_families(self):
        for i in range(len(self.edges)):
            edges = self.edges[i]
            targets = self.targets[i]
            order = torch.randperm(len(edges), generator=self.generator)
            self.edges[i] = edges[order]
            self.targets[i] = targets[order]

    def __iter__(self):
        offsets = [ 0 for _ in self.edges ]

        while True:
            choice = [ i for i in range(len(offsets)) \
                if offsets[i] < len(self.edges[i]) ]
            if len(choice) == 0:
                break

            fam_idx = choice[torch.randint(len(choice), (1,), generator=self.generator).item()]
            ofs = offsets[fam_idx]
            edges = self.edges[fam_idx][ofs:ofs + self.batch_size]
            targets = self.targets[fam_idx][ofs:ofs + self.batch_size]
            offsets[fam_idx] += self.batch_size

            yield (fam_idx, edges, targets)
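
# Iteration sketch (illustrative): every yielded batch belongs to exactly one
# relation family; edges columns are (relation type index, row, column):
#
#   batcher = FastBatcher(prep_d, batch_size=512, shuffle=True,
#       generator=torch.Generator(), part_type='train')
#   for fam_idx, edges, targets in batcher:
#       ...  # edges: (B, 3) long tensor, targets: (B,) float tensor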
class FastLoop(object):
    def __init__(
        self,
        model: FastModel,
        lr: float = 0.001,
        loss: Callable[[torch.Tensor, torch.Tensor], torch.Tensor] = \
            torch.nn.functional.binary_cross_entropy_with_logits,
        batch_size: int = 100,
        shuffle: bool = True,
        generator: torch.Generator = None) -> None:

        self._check_params(model, loss, generator)

        self.model = model
        self.lr = float(lr)
        self.loss = loss
        self.batch_size = int(batch_size)
        self.shuffle = bool(shuffle)
        self.generator = generator or torch.default_generator

        self.opt = None
        self.build()

    def _check_params(self, model, loss, generator):
        if not isinstance(model, FastModel):
            raise TypeError('model must be an instance of FastModel')

        if not isinstance(loss, FunctionType):
            raise TypeError('loss must be a function')

        if generator is not None and not isinstance(generator, torch.Generator):
            raise TypeError('generator must be an instance of torch.Generator')

    def build(self) -> None:
        opt = torch.optim.Adam(self.model.parameters(), lr=self.lr)
        self.opt = opt

    def run_epoch(self):
        batcher = FastBatcher(self.model.prep_d, batch_size=self.batch_size,
            shuffle=self.shuffle, generator=self.generator,
            part_type='train')

        loss_sum = 0
        for fam_idx, edges, targets in batcher:
            self.opt.zero_grad()
            pred = self.model(None)
            # edges columns: (relation type index, row node, column node);
            # assumes pred[fam_idx] has shape
            # (num_relation_types, num_row_nodes, num_column_nodes)
            input = pred[fam_idx][edges[:, 0], edges[:, 1], edges[:, 2]]
            loss = self.loss(input, targets)
            loss.backward()
            self.opt.step()
            loss_sum += loss.detach().cpu().item()
        return loss_sum

    def train(self, max_epochs):
        best_loss = None
        best_epoch = None
        for i in range(max_epochs):
            loss = self.run_epoch()
            if best_loss is None or loss < best_loss:
                best_loss = loss
                best_epoch = i
        return loss, best_loss, best_epoch
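
# Typical driver sketch (illustrative; prep_d stands for the output of
# prepare_training() and the model construction is assumed):
#
#   model = FastModel(prep_d)
#   loop = FastLoop(model, lr=0.001, batch_size=100)
#   loss, best_loss, best_epoch = loop.train(max_epochs=50)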
@@ -0,0 +1,79 @@
from .fastconv import FastConvLayer
from .bulkdec import BulkDecodeLayer
from .input import OneHotInputLayer
from .trainprep import PreparedData
import torch
import types
from typing import List, \
    Union, \
    Callable


class FastModel(torch.nn.Module):
    def __init__(self, prep_d: PreparedData,
        layer_dimensions: List[int] = [32, 64],
        keep_prob: float = 1.,
        rel_activation: Callable[[torch.Tensor], torch.Tensor] = lambda x: x,
        layer_activation: Callable[[torch.Tensor], torch.Tensor] = torch.nn.functional.relu,
        dec_activation: Callable[[torch.Tensor], torch.Tensor] = lambda x: x,
        **kwargs) -> None:

        super().__init__(**kwargs)

        self._check_params(prep_d, layer_dimensions, rel_activation,
            layer_activation, dec_activation)

        self.prep_d = prep_d
        self.layer_dimensions = layer_dimensions
        self.keep_prob = float(keep_prob)
        self.rel_activation = rel_activation
        self.layer_activation = layer_activation
        self.dec_activation = dec_activation

        self.seq = None
        self.build()

    def build(self):
        in_layer = OneHotInputLayer(self.prep_d)
        last_output_dim = in_layer.output_dim
        seq = [ in_layer ]

        for dim in self.layer_dimensions:
            conv_layer = FastConvLayer(input_dim = last_output_dim,
                output_dim = [dim] * len(self.prep_d.node_types),
                data = self.prep_d,
                keep_prob = self.keep_prob,
                rel_activation = self.rel_activation,
                layer_activation = self.layer_activation)
            last_output_dim = conv_layer.output_dim
            seq.append(conv_layer)

        dec_layer = BulkDecodeLayer(input_dim = last_output_dim,
            data = self.prep_d,
            keep_prob = self.keep_prob,
            activation = self.dec_activation)
        seq.append(dec_layer)

        seq = torch.nn.Sequential(*seq)
        self.seq = seq

    def forward(self, _):
        return self.seq(None)

    def _check_params(self, prep_d, layer_dimensions, rel_activation,
        layer_activation, dec_activation):

        if not isinstance(prep_d, PreparedData):
            raise TypeError('prep_d must be an instance of PreparedData')

        if not isinstance(layer_dimensions, list):
            raise TypeError('layer_dimensions must be a list')

        if not isinstance(rel_activation, types.FunctionType):
            raise TypeError('rel_activation must be a function')

        if not isinstance(layer_activation, types.FunctionType):
            raise TypeError('layer_activation must be a function')

        if not isinstance(dec_activation, types.FunctionType):
            raise TypeError('dec_activation must be a function')
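
# Architecture sketch (illustrative): with layer_dimensions=[32, 64] the model
# stacks OneHotInputLayer -> FastConvLayer(32) -> FastConvLayer(64) ->
# BulkDecodeLayer, so node representations go one-hot -> 32 -> 64 before
# decoding:
#
#   model = FastModel(prep_d, layer_dimensions=[32, 64])
#   pred = model(None)   # the input is ignored; one-hot features are built in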
@@ -0,0 +1,79 @@
#
# Copyright (C) Stanislaw Adaszewski, 2020
# License: GPLv3
#

import torch
from typing import Union, \
    List
from .data import Data


class InputLayer(torch.nn.Module):
    def __init__(self, data: Data, output_dim: Union[int, List[int]] = None,
        **kwargs) -> None:

        output_dim = output_dim or \
            list(map(lambda a: a.count, data.node_types))

        if not isinstance(output_dim, list):
            output_dim = [output_dim,] * len(data.node_types)

        super().__init__(**kwargs)
        self.output_dim = output_dim
        self.data = data

        self.is_sparse = False
        self.node_reps = None
        self.build()

    def build(self) -> None:
        self.node_reps = []
        for i, nt in enumerate(self.data.node_types):
            reps = torch.rand(nt.count, self.output_dim[i])
            reps = torch.nn.Parameter(reps)
            self.register_parameter('node_reps[%d]' % i, reps)
            self.node_reps.append(reps)

    def forward(self, x) -> List[torch.nn.Parameter]:
        return self.node_reps

    def __repr__(self) -> str:
        s = ''
        s += 'Icosagon input layer with output_dim: %s\n' % self.output_dim
        s += ' # of node types: %d\n' % len(self.data.node_types)
        for nt in self.data.node_types:
            s += ' - %s (%d)\n' % (nt.name, nt.count)
        return s.strip()


class OneHotInputLayer(torch.nn.Module):
    def __init__(self, data: Data, **kwargs) -> None:
        output_dim = [ a.count for a in data.node_types ]

        super().__init__(**kwargs)
        self.output_dim = output_dim
        self.data = data

        self.is_sparse = True
        self.node_reps = None
        self.build()

    def build(self) -> None:
        self.node_reps = torch.nn.ParameterList()
        for i, nt in enumerate(self.data.node_types):
            reps = torch.eye(nt.count).to_sparse()
            reps = torch.nn.Parameter(reps, requires_grad=False)
            # self.register_parameter('node_reps[%d]' % i, reps)
            self.node_reps.append(reps)

    def forward(self, x) -> List[torch.nn.Parameter]:
        return self.node_reps

    def __repr__(self) -> str:
        s = ''
        s += 'Icosagon one-hot input layer\n'
        s += ' # of node types: %d\n' % len(self.data.node_types)
        for nt in self.data.node_types:
            s += ' - %s (%d)\n' % (nt.name, nt.count)
        return s.strip()
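
# Quick sketch (illustrative): the one-hot layer emits a fixed sparse identity
# per node type, so downstream weights effectively select per-node rows:
#
#   d = Data()
#   d.add_node_type('Dummy', 10)
#   layer = OneHotInputLayer(d)
#   reps = layer(None)
#   assert reps[0].shape == (10, 10) and reps[0].is_sparse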
@@ -0,0 +1,145 @@
#
# Copyright (C) Stanislaw Adaszewski, 2020
# License: GPLv3
#

import numpy as np
import scipy.sparse as sp
import torch


def _check_tensor(adj_mat):
    if not isinstance(adj_mat, torch.Tensor):
        raise ValueError('adj_mat must be a torch.Tensor')


def _check_sparse(adj_mat):
    if not adj_mat.is_sparse:
        raise ValueError('adj_mat must be sparse')


def _check_dense(adj_mat):
    if adj_mat.is_sparse:
        raise ValueError('adj_mat must be dense')


def _check_square(adj_mat):
    if len(adj_mat.shape) != 2 or \
        adj_mat.shape[0] != adj_mat.shape[1]:
        raise ValueError('adj_mat must be a square matrix')


def _check_2d(adj_mat):
    if len(adj_mat.shape) != 2:
        raise ValueError('adj_mat must be a 2D matrix')


def _sparse_coo_tensor(indices, values, size):
    ctor = { torch.float32: torch.sparse.FloatTensor,
        torch.float64: torch.sparse.DoubleTensor,
        torch.uint8: torch.sparse.ByteTensor,
        torch.long: torch.sparse.LongTensor,
        torch.int: torch.sparse.IntTensor,
        torch.short: torch.sparse.ShortTensor,
        torch.bool: torch.sparse.ByteTensor }[values.dtype]
    return ctor(indices, values, size)


def add_eye_sparse(adj_mat: torch.Tensor) -> torch.Tensor:
    _check_tensor(adj_mat)
    _check_sparse(adj_mat)
    _check_square(adj_mat)

    adj_mat = adj_mat.coalesce()
    indices = adj_mat.indices()
    values = adj_mat.values()

    eye_indices = torch.arange(adj_mat.shape[0], dtype=indices.dtype,
        device=adj_mat.device).view(1, -1)
    eye_indices = torch.cat((eye_indices, eye_indices), 0)
    eye_values = torch.ones(adj_mat.shape[0], dtype=values.dtype,
        device=adj_mat.device)

    indices = torch.cat((indices, eye_indices), 1)
    values = torch.cat((values, eye_values), 0)

    adj_mat = _sparse_coo_tensor(indices, values, adj_mat.shape)

    return adj_mat


def norm_adj_mat_one_node_type_sparse(adj_mat: torch.Tensor) -> torch.Tensor:
    _check_tensor(adj_mat)
    _check_sparse(adj_mat)
    _check_square(adj_mat)

    adj_mat = add_eye_sparse(adj_mat)
    adj_mat = norm_adj_mat_two_node_types_sparse(adj_mat)

    return adj_mat


def norm_adj_mat_one_node_type_dense(adj_mat: torch.Tensor) -> torch.Tensor:
    _check_tensor(adj_mat)
    _check_dense(adj_mat)
    _check_square(adj_mat)

    adj_mat = adj_mat + torch.eye(adj_mat.shape[0], dtype=adj_mat.dtype,
        device=adj_mat.device)
    adj_mat = norm_adj_mat_two_node_types_dense(adj_mat)

    return adj_mat


def norm_adj_mat_one_node_type(adj_mat: torch.Tensor) -> torch.Tensor:
    _check_tensor(adj_mat)
    _check_square(adj_mat)

    if adj_mat.is_sparse:
        return norm_adj_mat_one_node_type_sparse(adj_mat)
    else:
        return norm_adj_mat_one_node_type_dense(adj_mat)


def norm_adj_mat_two_node_types_sparse(adj_mat: torch.Tensor) -> torch.Tensor:
    _check_tensor(adj_mat)
    _check_sparse(adj_mat)
    _check_2d(adj_mat)

    adj_mat = adj_mat.coalesce()
    indices = adj_mat.indices()
    values = adj_mat.values()
    degrees_row = torch.zeros(adj_mat.shape[0], device=adj_mat.device)
    degrees_row = degrees_row.index_add(0, indices[0], values.to(degrees_row.dtype))
    degrees_col = torch.zeros(adj_mat.shape[1], device=adj_mat.device)
    degrees_col = degrees_col.index_add(0, indices[1], values.to(degrees_col.dtype))
    values = values.to(degrees_row.dtype) / torch.sqrt(degrees_row[indices[0]] * degrees_col[indices[1]])
    adj_mat = _sparse_coo_tensor(indices, values, adj_mat.shape)

    return adj_mat


def norm_adj_mat_two_node_types_dense(adj_mat: torch.Tensor) -> torch.Tensor:
    _check_tensor(adj_mat)
    _check_dense(adj_mat)
    _check_2d(adj_mat)

    degrees_row = adj_mat.sum(1).view(-1, 1).to(torch.float32)
    degrees_col = adj_mat.sum(0).view(1, -1).to(torch.float32)
    degrees_row = torch.sqrt(degrees_row)
    degrees_col = torch.sqrt(degrees_col)
    adj_mat = adj_mat.to(degrees_row.dtype) / degrees_row
    adj_mat = adj_mat / degrees_col

    return adj_mat


def norm_adj_mat_two_node_types(adj_mat: torch.Tensor) -> torch.Tensor:
    _check_tensor(adj_mat)
    _check_2d(adj_mat)

    if adj_mat.is_sparse:
        return norm_adj_mat_two_node_types_sparse(adj_mat)
    else:
        return norm_adj_mat_two_node_types_dense(adj_mat)
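
# The two-node-type variants implement the symmetric normalization
# D_row^(-1/2) A D_col^(-1/2); the one-node-type variants first add self-loops.
# Tiny dense sketch (illustrative):
#
#   a = torch.tensor([[0., 1.], [1., 0.]])
#   n = norm_adj_mat_one_node_type(a)
#   # self-loops are added, then each entry is divided by
#   # sqrt(deg_row * deg_col); here every degree is 2, so every entry of n
#   # equals 0.5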
@@ -0,0 +1,47 @@
#
# Copyright (C) Stanislaw Adaszewski, 2020
# License: GPLv3
#

import numpy as np
import torch
import torch.utils.data
from typing import List, \
    Union


def fixed_unigram_candidate_sampler(
    true_classes: Union[np.ndarray, torch.Tensor],
    unigrams: List[Union[int, float]],
    distortion: float = 1.):

    if isinstance(true_classes, torch.Tensor):
        true_classes = true_classes.detach().cpu().numpy()

    if isinstance(unigrams, torch.Tensor):
        unigrams = unigrams.detach().cpu().numpy()

    if len(true_classes.shape) != 2:
        raise ValueError('true_classes must be a 2D matrix with shape (num_samples, num_true)')

    num_samples = true_classes.shape[0]
    unigrams = np.array(unigrams)
    if distortion != 1.:
        unigrams = unigrams.astype(np.float64) ** distortion
    # print('unigrams:', unigrams)

    indices = np.arange(num_samples)
    result = np.zeros(num_samples, dtype=np.int64)
    while len(indices) > 0:
        # print('len(indices):', len(indices))
        sampler = torch.utils.data.WeightedRandomSampler(unigrams, len(indices))
        candidates = np.array(list(sampler))
        candidates = np.reshape(candidates, (len(indices), 1))
        # print('candidates:', candidates)
        # print('true_classes:', true_classes[indices, :])
        result[indices] = candidates.T
        # resample the entries that collided with one of the true classes
        mask = (candidates == true_classes[indices, :])
        mask = mask.sum(1).astype(bool)
        # print('mask:', mask)
        indices = indices[mask]
    return torch.tensor(result)
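
# Sampling sketch (illustrative): draw one negative class per row of
# true_classes, with probability proportional to unigram weight (optionally
# distorted), resampling any draw that collides with a true class in its row:
#
#   true_classes = torch.tensor([[0], [1], [2]])
#   degrees = [1., 2., 3., 4.]
#   neg = fixed_unigram_candidate_sampler(true_classes, degrees, 0.75)
#   assert neg.shape == (3,)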
@@ -0,0 +1,215 @@
#
# Copyright (C) Stanislaw Adaszewski, 2020
# License: GPLv3
#

from .sampling import fixed_unigram_candidate_sampler
import torch
from dataclasses import dataclass, \
    field
from typing import Any, \
    List, \
    Tuple, \
    Dict
from .data import NodeType, \
    RelationType, \
    RelationTypeBase, \
    RelationFamily, \
    RelationFamilyBase, \
    Data
from collections import defaultdict
from .normalize import norm_adj_mat_one_node_type, \
    norm_adj_mat_two_node_types
import numpy as np


@dataclass
class TrainValTest(object):
    train: Any
    val: Any
    test: Any


@dataclass
class PreparedRelationType(RelationTypeBase):
    edges_pos: TrainValTest
    edges_neg: TrainValTest
    edges_back_pos: TrainValTest
    edges_back_neg: TrainValTest


@dataclass
class PreparedRelationFamily(RelationFamilyBase):
    relation_types: List[PreparedRelationType]


@dataclass
class PreparedData(object):
    node_types: List[NodeType]
    relation_families: List[PreparedRelationFamily]


def _empty_edge_list_tvt() -> TrainValTest:
    return TrainValTest(*[ torch.zeros((0, 2), dtype=torch.long) for _ in range(3) ])


def train_val_test_split_edges(edges: torch.Tensor,
    ratios: TrainValTest) -> TrainValTest:

    if not isinstance(edges, torch.Tensor):
        raise ValueError('edges must be a torch.Tensor')

    if len(edges.shape) != 2 or edges.shape[1] != 2:
        raise ValueError('edges shape must be (num_edges, 2)')

    if not isinstance(ratios, TrainValTest):
        raise ValueError('ratios must be a TrainValTest')

    # compare with a tolerance to be robust to floating-point rounding
    if abs(ratios.train + ratios.val + ratios.test - 1.0) > 1e-9:
        raise ValueError('Train, validation and test ratios must add up to 1')

    order = torch.randperm(len(edges))
    edges = edges[order, :]
    n = round(len(edges) * ratios.train)
    edges_train = edges[:n]
    n_1 = round(len(edges) * (ratios.train + ratios.val))
    edges_val = edges[n:n_1]
    edges_test = edges[n_1:]

    return TrainValTest(edges_train, edges_val, edges_test)
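
# Split sketch (illustrative): edges are shuffled, then cut into contiguous
# train/val/test chunks according to the ratios:
#
#   edges = torch.randint(0, 100, (1000, 2))
#   tvt = train_val_test_split_edges(edges, TrainValTest(.8, .1, .1))
#   # len(tvt.train) == 800, len(tvt.val) == 100, len(tvt.test) == 100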
def get_edges_and_degrees(adj_mat: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
    if adj_mat.is_sparse:
        adj_mat = adj_mat.coalesce()
        degrees = torch.zeros(adj_mat.shape[1], dtype=torch.int64,
            device=adj_mat.device)
        degrees = degrees.index_add(0, adj_mat.indices()[1],
            torch.ones(adj_mat.indices().shape[1], dtype=torch.int64,
                device=adj_mat.device))
        edges_pos = adj_mat.indices().transpose(0, 1)
    else:
        degrees = adj_mat.sum(0)
        edges_pos = torch.nonzero(adj_mat)
    return edges_pos, degrees


def prepare_adj_mat(adj_mat: torch.Tensor,
    ratios: TrainValTest) -> Tuple[torch.Tensor, TrainValTest, TrainValTest]:

    if not isinstance(adj_mat, torch.Tensor):
        raise ValueError('adj_mat must be a torch.Tensor')

    edges_pos, degrees = get_edges_and_degrees(adj_mat)

    # sample one negative neighbor per positive edge, with probability
    # proportional to degree ** 0.75
    neg_neighbors = fixed_unigram_candidate_sampler(
        edges_pos[:, 1].view(-1, 1), degrees, 0.75).to(adj_mat.device)
    # print(edges_pos.dtype)
    # print(neg_neighbors.dtype)
    edges_neg = torch.cat((edges_pos[:, 0].view(-1, 1), neg_neighbors.view(-1, 1)), 1)

    edges_pos = train_val_test_split_edges(edges_pos, ratios)
    edges_neg = train_val_test_split_edges(edges_neg, ratios)

    adj_mat_train = torch.sparse_coo_tensor(indices = edges_pos.train.transpose(0, 1),
        values=torch.ones(len(edges_pos.train)), size=adj_mat.shape, dtype=adj_mat.dtype,
        device=adj_mat.device)

    return adj_mat_train, edges_pos, edges_neg


def prep_rel_one_node_type(r: RelationType,
    ratios: TrainValTest) -> PreparedRelationType:

    adj_mat = r.adjacency_matrix
    adj_mat_train, edges_pos, edges_neg = prepare_adj_mat(adj_mat, ratios)
    adj_mat_back_train, edges_back_pos, edges_back_neg = \
        None, _empty_edge_list_tvt(), _empty_edge_list_tvt()

    # print('adj_mat_train:', adj_mat_train)
    adj_mat_train = norm_adj_mat_one_node_type(adj_mat_train)

    return PreparedRelationType(r.name, r.node_type_row, r.node_type_column,
        adj_mat_train, adj_mat_back_train, edges_pos, edges_neg,
        edges_back_pos, edges_back_neg)


def prep_rel_two_node_types_sym(r: RelationType,
    ratios: TrainValTest) -> PreparedRelationType:

    adj_mat = r.adjacency_matrix
    adj_mat_train, edges_pos, edges_neg = prepare_adj_mat(adj_mat, ratios)
    edges_back_pos, edges_back_neg = \
        _empty_edge_list_tvt(), _empty_edge_list_tvt()

    return PreparedRelationType(r.name, r.node_type_row,
        r.node_type_column,
        norm_adj_mat_two_node_types(adj_mat_train),
        norm_adj_mat_two_node_types(adj_mat_train.transpose(0, 1)),
        edges_pos, edges_neg, edges_back_pos, edges_back_neg)


def prep_rel_two_node_types_asym(r: RelationType,
    ratios: TrainValTest) -> PreparedRelationType:

    if r.adjacency_matrix is not None:
        adj_mat_train, edges_pos, edges_neg = \
            prepare_adj_mat(r.adjacency_matrix, ratios)
    else:
        adj_mat_train, edges_pos, edges_neg = \
            None, _empty_edge_list_tvt(), _empty_edge_list_tvt()

    if r.adjacency_matrix_backward is not None:
        adj_mat_back_train, edges_back_pos, edges_back_neg = \
            prepare_adj_mat(r.adjacency_matrix_backward, ratios)
    else:
        adj_mat_back_train, edges_back_pos, edges_back_neg = \
            None, _empty_edge_list_tvt(), _empty_edge_list_tvt()

    # either adjacency matrix may be missing; only normalize the ones present
    return PreparedRelationType(r.name, r.node_type_row,
        r.node_type_column,
        norm_adj_mat_two_node_types(adj_mat_train) \
            if adj_mat_train is not None else None,
        norm_adj_mat_two_node_types(adj_mat_back_train) \
            if adj_mat_back_train is not None else None,
        edges_pos, edges_neg, edges_back_pos, edges_back_neg)


def prepare_relation_type(r: RelationType,
    ratios: TrainValTest, is_symmetric: bool) -> PreparedRelationType:

    if not isinstance(r, RelationType):
        raise ValueError('r must be a RelationType')

    if not isinstance(ratios, TrainValTest):
        raise ValueError('ratios must be a TrainValTest')

    if r.node_type_row == r.node_type_column:
        return prep_rel_one_node_type(r, ratios)
    elif is_symmetric:
        return prep_rel_two_node_types_sym(r, ratios)
    else:
        return prep_rel_two_node_types_asym(r, ratios)


def prepare_relation_family(fam: RelationFamily,
    ratios: TrainValTest) -> PreparedRelationFamily:

    relation_types = []

    for r in fam.relation_types:
        relation_types.append(prepare_relation_type(r, ratios, fam.is_symmetric))

    return PreparedRelationFamily(fam.data, fam.name,
        fam.node_type_row, fam.node_type_column,
        fam.is_symmetric, fam.decoder_class,
        relation_types)


def prepare_training(data: Data, ratios: TrainValTest) -> PreparedData:
    if not isinstance(data, Data):
        raise ValueError('data must be of class Data')

    relation_families = [ prepare_relation_family(fam, ratios) \
        for fam in data.relation_families ]

    return PreparedData(data.node_types, relation_families)
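
# Pipeline sketch (illustrative; d stands for a populated Data object):
# prepare_training() splits every relation's positive and sampled negative
# edges into train/val/test and replaces each adjacency matrix by its
# normalized train-only version:
#
#   prep_d = prepare_training(d, TrainValTest(.8, .1, .1))
#   prep_d.relation_families[0].relation_types[0].edges_pos.train  # (N, 2)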
@@ -0,0 +1,19 @@
#
# Copyright (C) Stanislaw Adaszewski, 2020
# License: GPLv3
#

import torch
import numpy as np


def init_glorot(in_channels, out_channels, dtype=torch.float32):
    """Create a weight variable with Glorot & Bengio (AISTATS 2010)
    initialization.
    """
    init_range = np.sqrt(6.0 / (in_channels + out_channels))
    initial = -init_range + 2 * init_range * \
        torch.rand(( in_channels, out_channels ), dtype=dtype)
    initial = initial.requires_grad_(True)
    return initial
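
# Glorot/Xavier uniform draws from U(-r, r) with r = sqrt(6 / (fan_in + fan_out)),
# which keeps activation variance roughly constant across layers. Quick check
# (illustrative):
#
#   w = init_glorot(256, 128)
#   r = np.sqrt(6.0 / (256 + 128))
#   assert w.shape == (256, 128) and w.min() >= -r and w.max() <= r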