@@ -0,0 +1,209 @@ | |||||
# | |||||
# Copyright (C) Stanislaw Adaszewski, 2020 | |||||
# License: GPLv3 | |||||
# | |||||
from collections import defaultdict | |||||
from dataclasses import dataclass, field | |||||
import torch | |||||
from typing import List, \ | |||||
Dict, \ | |||||
Tuple, \ | |||||
Any, \ | |||||
Type | |||||
from .decode import DEDICOMDecoder, \ | |||||
BilinearDecoder | |||||
import numpy as np | |||||
def _equal(x: torch.Tensor, y: torch.Tensor): | |||||
if x.is_sparse ^ y.is_sparse: | |||||
raise ValueError('Cannot mix sparse and dense tensors') | |||||
if not x.is_sparse: | |||||
return (x == y) | |||||
return ((x - y).coalesce().values() == 0) | |||||
@dataclass | |||||
class NodeType(object): | |||||
name: str | |||||
count: int | |||||
@dataclass | |||||
class RelationTypeBase(object): | |||||
name: str | |||||
node_type_row: int | |||||
node_type_column: int | |||||
adjacency_matrix: torch.Tensor | |||||
adjacency_matrix_backward: torch.Tensor | |||||
@dataclass | |||||
class RelationType(RelationTypeBase): | |||||
pass | |||||
@dataclass | |||||
class RelationFamilyBase(object): | |||||
data: 'Data' | |||||
name: str | |||||
node_type_row: int | |||||
node_type_column: int | |||||
is_symmetric: bool | |||||
decoder_class: Type | |||||
@dataclass | |||||
class RelationFamily(RelationFamilyBase): | |||||
relation_types: List[RelationType] = None | |||||
def __post_init__(self) -> None: | |||||
if not self.is_symmetric and \ | |||||
self.decoder_class != DEDICOMDecoder and \ | |||||
self.decoder_class != BilinearDecoder: | |||||
raise TypeError('Family is assymetric but the specified decoder_class supports symmetric relations only') | |||||
self.relation_types = [] | |||||
def add_relation_type(self, | |||||
name: str, adjacency_matrix: torch.Tensor, | |||||
adjacency_matrix_backward: torch.Tensor = None) -> None: | |||||
name = str(name) | |||||
node_type_row = self.node_type_row | |||||
node_type_column = self.node_type_column | |||||
if adjacency_matrix is None and adjacency_matrix_backward is None: | |||||
raise ValueError('adjacency_matrix and adjacency_matrix_backward cannot both be None') | |||||
if adjacency_matrix is not None and \ | |||||
not isinstance(adjacency_matrix, torch.Tensor): | |||||
raise ValueError('adjacency_matrix must be a torch.Tensor') | |||||
if adjacency_matrix_backward is not None \ | |||||
and not isinstance(adjacency_matrix_backward, torch.Tensor): | |||||
raise ValueError('adjacency_matrix_backward must be a torch.Tensor') | |||||
if adjacency_matrix is not None and \ | |||||
adjacency_matrix.shape != (self.data.node_types[node_type_row].count, | |||||
self.data.node_types[node_type_column].count): | |||||
raise ValueError('adjacency_matrix shape must be (num_row_nodes, num_column_nodes)') | |||||
if adjacency_matrix_backward is not None and \ | |||||
adjacency_matrix_backward.shape != (self.data.node_types[node_type_column].count, | |||||
self.data.node_types[node_type_row].count): | |||||
raise ValueError('adjacency_matrix_backward shape must be (num_column_nodes, num_row_nodes)') | |||||
if node_type_row == node_type_column and \ | |||||
adjacency_matrix_backward is not None: | |||||
raise ValueError('Relation between nodes of the same type must be expressed using a single matrix') | |||||
if self.is_symmetric and adjacency_matrix_backward is not None: | |||||
raise ValueError('Cannot use a custom adjacency_matrix_backward in a symmetric relation family') | |||||
if self.is_symmetric and node_type_row == node_type_column and \ | |||||
not torch.all(_equal(adjacency_matrix, | |||||
adjacency_matrix.transpose(0, 1))): | |||||
raise ValueError('Relation family is symmetric but adjacency_matrix is assymetric') | |||||
if not self.is_symmetric and node_type_row != node_type_column and \ | |||||
adjacency_matrix_backward is None: | |||||
raise ValueError('Relation is asymmetric but adjacency_matrix_backward is None') | |||||
if self.is_symmetric and node_type_row != node_type_column: | |||||
adjacency_matrix_backward = adjacency_matrix.transpose(0, 1) | |||||
self.relation_types.append(RelationType(name, | |||||
node_type_row, node_type_column, | |||||
adjacency_matrix, adjacency_matrix_backward)) | |||||
def node_name(self, index): | |||||
return self.data.node_types[index].name | |||||
def __repr__(self): | |||||
s = 'Relation family %s' % self.name | |||||
for r in self.relation_types: | |||||
s += '\n - %s%s' % (r.name, ' (two-way)' \ | |||||
if (r.adjacency_matrix is not None \ | |||||
and r.adjacency_matrix_backward is not None) \ | |||||
or self.node_type_row == self.node_type_column \ | |||||
else '%s <- %s' % (self.node_name(self.node_type_row), | |||||
self.node_name(self.node_type_column))) | |||||
return s | |||||
def repr_indented(self): | |||||
s = ' - %s' % self.name | |||||
for r in self.relation_types: | |||||
s += '\n - %s%s' % (r.name, ' (two-way)' \ | |||||
if (r.adjacency_matrix is not None \ | |||||
and r.adjacency_matrix_backward is not None) \ | |||||
or self.node_type_row == self.node_type_column \ | |||||
else '%s <- %s' % (self.node_name(self.node_type_row), | |||||
self.node_name(self.node_type_column))) | |||||
return s | |||||
class Data(object): | |||||
node_types: List[NodeType] | |||||
relation_families: List[RelationFamily] | |||||
def __init__(self) -> None: | |||||
self.node_types = [] | |||||
self.relation_families = [] | |||||
def add_node_type(self, name: str, count: int) -> None: | |||||
name = str(name) | |||||
count = int(count) | |||||
if not name: | |||||
raise ValueError('You must provide a non-empty node type name') | |||||
if count <= 0: | |||||
raise ValueError('You must provide a positive node count') | |||||
self.node_types.append(NodeType(name, count)) | |||||
def add_relation_family(self, name: str, node_type_row: int, | |||||
node_type_column: int, is_symmetric: bool, | |||||
decoder_class: Type = DEDICOMDecoder): | |||||
name = str(name) | |||||
node_type_row = int(node_type_row) | |||||
node_type_column = int(node_type_column) | |||||
is_symmetric = bool(is_symmetric) | |||||
if node_type_row < 0 or node_type_row >= len(self.node_types): | |||||
raise ValueError('node_type_row outside of the valid range of node types') | |||||
if node_type_column < 0 or node_type_column >= len(self.node_types): | |||||
raise ValueError('node_type_column outside of the valid range of node types') | |||||
fam = RelationFamily(self, name, node_type_row, node_type_column, | |||||
is_symmetric, decoder_class) | |||||
self.relation_families.append(fam) | |||||
return fam | |||||
def __repr__(self): | |||||
n = len(self.node_types) | |||||
if n == 0: | |||||
return 'Empty Icosagon Data' | |||||
s = '' | |||||
s += 'Icosagon Data with:\n' | |||||
s += '- ' + str(n) + ' node type(s):\n' | |||||
for nt in self.node_types: | |||||
s += ' - ' + nt.name + '\n' | |||||
if len(self.relation_families) == 0: | |||||
s += '- No relation families\n' | |||||
return s.strip() | |||||
s += '- %d relation families:\n' % len(self.relation_families) | |||||
for fam in self.relation_families: | |||||
s += fam.repr_indented() + '\n' | |||||
return s.strip() |
@@ -0,0 +1,123 @@ | |||||
# | |||||
# Copyright (C) Stanislaw Adaszewski, 2020 | |||||
# License: GPLv3 | |||||
# | |||||
import torch | |||||
from .weights import init_glorot | |||||
from .dropout import dropout | |||||
class DEDICOMDecoder(torch.nn.Module): | |||||
"""DEDICOM Tensor Factorization Decoder model layer for link prediction.""" | |||||
def __init__(self, input_dim, num_relation_types, keep_prob=1., | |||||
activation=torch.sigmoid, **kwargs): | |||||
super().__init__(**kwargs) | |||||
self.input_dim = input_dim | |||||
self.num_relation_types = num_relation_types | |||||
self.keep_prob = keep_prob | |||||
self.activation = activation | |||||
self.global_interaction = torch.nn.Parameter(init_glorot(input_dim, input_dim)) | |||||
self.local_variation = torch.nn.ParameterList([ | |||||
torch.nn.Parameter(torch.flatten(init_glorot(input_dim, 1))) \ | |||||
for _ in range(num_relation_types) | |||||
]) | |||||
def forward(self, inputs_row, inputs_col, relation_index): | |||||
inputs_row = dropout(inputs_row, self.keep_prob) | |||||
inputs_col = dropout(inputs_col, self.keep_prob) | |||||
relation = torch.diag(self.local_variation[relation_index]) | |||||
product1 = torch.mm(inputs_row, relation) | |||||
product2 = torch.mm(product1, self.global_interaction) | |||||
product3 = torch.mm(product2, relation) | |||||
rec = torch.bmm(product3.view(product3.shape[0], 1, product3.shape[1]), | |||||
inputs_col.view(inputs_col.shape[0], inputs_col.shape[1], 1)) | |||||
rec = torch.flatten(rec) | |||||
return self.activation(rec) | |||||
class DistMultDecoder(torch.nn.Module): | |||||
"""DEDICOM Tensor Factorization Decoder model layer for link prediction.""" | |||||
def __init__(self, input_dim, num_relation_types, keep_prob=1., | |||||
activation=torch.sigmoid, **kwargs): | |||||
super().__init__(**kwargs) | |||||
self.input_dim = input_dim | |||||
self.num_relation_types = num_relation_types | |||||
self.keep_prob = keep_prob | |||||
self.activation = activation | |||||
self.relation = torch.nn.ParameterList([ | |||||
torch.nn.Parameter(torch.flatten(init_glorot(input_dim, 1))) \ | |||||
for _ in range(num_relation_types) | |||||
]) | |||||
def forward(self, inputs_row, inputs_col, relation_index): | |||||
inputs_row = dropout(inputs_row, self.keep_prob) | |||||
inputs_col = dropout(inputs_col, self.keep_prob) | |||||
relation = torch.diag(self.relation[relation_index]) | |||||
intermediate_product = torch.mm(inputs_row, relation) | |||||
rec = torch.bmm(intermediate_product.view(intermediate_product.shape[0], 1, intermediate_product.shape[1]), | |||||
inputs_col.view(inputs_col.shape[0], inputs_col.shape[1], 1)) | |||||
rec = torch.flatten(rec) | |||||
return self.activation(rec) | |||||
class BilinearDecoder(torch.nn.Module): | |||||
"""DEDICOM Tensor Factorization Decoder model layer for link prediction.""" | |||||
def __init__(self, input_dim, num_relation_types, keep_prob=1., | |||||
activation=torch.sigmoid, **kwargs): | |||||
super().__init__(**kwargs) | |||||
self.input_dim = input_dim | |||||
self.num_relation_types = num_relation_types | |||||
self.keep_prob = keep_prob | |||||
self.activation = activation | |||||
self.relation = torch.nn.ParameterList([ | |||||
torch.nn.Parameter(init_glorot(input_dim, input_dim)) \ | |||||
for _ in range(num_relation_types) | |||||
]) | |||||
def forward(self, inputs_row, inputs_col, relation_index): | |||||
inputs_row = dropout(inputs_row, self.keep_prob) | |||||
inputs_col = dropout(inputs_col, self.keep_prob) | |||||
intermediate_product = torch.mm(inputs_row, self.relation[relation_index]) | |||||
rec = torch.bmm(intermediate_product.view(intermediate_product.shape[0], 1, intermediate_product.shape[1]), | |||||
inputs_col.view(inputs_col.shape[0], inputs_col.shape[1], 1)) | |||||
rec = torch.flatten(rec) | |||||
return self.activation(rec) | |||||
class InnerProductDecoder(torch.nn.Module): | |||||
"""DEDICOM Tensor Factorization Decoder model layer for link prediction.""" | |||||
def __init__(self, input_dim, num_relation_types, keep_prob=1., | |||||
activation=torch.sigmoid, **kwargs): | |||||
super().__init__(**kwargs) | |||||
self.input_dim = input_dim | |||||
self.num_relation_types = num_relation_types | |||||
self.keep_prob = keep_prob | |||||
self.activation = activation | |||||
def forward(self, inputs_row, inputs_col, _): | |||||
inputs_row = dropout(inputs_row, self.keep_prob) | |||||
inputs_col = dropout(inputs_col, self.keep_prob) | |||||
rec = torch.bmm(inputs_row.view(inputs_row.shape[0], 1, inputs_row.shape[1]), | |||||
inputs_col.view(inputs_col.shape[0], inputs_col.shape[1], 1)) | |||||
rec = torch.flatten(rec) | |||||
return self.activation(rec) |
@@ -0,0 +1,42 @@ | |||||
# | |||||
# Copyright (C) Stanislaw Adaszewski, 2020 | |||||
# License: GPLv3 | |||||
# | |||||
import torch | |||||
from .normalize import _sparse_coo_tensor | |||||
def dropout_sparse(x, keep_prob): | |||||
x = x.coalesce() | |||||
i = x._indices() | |||||
v = x._values() | |||||
size = x.size() | |||||
n = keep_prob + torch.rand(len(v)) | |||||
n = torch.floor(n).to(torch.bool) | |||||
i = i[:,n] | |||||
v = v[n] | |||||
x = _sparse_coo_tensor(i, v, size=size) | |||||
return x * (1./keep_prob) | |||||
def dropout_dense(x, keep_prob): | |||||
# print('dropout_dense()') | |||||
x = x.clone() | |||||
i = torch.nonzero(x) | |||||
n = keep_prob + torch.rand(len(i)) | |||||
n = (1. - torch.floor(n)).to(torch.bool) | |||||
x[i[n, 0], i[n, 1]] = 0. | |||||
return x * (1./keep_prob) | |||||
def dropout(x, keep_prob): | |||||
if x.is_sparse: | |||||
return dropout_sparse(x, keep_prob) | |||||
else: | |||||
return dropout_dense(x, keep_prob) |
@@ -0,0 +1,255 @@ | |||||
from typing import List, \ | |||||
Union, \ | |||||
Callable | |||||
from .data import Data, \ | |||||
RelationFamily | |||||
from .trainprep import PreparedData, \ | |||||
PreparedRelationFamily | |||||
import torch | |||||
from .weights import init_glorot | |||||
from .normalize import _sparse_coo_tensor | |||||
import types | |||||
def _sparse_diag_cat(matrices: List[torch.Tensor]): | |||||
if len(matrices) == 0: | |||||
raise ValueError('The list of matrices must be non-empty') | |||||
if not all(m.is_sparse for m in matrices): | |||||
raise ValueError('All matrices must be sparse') | |||||
if not all(len(m.shape) == 2 for m in matrices): | |||||
raise ValueError('All matrices must be 2D') | |||||
indices = [] | |||||
values = [] | |||||
row_offset = 0 | |||||
col_offset = 0 | |||||
for m in matrices: | |||||
ind = m._indices().clone() | |||||
ind[0] += row_offset | |||||
ind[1] += col_offset | |||||
indices.append(ind) | |||||
values.append(m._values()) | |||||
row_offset += m.shape[0] | |||||
col_offset += m.shape[1] | |||||
indices = torch.cat(indices, dim=1) | |||||
values = torch.cat(values) | |||||
return _sparse_coo_tensor(indices, values, size=(row_offset, col_offset)) | |||||
def _cat(matrices: List[torch.Tensor]): | |||||
if len(matrices) == 0: | |||||
raise ValueError('Empty list passed to _cat()') | |||||
n = sum(a.is_sparse for a in matrices) | |||||
if n != 0 and n != len(matrices): | |||||
raise ValueError('All matrices must have the same layout (dense or sparse)') | |||||
if not all(a.shape[1:] == matrices[0].shape[1:] for a in matrices): | |||||
raise ValueError('All matrices must have the same dimensions apart from dimension 0') | |||||
if not matrices[0].is_sparse: | |||||
return torch.cat(matrices) | |||||
total_rows = sum(a.shape[0] for a in matrices) | |||||
indices = [] | |||||
values = [] | |||||
row_offset = 0 | |||||
for a in matrices: | |||||
ind = a._indices().clone() | |||||
val = a._values() | |||||
ind[0] += row_offset | |||||
ind = ind.transpose(0, 1) | |||||
indices.append(ind) | |||||
values.append(val) | |||||
row_offset += a.shape[0] | |||||
indices = torch.cat(indices).transpose(0, 1) | |||||
values = torch.cat(values) | |||||
res = _sparse_coo_tensor(indices, values, size=(row_offset, matrices[0].shape[1])) | |||||
return res | |||||
class FastGraphConv(torch.nn.Module): | |||||
def __init__(self, | |||||
in_channels: int, | |||||
out_channels: int, | |||||
adjacency_matrices: List[torch.Tensor], | |||||
keep_prob: float = 1., | |||||
activation: Callable[[torch.Tensor], torch.Tensor] = lambda x: x, | |||||
**kwargs) -> None: | |||||
super().__init__(**kwargs) | |||||
in_channels = int(in_channels) | |||||
out_channels = int(out_channels) | |||||
if not isinstance(adjacency_matrices, list): | |||||
raise TypeError('adjacency_matrices must be a list') | |||||
if len(adjacency_matrices) == 0: | |||||
raise ValueError('adjacency_matrices must not be empty') | |||||
if not all(isinstance(m, torch.Tensor) for m in adjacency_matrices): | |||||
raise TypeError('adjacency_matrices elements must be of class torch.Tensor') | |||||
if not all(m.is_sparse for m in adjacency_matrices): | |||||
raise ValueError('adjacency_matrices elements must be sparse') | |||||
keep_prob = float(keep_prob) | |||||
if not isinstance(activation, types.FunctionType): | |||||
raise TypeError('activation must be a function') | |||||
self.in_channels = in_channels | |||||
self.out_channels = out_channels | |||||
self.adjacency_matrices = adjacency_matrices | |||||
self.keep_prob = keep_prob | |||||
self.activation = activation | |||||
self.num_row_nodes = len(adjacency_matrices[0]) | |||||
self.num_relation_types = len(adjacency_matrices) | |||||
self.adjacency_matrices = _sparse_diag_cat(adjacency_matrices) | |||||
self.weights = torch.cat([ | |||||
init_glorot(in_channels, out_channels) \ | |||||
for _ in range(self.num_relation_types) | |||||
], dim=1) | |||||
def forward(self, x) -> torch.Tensor: | |||||
if self.keep_prob < 1.: | |||||
x = dropout(x, self.keep_prob) | |||||
res = torch.sparse.mm(x, self.weights) \ | |||||
if x.is_sparse \ | |||||
else torch.mm(x, self.weights) | |||||
res = torch.split(res, res.shape[1] // self.num_relation_types, dim=1) | |||||
res = torch.cat(res) | |||||
res = torch.sparse.mm(self.adjacency_matrices, res) \ | |||||
if self.adjacency_matrices.is_sparse \ | |||||
else torch.mm(self.adjacency_matrices, res) | |||||
res = res.view(self.num_relation_types, self.num_row_nodes, self.out_channels) | |||||
if self.activation is not None: | |||||
res = self.activation(res) | |||||
return res | |||||
class FastConvLayer(torch.nn.Module): | |||||
def __init__(self, | |||||
input_dim: List[int], | |||||
output_dim: List[int], | |||||
data: Union[Data, PreparedData], | |||||
keep_prob: float = 1., | |||||
rel_activation: Callable[[torch.Tensor], torch.Tensor] = lambda x: x, | |||||
layer_activation: Callable[[torch.Tensor], torch.Tensor] = torch.nn.functional.relu, | |||||
**kwargs): | |||||
super().__init__(**kwargs) | |||||
self._check_params(input_dim, output_dim, data, keep_prob, | |||||
rel_activation, layer_activation) | |||||
self.input_dim = input_dim | |||||
self.output_dim = output_dim | |||||
self.data = data | |||||
self.keep_prob = keep_prob | |||||
self.rel_activation = rel_activation | |||||
self.layer_activation = layer_activation | |||||
self.is_sparse = False | |||||
self.next_layer_repr = None | |||||
self.build() | |||||
def build(self): | |||||
self.next_layer_repr = torch.nn.ModuleList([ | |||||
torch.nn.ModuleList() \ | |||||
for _ in range(len(self.data.node_types)) | |||||
]) | |||||
for fam in self.data.relation_families: | |||||
self.build_family(fam) | |||||
def build_family(self, fam) -> None: | |||||
if fam.node_type_row == fam.node_type_column: | |||||
self.build_fam_one_node_type(fam) | |||||
else: | |||||
self.build_fam_two_node_types(fam) | |||||
def build_fam_one_node_type(self, fam) -> None: | |||||
adjacency_matrices = [ | |||||
r.adjacency_matrix \ | |||||
for r in fam.relation_types | |||||
] | |||||
conv = FastGraphConv(self.input_dim[fam.node_type_column], | |||||
self.output_dim[fam.node_type_row], | |||||
adjacency_matrices, | |||||
self.keep_prob, | |||||
self.rel_activation) | |||||
conv.input_node_type = fam.node_type_column | |||||
self.next_layer_repr[fam.node_type_row].append(conv) | |||||
def build_fam_two_node_types(self, fam) -> None: | |||||
adjacency_matrices = [ | |||||
r.adjacency_matrix \ | |||||
for r in fam.relation_types \ | |||||
if r.adjacency_matrix is not None | |||||
] | |||||
adjacency_matrices_backward = [ | |||||
r.adjacency_matrix_backward \ | |||||
for r in fam.relation_types \ | |||||
if r.adjacency_matrix_backward is not None | |||||
] | |||||
conv = FastGraphConv(self.input_dim[fam.node_type_column], | |||||
self.output_dim[fam.node_type_row], | |||||
adjacency_matrices, | |||||
self.keep_prob, | |||||
self.rel_activation) | |||||
conv_backward = FastGraphConv(self.input_dim[fam.node_type_row], | |||||
self.output_dim[fam.node_type_column], | |||||
adjacency_matrices_backward, | |||||
self.keep_prob, | |||||
self.rel_activation) | |||||
conv.input_node_type = fam.node_type_column | |||||
conv_backward.input_node_type = fam.node_type_row | |||||
self.next_layer_repr[fam.node_type_row].append(conv) | |||||
self.next_layer_repr[fam.node_type_column].append(conv_backward) | |||||
def forward(self, prev_layer_repr): | |||||
next_layer_repr = [ [] \ | |||||
for _ in range(len(self.data.node_types)) ] | |||||
for output_node_type in range(len(self.data.node_types)): | |||||
for conv in self.next_layer_repr[output_node_type]: | |||||
rep = conv(prev_layer_repr[conv.input_node_type]) | |||||
rep = torch.sum(rep, dim=0) | |||||
rep = torch.nn.functional.normalize(rep, p=2, dim=1) | |||||
next_layer_repr[output_node_type].append(rep) | |||||
if len(next_layer_repr[output_node_type]) == 0: | |||||
next_layer_repr[output_node_type] = \ | |||||
torch.zeros(self.data.node_types[output_node_type].count, self.output_dim[output_node_type]) | |||||
else: | |||||
next_layer_repr[output_node_type] = \ | |||||
sum(next_layer_repr[output_node_type]) | |||||
next_layer_repr[output_node_type] = \ | |||||
self.layer_activation(next_layer_repr[output_node_type]) | |||||
return next_layer_repr | |||||
@staticmethod | |||||
def _check_params(input_dim, output_dim, data, keep_prob, | |||||
rel_activation, layer_activation): | |||||
if not isinstance(input_dim, list): | |||||
raise ValueError('input_dim must be a list') | |||||
if not output_dim: | |||||
raise ValueError('output_dim must be specified') | |||||
if not isinstance(output_dim, list): | |||||
output_dim = [output_dim] * len(data.node_types) | |||||
if not isinstance(data, Data) and not isinstance(data, PreparedData): | |||||
raise ValueError('data must be of type Data or PreparedData') |
@@ -0,0 +1,138 @@ | |||||
import torch | |||||
from typing import List | |||||
from .trainprep import PreparedData | |||||
from dataclasses import dataclass | |||||
import random | |||||
from collections import defaultdict | |||||
@dataclass | |||||
class TrainingBatch(object): | |||||
relation_family_index: int | |||||
relation_type_index: int | |||||
node_type_row: int | |||||
node_type_column: int | |||||
edges: torch.Tensor | |||||
class FastBatcher(object): | |||||
def __init__(self, | |||||
prep_d: PreparedData, | |||||
batch_size: int) -> None: | |||||
if not isinstance(prep_d, PreparedData): | |||||
raise TypeError('prep_d must be an instance of PreparedData') | |||||
self.prep_d = prep_d | |||||
self.batch_size = int(batch_size) | |||||
self.edges = None | |||||
self.build() | |||||
def build(self): | |||||
self.edges = [] | |||||
for fam_idx, fam in enumerate(self.prep_d.relation_families): | |||||
edges = [] | |||||
targets = [] | |||||
edges_back = [] | |||||
targets_back = [] | |||||
for rel_idx, rel in enumerate(fam.relation_types): | |||||
edges.append(rel.edges_pos.train) | |||||
edges.append(rel.edges_neg.train) | |||||
targets.append(torch.ones(len(rel.edges_pos.train))) | |||||
targets.append(torch.zeros(len(rel.edges_neg.train))) | |||||
edges_back.append(rel.edges_back_pos.train) | |||||
edges_back.append(rel.edges_back_neg.train) | |||||
targets_back.apend(torch.zeros(len(rel.edges_back_pos.train))) | |||||
targets_back.apend(torch.zeros(len(rel.edges_back_neg.train))) | |||||
edges = torch.cat(edges) | |||||
targets = torch.cat(targets) | |||||
edges_back = torch.cat(edges_back) | |||||
targets_back = torch.cat(targets_back) | |||||
order = torch.randperm(len(edges)) | |||||
edges = edges[order] | |||||
targets = targets[order] | |||||
order_back = torch.randperm(len(edges_back)) | |||||
edges_back = edges_back[order_back] | |||||
targets_back = targets_back[order_back] | |||||
self.edges.append({'fam_idx': fam_idx, 'rel_idx': rel_idx, 'back': False, | |||||
'edges': edges, 'targets': targets, 'ofs': 0}) | |||||
self.edges.append({'fam_idx': fam_idx, 'rel_idx': rel_idx, 'back': True, | |||||
'edges': edges_back, 'targets': targets_back, 'ofs': 0}) | |||||
def __iter__(self): | |||||
while True: | |||||
edges = [ e for e in self.edges \ | |||||
if e['ofs'] < len(e['edges']) ] | |||||
# TODO: need to finish this | |||||
def __iter_old__(self): | |||||
edge_types = ['edges_pos', 'edges_neg', 'edges_back_pos', 'edges_back_neg'] | |||||
offsets = {} | |||||
orders = {} | |||||
done = {} | |||||
for fam_idx, fam in enumerate(self.prep_d.relation_families): | |||||
for rel_idx, rel in enumerate(fam.relation_types): | |||||
for et in edge_types: | |||||
done[fam_idx, rel_idx, et] = False | |||||
while True: | |||||
fam_idx = torch.randint(0, len(self.prep_d.relation_families), (1,)).item() | |||||
fam = self.prep_d.relation_families[fam_idx] | |||||
rel_idx = torch.randint(0, len(fam.relation_types), (1,)).item() | |||||
rel = fam.relation_types[rel_idx] | |||||
et = random.choice(edge_types) | |||||
edges = getattr(rel, et).train | |||||
key = (fam_idx, rel_idx, et) | |||||
if key not in orders: | |||||
orders[key] = torch.randperm(len(edges)) | |||||
offsets[key] = 0 | |||||
ord = orders[key] | |||||
ofs = offsets[key] | |||||
nt_row = rel.node_type_row | |||||
nt_col = rel.node_type_column | |||||
if 'back' in et: | |||||
nt_row, nt_col = nt_col, nt_row | |||||
if ofs < len(edges): | |||||
offsets[key] += self.batch_size | |||||
ord = ord[ofs:ofs+self.batch_size] | |||||
edges = edges[ord] | |||||
yield TrainingBatch(fam_idx, rel_idx, nt_row, nt_column, edges) | |||||
else: | |||||
done[key] = True | |||||
for fam in self.prep_d.relation_families: | |||||
edges = [] | |||||
for rel in fam.relation_types: | |||||
edges.append(rel.edges_pos.train) | |||||
edges.append(rel.edges_back_pos.train) | |||||
edges.append(rel.edges_neg.train) | |||||
edges.append(rel.edges_back_neg.train) | |||||
edges = torch.cat(e) | |||||
class FastDecLayer(torch.nn.Module): | |||||
def __init__(self, **kwargs): | |||||
super().__init__(**kwargs) | |||||
def forward(self, | |||||
last_layer_repr: List[torch.Tensor], | |||||
training_batch: TrainingBatch): |
@@ -0,0 +1,166 @@ | |||||
from .fastmodel import FastModel | |||||
from .trainprep import PreparedData | |||||
import torch | |||||
from typing import Callable | |||||
from types import FunctionType | |||||
import time | |||||
import random | |||||
class FastBatcher(object): | |||||
def __init__(self, prep_d: PreparedData, batch_size: int, | |||||
shuffle: bool, generator: torch.Generator, | |||||
part_type: str) -> None: | |||||
if not isinstance(prep_d, PreparedData): | |||||
raise TypeError('prep_d must be an instance of PreparedData') | |||||
if not isinstance(generator, torch.Generator): | |||||
raise TypeError('generator must be an instance of torch.Generator') | |||||
if part_type not in ['train', 'val', 'test']: | |||||
raise ValueError('part_type must be set to train, val or test') | |||||
self.prep_d = prep_d | |||||
self.batch_size = int(batch_size) | |||||
self.shuffle = bool(shuffle) | |||||
self.generator = generator | |||||
self.part_type = part_type | |||||
self.edges = None | |||||
self.targets = None | |||||
self.build() | |||||
def build(self): | |||||
self.edges = [] | |||||
self.targets = [] | |||||
for fam in self.prep_d.relation_families: | |||||
edges = [] | |||||
targets = [] | |||||
for i, rel in enumerate(fam.relation_types): | |||||
edges_pos = getattr(rel.edges_pos, self.part_type) | |||||
edges_neg = getattr(rel.edges_neg, self.part_type) | |||||
edges_back_pos = getattr(rel.edges_back_pos, self.part_type) | |||||
edges_back_neg = getattr(rel.edges_back_neg, self.part_type) | |||||
e = torch.cat([ edges_pos, | |||||
torch.cat([edges_back_pos[:, 1], edges_back_pos[:, 0]], dim=1) ]) | |||||
e = torch.cat([torch.ones(len(e), 1, dtype=torch.long) * i , e ], dim=1) | |||||
t = torch.ones(len(e)) | |||||
edges.append(e) | |||||
targets.append(t) | |||||
e = torch.cat([ edges_neg, | |||||
torch.cat([edges_back_neg[:, 1], edges_back_neg[:, 0]], dim=1) ]) | |||||
e = torch.cat([ torch.ones(len(e), 1, dtype=torch.long) * i, e ], dim=1) | |||||
t = torch.zeros(len(e)) | |||||
edges.append(e) | |||||
targets.append(t) | |||||
edges = torch.cat(edges) | |||||
targets = torch.cat(targets) | |||||
self.edges.append(edges) | |||||
self.targets.append(targets) | |||||
# print(self.edges) | |||||
# print(self.targets) | |||||
if self.shuffle: | |||||
self.shuffle_families() | |||||
def shuffle_families(self): | |||||
for i in range(len(self.edges)): | |||||
edges = self.edges[i] | |||||
targets = self.targets[i] | |||||
order = torch.randperm(len(edges), generator=self.generator) | |||||
self.edges[i] = edges[order] | |||||
self.targets[i] = targets[order] | |||||
def __iter__(self): | |||||
offsets = [ 0 for _ in self.edges ] | |||||
while True: | |||||
choice = [ i for i in range(len(offsets)) \ | |||||
if offsets[i] < len(self.edges[i]) ] | |||||
if len(choice) == 0: | |||||
break | |||||
fam_idx = torch.randint(len(choice), (1,), generator=self.generator).item() | |||||
ofs = offsets[fam_idx] | |||||
edges = self.edges[fam_idx][ofs:ofs + self.batch_size] | |||||
targets = self.targets[fam_idx][ofs:ofs + self.batch_size] | |||||
offsets[fam_idx] += self.batch_size | |||||
yield (fam_idx, edges, targets) | |||||
class FastLoop(object): | |||||
def __init__( | |||||
self, | |||||
model: FastModel, | |||||
lr: float = 0.001, | |||||
loss: Callable[[torch.Tensor, torch.Tensor], torch.Tensor] = \ | |||||
torch.nn.functional.binary_cross_entropy_with_logits, | |||||
batch_size: int = 100, | |||||
shuffle: bool = True, | |||||
generator: torch.Generator = None) -> None: | |||||
self._check_params(model, loss, generator) | |||||
self.model = model | |||||
self.lr = float(lr) | |||||
self.loss = loss | |||||
self.batch_size = int(batch_size) | |||||
self.shuffle = bool(shuffle) | |||||
self.generator = generator or torch.default_generator | |||||
self.opt = None | |||||
self.build() | |||||
def _check_params(self, model, loss, generator): | |||||
if not isinstance(model, FastModel): | |||||
raise TypeError('model must be an instance of FastModel') | |||||
if not isinstance(loss, FunctionType): | |||||
raise TypeError('loss must be a function') | |||||
if generator is not None and not isinstance(generator, torch.Generator): | |||||
raise TypeError('generator must be an instance of torch.Generator') | |||||
def build(self) -> None: | |||||
opt = torch.optim.Adam(self.model.parameters(), lr=self.lr) | |||||
self.opt = opt | |||||
def run_epoch(self): | |||||
prep_d = self.model.prep_d | |||||
batcher = FastBatcher(self.model.prep_d, batch_size=self.batch_size, | |||||
shuffle = self.shuffle, generator=self.generator) | |||||
# pred = self.model(None) | |||||
# n = len(list(iter(batch))) | |||||
loss_sum = 0 | |||||
for fam_idx, edges, targets in batcher: | |||||
self.opt.zero_grad() | |||||
pred = self.model(None) | |||||
# process pred, get input and targets | |||||
input = pred[fam_idx][edges[:, 0], edges[:, 1]] | |||||
loss = self.loss(input, targets) | |||||
loss.backward() | |||||
self.opt.step() | |||||
loss_sum += loss.detach().cpu().item() | |||||
return loss_sum | |||||
def train(self, max_epochs): | |||||
best_loss = None | |||||
best_epoch = None | |||||
for i in range(max_epochs): | |||||
loss = self.run_epoch() | |||||
if best_loss is None or loss < best_loss: | |||||
best_loss = loss | |||||
best_epoch = i | |||||
return loss, best_loss, best_epoch |
@@ -0,0 +1,79 @@ | |||||
from .fastconv import FastConvLayer | |||||
from .bulkdec import BulkDecodeLayer | |||||
from .input import OneHotInputLayer | |||||
from .trainprep import PreparedData | |||||
import torch | |||||
import types | |||||
from typing import List, \ | |||||
Union, \ | |||||
Callable | |||||
class FastModel(torch.nn.Module): | |||||
def __init__(self, prep_d: PreparedData, | |||||
layer_dimensions: List[int] = [32, 64], | |||||
keep_prob: float = 1., | |||||
rel_activation: Callable[[torch.Tensor], torch.Tensor] = lambda x: x, | |||||
layer_activation: Callable[[torch.Tensor], torch.Tensor] = torch.nn.functional.relu, | |||||
dec_activation: Callable[[torch.Tensor], torch.Tensor] = lambda x: x, | |||||
**kwargs) -> None: | |||||
super().__init__(**kwargs) | |||||
self._check_params(prep_d, layer_dimensions, rel_activation, | |||||
layer_activation, dec_activation) | |||||
self.prep_d = prep_d | |||||
self.layer_dimensions = layer_dimensions | |||||
self.keep_prob = float(keep_prob) | |||||
self.rel_activation = rel_activation | |||||
self.layer_activation = layer_activation | |||||
self.dec_activation = dec_activation | |||||
self.seq = None | |||||
self.build() | |||||
def build(self): | |||||
in_layer = OneHotInputLayer(self.prep_d) | |||||
last_output_dim = in_layer.output_dim | |||||
seq = [ in_layer ] | |||||
for dim in self.layer_dimensions: | |||||
conv_layer = FastConvLayer(input_dim = last_output_dim, | |||||
output_dim = [dim] * len(self.prep_d.node_types), | |||||
data = self.prep_d, | |||||
keep_prob = self.keep_prob, | |||||
rel_activation = self.rel_activation, | |||||
layer_activation = self.layer_activation) | |||||
last_output_dim = conv_layer.output_dim | |||||
seq.append(conv_layer) | |||||
dec_layer = BulkDecodeLayer(input_dim = last_output_dim, | |||||
data = self.prep_d, | |||||
keep_prob = self.keep_prob, | |||||
activation = self.dec_activation) | |||||
seq.append(dec_layer) | |||||
seq = torch.nn.Sequential(*seq) | |||||
self.seq = seq | |||||
def forward(self, _): | |||||
return self.seq(None) | |||||
def _check_params(self, prep_d, layer_dimensions, rel_activation, | |||||
layer_activation, dec_activation): | |||||
if not isinstance(prep_d, PreparedData): | |||||
raise TypeError('prep_d must be an instanced of PreparedData') | |||||
if not isinstance(layer_dimensions, list): | |||||
raise TypeError('layer_dimensions must be a list') | |||||
if not isinstance(rel_activation, types.FunctionType): | |||||
raise TypeError('rel_activation must be a function') | |||||
if not isinstance(layer_activation, types.FunctionType): | |||||
raise TypeError('layer_activation must be a function') | |||||
if not isinstance(dec_activation, types.FunctionType): | |||||
raise TypeError('dec_activation must be a function') |
@@ -0,0 +1,79 @@ | |||||
# | |||||
# Copyright (C) Stanislaw Adaszewski, 2020 | |||||
# License: GPLv3 | |||||
# | |||||
import torch | |||||
from typing import Union, \ | |||||
List | |||||
from .data import Data | |||||
class InputLayer(torch.nn.Module): | |||||
def __init__(self, data: Data, output_dim: Union[int, List[int]] = None, | |||||
**kwargs) -> None: | |||||
output_dim = output_dim or \ | |||||
list(map(lambda a: a.count, data.node_types)) | |||||
if not isinstance(output_dim, list): | |||||
output_dim = [output_dim,] * len(data.node_types) | |||||
super().__init__(**kwargs) | |||||
self.output_dim = output_dim | |||||
self.data = data | |||||
self.is_sparse=False | |||||
self.node_reps = None | |||||
self.build() | |||||
def build(self) -> None: | |||||
self.node_reps = [] | |||||
for i, nt in enumerate(self.data.node_types): | |||||
reps = torch.rand(nt.count, self.output_dim[i]) | |||||
reps = torch.nn.Parameter(reps) | |||||
self.register_parameter('node_reps[%d]' % i, reps) | |||||
self.node_reps.append(reps) | |||||
def forward(self, x) -> List[torch.nn.Parameter]: | |||||
return self.node_reps | |||||
def __repr__(self) -> str: | |||||
s = '' | |||||
s += 'Icosagon input layer with output_dim: %s\n' % self.output_dim | |||||
s += ' # of node types: %d\n' % len(self.data.node_types) | |||||
for nt in self.data.node_types: | |||||
s += ' - %s (%d)\n' % (nt.name, nt.count) | |||||
return s.strip() | |||||
class OneHotInputLayer(torch.nn.Module): | |||||
def __init__(self, data: Data, **kwargs) -> None: | |||||
output_dim = [ a.count for a in data.node_types ] | |||||
super().__init__(**kwargs) | |||||
self.output_dim = output_dim | |||||
self.data = data | |||||
self.is_sparse=True | |||||
self.node_reps = None | |||||
self.build() | |||||
def build(self) -> None: | |||||
self.node_reps = torch.nn.ParameterList() | |||||
for i, nt in enumerate(self.data.node_types): | |||||
reps = torch.eye(nt.count).to_sparse() | |||||
reps = torch.nn.Parameter(reps, requires_grad=False) | |||||
# self.register_parameter('node_reps[%d]' % i, reps) | |||||
self.node_reps.append(reps) | |||||
def forward(self, x) -> List[torch.nn.Parameter]: | |||||
return self.node_reps | |||||
def __repr__(self) -> str: | |||||
s = '' | |||||
s += 'Icosagon one-hot input layer\n' | |||||
s += ' # of node types: %d\n' % len(self.data.node_types) | |||||
for nt in self.data.node_types: | |||||
s += ' - %s (%d)\n' % (nt.name, nt.count) | |||||
return s.strip() |
@@ -0,0 +1,145 @@ | |||||
# | |||||
# Copyright (C) Stanislaw Adaszewski, 2020 | |||||
# License: GPLv3 | |||||
# | |||||
import numpy as np | |||||
import scipy.sparse as sp | |||||
import torch | |||||
def _check_tensor(adj_mat): | |||||
if not isinstance(adj_mat, torch.Tensor): | |||||
raise ValueError('adj_mat must be a torch.Tensor') | |||||
def _check_sparse(adj_mat): | |||||
if not adj_mat.is_sparse: | |||||
raise ValueError('adj_mat must be sparse') | |||||
def _check_dense(adj_mat): | |||||
if adj_mat.is_sparse: | |||||
raise ValueError('adj_mat must be dense') | |||||
def _check_square(adj_mat): | |||||
if len(adj_mat.shape) != 2 or \ | |||||
adj_mat.shape[0] != adj_mat.shape[1]: | |||||
raise ValueError('adj_mat must be a square matrix') | |||||
def _check_2d(adj_mat): | |||||
if len(adj_mat.shape) != 2: | |||||
raise ValueError('adj_mat must be a square matrix') | |||||
def _sparse_coo_tensor(indices, values, size): | |||||
ctor = { torch.float32: torch.sparse.FloatTensor, | |||||
torch.float32: torch.sparse.DoubleTensor, | |||||
torch.uint8: torch.sparse.ByteTensor, | |||||
torch.long: torch.sparse.LongTensor, | |||||
torch.int: torch.sparse.IntTensor, | |||||
torch.short: torch.sparse.ShortTensor, | |||||
torch.bool: torch.sparse.ByteTensor }[values.dtype] | |||||
return ctor(indices, values, size) | |||||
def add_eye_sparse(adj_mat: torch.Tensor) -> torch.Tensor: | |||||
_check_tensor(adj_mat) | |||||
_check_sparse(adj_mat) | |||||
_check_square(adj_mat) | |||||
adj_mat = adj_mat.coalesce() | |||||
indices = adj_mat.indices() | |||||
values = adj_mat.values() | |||||
eye_indices = torch.arange(adj_mat.shape[0], dtype=indices.dtype, | |||||
device=adj_mat.device).view(1, -1) | |||||
eye_indices = torch.cat((eye_indices, eye_indices), 0) | |||||
eye_values = torch.ones(adj_mat.shape[0], dtype=values.dtype, | |||||
device=adj_mat.device) | |||||
indices = torch.cat((indices, eye_indices), 1) | |||||
values = torch.cat((values, eye_values), 0) | |||||
adj_mat = _sparse_coo_tensor(indices, values, adj_mat.shape) | |||||
return adj_mat | |||||
def norm_adj_mat_one_node_type_sparse(adj_mat: torch.Tensor) -> torch.Tensor: | |||||
_check_tensor(adj_mat) | |||||
_check_sparse(adj_mat) | |||||
_check_square(adj_mat) | |||||
adj_mat = add_eye_sparse(adj_mat) | |||||
adj_mat = norm_adj_mat_two_node_types_sparse(adj_mat) | |||||
return adj_mat | |||||
def norm_adj_mat_one_node_type_dense(adj_mat: torch.Tensor) -> torch.Tensor: | |||||
_check_tensor(adj_mat) | |||||
_check_dense(adj_mat) | |||||
_check_square(adj_mat) | |||||
adj_mat = adj_mat + torch.eye(adj_mat.shape[0], dtype=adj_mat.dtype, | |||||
device=adj_mat.device) | |||||
adj_mat = norm_adj_mat_two_node_types_dense(adj_mat) | |||||
return adj_mat | |||||
def norm_adj_mat_one_node_type(adj_mat: torch.Tensor) -> torch.Tensor: | |||||
_check_tensor(adj_mat) | |||||
_check_square(adj_mat) | |||||
if adj_mat.is_sparse: | |||||
return norm_adj_mat_one_node_type_sparse(adj_mat) | |||||
else: | |||||
return norm_adj_mat_one_node_type_dense(adj_mat) | |||||
def norm_adj_mat_two_node_types_sparse(adj_mat: torch.Tensor) -> torch.Tensor: | |||||
_check_tensor(adj_mat) | |||||
_check_sparse(adj_mat) | |||||
_check_2d(adj_mat) | |||||
adj_mat = adj_mat.coalesce() | |||||
indices = adj_mat.indices() | |||||
values = adj_mat.values() | |||||
degrees_row = torch.zeros(adj_mat.shape[0], device=adj_mat.device) | |||||
degrees_row = degrees_row.index_add(0, indices[0], values.to(degrees_row.dtype)) | |||||
degrees_col = torch.zeros(adj_mat.shape[1], device=adj_mat.device) | |||||
degrees_col = degrees_col.index_add(0, indices[1], values.to(degrees_col.dtype)) | |||||
values = values.to(degrees_row.dtype) / torch.sqrt(degrees_row[indices[0]] * degrees_col[indices[1]]) | |||||
adj_mat = _sparse_coo_tensor(indices, values, adj_mat.shape) | |||||
return adj_mat | |||||
def norm_adj_mat_two_node_types_dense(adj_mat: torch.Tensor) -> torch.Tensor: | |||||
_check_tensor(adj_mat) | |||||
_check_dense(adj_mat) | |||||
_check_2d(adj_mat) | |||||
degrees_row = adj_mat.sum(1).view(-1, 1).to(torch.float32) | |||||
degrees_col = adj_mat.sum(0).view(1, -1).to(torch.float32) | |||||
degrees_row = torch.sqrt(degrees_row) | |||||
degrees_col = torch.sqrt(degrees_col) | |||||
adj_mat = adj_mat.to(degrees_row.dtype) / degrees_row | |||||
adj_mat = adj_mat / degrees_col | |||||
return adj_mat | |||||
def norm_adj_mat_two_node_types(adj_mat: torch.Tensor) -> torch.Tensor: | |||||
_check_tensor(adj_mat) | |||||
_check_2d(adj_mat) | |||||
if adj_mat.is_sparse: | |||||
return norm_adj_mat_two_node_types_sparse(adj_mat) | |||||
else: | |||||
return norm_adj_mat_two_node_types_dense(adj_mat) |
@@ -0,0 +1,47 @@ | |||||
# | |||||
# Copyright (C) Stanislaw Adaszewski, 2020 | |||||
# License: GPLv3 | |||||
# | |||||
import numpy as np | |||||
import torch | |||||
import torch.utils.data | |||||
from typing import List, \ | |||||
Union | |||||
def fixed_unigram_candidate_sampler( | |||||
true_classes: Union[np.array, torch.Tensor], | |||||
unigrams: List[Union[int, float]], | |||||
distortion: float = 1.): | |||||
if isinstance(true_classes, torch.Tensor): | |||||
true_classes = true_classes.detach().cpu().numpy() | |||||
if isinstance(unigrams, torch.Tensor): | |||||
unigrams = unigrams.detach().cpu().numpy() | |||||
if len(true_classes.shape) != 2: | |||||
raise ValueError('true_classes must be a 2D matrix with shape (num_samples, num_true)') | |||||
num_samples = true_classes.shape[0] | |||||
unigrams = np.array(unigrams) | |||||
if distortion != 1.: | |||||
unigrams = unigrams.astype(np.float64) ** distortion | |||||
# print('unigrams:', unigrams) | |||||
indices = np.arange(num_samples) | |||||
result = np.zeros(num_samples, dtype=np.int64) | |||||
while len(indices) > 0: | |||||
# print('len(indices):', len(indices)) | |||||
sampler = torch.utils.data.WeightedRandomSampler(unigrams, len(indices)) | |||||
candidates = np.array(list(sampler)) | |||||
candidates = np.reshape(candidates, (len(indices), 1)) | |||||
# print('candidates:', candidates) | |||||
# print('true_classes:', true_classes[indices, :]) | |||||
result[indices] = candidates.T | |||||
mask = (candidates == true_classes[indices, :]) | |||||
mask = mask.sum(1).astype(np.bool) | |||||
# print('mask:', mask) | |||||
indices = indices[mask] | |||||
return torch.tensor(result) |
@@ -0,0 +1,215 @@ | |||||
# | |||||
# Copyright (C) Stanislaw Adaszewski, 2020 | |||||
# License: GPLv3 | |||||
# | |||||
from .sampling import fixed_unigram_candidate_sampler | |||||
import torch | |||||
from dataclasses import dataclass, \ | |||||
field | |||||
from typing import Any, \ | |||||
List, \ | |||||
Tuple, \ | |||||
Dict | |||||
from .data import NodeType, \ | |||||
RelationType, \ | |||||
RelationTypeBase, \ | |||||
RelationFamily, \ | |||||
RelationFamilyBase, \ | |||||
Data | |||||
from collections import defaultdict | |||||
from .normalize import norm_adj_mat_one_node_type, \ | |||||
norm_adj_mat_two_node_types | |||||
import numpy as np | |||||
@dataclass | |||||
class TrainValTest(object): | |||||
train: Any | |||||
val: Any | |||||
test: Any | |||||
@dataclass | |||||
class PreparedRelationType(RelationTypeBase): | |||||
edges_pos: TrainValTest | |||||
edges_neg: TrainValTest | |||||
edges_back_pos: TrainValTest | |||||
edges_back_neg: TrainValTest | |||||
@dataclass | |||||
class PreparedRelationFamily(RelationFamilyBase): | |||||
relation_types: List[PreparedRelationType] | |||||
@dataclass | |||||
class PreparedData(object): | |||||
node_types: List[NodeType] | |||||
relation_families: List[PreparedRelationFamily] | |||||
def _empty_edge_list_tvt() -> TrainValTest: | |||||
return TrainValTest(*[ torch.zeros((0, 2), dtype=torch.long) for _ in range(3) ]) | |||||
def train_val_test_split_edges(edges: torch.Tensor, | |||||
ratios: TrainValTest) -> TrainValTest: | |||||
if not isinstance(edges, torch.Tensor): | |||||
raise ValueError('edges must be a torch.Tensor') | |||||
if len(edges.shape) != 2 or edges.shape[1] != 2: | |||||
raise ValueError('edges shape must be (num_edges, 2)') | |||||
if not isinstance(ratios, TrainValTest): | |||||
raise ValueError('ratios must be a TrainValTest') | |||||
if ratios.train + ratios.val + ratios.test != 1.0: | |||||
raise ValueError('Train, validation and test ratios must add up to 1') | |||||
order = torch.randperm(len(edges)) | |||||
edges = edges[order, :] | |||||
n = round(len(edges) * ratios.train) | |||||
edges_train = edges[:n] | |||||
n_1 = round(len(edges) * (ratios.train + ratios.val)) | |||||
edges_val = edges[n:n_1] | |||||
edges_test = edges[n_1:] | |||||
return TrainValTest(edges_train, edges_val, edges_test) | |||||
def get_edges_and_degrees(adj_mat: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: | |||||
if adj_mat.is_sparse: | |||||
adj_mat = adj_mat.coalesce() | |||||
degrees = torch.zeros(adj_mat.shape[1], dtype=torch.int64, | |||||
device=adj_mat.device) | |||||
degrees = degrees.index_add(0, adj_mat.indices()[1], | |||||
torch.ones(adj_mat.indices().shape[1], dtype=torch.int64, | |||||
device=adj_mat.device)) | |||||
edges_pos = adj_mat.indices().transpose(0, 1) | |||||
else: | |||||
degrees = adj_mat.sum(0) | |||||
edges_pos = torch.nonzero(adj_mat) | |||||
return edges_pos, degrees | |||||
def prepare_adj_mat(adj_mat: torch.Tensor, | |||||
ratios: TrainValTest) -> Tuple[TrainValTest, TrainValTest]: | |||||
if not isinstance(adj_mat, torch.Tensor): | |||||
raise ValueError('adj_mat must be a torch.Tensor') | |||||
edges_pos, degrees = get_edges_and_degrees(adj_mat) | |||||
neg_neighbors = fixed_unigram_candidate_sampler( | |||||
edges_pos[:, 1].view(-1, 1), degrees, 0.75).to(adj_mat.device) | |||||
print(edges_pos.dtype) | |||||
print(neg_neighbors.dtype) | |||||
edges_neg = torch.cat((edges_pos[:, 0].view(-1, 1), neg_neighbors.view(-1, 1)), 1) | |||||
edges_pos = train_val_test_split_edges(edges_pos, ratios) | |||||
edges_neg = train_val_test_split_edges(edges_neg, ratios) | |||||
adj_mat_train = torch.sparse_coo_tensor(indices = edges_pos.train.transpose(0, 1), | |||||
values=torch.ones(len(edges_pos.train)), size=adj_mat.shape, dtype=adj_mat.dtype, | |||||
device=adj_mat.device) | |||||
return adj_mat_train, edges_pos, edges_neg | |||||
def prep_rel_one_node_type(r: RelationType, | |||||
ratios: TrainValTest) -> PreparedRelationType: | |||||
adj_mat = r.adjacency_matrix | |||||
adj_mat_train, edges_pos, edges_neg = prepare_adj_mat(adj_mat, ratios) | |||||
adj_mat_back_train, edges_back_pos, edges_back_neg = \ | |||||
None, _empty_edge_list_tvt(), _empty_edge_list_tvt() | |||||
print('adj_mat_train:', adj_mat_train) | |||||
adj_mat_train = norm_adj_mat_one_node_type(adj_mat_train) | |||||
return PreparedRelationType(r.name, r.node_type_row, r.node_type_column, | |||||
adj_mat_train, adj_mat_back_train, edges_pos, edges_neg, | |||||
edges_back_pos, edges_back_neg) | |||||
def prep_rel_two_node_types_sym(r: RelationType, | |||||
ratios: TrainValTest) -> PreparedRelationType: | |||||
adj_mat = r.adjacency_matrix | |||||
adj_mat_train, edges_pos, edges_neg = prepare_adj_mat(adj_mat, ratios) | |||||
edges_back_pos, edges_back_neg = \ | |||||
_empty_edge_list_tvt(), _empty_edge_list_tvt() | |||||
return PreparedRelationType(r.name, r.node_type_row, | |||||
r.node_type_column, | |||||
norm_adj_mat_two_node_types(adj_mat_train), | |||||
norm_adj_mat_two_node_types(adj_mat_train.transpose(0, 1)), | |||||
edges_pos, edges_neg, edges_back_pos, edges_back_neg) | |||||
def prep_rel_two_node_types_asym(r: RelationType, | |||||
ratios: TrainValTest) -> PreparedRelationType: | |||||
if r.adjacency_matrix is not None: | |||||
adj_mat_train, edges_pos, edges_neg =\ | |||||
prepare_adj_mat(r.adjacency_matrix, ratios) | |||||
else: | |||||
adj_mat_train, edges_pos, edges_neg = \ | |||||
None, _empty_edge_list_tvt(), _empty_edge_list_tvt() | |||||
if r.adjacency_matrix_backward is not None: | |||||
adj_mat_back_train, edges_back_pos, edges_back_neg = \ | |||||
prepare_adj_mat(r.adjacency_matrix_backward, ratios) | |||||
else: | |||||
adj_mat_back_train, edges_back_pos, edges_back_neg = \ | |||||
None, _empty_edge_list_tvt(), _empty_edge_list_tvt() | |||||
return PreparedRelationType(r.name, r.node_type_row, | |||||
r.node_type_column, | |||||
norm_adj_mat_two_node_types(adj_mat_train), | |||||
norm_adj_mat_two_node_types(adj_mat_back_train), | |||||
edges_pos, edges_neg, edges_back_pos, edges_back_neg) | |||||
def prepare_relation_type(r: RelationType, | |||||
ratios: TrainValTest, is_symmetric: bool) -> PreparedRelationType: | |||||
if not isinstance(r, RelationType): | |||||
raise ValueError('r must be a RelationType') | |||||
if not isinstance(ratios, TrainValTest): | |||||
raise ValueError('ratios must be a TrainValTest') | |||||
if r.node_type_row == r.node_type_column: | |||||
return prep_rel_one_node_type(r, ratios) | |||||
elif is_symmetric: | |||||
return prep_rel_two_node_types_sym(r, ratios) | |||||
else: | |||||
return prep_rel_two_node_types_asym(r, ratios) | |||||
def prepare_relation_family(fam: RelationFamily, | |||||
ratios: TrainValTest) -> PreparedRelationFamily: | |||||
relation_types = [] | |||||
for r in fam.relation_types: | |||||
relation_types.append(prepare_relation_type(r, ratios, fam.is_symmetric)) | |||||
return PreparedRelationFamily(fam.data, fam.name, | |||||
fam.node_type_row, fam.node_type_column, | |||||
fam.is_symmetric, fam.decoder_class, | |||||
relation_types) | |||||
def prepare_training(data: Data, ratios: TrainValTest) -> PreparedData: | |||||
if not isinstance(data, Data): | |||||
raise ValueError('data must be of class Data') | |||||
relation_families = [ prepare_relation_family(fam, ratios) \ | |||||
for fam in data.relation_families ] | |||||
return PreparedData(data.node_types, relation_families) |
@@ -0,0 +1,19 @@ | |||||
# | |||||
# Copyright (C) Stanislaw Adaszewski, 2020 | |||||
# License: GPLv3 | |||||
# | |||||
import torch | |||||
import numpy as np | |||||
def init_glorot(in_channels, out_channels, dtype=torch.float32): | |||||
"""Create a weight variable with Glorot & Bengio (AISTATS 2010) | |||||
initialization. | |||||
""" | |||||
init_range = np.sqrt(6.0 / (in_channels + out_channels)) | |||||
initial = -init_range + 2 * init_range * \ | |||||
torch.rand(( in_channels, out_channels ), dtype=dtype) | |||||
initial = initial.requires_grad_(True) | |||||
return initial |