diff --git a/src/triacontagon/data.py b/src/triacontagon/data.py new file mode 100644 index 0000000..4505adf --- /dev/null +++ b/src/triacontagon/data.py @@ -0,0 +1,209 @@ +# +# Copyright (C) Stanislaw Adaszewski, 2020 +# License: GPLv3 +# + + +from collections import defaultdict +from dataclasses import dataclass, field +import torch +from typing import List, \ + Dict, \ + Tuple, \ + Any, \ + Type +from .decode import DEDICOMDecoder, \ + BilinearDecoder +import numpy as np + + +def _equal(x: torch.Tensor, y: torch.Tensor): + if x.is_sparse ^ y.is_sparse: + raise ValueError('Cannot mix sparse and dense tensors') + + if not x.is_sparse: + return (x == y) + + return ((x - y).coalesce().values() == 0) + + +@dataclass +class NodeType(object): + name: str + count: int + + +@dataclass +class RelationTypeBase(object): + name: str + node_type_row: int + node_type_column: int + adjacency_matrix: torch.Tensor + adjacency_matrix_backward: torch.Tensor + + +@dataclass +class RelationType(RelationTypeBase): + pass + + +@dataclass +class RelationFamilyBase(object): + data: 'Data' + name: str + node_type_row: int + node_type_column: int + is_symmetric: bool + decoder_class: Type + + +@dataclass +class RelationFamily(RelationFamilyBase): + relation_types: List[RelationType] = None + + def __post_init__(self) -> None: + if not self.is_symmetric and \ + self.decoder_class != DEDICOMDecoder and \ + self.decoder_class != BilinearDecoder: + raise TypeError('Family is assymetric but the specified decoder_class supports symmetric relations only') + + self.relation_types = [] + + def add_relation_type(self, + name: str, adjacency_matrix: torch.Tensor, + adjacency_matrix_backward: torch.Tensor = None) -> None: + + name = str(name) + node_type_row = self.node_type_row + node_type_column = self.node_type_column + + if adjacency_matrix is None and adjacency_matrix_backward is None: + raise ValueError('adjacency_matrix and adjacency_matrix_backward cannot both be None') + + if adjacency_matrix is not None and \ + not isinstance(adjacency_matrix, torch.Tensor): + raise ValueError('adjacency_matrix must be a torch.Tensor') + + if adjacency_matrix_backward is not None \ + and not isinstance(adjacency_matrix_backward, torch.Tensor): + raise ValueError('adjacency_matrix_backward must be a torch.Tensor') + + if adjacency_matrix is not None and \ + adjacency_matrix.shape != (self.data.node_types[node_type_row].count, + self.data.node_types[node_type_column].count): + raise ValueError('adjacency_matrix shape must be (num_row_nodes, num_column_nodes)') + + if adjacency_matrix_backward is not None and \ + adjacency_matrix_backward.shape != (self.data.node_types[node_type_column].count, + self.data.node_types[node_type_row].count): + raise ValueError('adjacency_matrix_backward shape must be (num_column_nodes, num_row_nodes)') + + if node_type_row == node_type_column and \ + adjacency_matrix_backward is not None: + raise ValueError('Relation between nodes of the same type must be expressed using a single matrix') + + if self.is_symmetric and adjacency_matrix_backward is not None: + raise ValueError('Cannot use a custom adjacency_matrix_backward in a symmetric relation family') + + if self.is_symmetric and node_type_row == node_type_column and \ + not torch.all(_equal(adjacency_matrix, + adjacency_matrix.transpose(0, 1))): + raise ValueError('Relation family is symmetric but adjacency_matrix is assymetric') + + if not self.is_symmetric and node_type_row != node_type_column and \ + adjacency_matrix_backward is None: + raise 
ValueError('Relation is asymmetric but adjacency_matrix_backward is None') + + if self.is_symmetric and node_type_row != node_type_column: + adjacency_matrix_backward = adjacency_matrix.transpose(0, 1) + + self.relation_types.append(RelationType(name, + node_type_row, node_type_column, + adjacency_matrix, adjacency_matrix_backward)) + + def node_name(self, index): + return self.data.node_types[index].name + + def __repr__(self): + s = 'Relation family %s' % self.name + + for r in self.relation_types: + s += '\n - %s%s' % (r.name, ' (two-way)' \ + if (r.adjacency_matrix is not None \ + and r.adjacency_matrix_backward is not None) \ + or self.node_type_row == self.node_type_column \ + else '%s <- %s' % (self.node_name(self.node_type_row), + self.node_name(self.node_type_column))) + + return s + + def repr_indented(self): + s = ' - %s' % self.name + + for r in self.relation_types: + s += '\n - %s%s' % (r.name, ' (two-way)' \ + if (r.adjacency_matrix is not None \ + and r.adjacency_matrix_backward is not None) \ + or self.node_type_row == self.node_type_column \ + else '%s <- %s' % (self.node_name(self.node_type_row), + self.node_name(self.node_type_column))) + + return s + + +class Data(object): + node_types: List[NodeType] + relation_families: List[RelationFamily] + + def __init__(self) -> None: + self.node_types = [] + self.relation_families = [] + + def add_node_type(self, name: str, count: int) -> None: + name = str(name) + count = int(count) + if not name: + raise ValueError('You must provide a non-empty node type name') + if count <= 0: + raise ValueError('You must provide a positive node count') + self.node_types.append(NodeType(name, count)) + + def add_relation_family(self, name: str, node_type_row: int, + node_type_column: int, is_symmetric: bool, + decoder_class: Type = DEDICOMDecoder): + + name = str(name) + node_type_row = int(node_type_row) + node_type_column = int(node_type_column) + is_symmetric = bool(is_symmetric) + + if node_type_row < 0 or node_type_row >= len(self.node_types): + raise ValueError('node_type_row outside of the valid range of node types') + + if node_type_column < 0 or node_type_column >= len(self.node_types): + raise ValueError('node_type_column outside of the valid range of node types') + + fam = RelationFamily(self, name, node_type_row, node_type_column, + is_symmetric, decoder_class) + self.relation_families.append(fam) + + return fam + + def __repr__(self): + n = len(self.node_types) + if n == 0: + return 'Empty Icosagon Data' + s = '' + s += 'Icosagon Data with:\n' + s += '- ' + str(n) + ' node type(s):\n' + for nt in self.node_types: + s += ' - ' + nt.name + '\n' + if len(self.relation_families) == 0: + s += '- No relation families\n' + return s.strip() + + s += '- %d relation families:\n' % len(self.relation_families) + for fam in self.relation_families: + s += fam.repr_indented() + '\n' + + return s.strip() diff --git a/src/triacontagon/decode.py b/src/triacontagon/decode.py new file mode 100644 index 0000000..00df8b2 --- /dev/null +++ b/src/triacontagon/decode.py @@ -0,0 +1,123 @@ +# +# Copyright (C) Stanislaw Adaszewski, 2020 +# License: GPLv3 +# + + +import torch +from .weights import init_glorot +from .dropout import dropout + + +class DEDICOMDecoder(torch.nn.Module): + """DEDICOM Tensor Factorization Decoder model layer for link prediction.""" + def __init__(self, input_dim, num_relation_types, keep_prob=1., + activation=torch.sigmoid, **kwargs): + + super().__init__(**kwargs) + self.input_dim = input_dim + self.num_relation_types = 
num_relation_types + self.keep_prob = keep_prob + self.activation = activation + + self.global_interaction = torch.nn.Parameter(init_glorot(input_dim, input_dim)) + self.local_variation = torch.nn.ParameterList([ + torch.nn.Parameter(torch.flatten(init_glorot(input_dim, 1))) \ + for _ in range(num_relation_types) + ]) + + def forward(self, inputs_row, inputs_col, relation_index): + inputs_row = dropout(inputs_row, self.keep_prob) + inputs_col = dropout(inputs_col, self.keep_prob) + + relation = torch.diag(self.local_variation[relation_index]) + + product1 = torch.mm(inputs_row, relation) + product2 = torch.mm(product1, self.global_interaction) + product3 = torch.mm(product2, relation) + rec = torch.bmm(product3.view(product3.shape[0], 1, product3.shape[1]), + inputs_col.view(inputs_col.shape[0], inputs_col.shape[1], 1)) + rec = torch.flatten(rec) + + return self.activation(rec) + + +class DistMultDecoder(torch.nn.Module): + """DEDICOM Tensor Factorization Decoder model layer for link prediction.""" + def __init__(self, input_dim, num_relation_types, keep_prob=1., + activation=torch.sigmoid, **kwargs): + + super().__init__(**kwargs) + self.input_dim = input_dim + self.num_relation_types = num_relation_types + self.keep_prob = keep_prob + self.activation = activation + + self.relation = torch.nn.ParameterList([ + torch.nn.Parameter(torch.flatten(init_glorot(input_dim, 1))) \ + for _ in range(num_relation_types) + ]) + + def forward(self, inputs_row, inputs_col, relation_index): + inputs_row = dropout(inputs_row, self.keep_prob) + inputs_col = dropout(inputs_col, self.keep_prob) + + relation = torch.diag(self.relation[relation_index]) + + intermediate_product = torch.mm(inputs_row, relation) + rec = torch.bmm(intermediate_product.view(intermediate_product.shape[0], 1, intermediate_product.shape[1]), + inputs_col.view(inputs_col.shape[0], inputs_col.shape[1], 1)) + rec = torch.flatten(rec) + + return self.activation(rec) + + +class BilinearDecoder(torch.nn.Module): + """DEDICOM Tensor Factorization Decoder model layer for link prediction.""" + def __init__(self, input_dim, num_relation_types, keep_prob=1., + activation=torch.sigmoid, **kwargs): + + super().__init__(**kwargs) + self.input_dim = input_dim + self.num_relation_types = num_relation_types + self.keep_prob = keep_prob + self.activation = activation + + self.relation = torch.nn.ParameterList([ + torch.nn.Parameter(init_glorot(input_dim, input_dim)) \ + for _ in range(num_relation_types) + ]) + + def forward(self, inputs_row, inputs_col, relation_index): + inputs_row = dropout(inputs_row, self.keep_prob) + inputs_col = dropout(inputs_col, self.keep_prob) + + intermediate_product = torch.mm(inputs_row, self.relation[relation_index]) + rec = torch.bmm(intermediate_product.view(intermediate_product.shape[0], 1, intermediate_product.shape[1]), + inputs_col.view(inputs_col.shape[0], inputs_col.shape[1], 1)) + rec = torch.flatten(rec) + + return self.activation(rec) + + +class InnerProductDecoder(torch.nn.Module): + """DEDICOM Tensor Factorization Decoder model layer for link prediction.""" + def __init__(self, input_dim, num_relation_types, keep_prob=1., + activation=torch.sigmoid, **kwargs): + + super().__init__(**kwargs) + self.input_dim = input_dim + self.num_relation_types = num_relation_types + self.keep_prob = keep_prob + self.activation = activation + + + def forward(self, inputs_row, inputs_col, _): + inputs_row = dropout(inputs_row, self.keep_prob) + inputs_col = dropout(inputs_col, self.keep_prob) + + rec = 
torch.bmm(inputs_row.view(inputs_row.shape[0], 1, inputs_row.shape[1]), + inputs_col.view(inputs_col.shape[0], inputs_col.shape[1], 1)) + rec = torch.flatten(rec) + + return self.activation(rec) diff --git a/src/triacontagon/dropout.py b/src/triacontagon/dropout.py new file mode 100644 index 0000000..63cfb58 --- /dev/null +++ b/src/triacontagon/dropout.py @@ -0,0 +1,42 @@ +# +# Copyright (C) Stanislaw Adaszewski, 2020 +# License: GPLv3 +# + + +import torch +from .normalize import _sparse_coo_tensor + + +def dropout_sparse(x, keep_prob): + x = x.coalesce() + i = x._indices() + v = x._values() + size = x.size() + + n = keep_prob + torch.rand(len(v)) + n = torch.floor(n).to(torch.bool) + i = i[:,n] + v = v[n] + x = _sparse_coo_tensor(i, v, size=size) + + return x * (1./keep_prob) + + +def dropout_dense(x, keep_prob): + # print('dropout_dense()') + x = x.clone() + i = torch.nonzero(x) + + n = keep_prob + torch.rand(len(i)) + n = (1. - torch.floor(n)).to(torch.bool) + x[i[n, 0], i[n, 1]] = 0. + + return x * (1./keep_prob) + + +def dropout(x, keep_prob): + if x.is_sparse: + return dropout_sparse(x, keep_prob) + else: + return dropout_dense(x, keep_prob) diff --git a/src/triacontagon/fastconv.py b/src/triacontagon/fastconv.py new file mode 100644 index 0000000..038e2fc --- /dev/null +++ b/src/triacontagon/fastconv.py @@ -0,0 +1,255 @@ +from typing import List, \ + Union, \ + Callable +from .data import Data, \ + RelationFamily +from .trainprep import PreparedData, \ + PreparedRelationFamily +import torch +from .weights import init_glorot +from .normalize import _sparse_coo_tensor +import types + + +def _sparse_diag_cat(matrices: List[torch.Tensor]): + if len(matrices) == 0: + raise ValueError('The list of matrices must be non-empty') + + if not all(m.is_sparse for m in matrices): + raise ValueError('All matrices must be sparse') + + if not all(len(m.shape) == 2 for m in matrices): + raise ValueError('All matrices must be 2D') + + indices = [] + values = [] + row_offset = 0 + col_offset = 0 + + for m in matrices: + ind = m._indices().clone() + ind[0] += row_offset + ind[1] += col_offset + indices.append(ind) + values.append(m._values()) + row_offset += m.shape[0] + col_offset += m.shape[1] + + indices = torch.cat(indices, dim=1) + values = torch.cat(values) + + return _sparse_coo_tensor(indices, values, size=(row_offset, col_offset)) + + +def _cat(matrices: List[torch.Tensor]): + if len(matrices) == 0: + raise ValueError('Empty list passed to _cat()') + + n = sum(a.is_sparse for a in matrices) + if n != 0 and n != len(matrices): + raise ValueError('All matrices must have the same layout (dense or sparse)') + + if not all(a.shape[1:] == matrices[0].shape[1:] for a in matrices): + raise ValueError('All matrices must have the same dimensions apart from dimension 0') + + if not matrices[0].is_sparse: + return torch.cat(matrices) + + total_rows = sum(a.shape[0] for a in matrices) + indices = [] + values = [] + row_offset = 0 + + for a in matrices: + ind = a._indices().clone() + val = a._values() + ind[0] += row_offset + ind = ind.transpose(0, 1) + indices.append(ind) + values.append(val) + row_offset += a.shape[0] + + indices = torch.cat(indices).transpose(0, 1) + values = torch.cat(values) + + res = _sparse_coo_tensor(indices, values, size=(row_offset, matrices[0].shape[1])) + return res + + +class FastGraphConv(torch.nn.Module): + def __init__(self, + in_channels: int, + out_channels: int, + adjacency_matrices: List[torch.Tensor], + keep_prob: float = 1., + activation: Callable[[torch.Tensor], 
torch.Tensor] = lambda x: x, + **kwargs) -> None: + + super().__init__(**kwargs) + + in_channels = int(in_channels) + out_channels = int(out_channels) + if not isinstance(adjacency_matrices, list): + raise TypeError('adjacency_matrices must be a list') + if len(adjacency_matrices) == 0: + raise ValueError('adjacency_matrices must not be empty') + if not all(isinstance(m, torch.Tensor) for m in adjacency_matrices): + raise TypeError('adjacency_matrices elements must be of class torch.Tensor') + if not all(m.is_sparse for m in adjacency_matrices): + raise ValueError('adjacency_matrices elements must be sparse') + keep_prob = float(keep_prob) + if not isinstance(activation, types.FunctionType): + raise TypeError('activation must be a function') + + self.in_channels = in_channels + self.out_channels = out_channels + self.adjacency_matrices = adjacency_matrices + self.keep_prob = keep_prob + self.activation = activation + + self.num_row_nodes = len(adjacency_matrices[0]) + self.num_relation_types = len(adjacency_matrices) + + self.adjacency_matrices = _sparse_diag_cat(adjacency_matrices) + + self.weights = torch.cat([ + init_glorot(in_channels, out_channels) \ + for _ in range(self.num_relation_types) + ], dim=1) + + def forward(self, x) -> torch.Tensor: + if self.keep_prob < 1.: + x = dropout(x, self.keep_prob) + res = torch.sparse.mm(x, self.weights) \ + if x.is_sparse \ + else torch.mm(x, self.weights) + res = torch.split(res, res.shape[1] // self.num_relation_types, dim=1) + res = torch.cat(res) + res = torch.sparse.mm(self.adjacency_matrices, res) \ + if self.adjacency_matrices.is_sparse \ + else torch.mm(self.adjacency_matrices, res) + res = res.view(self.num_relation_types, self.num_row_nodes, self.out_channels) + if self.activation is not None: + res = self.activation(res) + + return res + + +class FastConvLayer(torch.nn.Module): + def __init__(self, + input_dim: List[int], + output_dim: List[int], + data: Union[Data, PreparedData], + keep_prob: float = 1., + rel_activation: Callable[[torch.Tensor], torch.Tensor] = lambda x: x, + layer_activation: Callable[[torch.Tensor], torch.Tensor] = torch.nn.functional.relu, + **kwargs): + + super().__init__(**kwargs) + + self._check_params(input_dim, output_dim, data, keep_prob, + rel_activation, layer_activation) + + self.input_dim = input_dim + self.output_dim = output_dim + self.data = data + self.keep_prob = keep_prob + self.rel_activation = rel_activation + self.layer_activation = layer_activation + + self.is_sparse = False + self.next_layer_repr = None + self.build() + + def build(self): + self.next_layer_repr = torch.nn.ModuleList([ + torch.nn.ModuleList() \ + for _ in range(len(self.data.node_types)) + ]) + for fam in self.data.relation_families: + self.build_family(fam) + + def build_family(self, fam) -> None: + if fam.node_type_row == fam.node_type_column: + self.build_fam_one_node_type(fam) + else: + self.build_fam_two_node_types(fam) + + def build_fam_one_node_type(self, fam) -> None: + adjacency_matrices = [ + r.adjacency_matrix \ + for r in fam.relation_types + ] + conv = FastGraphConv(self.input_dim[fam.node_type_column], + self.output_dim[fam.node_type_row], + adjacency_matrices, + self.keep_prob, + self.rel_activation) + conv.input_node_type = fam.node_type_column + self.next_layer_repr[fam.node_type_row].append(conv) + + def build_fam_two_node_types(self, fam) -> None: + adjacency_matrices = [ + r.adjacency_matrix \ + for r in fam.relation_types \ + if r.adjacency_matrix is not None + ] + + adjacency_matrices_backward = [ + 
            r.adjacency_matrix_backward \
+                for r in fam.relation_types \
+                if r.adjacency_matrix_backward is not None
+        ]
+
+        conv = FastGraphConv(self.input_dim[fam.node_type_column],
+            self.output_dim[fam.node_type_row],
+            adjacency_matrices,
+            self.keep_prob,
+            self.rel_activation)
+
+        conv_backward = FastGraphConv(self.input_dim[fam.node_type_row],
+            self.output_dim[fam.node_type_column],
+            adjacency_matrices_backward,
+            self.keep_prob,
+            self.rel_activation)
+
+        conv.input_node_type = fam.node_type_column
+        conv_backward.input_node_type = fam.node_type_row
+
+        self.next_layer_repr[fam.node_type_row].append(conv)
+        self.next_layer_repr[fam.node_type_column].append(conv_backward)
+
+    def forward(self, prev_layer_repr):
+        next_layer_repr = [ [] \
+            for _ in range(len(self.data.node_types)) ]
+        for output_node_type in range(len(self.data.node_types)):
+            for conv in self.next_layer_repr[output_node_type]:
+                rep = conv(prev_layer_repr[conv.input_node_type])
+                rep = torch.sum(rep, dim=0)
+                rep = torch.nn.functional.normalize(rep, p=2, dim=1)
+                next_layer_repr[output_node_type].append(rep)
+            if len(next_layer_repr[output_node_type]) == 0:
+                next_layer_repr[output_node_type] = \
+                    torch.zeros(self.data.node_types[output_node_type].count, self.output_dim[output_node_type])
+            else:
+                next_layer_repr[output_node_type] = \
+                    sum(next_layer_repr[output_node_type])
+                next_layer_repr[output_node_type] = \
+                    self.layer_activation(next_layer_repr[output_node_type])
+        return next_layer_repr
+
+    @staticmethod
+    def _check_params(input_dim, output_dim, data, keep_prob,
+        rel_activation, layer_activation):
+
+        if not isinstance(input_dim, list):
+            raise ValueError('input_dim must be a list')
+
+        if not output_dim:
+            raise ValueError('output_dim must be specified')
+
+        if not isinstance(output_dim, list):
+            output_dim = [output_dim] * len(data.node_types)
+
+        if not isinstance(data, Data) and not isinstance(data, PreparedData):
+            raise ValueError('data must be of type Data or PreparedData')
diff --git a/src/triacontagon/fastdec.py b/src/triacontagon/fastdec.py
new file mode 100644
index 0000000..ca08892
--- /dev/null
+++ b/src/triacontagon/fastdec.py
@@ -0,0 +1,138 @@
+import torch
+from typing import List
+from .trainprep import PreparedData
+from dataclasses import dataclass
+import random
+from collections import defaultdict
+
+
+@dataclass
+class TrainingBatch(object):
+    relation_family_index: int
+    relation_type_index: int
+    node_type_row: int
+    node_type_column: int
+    edges: torch.Tensor
+
+
+class FastBatcher(object):
+    def __init__(self,
+        prep_d: PreparedData,
+        batch_size: int) -> None:
+
+        if not isinstance(prep_d, PreparedData):
+            raise TypeError('prep_d must be an instance of PreparedData')
+
+        self.prep_d = prep_d
+        self.batch_size = int(batch_size)
+
+        self.edges = None
+        self.build()
+
+    def build(self):
+        self.edges = []
+        for fam_idx, fam in enumerate(self.prep_d.relation_families):
+            edges = []
+            targets = []
+            edges_back = []
+            targets_back = []
+            for rel_idx, rel in enumerate(fam.relation_types):
+                # positive edges get target 1, negative edges get target 0
+                edges.append(rel.edges_pos.train)
+                edges.append(rel.edges_neg.train)
+                targets.append(torch.ones(len(rel.edges_pos.train)))
+                targets.append(torch.zeros(len(rel.edges_neg.train)))
+
+                edges_back.append(rel.edges_back_pos.train)
+                edges_back.append(rel.edges_back_neg.train)
+                targets_back.append(torch.ones(len(rel.edges_back_pos.train)))
+                targets_back.append(torch.zeros(len(rel.edges_back_neg.train)))
+
+            edges = torch.cat(edges)
+            targets = torch.cat(targets)
+            edges_back = torch.cat(edges_back)
+            targets_back = torch.cat(targets_back)
+
+            order = torch.randperm(len(edges))
+            edges = edges[order]
+            targets = targets[order]
+
+            order_back = torch.randperm(len(edges_back))
+            edges_back = edges_back[order_back]
+            targets_back = targets_back[order_back]
+
+            self.edges.append({'fam_idx': fam_idx, 'rel_idx': rel_idx, 'back': False,
+                'edges': edges, 'targets': targets, 'ofs': 0})
+            self.edges.append({'fam_idx': fam_idx, 'rel_idx': rel_idx, 'back': True,
+                'edges': edges_back, 'targets': targets_back, 'ofs': 0})
+
+    def __iter__(self):
+        while True:
+            edges = [ e for e in self.edges \
+                if e['ofs'] < len(e['edges']) ]
+            # TODO: need to finish this
+
+    def __iter_old__(self):
+        edge_types = ['edges_pos', 'edges_neg', 'edges_back_pos', 'edges_back_neg']
+
+        offsets = {}
+        orders = {}
+        done = {}
+
+        for fam_idx, fam in enumerate(self.prep_d.relation_families):
+            for rel_idx, rel in enumerate(fam.relation_types):
+                for et in edge_types:
+                    done[fam_idx, rel_idx, et] = False
+
+        while True:
+            fam_idx = torch.randint(0, len(self.prep_d.relation_families), (1,)).item()
+            fam = self.prep_d.relation_families[fam_idx]
+
+            rel_idx = torch.randint(0, len(fam.relation_types), (1,)).item()
+            rel = fam.relation_types[rel_idx]
+
+            et = random.choice(edge_types)
+            edges = getattr(rel, et).train
+
+            key = (fam_idx, rel_idx, et)
+            if key not in orders:
+                orders[key] = torch.randperm(len(edges))
+                offsets[key] = 0
+
+            ord = orders[key]
+            ofs = offsets[key]
+
+            nt_row = rel.node_type_row
+            nt_col = rel.node_type_column
+
+            if 'back' in et:
+                nt_row, nt_col = nt_col, nt_row
+
+            if ofs < len(edges):
+                offsets[key] += self.batch_size
+                ord = ord[ofs:ofs+self.batch_size]
+                edges = edges[ord]
+                yield TrainingBatch(fam_idx, rel_idx, nt_row, nt_col, edges)
+            else:
+                done[key] = True
+
+        for fam in self.prep_d.relation_families:
+            edges = []
+            for rel in fam.relation_types:
+                edges.append(rel.edges_pos.train)
+                edges.append(rel.edges_back_pos.train)
+                edges.append(rel.edges_neg.train)
+                edges.append(rel.edges_back_neg.train)
+            edges = torch.cat(edges)
+
+
+class FastDecLayer(torch.nn.Module):
+    def __init__(self, **kwargs):
+        super().__init__(**kwargs)
+
+    def forward(self,
+        last_layer_repr: List[torch.Tensor],
+        training_batch: TrainingBatch):
+        # TODO: decode last_layer_repr for the edges in training_batch
+        raise NotImplementedError
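FastBatcher.__iter__() above stops at a TODO. Purely as a minimal sketch of where that loop appears to be heading, assuming only the bookkeeping already built in build(), the entries could be drained as below; the helper name iterate_batches is hypothetical and not part of this patch, and the per-edge targets are ignored because TrainingBatch does not carry them yet.

import random

from triacontagon.fastdec import TrainingBatch


def iterate_batches(prep_d, entries, batch_size):
    # 'entries' are the dicts assembled in FastBatcher.build():
    # {'fam_idx', 'rel_idx', 'back', 'edges', 'targets', 'ofs'}.
    while True:
        pending = [e for e in entries if e['ofs'] < len(e['edges'])]
        if not pending:
            break
        entry = random.choice(pending)
        rel = prep_d.relation_families[entry['fam_idx']] \
            .relation_types[entry['rel_idx']]
        nt_row, nt_col = rel.node_type_row, rel.node_type_column
        if entry['back']:
            # backward edges run from the column node type to the row node type
            nt_row, nt_col = nt_col, nt_row
        ofs = entry['ofs']
        entry['ofs'] += batch_size
        yield TrainingBatch(entry['fam_idx'], entry['rel_idx'],
                            nt_row, nt_col,
                            entry['edges'][ofs:ofs + batch_size])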
diff --git a/src/triacontagon/fastloop.py b/src/triacontagon/fastloop.py
new file mode 100644
index 0000000..f955932
--- /dev/null
+++ b/src/triacontagon/fastloop.py
@@ -0,0 +1,166 @@
+from .fastmodel import FastModel
+from .trainprep import PreparedData
+import torch
+from typing import Callable
+from types import FunctionType
+import time
+import random
+
+
+class FastBatcher(object):
+    def __init__(self, prep_d: PreparedData, batch_size: int,
+        shuffle: bool, generator: torch.Generator,
+        part_type: str) -> None:
+
+        if not isinstance(prep_d, PreparedData):
+            raise TypeError('prep_d must be an instance of PreparedData')
+
+        if not isinstance(generator, torch.Generator):
+            raise TypeError('generator must be an instance of torch.Generator')
+
+        if part_type not in ['train', 'val', 'test']:
+            raise ValueError('part_type must be set to train, val or test')
+
+        self.prep_d = prep_d
+        self.batch_size = int(batch_size)
+        self.shuffle = bool(shuffle)
+        self.generator = generator
+        self.part_type = part_type
+
+        self.edges = None
+        self.targets = None
+        self.build()
+
+    def build(self):
+        self.edges = []
+        self.targets = []
+
+        for fam in self.prep_d.relation_families:
+            edges = []
+            targets = []
+            for i, rel in enumerate(fam.relation_types):
+
+                edges_pos = getattr(rel.edges_pos, self.part_type)
+                edges_neg = getattr(rel.edges_neg, self.part_type)
+                edges_back_pos = getattr(rel.edges_back_pos, self.part_type)
+                edges_back_neg = getattr(rel.edges_back_neg, self.part_type)
+
+                # backward edges are stored as (column, row); flip them into (row, column)
+                e = torch.cat([ edges_pos,
+                    edges_back_pos[:, [1, 0]] ])
+                e = torch.cat([ torch.ones(len(e), 1, dtype=torch.long) * i, e ], dim=1)
+                t = torch.ones(len(e))
+                edges.append(e)
+                targets.append(t)
+
+                e = torch.cat([ edges_neg,
+                    edges_back_neg[:, [1, 0]] ])
+                e = torch.cat([ torch.ones(len(e), 1, dtype=torch.long) * i, e ], dim=1)
+                t = torch.zeros(len(e))
+                edges.append(e)
+                targets.append(t)
+
+            edges = torch.cat(edges)
+            targets = torch.cat(targets)
+
+            self.edges.append(edges)
+            self.targets.append(targets)
+
+        # print(self.edges)
+        # print(self.targets)
+
+        if self.shuffle:
+            self.shuffle_families()
+
+    def shuffle_families(self):
+        for i in range(len(self.edges)):
+            edges = self.edges[i]
+            targets = self.targets[i]
+            order = torch.randperm(len(edges), generator=self.generator)
+            self.edges[i] = edges[order]
+            self.targets[i] = targets[order]
+
+    def __iter__(self):
+        offsets = [ 0 for _ in self.edges ]
+
+        while True:
+            choice = [ i for i in range(len(offsets)) \
+                if offsets[i] < len(self.edges[i]) ]
+            if len(choice) == 0:
+                break
+            fam_idx = choice[torch.randint(len(choice), (1,), generator=self.generator).item()]
+            ofs = offsets[fam_idx]
+            edges = self.edges[fam_idx][ofs:ofs + self.batch_size]
+            targets = self.targets[fam_idx][ofs:ofs + self.batch_size]
+            offsets[fam_idx] += self.batch_size
+            yield (fam_idx, edges, targets)
+
+
+class FastLoop(object):
+    def __init__(
+        self,
+        model: FastModel,
+        lr: float = 0.001,
+        loss: Callable[[torch.Tensor, torch.Tensor], torch.Tensor] = \
+            torch.nn.functional.binary_cross_entropy_with_logits,
+        batch_size: int = 100,
+        shuffle: bool = True,
+        generator: torch.Generator = None) -> None:
+
+        self._check_params(model, loss, generator)
+
+        self.model = model
+        self.lr = float(lr)
+        self.loss = loss
+        self.batch_size = int(batch_size)
+        self.shuffle = bool(shuffle)
+        self.generator = generator or torch.default_generator
+
+        self.opt = None
+
+        self.build()
+
+    def _check_params(self, model, loss, generator):
+        if not isinstance(model, FastModel):
+            raise TypeError('model must be an instance of FastModel')
+
+        if not isinstance(loss, FunctionType):
+            raise TypeError('loss must be a function')
+
+        if generator is not None and not isinstance(generator, torch.Generator):
+            raise TypeError('generator must be an instance of torch.Generator')
+
+    def build(self) -> None:
+        opt = torch.optim.Adam(self.model.parameters(), lr=self.lr)
+        self.opt = opt
+
+    def run_epoch(self):
+        prep_d = self.model.prep_d
+
+        batcher = FastBatcher(self.model.prep_d, batch_size=self.batch_size,
+            shuffle=self.shuffle, generator=self.generator,
+            part_type='train')
+        # pred = self.model(None)
+        # n = len(list(iter(batch)))
+        loss_sum = 0
+        for fam_idx, edges, targets in batcher:
+            self.opt.zero_grad()
+            pred = self.model(None)
+
+            # process pred, get input and targets
+            input = pred[fam_idx][edges[:, 0], edges[:, 1]]
+
+            loss = self.loss(input, targets)
+            loss.backward()
+            self.opt.step()
+            loss_sum += loss.detach().cpu().item()
+        return loss_sum
+
+    def train(self, max_epochs):
+        best_loss = None
+        best_epoch = None
+        for i in range(max_epochs):
+            loss = self.run_epoch()
+            if best_loss is None or loss < best_loss:
+                best_loss = loss
+                best_epoch = i
+        return loss, best_loss, best_epoch
diff --git
a/src/triacontagon/fastmodel.py b/src/triacontagon/fastmodel.py new file mode 100644 index 0000000..a68fe58 --- /dev/null +++ b/src/triacontagon/fastmodel.py @@ -0,0 +1,79 @@ +from .fastconv import FastConvLayer +from .bulkdec import BulkDecodeLayer +from .input import OneHotInputLayer +from .trainprep import PreparedData +import torch +import types +from typing import List, \ + Union, \ + Callable + + +class FastModel(torch.nn.Module): + def __init__(self, prep_d: PreparedData, + layer_dimensions: List[int] = [32, 64], + keep_prob: float = 1., + rel_activation: Callable[[torch.Tensor], torch.Tensor] = lambda x: x, + layer_activation: Callable[[torch.Tensor], torch.Tensor] = torch.nn.functional.relu, + dec_activation: Callable[[torch.Tensor], torch.Tensor] = lambda x: x, + **kwargs) -> None: + + super().__init__(**kwargs) + + self._check_params(prep_d, layer_dimensions, rel_activation, + layer_activation, dec_activation) + + self.prep_d = prep_d + self.layer_dimensions = layer_dimensions + self.keep_prob = float(keep_prob) + self.rel_activation = rel_activation + self.layer_activation = layer_activation + self.dec_activation = dec_activation + + self.seq = None + self.build() + + def build(self): + in_layer = OneHotInputLayer(self.prep_d) + last_output_dim = in_layer.output_dim + seq = [ in_layer ] + + for dim in self.layer_dimensions: + conv_layer = FastConvLayer(input_dim = last_output_dim, + output_dim = [dim] * len(self.prep_d.node_types), + data = self.prep_d, + keep_prob = self.keep_prob, + rel_activation = self.rel_activation, + layer_activation = self.layer_activation) + last_output_dim = conv_layer.output_dim + seq.append(conv_layer) + + dec_layer = BulkDecodeLayer(input_dim = last_output_dim, + data = self.prep_d, + keep_prob = self.keep_prob, + activation = self.dec_activation) + seq.append(dec_layer) + + seq = torch.nn.Sequential(*seq) + self.seq = seq + + def forward(self, _): + return self.seq(None) + + def _check_params(self, prep_d, layer_dimensions, rel_activation, + layer_activation, dec_activation): + + if not isinstance(prep_d, PreparedData): + raise TypeError('prep_d must be an instanced of PreparedData') + + if not isinstance(layer_dimensions, list): + raise TypeError('layer_dimensions must be a list') + + if not isinstance(rel_activation, types.FunctionType): + raise TypeError('rel_activation must be a function') + + if not isinstance(layer_activation, types.FunctionType): + raise TypeError('layer_activation must be a function') + + if not isinstance(dec_activation, types.FunctionType): + raise TypeError('dec_activation must be a function') diff --git a/src/triacontagon/input.py b/src/triacontagon/input.py new file mode 100644 index 0000000..3bf5824 --- /dev/null +++ b/src/triacontagon/input.py @@ -0,0 +1,79 @@ +# +# Copyright (C) Stanislaw Adaszewski, 2020 +# License: GPLv3 +# + + +import torch +from typing import Union, \ + List +from .data import Data + + +class InputLayer(torch.nn.Module): + def __init__(self, data: Data, output_dim: Union[int, List[int]] = None, + **kwargs) -> None: + + output_dim = output_dim or \ + list(map(lambda a: a.count, data.node_types)) + + if not isinstance(output_dim, list): + output_dim = [output_dim,] * len(data.node_types) + + super().__init__(**kwargs) + self.output_dim = output_dim + self.data = data + + self.is_sparse=False + self.node_reps = None + self.build() + + def build(self) -> None: + self.node_reps = [] + for i, nt in enumerate(self.data.node_types): + reps = torch.rand(nt.count, self.output_dim[i]) + reps = 
torch.nn.Parameter(reps)
+            self.register_parameter('node_reps[%d]' % i, reps)
+            self.node_reps.append(reps)
+
+    def forward(self, x) -> List[torch.nn.Parameter]:
+        return self.node_reps
+
+    def __repr__(self) -> str:
+        s = ''
+        s += 'Icosagon input layer with output_dim: %s\n' % self.output_dim
+        s += '  # of node types: %d\n' % len(self.data.node_types)
+        for nt in self.data.node_types:
+            s += '    - %s (%d)\n' % (nt.name, nt.count)
+        return s.strip()
+
+
+class OneHotInputLayer(torch.nn.Module):
+    def __init__(self, data: Data, **kwargs) -> None:
+        output_dim = [ a.count for a in data.node_types ]
+        super().__init__(**kwargs)
+        self.output_dim = output_dim
+        self.data = data
+
+        self.is_sparse = True
+        self.node_reps = None
+        self.build()
+
+    def build(self) -> None:
+        self.node_reps = torch.nn.ParameterList()
+        for i, nt in enumerate(self.data.node_types):
+            reps = torch.eye(nt.count).to_sparse()
+            reps = torch.nn.Parameter(reps, requires_grad=False)
+            # self.register_parameter('node_reps[%d]' % i, reps)
+            self.node_reps.append(reps)
+
+    def forward(self, x) -> List[torch.nn.Parameter]:
+        return self.node_reps
+
+    def __repr__(self) -> str:
+        s = ''
+        s += 'Icosagon one-hot input layer\n'
+        s += '  # of node types: %d\n' % len(self.data.node_types)
+        for nt in self.data.node_types:
+            s += '    - %s (%d)\n' % (nt.name, nt.count)
+        return s.strip()
diff --git a/src/triacontagon/normalize.py b/src/triacontagon/normalize.py
new file mode 100644
index 0000000..e13fb05
--- /dev/null
+++ b/src/triacontagon/normalize.py
@@ -0,0 +1,145 @@
+#
+# Copyright (C) Stanislaw Adaszewski, 2020
+# License: GPLv3
+#
+
+
+import numpy as np
+import scipy.sparse as sp
+import torch
+
+
+def _check_tensor(adj_mat):
+    if not isinstance(adj_mat, torch.Tensor):
+        raise ValueError('adj_mat must be a torch.Tensor')
+
+
+def _check_sparse(adj_mat):
+    if not adj_mat.is_sparse:
+        raise ValueError('adj_mat must be sparse')
+
+
+def _check_dense(adj_mat):
+    if adj_mat.is_sparse:
+        raise ValueError('adj_mat must be dense')
+
+
+def _check_square(adj_mat):
+    if len(adj_mat.shape) != 2 or \
+        adj_mat.shape[0] != adj_mat.shape[1]:
+        raise ValueError('adj_mat must be a square matrix')
+
+
+def _check_2d(adj_mat):
+    if len(adj_mat.shape) != 2:
+        raise ValueError('adj_mat must be a 2D matrix')
+
+
+def _sparse_coo_tensor(indices, values, size):
+    ctor = { torch.float32: torch.sparse.FloatTensor,
+        torch.float64: torch.sparse.DoubleTensor,
+        torch.uint8: torch.sparse.ByteTensor,
+        torch.long: torch.sparse.LongTensor,
+        torch.int: torch.sparse.IntTensor,
+        torch.short: torch.sparse.ShortTensor,
+        torch.bool: torch.sparse.ByteTensor }[values.dtype]
+    return ctor(indices, values, size)
+
+
+def add_eye_sparse(adj_mat: torch.Tensor) -> torch.Tensor:
+    _check_tensor(adj_mat)
+    _check_sparse(adj_mat)
+    _check_square(adj_mat)
+
+    adj_mat = adj_mat.coalesce()
+    indices = adj_mat.indices()
+    values = adj_mat.values()
+
+    eye_indices = torch.arange(adj_mat.shape[0], dtype=indices.dtype,
+        device=adj_mat.device).view(1, -1)
+    eye_indices = torch.cat((eye_indices, eye_indices), 0)
+    eye_values = torch.ones(adj_mat.shape[0], dtype=values.dtype,
+        device=adj_mat.device)
+
+    indices = torch.cat((indices, eye_indices), 1)
+    values = torch.cat((values, eye_values), 0)
+
+    adj_mat = _sparse_coo_tensor(indices, values, adj_mat.shape)
+
+    return adj_mat
+
+
+def norm_adj_mat_one_node_type_sparse(adj_mat: torch.Tensor) -> torch.Tensor:
+    _check_tensor(adj_mat)
+    _check_sparse(adj_mat)
+    _check_square(adj_mat)
+
+    adj_mat = add_eye_sparse(adj_mat)
adj_mat = norm_adj_mat_two_node_types_sparse(adj_mat) + + return adj_mat + + +def norm_adj_mat_one_node_type_dense(adj_mat: torch.Tensor) -> torch.Tensor: + _check_tensor(adj_mat) + _check_dense(adj_mat) + _check_square(adj_mat) + + adj_mat = adj_mat + torch.eye(adj_mat.shape[0], dtype=adj_mat.dtype, + device=adj_mat.device) + adj_mat = norm_adj_mat_two_node_types_dense(adj_mat) + + return adj_mat + + +def norm_adj_mat_one_node_type(adj_mat: torch.Tensor) -> torch.Tensor: + _check_tensor(adj_mat) + _check_square(adj_mat) + + if adj_mat.is_sparse: + return norm_adj_mat_one_node_type_sparse(adj_mat) + else: + return norm_adj_mat_one_node_type_dense(adj_mat) + + +def norm_adj_mat_two_node_types_sparse(adj_mat: torch.Tensor) -> torch.Tensor: + _check_tensor(adj_mat) + _check_sparse(adj_mat) + _check_2d(adj_mat) + + adj_mat = adj_mat.coalesce() + indices = adj_mat.indices() + values = adj_mat.values() + degrees_row = torch.zeros(adj_mat.shape[0], device=adj_mat.device) + degrees_row = degrees_row.index_add(0, indices[0], values.to(degrees_row.dtype)) + degrees_col = torch.zeros(adj_mat.shape[1], device=adj_mat.device) + degrees_col = degrees_col.index_add(0, indices[1], values.to(degrees_col.dtype)) + values = values.to(degrees_row.dtype) / torch.sqrt(degrees_row[indices[0]] * degrees_col[indices[1]]) + adj_mat = _sparse_coo_tensor(indices, values, adj_mat.shape) + + return adj_mat + + +def norm_adj_mat_two_node_types_dense(adj_mat: torch.Tensor) -> torch.Tensor: + _check_tensor(adj_mat) + _check_dense(adj_mat) + _check_2d(adj_mat) + + degrees_row = adj_mat.sum(1).view(-1, 1).to(torch.float32) + degrees_col = adj_mat.sum(0).view(1, -1).to(torch.float32) + degrees_row = torch.sqrt(degrees_row) + degrees_col = torch.sqrt(degrees_col) + adj_mat = adj_mat.to(degrees_row.dtype) / degrees_row + adj_mat = adj_mat / degrees_col + + return adj_mat + + +def norm_adj_mat_two_node_types(adj_mat: torch.Tensor) -> torch.Tensor: + _check_tensor(adj_mat) + _check_2d(adj_mat) + + if adj_mat.is_sparse: + return norm_adj_mat_two_node_types_sparse(adj_mat) + else: + return norm_adj_mat_two_node_types_dense(adj_mat) diff --git a/src/triacontagon/sampling.py b/src/triacontagon/sampling.py new file mode 100644 index 0000000..7c55944 --- /dev/null +++ b/src/triacontagon/sampling.py @@ -0,0 +1,47 @@ +# +# Copyright (C) Stanislaw Adaszewski, 2020 +# License: GPLv3 +# + + +import numpy as np +import torch +import torch.utils.data +from typing import List, \ + Union + + +def fixed_unigram_candidate_sampler( + true_classes: Union[np.array, torch.Tensor], + unigrams: List[Union[int, float]], + distortion: float = 1.): + + if isinstance(true_classes, torch.Tensor): + true_classes = true_classes.detach().cpu().numpy() + + if isinstance(unigrams, torch.Tensor): + unigrams = unigrams.detach().cpu().numpy() + + if len(true_classes.shape) != 2: + raise ValueError('true_classes must be a 2D matrix with shape (num_samples, num_true)') + + num_samples = true_classes.shape[0] + unigrams = np.array(unigrams) + if distortion != 1.: + unigrams = unigrams.astype(np.float64) ** distortion + # print('unigrams:', unigrams) + indices = np.arange(num_samples) + result = np.zeros(num_samples, dtype=np.int64) + while len(indices) > 0: + # print('len(indices):', len(indices)) + sampler = torch.utils.data.WeightedRandomSampler(unigrams, len(indices)) + candidates = np.array(list(sampler)) + candidates = np.reshape(candidates, (len(indices), 1)) + # print('candidates:', candidates) + # print('true_classes:', true_classes[indices, :]) + 
result[indices] = candidates.T + mask = (candidates == true_classes[indices, :]) + mask = mask.sum(1).astype(np.bool) + # print('mask:', mask) + indices = indices[mask] + return torch.tensor(result) diff --git a/src/triacontagon/trainprep.py b/src/triacontagon/trainprep.py new file mode 100644 index 0000000..c49300a --- /dev/null +++ b/src/triacontagon/trainprep.py @@ -0,0 +1,215 @@ +# +# Copyright (C) Stanislaw Adaszewski, 2020 +# License: GPLv3 +# + + +from .sampling import fixed_unigram_candidate_sampler +import torch +from dataclasses import dataclass, \ + field +from typing import Any, \ + List, \ + Tuple, \ + Dict +from .data import NodeType, \ + RelationType, \ + RelationTypeBase, \ + RelationFamily, \ + RelationFamilyBase, \ + Data +from collections import defaultdict +from .normalize import norm_adj_mat_one_node_type, \ + norm_adj_mat_two_node_types +import numpy as np + + +@dataclass +class TrainValTest(object): + train: Any + val: Any + test: Any + + +@dataclass +class PreparedRelationType(RelationTypeBase): + edges_pos: TrainValTest + edges_neg: TrainValTest + edges_back_pos: TrainValTest + edges_back_neg: TrainValTest + + +@dataclass +class PreparedRelationFamily(RelationFamilyBase): + relation_types: List[PreparedRelationType] + + +@dataclass +class PreparedData(object): + node_types: List[NodeType] + relation_families: List[PreparedRelationFamily] + + +def _empty_edge_list_tvt() -> TrainValTest: + return TrainValTest(*[ torch.zeros((0, 2), dtype=torch.long) for _ in range(3) ]) + + +def train_val_test_split_edges(edges: torch.Tensor, + ratios: TrainValTest) -> TrainValTest: + + if not isinstance(edges, torch.Tensor): + raise ValueError('edges must be a torch.Tensor') + + if len(edges.shape) != 2 or edges.shape[1] != 2: + raise ValueError('edges shape must be (num_edges, 2)') + + if not isinstance(ratios, TrainValTest): + raise ValueError('ratios must be a TrainValTest') + + if ratios.train + ratios.val + ratios.test != 1.0: + raise ValueError('Train, validation and test ratios must add up to 1') + + order = torch.randperm(len(edges)) + edges = edges[order, :] + n = round(len(edges) * ratios.train) + edges_train = edges[:n] + n_1 = round(len(edges) * (ratios.train + ratios.val)) + edges_val = edges[n:n_1] + edges_test = edges[n_1:] + + return TrainValTest(edges_train, edges_val, edges_test) + + +def get_edges_and_degrees(adj_mat: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: + if adj_mat.is_sparse: + adj_mat = adj_mat.coalesce() + degrees = torch.zeros(adj_mat.shape[1], dtype=torch.int64, + device=adj_mat.device) + degrees = degrees.index_add(0, adj_mat.indices()[1], + torch.ones(adj_mat.indices().shape[1], dtype=torch.int64, + device=adj_mat.device)) + edges_pos = adj_mat.indices().transpose(0, 1) + else: + degrees = adj_mat.sum(0) + edges_pos = torch.nonzero(adj_mat) + return edges_pos, degrees + + +def prepare_adj_mat(adj_mat: torch.Tensor, + ratios: TrainValTest) -> Tuple[TrainValTest, TrainValTest]: + + if not isinstance(adj_mat, torch.Tensor): + raise ValueError('adj_mat must be a torch.Tensor') + + edges_pos, degrees = get_edges_and_degrees(adj_mat) + + neg_neighbors = fixed_unigram_candidate_sampler( + edges_pos[:, 1].view(-1, 1), degrees, 0.75).to(adj_mat.device) + print(edges_pos.dtype) + print(neg_neighbors.dtype) + edges_neg = torch.cat((edges_pos[:, 0].view(-1, 1), neg_neighbors.view(-1, 1)), 1) + + edges_pos = train_val_test_split_edges(edges_pos, ratios) + edges_neg = train_val_test_split_edges(edges_neg, ratios) + + adj_mat_train = 
torch.sparse_coo_tensor(indices = edges_pos.train.transpose(0, 1), + values=torch.ones(len(edges_pos.train)), size=adj_mat.shape, dtype=adj_mat.dtype, + device=adj_mat.device) + + return adj_mat_train, edges_pos, edges_neg + + +def prep_rel_one_node_type(r: RelationType, + ratios: TrainValTest) -> PreparedRelationType: + + adj_mat = r.adjacency_matrix + adj_mat_train, edges_pos, edges_neg = prepare_adj_mat(adj_mat, ratios) + adj_mat_back_train, edges_back_pos, edges_back_neg = \ + None, _empty_edge_list_tvt(), _empty_edge_list_tvt() + + print('adj_mat_train:', adj_mat_train) + adj_mat_train = norm_adj_mat_one_node_type(adj_mat_train) + + return PreparedRelationType(r.name, r.node_type_row, r.node_type_column, + adj_mat_train, adj_mat_back_train, edges_pos, edges_neg, + edges_back_pos, edges_back_neg) + + +def prep_rel_two_node_types_sym(r: RelationType, + ratios: TrainValTest) -> PreparedRelationType: + + adj_mat = r.adjacency_matrix + adj_mat_train, edges_pos, edges_neg = prepare_adj_mat(adj_mat, ratios) + edges_back_pos, edges_back_neg = \ + _empty_edge_list_tvt(), _empty_edge_list_tvt() + + return PreparedRelationType(r.name, r.node_type_row, + r.node_type_column, + norm_adj_mat_two_node_types(adj_mat_train), + norm_adj_mat_two_node_types(adj_mat_train.transpose(0, 1)), + edges_pos, edges_neg, edges_back_pos, edges_back_neg) + + +def prep_rel_two_node_types_asym(r: RelationType, + ratios: TrainValTest) -> PreparedRelationType: + + if r.adjacency_matrix is not None: + adj_mat_train, edges_pos, edges_neg =\ + prepare_adj_mat(r.adjacency_matrix, ratios) + else: + adj_mat_train, edges_pos, edges_neg = \ + None, _empty_edge_list_tvt(), _empty_edge_list_tvt() + + if r.adjacency_matrix_backward is not None: + adj_mat_back_train, edges_back_pos, edges_back_neg = \ + prepare_adj_mat(r.adjacency_matrix_backward, ratios) + else: + adj_mat_back_train, edges_back_pos, edges_back_neg = \ + None, _empty_edge_list_tvt(), _empty_edge_list_tvt() + + return PreparedRelationType(r.name, r.node_type_row, + r.node_type_column, + norm_adj_mat_two_node_types(adj_mat_train), + norm_adj_mat_two_node_types(adj_mat_back_train), + edges_pos, edges_neg, edges_back_pos, edges_back_neg) + + +def prepare_relation_type(r: RelationType, + ratios: TrainValTest, is_symmetric: bool) -> PreparedRelationType: + + if not isinstance(r, RelationType): + raise ValueError('r must be a RelationType') + + if not isinstance(ratios, TrainValTest): + raise ValueError('ratios must be a TrainValTest') + + if r.node_type_row == r.node_type_column: + return prep_rel_one_node_type(r, ratios) + elif is_symmetric: + return prep_rel_two_node_types_sym(r, ratios) + else: + return prep_rel_two_node_types_asym(r, ratios) + + +def prepare_relation_family(fam: RelationFamily, + ratios: TrainValTest) -> PreparedRelationFamily: + + relation_types = [] + + for r in fam.relation_types: + relation_types.append(prepare_relation_type(r, ratios, fam.is_symmetric)) + + return PreparedRelationFamily(fam.data, fam.name, + fam.node_type_row, fam.node_type_column, + fam.is_symmetric, fam.decoder_class, + relation_types) + + +def prepare_training(data: Data, ratios: TrainValTest) -> PreparedData: + if not isinstance(data, Data): + raise ValueError('data must be of class Data') + + relation_families = [ prepare_relation_family(fam, ratios) \ + for fam in data.relation_families ] + + return PreparedData(data.node_types, relation_families) diff --git a/src/triacontagon/weights.py b/src/triacontagon/weights.py new file mode 100644 index 0000000..2dcb7b4 --- 
/dev/null +++ b/src/triacontagon/weights.py @@ -0,0 +1,19 @@ +# +# Copyright (C) Stanislaw Adaszewski, 2020 +# License: GPLv3 +# + + +import torch +import numpy as np + + +def init_glorot(in_channels, out_channels, dtype=torch.float32): + """Create a weight variable with Glorot & Bengio (AISTATS 2010) + initialization. + """ + init_range = np.sqrt(6.0 / (in_channels + out_channels)) + initial = -init_range + 2 * init_range * \ + torch.rand(( in_channels, out_channels ), dtype=dtype) + initial = initial.requires_grad_(True) + return initial
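Taken together, the modules added in this patch are meant to be used roughly as follows. This is a hedged usage sketch, not part of the commit: the node-type names, counts and hyper-parameters are made up, and FastModel additionally imports BulkDecodeLayer from a bulkdec module that is not included in this diff, so that module is assumed to be available.

import torch

from triacontagon.data import Data
from triacontagon.trainprep import TrainValTest, prepare_training
from triacontagon.fastmodel import FastModel
from triacontagon.fastloop import FastLoop

# Two node types and one symmetric relation family between them.
d = Data()
d.add_node_type('Gene', 1000)
d.add_node_type('Drug', 100)

fam = d.add_relation_family('Drug-Gene', 1, 0, is_symmetric=True)
adj = (torch.rand(100, 1000) < 0.05).to(torch.float32)  # drug x gene
fam.add_relation_type('targets', adj)

# Negative sampling, train/val/test split and adjacency normalization.
prep_d = prepare_training(d, TrainValTest(train=0.8, val=0.1, test=0.1))

# Graph convolution layers followed by the (external) bulk decoding layer.
model = FastModel(prep_d, layer_dimensions=[32, 64])

loop = FastLoop(model, lr=0.001, batch_size=512,
                generator=torch.manual_seed(0))
last_loss, best_loss, best_epoch = loop.train(max_epochs=10)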