@@ -0,0 +1,209 @@
#
# Copyright (C) Stanislaw Adaszewski, 2020
# License: GPLv3
#

from collections import defaultdict
from dataclasses import dataclass, field
import torch
from typing import List, \
    Dict, \
    Tuple, \
    Any, \
    Type
from .decode import DEDICOMDecoder, \
    BilinearDecoder
import numpy as np


def _equal(x: torch.Tensor, y: torch.Tensor):
    if x.is_sparse ^ y.is_sparse:
        raise ValueError('Cannot mix sparse and dense tensors')

    if not x.is_sparse:
        return (x == y)

    # For sparse tensors, compare by checking that the difference
    # contains only zero values.
    return ((x - y).coalesce().values() == 0)


@dataclass
class NodeType(object):
    name: str
    count: int


@dataclass
class RelationTypeBase(object):
    name: str
    node_type_row: int
    node_type_column: int
    adjacency_matrix: torch.Tensor
    adjacency_matrix_backward: torch.Tensor


@dataclass
class RelationType(RelationTypeBase):
    pass


@dataclass
class RelationFamilyBase(object):
    data: 'Data'
    name: str
    node_type_row: int
    node_type_column: int
    is_symmetric: bool
    decoder_class: Type


@dataclass
class RelationFamily(RelationFamilyBase):
    relation_types: List[RelationType] = None

    def __post_init__(self) -> None:
        if not self.is_symmetric and \
            self.decoder_class != DEDICOMDecoder and \
            self.decoder_class != BilinearDecoder:
            raise TypeError('Family is asymmetric but the specified decoder_class supports symmetric relations only')

        self.relation_types = []

    def add_relation_type(self,
        name: str, adjacency_matrix: torch.Tensor,
        adjacency_matrix_backward: torch.Tensor = None) -> None:

        name = str(name)
        node_type_row = self.node_type_row
        node_type_column = self.node_type_column

        if adjacency_matrix is None and adjacency_matrix_backward is None:
            raise ValueError('adjacency_matrix and adjacency_matrix_backward cannot both be None')

        if adjacency_matrix is not None and \
            not isinstance(adjacency_matrix, torch.Tensor):
            raise ValueError('adjacency_matrix must be a torch.Tensor')

        if adjacency_matrix_backward is not None \
            and not isinstance(adjacency_matrix_backward, torch.Tensor):
            raise ValueError('adjacency_matrix_backward must be a torch.Tensor')

        if adjacency_matrix is not None and \
            adjacency_matrix.shape != (self.data.node_types[node_type_row].count,
                self.data.node_types[node_type_column].count):
            raise ValueError('adjacency_matrix shape must be (num_row_nodes, num_column_nodes)')

        if adjacency_matrix_backward is not None and \
            adjacency_matrix_backward.shape != (self.data.node_types[node_type_column].count,
                self.data.node_types[node_type_row].count):
            raise ValueError('adjacency_matrix_backward shape must be (num_column_nodes, num_row_nodes)')

        if node_type_row == node_type_column and \
            adjacency_matrix_backward is not None:
            raise ValueError('Relation between nodes of the same type must be expressed using a single matrix')

        if self.is_symmetric and adjacency_matrix_backward is not None:
            raise ValueError('Cannot use a custom adjacency_matrix_backward in a symmetric relation family')

        if self.is_symmetric and node_type_row == node_type_column and \
            not torch.all(_equal(adjacency_matrix,
                adjacency_matrix.transpose(0, 1))):
            raise ValueError('Relation family is symmetric but adjacency_matrix is asymmetric')

        if not self.is_symmetric and node_type_row != node_type_column and \
            adjacency_matrix_backward is None:
            raise ValueError('Relation is asymmetric but adjacency_matrix_backward is None')

        if self.is_symmetric and node_type_row != node_type_column:
            adjacency_matrix_backward = adjacency_matrix.transpose(0, 1)

        self.relation_types.append(RelationType(name,
            node_type_row, node_type_column,
            adjacency_matrix, adjacency_matrix_backward))

    def node_name(self, index):
        return self.data.node_types[index].name

    def __repr__(self):
        s = 'Relation family %s' % self.name
        for r in self.relation_types:
            s += '\n - %s%s' % (r.name, ' (two-way)' \
                if (r.adjacency_matrix is not None \
                    and r.adjacency_matrix_backward is not None) \
                    or self.node_type_row == self.node_type_column \
                else ' (%s <- %s)' % (self.node_name(self.node_type_row),
                    self.node_name(self.node_type_column)))
        return s

    def repr_indented(self):
        s = ' - %s' % self.name
        for r in self.relation_types:
            s += '\n   - %s%s' % (r.name, ' (two-way)' \
                if (r.adjacency_matrix is not None \
                    and r.adjacency_matrix_backward is not None) \
                    or self.node_type_row == self.node_type_column \
                else ' (%s <- %s)' % (self.node_name(self.node_type_row),
                    self.node_name(self.node_type_column)))
        return s


class Data(object):
    node_types: List[NodeType]
    relation_families: List[RelationFamily]

    def __init__(self) -> None:
        self.node_types = []
        self.relation_families = []

    def add_node_type(self, name: str, count: int) -> None:
        name = str(name)
        count = int(count)
        if not name:
            raise ValueError('You must provide a non-empty node type name')
        if count <= 0:
            raise ValueError('You must provide a positive node count')
        self.node_types.append(NodeType(name, count))

    def add_relation_family(self, name: str, node_type_row: int,
        node_type_column: int, is_symmetric: bool,
        decoder_class: Type = DEDICOMDecoder):

        name = str(name)
        node_type_row = int(node_type_row)
        node_type_column = int(node_type_column)
        is_symmetric = bool(is_symmetric)

        if node_type_row < 0 or node_type_row >= len(self.node_types):
            raise ValueError('node_type_row outside of the valid range of node types')

        if node_type_column < 0 or node_type_column >= len(self.node_types):
            raise ValueError('node_type_column outside of the valid range of node types')

        fam = RelationFamily(self, name, node_type_row, node_type_column,
            is_symmetric, decoder_class)
        self.relation_families.append(fam)

        return fam

    def __repr__(self):
        n = len(self.node_types)
        if n == 0:
            return 'Empty Icosagon Data'
        s = ''
        s += 'Icosagon Data with:\n'
        s += '- ' + str(n) + ' node type(s):\n'
        for nt in self.node_types:
            s += '  - ' + nt.name + '\n'
        if len(self.relation_families) == 0:
            s += '- No relation families\n'
            return s.strip()
        s += '- %d relation families:\n' % len(self.relation_families)
        for fam in self.relation_families:
            s += fam.repr_indented() + '\n'
        return s.strip()
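
# A minimal usage sketch for the classes above (assuming this module is
# importable as `icosagon.data`, a package name inferred from the relative
# imports in this changeset):
#
#     import torch
#     from icosagon.data import Data
#
#     d = Data()
#     d.add_node_type('Gene', 1000)
#     d.add_node_type('Drug', 100)
#
#     fam = d.add_relation_family('Drug-Gene', 1, 0, is_symmetric=False)
#     adj = (torch.rand(100, 1000) < 0.001).to(torch.float32)       # (num_drug, num_gene)
#     adj_back = (torch.rand(1000, 100) < 0.001).to(torch.float32)  # (num_gene, num_drug)
#     fam.add_relation_type('Target', adj, adj_back)
#
#     print(d)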
@@ -0,0 +1,123 @@
#
# Copyright (C) Stanislaw Adaszewski, 2020
# License: GPLv3
#

import torch
from .weights import init_glorot
from .dropout import dropout


class DEDICOMDecoder(torch.nn.Module):
    """DEDICOM Tensor Factorization Decoder model layer for link prediction."""
    def __init__(self, input_dim, num_relation_types, keep_prob=1.,
        activation=torch.sigmoid, **kwargs):

        super().__init__(**kwargs)
        self.input_dim = input_dim
        self.num_relation_types = num_relation_types
        self.keep_prob = keep_prob
        self.activation = activation

        self.global_interaction = torch.nn.Parameter(init_glorot(input_dim, input_dim))
        self.local_variation = torch.nn.ParameterList([
            torch.nn.Parameter(torch.flatten(init_glorot(input_dim, 1))) \
                for _ in range(num_relation_types)
        ])

    def forward(self, inputs_row, inputs_col, relation_index):
        inputs_row = dropout(inputs_row, self.keep_prob)
        inputs_col = dropout(inputs_col, self.keep_prob)

        relation = torch.diag(self.local_variation[relation_index])

        product1 = torch.mm(inputs_row, relation)
        product2 = torch.mm(product1, self.global_interaction)
        product3 = torch.mm(product2, relation)

        # Row-wise dot product of product3 and inputs_col, one score per
        # (row, col) pair.
        rec = torch.bmm(product3.view(product3.shape[0], 1, product3.shape[1]),
            inputs_col.view(inputs_col.shape[0], inputs_col.shape[1], 1))
        rec = torch.flatten(rec)

        return self.activation(rec)


class DistMultDecoder(torch.nn.Module):
    """DistMult Tensor Factorization Decoder model layer for link prediction."""
    def __init__(self, input_dim, num_relation_types, keep_prob=1.,
        activation=torch.sigmoid, **kwargs):

        super().__init__(**kwargs)
        self.input_dim = input_dim
        self.num_relation_types = num_relation_types
        self.keep_prob = keep_prob
        self.activation = activation

        self.relation = torch.nn.ParameterList([
            torch.nn.Parameter(torch.flatten(init_glorot(input_dim, 1))) \
                for _ in range(num_relation_types)
        ])

    def forward(self, inputs_row, inputs_col, relation_index):
        inputs_row = dropout(inputs_row, self.keep_prob)
        inputs_col = dropout(inputs_col, self.keep_prob)

        relation = torch.diag(self.relation[relation_index])

        intermediate_product = torch.mm(inputs_row, relation)
        rec = torch.bmm(intermediate_product.view(intermediate_product.shape[0], 1, intermediate_product.shape[1]),
            inputs_col.view(inputs_col.shape[0], inputs_col.shape[1], 1))
        rec = torch.flatten(rec)

        return self.activation(rec)


class BilinearDecoder(torch.nn.Module):
    """Bilinear Decoder model layer for link prediction."""
    def __init__(self, input_dim, num_relation_types, keep_prob=1.,
        activation=torch.sigmoid, **kwargs):

        super().__init__(**kwargs)
        self.input_dim = input_dim
        self.num_relation_types = num_relation_types
        self.keep_prob = keep_prob
        self.activation = activation

        self.relation = torch.nn.ParameterList([
            torch.nn.Parameter(init_glorot(input_dim, input_dim)) \
                for _ in range(num_relation_types)
        ])

    def forward(self, inputs_row, inputs_col, relation_index):
        inputs_row = dropout(inputs_row, self.keep_prob)
        inputs_col = dropout(inputs_col, self.keep_prob)

        intermediate_product = torch.mm(inputs_row, self.relation[relation_index])
        rec = torch.bmm(intermediate_product.view(intermediate_product.shape[0], 1, intermediate_product.shape[1]),
            inputs_col.view(inputs_col.shape[0], inputs_col.shape[1], 1))
        rec = torch.flatten(rec)

        return self.activation(rec)


class InnerProductDecoder(torch.nn.Module):
    """Inner-product Decoder model layer for link prediction."""
    def __init__(self, input_dim, num_relation_types, keep_prob=1.,
        activation=torch.sigmoid, **kwargs):

        super().__init__(**kwargs)
        self.input_dim = input_dim
        self.num_relation_types = num_relation_types
        self.keep_prob = keep_prob
        self.activation = activation

    def forward(self, inputs_row, inputs_col, _):
        inputs_row = dropout(inputs_row, self.keep_prob)
        inputs_col = dropout(inputs_col, self.keep_prob)

        rec = torch.bmm(inputs_row.view(inputs_row.shape[0], 1, inputs_row.shape[1]),
            inputs_col.view(inputs_col.shape[0], inputs_col.shape[1], 1))
        rec = torch.flatten(rec)

        return self.activation(rec)
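
# A minimal usage sketch (shapes only; all four decoders map two batches of
# node representations to one score per row):
#
#     import torch
#     dec = DEDICOMDecoder(input_dim=32, num_relation_types=7,
#         keep_prob=0.5, activation=torch.sigmoid)
#     rows = torch.rand(16, 32)   # 16 edges, left endpoints
#     cols = torch.rand(16, 32)   # 16 edges, right endpoints
#     scores = dec(rows, cols, relation_index=3)
#     assert scores.shape == (16,)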
@@ -0,0 +1,42 @@
#
# Copyright (C) Stanislaw Adaszewski, 2020
# License: GPLv3
#

import torch
from .normalize import _sparse_coo_tensor


def dropout_sparse(x, keep_prob):
    x = x.coalesce()
    i = x._indices()
    v = x._values()
    size = x.size()

    # Keep each non-zero entry with probability keep_prob.
    n = keep_prob + torch.rand(len(v), device=v.device)
    n = torch.floor(n).to(torch.bool)
    i = i[:, n]
    v = v[n]
    x = _sparse_coo_tensor(i, v, size=size)

    return x * (1./keep_prob)


def dropout_dense(x, keep_prob):
    # Note: this variant assumes a 2D input and, like the sparse variant,
    # only drops entries that are already non-zero.
    x = x.clone()
    i = torch.nonzero(x)

    n = keep_prob + torch.rand(len(i), device=x.device)
    n = (1. - torch.floor(n)).to(torch.bool)
    x[i[n, 0], i[n, 1]] = 0.

    return x * (1./keep_prob)


def dropout(x, keep_prob):
    if x.is_sparse:
        return dropout_sparse(x, keep_prob)
    else:
        return dropout_dense(x, keep_prob)
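
# A quick sanity check of the above: with keep_prob=0.5 roughly half of the
# non-zero entries are zeroed out and the survivors are rescaled by
# 1/keep_prob, so surviving entries equal the originals times 2:
#
#     import torch
#     x = torch.rand(5, 5)
#     y = dropout(x, keep_prob=0.5)
#     mask = (y != 0)
#     assert torch.allclose(y[mask], 2 * x[mask])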
@@ -0,0 +1,255 @@
from typing import List, \
    Union, \
    Callable
from .data import Data, \
    RelationFamily
from .trainprep import PreparedData, \
    PreparedRelationFamily
import torch
from .weights import init_glorot
from .normalize import _sparse_coo_tensor
from .dropout import dropout


def _sparse_diag_cat(matrices: List[torch.Tensor]):
    if len(matrices) == 0:
        raise ValueError('The list of matrices must be non-empty')

    if not all(m.is_sparse for m in matrices):
        raise ValueError('All matrices must be sparse')

    if not all(len(m.shape) == 2 for m in matrices):
        raise ValueError('All matrices must be 2D')

    indices = []
    values = []
    row_offset = 0
    col_offset = 0

    for m in matrices:
        ind = m._indices().clone()
        ind[0] += row_offset
        ind[1] += col_offset
        indices.append(ind)
        values.append(m._values())
        row_offset += m.shape[0]
        col_offset += m.shape[1]

    indices = torch.cat(indices, dim=1)
    values = torch.cat(values)

    return _sparse_coo_tensor(indices, values, size=(row_offset, col_offset))


def _cat(matrices: List[torch.Tensor]):
    if len(matrices) == 0:
        raise ValueError('Empty list passed to _cat()')

    n = sum(a.is_sparse for a in matrices)
    if n != 0 and n != len(matrices):
        raise ValueError('All matrices must have the same layout (dense or sparse)')

    if not all(a.shape[1:] == matrices[0].shape[1:] for a in matrices):
        raise ValueError('All matrices must have the same dimensions apart from dimension 0')

    if not matrices[0].is_sparse:
        return torch.cat(matrices)

    indices = []
    values = []
    row_offset = 0

    for a in matrices:
        ind = a._indices().clone()
        val = a._values()
        ind[0] += row_offset
        ind = ind.transpose(0, 1)
        indices.append(ind)
        values.append(val)
        row_offset += a.shape[0]

    indices = torch.cat(indices).transpose(0, 1)
    values = torch.cat(values)

    res = _sparse_coo_tensor(indices, values, size=(row_offset, matrices[0].shape[1]))

    return res
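
# Both helpers operate on lists of 2D tensors: _sparse_diag_cat() places
# sparse matrices block-diagonally in one big matrix, while _cat()
# concatenates along dimension 0 (all-sparse or all-dense, never mixed).
# For example:
#
#     import torch
#     a = torch.eye(2).to_sparse()
#     b = torch.ones(3, 3).to_sparse()
#     d = _sparse_diag_cat([a, b])                       # shape (5, 5)
#     c = _cat([torch.ones(2, 3), torch.zeros(4, 3)])    # shape (6, 3)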
class FastGraphConv(torch.nn.Module):
    def __init__(self,
        in_channels: int,
        out_channels: int,
        adjacency_matrices: List[torch.Tensor],
        keep_prob: float = 1.,
        activation: Callable[[torch.Tensor], torch.Tensor] = lambda x: x,
        **kwargs) -> None:

        super().__init__(**kwargs)

        in_channels = int(in_channels)
        out_channels = int(out_channels)

        if not isinstance(adjacency_matrices, list):
            raise TypeError('adjacency_matrices must be a list')

        if len(adjacency_matrices) == 0:
            raise ValueError('adjacency_matrices must not be empty')

        if not all(isinstance(m, torch.Tensor) for m in adjacency_matrices):
            raise TypeError('adjacency_matrices elements must be of class torch.Tensor')

        if not all(m.is_sparse for m in adjacency_matrices):
            raise ValueError('adjacency_matrices elements must be sparse')

        keep_prob = float(keep_prob)

        if not callable(activation):
            raise TypeError('activation must be callable')

        self.in_channels = in_channels
        self.out_channels = out_channels
        self.keep_prob = keep_prob
        self.activation = activation

        self.num_row_nodes = adjacency_matrices[0].shape[0]
        self.num_relation_types = len(adjacency_matrices)

        # All relation types are processed in one go: the adjacency
        # matrices are stacked block-diagonally and the per-relation weight
        # matrices are concatenated along the column dimension. The weights
        # are wrapped in a Parameter so that the optimizer sees them.
        self.adjacency_matrices = _sparse_diag_cat(adjacency_matrices)
        self.weights = torch.nn.Parameter(torch.cat([
            init_glorot(in_channels, out_channels) \
                for _ in range(self.num_relation_types)
        ], dim=1))

    def forward(self, x) -> torch.Tensor:
        if self.keep_prob < 1.:
            x = dropout(x, self.keep_prob)

        res = torch.sparse.mm(x, self.weights) \
            if x.is_sparse \
            else torch.mm(x, self.weights)

        res = torch.split(res, res.shape[1] // self.num_relation_types, dim=1)
        res = torch.cat(res)

        res = torch.sparse.mm(self.adjacency_matrices, res) \
            if self.adjacency_matrices.is_sparse \
            else torch.mm(self.adjacency_matrices, res)

        res = res.view(self.num_relation_types, self.num_row_nodes, self.out_channels)

        if self.activation is not None:
            res = self.activation(res)

        return res
class FastConvLayer(torch.nn.Module):
    def __init__(self,
        input_dim: List[int],
        output_dim: List[int],
        data: Union[Data, PreparedData],
        keep_prob: float = 1.,
        rel_activation: Callable[[torch.Tensor], torch.Tensor] = lambda x: x,
        layer_activation: Callable[[torch.Tensor], torch.Tensor] = torch.nn.functional.relu,
        **kwargs):

        super().__init__(**kwargs)

        self._check_params(input_dim, output_dim, data, keep_prob,
            rel_activation, layer_activation)

        self.input_dim = input_dim
        self.output_dim = output_dim
        self.data = data
        self.keep_prob = keep_prob
        self.rel_activation = rel_activation
        self.layer_activation = layer_activation

        self.is_sparse = False
        self.next_layer_repr = None
        self.build()

    def build(self):
        self.next_layer_repr = torch.nn.ModuleList([
            torch.nn.ModuleList() \
                for _ in range(len(self.data.node_types))
        ])
        for fam in self.data.relation_families:
            self.build_family(fam)

    def build_family(self, fam) -> None:
        if fam.node_type_row == fam.node_type_column:
            self.build_fam_one_node_type(fam)
        else:
            self.build_fam_two_node_types(fam)

    def build_fam_one_node_type(self, fam) -> None:
        adjacency_matrices = [
            r.adjacency_matrix \
                for r in fam.relation_types
        ]
        conv = FastGraphConv(self.input_dim[fam.node_type_column],
            self.output_dim[fam.node_type_row],
            adjacency_matrices,
            self.keep_prob,
            self.rel_activation)
        conv.input_node_type = fam.node_type_column
        self.next_layer_repr[fam.node_type_row].append(conv)

    def build_fam_two_node_types(self, fam) -> None:
        adjacency_matrices = [
            r.adjacency_matrix \
                for r in fam.relation_types \
                if r.adjacency_matrix is not None
        ]

        adjacency_matrices_backward = [
            r.adjacency_matrix_backward \
                for r in fam.relation_types \
                if r.adjacency_matrix_backward is not None
        ]

        conv = FastGraphConv(self.input_dim[fam.node_type_column],
            self.output_dim[fam.node_type_row],
            adjacency_matrices,
            self.keep_prob,
            self.rel_activation)

        conv_backward = FastGraphConv(self.input_dim[fam.node_type_row],
            self.output_dim[fam.node_type_column],
            adjacency_matrices_backward,
            self.keep_prob,
            self.rel_activation)

        conv.input_node_type = fam.node_type_column
        conv_backward.input_node_type = fam.node_type_row

        self.next_layer_repr[fam.node_type_row].append(conv)
        self.next_layer_repr[fam.node_type_column].append(conv_backward)

    def forward(self, prev_layer_repr):
        next_layer_repr = [ [] \
            for _ in range(len(self.data.node_types)) ]

        for output_node_type in range(len(self.data.node_types)):
            for conv in self.next_layer_repr[output_node_type]:
                rep = conv(prev_layer_repr[conv.input_node_type])
                rep = torch.sum(rep, dim=0)
                rep = torch.nn.functional.normalize(rep, p=2, dim=1)
                next_layer_repr[output_node_type].append(rep)
            if len(next_layer_repr[output_node_type]) == 0:
                next_layer_repr[output_node_type] = \
                    torch.zeros(self.data.node_types[output_node_type].count, self.output_dim[output_node_type])
            else:
                next_layer_repr[output_node_type] = \
                    sum(next_layer_repr[output_node_type])
            next_layer_repr[output_node_type] = \
                self.layer_activation(next_layer_repr[output_node_type])

        return next_layer_repr

    @staticmethod
    def _check_params(input_dim, output_dim, data, keep_prob,
        rel_activation, layer_activation):

        if not isinstance(input_dim, list):
            raise ValueError('input_dim must be a list')

        if not output_dim:
            raise ValueError('output_dim must be specified')

        if not isinstance(output_dim, list):
            raise ValueError('output_dim must be a list')

        if not isinstance(data, Data) and not isinstance(data, PreparedData):
            raise ValueError('data must be of type Data or PreparedData')

        if not callable(rel_activation):
            raise ValueError('rel_activation must be callable')

        if not callable(layer_activation):
            raise ValueError('layer_activation must be callable')
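
# A minimal sketch of FastGraphConv on random sparse adjacency matrices
# (shapes are illustrative; 3 relation types, each 1000 row nodes by 800
# column nodes, 16 input channels):
#
#     import torch
#     adj = [ (torch.rand(1000, 800) < 0.01).to(torch.float32).to_sparse()
#         for _ in range(3) ]
#     conv = FastGraphConv(in_channels=16, out_channels=32,
#         adjacency_matrices=adj)
#     x = torch.rand(800, 16)      # input reprs of the *column* nodes
#     out = conv(x)
#     assert out.shape == (3, 1000, 32)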
@@ -0,0 +1,138 @@
import torch
from typing import List
from .trainprep import PreparedData
from dataclasses import dataclass
import random
from collections import defaultdict


@dataclass
class TrainingBatch(object):
    relation_family_index: int
    relation_type_index: int
    node_type_row: int
    node_type_column: int
    edges: torch.Tensor


class FastBatcher(object):
    def __init__(self,
        prep_d: PreparedData,
        batch_size: int) -> None:

        if not isinstance(prep_d, PreparedData):
            raise TypeError('prep_d must be an instance of PreparedData')

        self.prep_d = prep_d
        self.batch_size = int(batch_size)

        self.edges = None
        self.build()

    def build(self):
        self.edges = []
        for fam_idx, fam in enumerate(self.prep_d.relation_families):
            edges = []
            targets = []
            edges_back = []
            targets_back = []
            for rel_idx, rel in enumerate(fam.relation_types):
                edges.append(rel.edges_pos.train)
                edges.append(rel.edges_neg.train)
                targets.append(torch.ones(len(rel.edges_pos.train)))
                targets.append(torch.zeros(len(rel.edges_neg.train)))

                edges_back.append(rel.edges_back_pos.train)
                edges_back.append(rel.edges_back_neg.train)
                targets_back.append(torch.ones(len(rel.edges_back_pos.train)))
                targets_back.append(torch.zeros(len(rel.edges_back_neg.train)))

            edges = torch.cat(edges)
            targets = torch.cat(targets)
            edges_back = torch.cat(edges_back)
            targets_back = torch.cat(targets_back)

            order = torch.randperm(len(edges))
            edges = edges[order]
            targets = targets[order]

            order_back = torch.randperm(len(edges_back))
            edges_back = edges_back[order_back]
            targets_back = targets_back[order_back]

            # TODO: edges of all relation types in a family are pooled
            # together above, so 'rel_idx' here records only the index of
            # the family's last relation type.
            self.edges.append({'fam_idx': fam_idx, 'rel_idx': rel_idx, 'back': False,
                'edges': edges, 'targets': targets, 'ofs': 0})
            self.edges.append({'fam_idx': fam_idx, 'rel_idx': rel_idx, 'back': True,
                'edges': edges_back, 'targets': targets_back, 'ofs': 0})

    def __iter__(self):
        while True:
            edges = [ e for e in self.edges \
                if e['ofs'] < len(e['edges']) ]
            if len(edges) == 0:
                return
            # TODO: need to finish this -- pick one of the remaining
            # chunks, advance its 'ofs' and yield a TrainingBatch
            raise NotImplementedError('FastBatcher.__iter__() is not finished yet')

    def __iter_old__(self):
        edge_types = ['edges_pos', 'edges_neg', 'edges_back_pos', 'edges_back_neg']

        offsets = {}
        orders = {}
        done = {}

        for fam_idx, fam in enumerate(self.prep_d.relation_families):
            for rel_idx, rel in enumerate(fam.relation_types):
                for et in edge_types:
                    done[fam_idx, rel_idx, et] = False

        while not all(done.values()):
            fam_idx = torch.randint(0, len(self.prep_d.relation_families), (1,)).item()
            fam = self.prep_d.relation_families[fam_idx]

            rel_idx = torch.randint(0, len(fam.relation_types), (1,)).item()
            rel = fam.relation_types[rel_idx]

            et = random.choice(edge_types)
            edges = getattr(rel, et).train
            key = (fam_idx, rel_idx, et)

            if key not in orders:
                orders[key] = torch.randperm(len(edges))
                offsets[key] = 0

            perm = orders[key]
            ofs = offsets[key]

            nt_row = rel.node_type_row
            nt_col = rel.node_type_column
            if 'back' in et:
                nt_row, nt_col = nt_col, nt_row

            if ofs < len(edges):
                offsets[key] += self.batch_size
                perm = perm[ofs:ofs+self.batch_size]
                edges = edges[perm]
                yield TrainingBatch(fam_idx, rel_idx, nt_row, nt_col, edges)
            else:
                done[key] = True


class FastDecLayer(torch.nn.Module):
    def __init__(self, **kwargs):
        super().__init__(**kwargs)

    def forward(self,
        last_layer_repr: List[torch.Tensor],
        training_batch: TrainingBatch):

        # TODO: decode the edges in training_batch against the node
        # representations in last_layer_repr
        raise NotImplementedError('FastDecLayer.forward() is not implemented yet')
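
# A minimal sketch of the batch object the (unfinished) classes above are
# built around:
#
#     import torch
#     batch = TrainingBatch(
#         relation_family_index=0,
#         relation_type_index=0,
#         node_type_row=0,
#         node_type_column=1,
#         edges=torch.tensor([[0, 1], [2, 3]]))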
@@ -0,0 +1,166 @@
from .fastmodel import FastModel
from .trainprep import PreparedData
import torch
from typing import Callable
import time
import random


class FastBatcher(object):
    def __init__(self, prep_d: PreparedData, batch_size: int,
        shuffle: bool, generator: torch.Generator,
        part_type: str) -> None:

        if not isinstance(prep_d, PreparedData):
            raise TypeError('prep_d must be an instance of PreparedData')

        if not isinstance(generator, torch.Generator):
            raise TypeError('generator must be an instance of torch.Generator')

        if part_type not in ['train', 'val', 'test']:
            raise ValueError('part_type must be set to train, val or test')

        self.prep_d = prep_d
        self.batch_size = int(batch_size)
        self.shuffle = bool(shuffle)
        self.generator = generator
        self.part_type = part_type

        self.edges = None
        self.targets = None
        self.build()

    def build(self):
        self.edges = []
        self.targets = []

        for fam in self.prep_d.relation_families:
            edges = []
            targets = []
            for i, rel in enumerate(fam.relation_types):
                edges_pos = getattr(rel.edges_pos, self.part_type)
                edges_neg = getattr(rel.edges_neg, self.part_type)
                edges_back_pos = getattr(rel.edges_back_pos, self.part_type)
                edges_back_neg = getattr(rel.edges_back_neg, self.part_type)

                # Backward edges are stored as (col, row); swap the columns
                # to bring them to the same layout as the forward edges,
                # then prepend the relation type index as the first column.
                e = torch.cat([ edges_pos,
                    edges_back_pos[:, [1, 0]] ])
                e = torch.cat([ torch.ones(len(e), 1, dtype=torch.long) * i, e ], dim=1)
                t = torch.ones(len(e))
                edges.append(e)
                targets.append(t)

                e = torch.cat([ edges_neg,
                    edges_back_neg[:, [1, 0]] ])
                e = torch.cat([ torch.ones(len(e), 1, dtype=torch.long) * i, e ], dim=1)
                t = torch.zeros(len(e))
                edges.append(e)
                targets.append(t)

            edges = torch.cat(edges)
            targets = torch.cat(targets)

            self.edges.append(edges)
            self.targets.append(targets)

        if self.shuffle:
            self.shuffle_families()

    def shuffle_families(self):
        for i in range(len(self.edges)):
            edges = self.edges[i]
            targets = self.targets[i]
            order = torch.randperm(len(edges), generator=self.generator)
            self.edges[i] = edges[order]
            self.targets[i] = targets[order]

    def __iter__(self):
        offsets = [ 0 for _ in self.edges ]

        while True:
            choice = [ i for i in range(len(offsets)) \
                if offsets[i] < len(self.edges[i]) ]
            if len(choice) == 0:
                break
            fam_idx = choice[torch.randint(len(choice), (1,),
                generator=self.generator).item()]
            ofs = offsets[fam_idx]
            edges = self.edges[fam_idx][ofs:ofs + self.batch_size]
            targets = self.targets[fam_idx][ofs:ofs + self.batch_size]
            offsets[fam_idx] += self.batch_size
            yield (fam_idx, edges, targets)


class FastLoop(object):
    def __init__(
        self,
        model: FastModel,
        lr: float = 0.001,
        loss: Callable[[torch.Tensor, torch.Tensor], torch.Tensor] = \
            torch.nn.functional.binary_cross_entropy_with_logits,
        batch_size: int = 100,
        shuffle: bool = True,
        generator: torch.Generator = None) -> None:

        self._check_params(model, loss, generator)

        self.model = model
        self.lr = float(lr)
        self.loss = loss
        self.batch_size = int(batch_size)
        self.shuffle = bool(shuffle)
        self.generator = generator or torch.default_generator

        self.opt = None

        self.build()

    def _check_params(self, model, loss, generator):
        if not isinstance(model, FastModel):
            raise TypeError('model must be an instance of FastModel')

        if not callable(loss):
            raise TypeError('loss must be callable')

        if generator is not None and not isinstance(generator, torch.Generator):
            raise TypeError('generator must be an instance of torch.Generator')

    def build(self) -> None:
        opt = torch.optim.Adam(self.model.parameters(), lr=self.lr)
        self.opt = opt

    def run_epoch(self):
        batcher = FastBatcher(self.model.prep_d, batch_size=self.batch_size,
            shuffle=self.shuffle, generator=self.generator,
            part_type='train')

        loss_sum = 0
        for fam_idx, edges, targets in batcher:
            self.opt.zero_grad()
            pred = self.model(None)
            # process pred, get input and targets
            # TODO: check this indexing against BulkDecodeLayer's output layout
            batch_pred = pred[fam_idx][edges[:, 0], edges[:, 1]]
            loss = self.loss(batch_pred, targets)
            loss.backward()
            self.opt.step()
            loss_sum += loss.detach().cpu().item()
        return loss_sum

    def train(self, max_epochs):
        best_loss = None
        best_epoch = None
        for i in range(max_epochs):
            loss = self.run_epoch()
            if best_loss is None or loss < best_loss:
                best_loss = loss
                best_epoch = i
        return loss, best_loss, best_epoch
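
# A minimal end-to-end sketch (assuming `prep_d` is a PreparedData obtained
# from icosagon.trainprep.prepare_training(); torch.manual_seed() returns
# the default torch.Generator, which satisfies the type check above):
#
#     model = FastModel(prep_d)
#     loop = FastLoop(model, lr=0.001, batch_size=512,
#         generator=torch.manual_seed(42))
#     last_loss, best_loss, best_epoch = loop.train(max_epochs=50)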
@@ -0,0 +1,79 @@
from .fastconv import FastConvLayer
from .bulkdec import BulkDecodeLayer
from .input import OneHotInputLayer
from .trainprep import PreparedData
import torch
from typing import List, \
    Union, \
    Callable


class FastModel(torch.nn.Module):
    def __init__(self, prep_d: PreparedData,
        layer_dimensions: List[int] = [32, 64],
        keep_prob: float = 1.,
        rel_activation: Callable[[torch.Tensor], torch.Tensor] = lambda x: x,
        layer_activation: Callable[[torch.Tensor], torch.Tensor] = torch.nn.functional.relu,
        dec_activation: Callable[[torch.Tensor], torch.Tensor] = lambda x: x,
        **kwargs) -> None:

        super().__init__(**kwargs)

        self._check_params(prep_d, layer_dimensions, rel_activation,
            layer_activation, dec_activation)

        self.prep_d = prep_d
        self.layer_dimensions = layer_dimensions
        self.keep_prob = float(keep_prob)
        self.rel_activation = rel_activation
        self.layer_activation = layer_activation
        self.dec_activation = dec_activation

        self.seq = None
        self.build()

    def build(self):
        in_layer = OneHotInputLayer(self.prep_d)
        last_output_dim = in_layer.output_dim
        seq = [ in_layer ]

        for dim in self.layer_dimensions:
            conv_layer = FastConvLayer(input_dim=last_output_dim,
                output_dim=[dim] * len(self.prep_d.node_types),
                data=self.prep_d,
                keep_prob=self.keep_prob,
                rel_activation=self.rel_activation,
                layer_activation=self.layer_activation)
            last_output_dim = conv_layer.output_dim
            seq.append(conv_layer)

        dec_layer = BulkDecodeLayer(input_dim=last_output_dim,
            data=self.prep_d,
            keep_prob=self.keep_prob,
            activation=self.dec_activation)
        seq.append(dec_layer)

        seq = torch.nn.Sequential(*seq)
        self.seq = seq

    def forward(self, _):
        return self.seq(None)

    def _check_params(self, prep_d, layer_dimensions, rel_activation,
        layer_activation, dec_activation):

        if not isinstance(prep_d, PreparedData):
            raise TypeError('prep_d must be an instance of PreparedData')

        if not isinstance(layer_dimensions, list):
            raise TypeError('layer_dimensions must be a list')

        if not callable(rel_activation):
            raise TypeError('rel_activation must be callable')

        if not callable(layer_activation):
            raise TypeError('layer_activation must be callable')

        if not callable(dec_activation):
            raise TypeError('dec_activation must be callable')
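
# A minimal usage sketch (assuming `prep_d` comes from
# icosagon.trainprep.prepare_training(); note that BulkDecodeLayer lives in
# the .bulkdec module, which is not part of this changeset):
#
#     model = FastModel(prep_d, layer_dimensions=[32, 64], keep_prob=0.9)
#     pred = model(None)   # the one-hot input layer ignores its argument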
@@ -0,0 +1,79 @@
#
# Copyright (C) Stanislaw Adaszewski, 2020
# License: GPLv3
#

import torch
from typing import Union, \
    List
from .data import Data


class InputLayer(torch.nn.Module):
    def __init__(self, data: Data, output_dim: Union[int, List[int]] = None,
        **kwargs) -> None:

        output_dim = output_dim or \
            list(map(lambda a: a.count, data.node_types))

        if not isinstance(output_dim, list):
            output_dim = [output_dim,] * len(data.node_types)

        super().__init__(**kwargs)
        self.output_dim = output_dim
        self.data = data

        self.is_sparse = False
        self.node_reps = None
        self.build()

    def build(self) -> None:
        self.node_reps = []
        for i, nt in enumerate(self.data.node_types):
            reps = torch.rand(nt.count, self.output_dim[i])
            reps = torch.nn.Parameter(reps)
            self.register_parameter('node_reps[%d]' % i, reps)
            self.node_reps.append(reps)

    def forward(self, x) -> List[torch.nn.Parameter]:
        return self.node_reps

    def __repr__(self) -> str:
        s = ''
        s += 'Icosagon input layer with output_dim: %s\n' % self.output_dim
        s += '  # of node types: %d\n' % len(self.data.node_types)
        for nt in self.data.node_types:
            s += '  - %s (%d)\n' % (nt.name, nt.count)
        return s.strip()


class OneHotInputLayer(torch.nn.Module):
    def __init__(self, data: Data, **kwargs) -> None:
        output_dim = [ a.count for a in data.node_types ]
        super().__init__(**kwargs)
        self.output_dim = output_dim
        self.data = data

        self.is_sparse = True
        self.node_reps = None
        self.build()

    def build(self) -> None:
        self.node_reps = torch.nn.ParameterList()
        for i, nt in enumerate(self.data.node_types):
            reps = torch.eye(nt.count).to_sparse()
            reps = torch.nn.Parameter(reps, requires_grad=False)
            self.node_reps.append(reps)

    def forward(self, x) -> List[torch.nn.Parameter]:
        return self.node_reps

    def __repr__(self) -> str:
        s = ''
        s += 'Icosagon one-hot input layer\n'
        s += '  # of node types: %d\n' % len(self.data.node_types)
        for nt in self.data.node_types:
            s += '  - %s (%d)\n' % (nt.name, nt.count)
        return s.strip()
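
# A minimal usage sketch (assuming the Data class from icosagon.data):
#
#     from icosagon.data import Data
#
#     d = Data()
#     d.add_node_type('Gene', 1000)
#     in_layer = OneHotInputLayer(d)
#     reps = in_layer(None)
#     assert reps[0].shape == (1000, 1000) and reps[0].is_sparse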
@@ -0,0 +1,145 @@
#
# Copyright (C) Stanislaw Adaszewski, 2020
# License: GPLv3
#

import numpy as np
import scipy.sparse as sp
import torch


def _check_tensor(adj_mat):
    if not isinstance(adj_mat, torch.Tensor):
        raise ValueError('adj_mat must be a torch.Tensor')


def _check_sparse(adj_mat):
    if not adj_mat.is_sparse:
        raise ValueError('adj_mat must be sparse')


def _check_dense(adj_mat):
    if adj_mat.is_sparse:
        raise ValueError('adj_mat must be dense')


def _check_square(adj_mat):
    if len(adj_mat.shape) != 2 or \
        adj_mat.shape[0] != adj_mat.shape[1]:
        raise ValueError('adj_mat must be a square matrix')


def _check_2d(adj_mat):
    if len(adj_mat.shape) != 2:
        raise ValueError('adj_mat must be a 2D matrix')


def _sparse_coo_tensor(indices, values, size):
    # Pick the legacy constructor matching the values dtype; there is no
    # dedicated sparse bool constructor, hence the ByteTensor fallback.
    ctor = { torch.float32: torch.sparse.FloatTensor,
        torch.float64: torch.sparse.DoubleTensor,
        torch.uint8: torch.sparse.ByteTensor,
        torch.long: torch.sparse.LongTensor,
        torch.int: torch.sparse.IntTensor,
        torch.short: torch.sparse.ShortTensor,
        torch.bool: torch.sparse.ByteTensor }[values.dtype]
    return ctor(indices, values, size)


def add_eye_sparse(adj_mat: torch.Tensor) -> torch.Tensor:
    _check_tensor(adj_mat)
    _check_sparse(adj_mat)
    _check_square(adj_mat)

    adj_mat = adj_mat.coalesce()
    indices = adj_mat.indices()
    values = adj_mat.values()

    eye_indices = torch.arange(adj_mat.shape[0], dtype=indices.dtype,
        device=adj_mat.device).view(1, -1)
    eye_indices = torch.cat((eye_indices, eye_indices), 0)
    eye_values = torch.ones(adj_mat.shape[0], dtype=values.dtype,
        device=adj_mat.device)

    indices = torch.cat((indices, eye_indices), 1)
    values = torch.cat((values, eye_values), 0)

    adj_mat = _sparse_coo_tensor(indices, values, adj_mat.shape)

    return adj_mat


def norm_adj_mat_one_node_type_sparse(adj_mat: torch.Tensor) -> torch.Tensor:
    _check_tensor(adj_mat)
    _check_sparse(adj_mat)
    _check_square(adj_mat)

    adj_mat = add_eye_sparse(adj_mat)
    adj_mat = norm_adj_mat_two_node_types_sparse(adj_mat)

    return adj_mat


def norm_adj_mat_one_node_type_dense(adj_mat: torch.Tensor) -> torch.Tensor:
    _check_tensor(adj_mat)
    _check_dense(adj_mat)
    _check_square(adj_mat)

    adj_mat = adj_mat + torch.eye(adj_mat.shape[0], dtype=adj_mat.dtype,
        device=adj_mat.device)
    adj_mat = norm_adj_mat_two_node_types_dense(adj_mat)

    return adj_mat


def norm_adj_mat_one_node_type(adj_mat: torch.Tensor) -> torch.Tensor:
    _check_tensor(adj_mat)
    _check_square(adj_mat)

    if adj_mat.is_sparse:
        return norm_adj_mat_one_node_type_sparse(adj_mat)
    else:
        return norm_adj_mat_one_node_type_dense(adj_mat)


def norm_adj_mat_two_node_types_sparse(adj_mat: torch.Tensor) -> torch.Tensor:
    _check_tensor(adj_mat)
    _check_sparse(adj_mat)
    _check_2d(adj_mat)

    adj_mat = adj_mat.coalesce()
    indices = adj_mat.indices()
    values = adj_mat.values()
    degrees_row = torch.zeros(adj_mat.shape[0], device=adj_mat.device)
    degrees_row = degrees_row.index_add(0, indices[0], values.to(degrees_row.dtype))
    degrees_col = torch.zeros(adj_mat.shape[1], device=adj_mat.device)
    degrees_col = degrees_col.index_add(0, indices[1], values.to(degrees_col.dtype))
    values = values.to(degrees_row.dtype) / torch.sqrt(degrees_row[indices[0]] * degrees_col[indices[1]])
    adj_mat = _sparse_coo_tensor(indices, values, adj_mat.shape)

    return adj_mat


def norm_adj_mat_two_node_types_dense(adj_mat: torch.Tensor) -> torch.Tensor:
    _check_tensor(adj_mat)
    _check_dense(adj_mat)
    _check_2d(adj_mat)

    degrees_row = adj_mat.sum(1).view(-1, 1).to(torch.float32)
    degrees_col = adj_mat.sum(0).view(1, -1).to(torch.float32)
    degrees_row = torch.sqrt(degrees_row)
    degrees_col = torch.sqrt(degrees_col)
    adj_mat = adj_mat.to(degrees_row.dtype) / degrees_row
    adj_mat = adj_mat / degrees_col

    return adj_mat


def norm_adj_mat_two_node_types(adj_mat: torch.Tensor) -> torch.Tensor:
    _check_tensor(adj_mat)
    _check_2d(adj_mat)

    if adj_mat.is_sparse:
        return norm_adj_mat_two_node_types_sparse(adj_mat)
    else:
        return norm_adj_mat_two_node_types_dense(adj_mat)
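
# The normalization above is the symmetric D^{-1/2} A D^{-1/2} scheme
# familiar from graph convolutional networks, extended to rectangular
# (two-node-type) matrices; the one-node-type variant adds self-loops
# first, so no degree is ever zero. A quick check:
#
#     import torch
#     a = (torch.rand(5, 5) < 0.5).to(torch.float32)
#     a = ((a + a.transpose(0, 1)) > 0).to(torch.float32)   # symmetrize
#     n = norm_adj_mat_one_node_type(a)
#     # each entry of `n` is a[i, j] / sqrt(deg(i) * deg(j)), self-loops included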
@@ -0,0 +1,47 @@
#
# Copyright (C) Stanislaw Adaszewski, 2020
# License: GPLv3
#

import numpy as np
import torch
import torch.utils.data
from typing import List, \
    Union


def fixed_unigram_candidate_sampler(
    true_classes: Union[np.ndarray, torch.Tensor],
    unigrams: List[Union[int, float]],
    distortion: float = 1.):

    if isinstance(true_classes, torch.Tensor):
        true_classes = true_classes.detach().cpu().numpy()

    if isinstance(unigrams, torch.Tensor):
        unigrams = unigrams.detach().cpu().numpy()

    if len(true_classes.shape) != 2:
        raise ValueError('true_classes must be a 2D matrix with shape (num_samples, num_true)')

    num_samples = true_classes.shape[0]
    unigrams = np.array(unigrams)
    if distortion != 1.:
        unigrams = unigrams.astype(np.float64) ** distortion

    indices = np.arange(num_samples)
    result = np.zeros(num_samples, dtype=np.int64)

    # Rejection sampling: keep redrawing candidates for the rows whose
    # drawn candidate collides with one of the corresponding true classes.
    while len(indices) > 0:
        sampler = torch.utils.data.WeightedRandomSampler(unigrams, len(indices))
        candidates = np.array(list(sampler))
        candidates = np.reshape(candidates, (len(indices), 1))
        result[indices] = candidates.T
        mask = (candidates == true_classes[indices, :])
        mask = mask.sum(1).astype(bool)
        indices = indices[mask]

    return torch.tensor(result)
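
# A minimal usage sketch: sample one negative candidate per row,
# proportional to (weight ** distortion), avoiding the true neighbors
# listed in `true_classes`:
#
#     import torch
#     true_classes = torch.tensor([[0], [1], [2]])
#     degrees = [1., 2., 3., 4.]
#     neg = fixed_unigram_candidate_sampler(true_classes, degrees, 0.75)
#     assert neg.shape == (3,)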
@@ -0,0 +1,215 @@
#
# Copyright (C) Stanislaw Adaszewski, 2020
# License: GPLv3
#

from .sampling import fixed_unigram_candidate_sampler
import torch
from dataclasses import dataclass, \
    field
from typing import Any, \
    List, \
    Tuple, \
    Dict
from .data import NodeType, \
    RelationType, \
    RelationTypeBase, \
    RelationFamily, \
    RelationFamilyBase, \
    Data
from collections import defaultdict
from .normalize import norm_adj_mat_one_node_type, \
    norm_adj_mat_two_node_types
import numpy as np


@dataclass
class TrainValTest(object):
    train: Any
    val: Any
    test: Any


@dataclass
class PreparedRelationType(RelationTypeBase):
    edges_pos: TrainValTest
    edges_neg: TrainValTest
    edges_back_pos: TrainValTest
    edges_back_neg: TrainValTest


@dataclass
class PreparedRelationFamily(RelationFamilyBase):
    relation_types: List[PreparedRelationType]


@dataclass
class PreparedData(object):
    node_types: List[NodeType]
    relation_families: List[PreparedRelationFamily]


def _empty_edge_list_tvt() -> TrainValTest:
    return TrainValTest(*[ torch.zeros((0, 2), dtype=torch.long) for _ in range(3) ])


def train_val_test_split_edges(edges: torch.Tensor,
    ratios: TrainValTest) -> TrainValTest:

    if not isinstance(edges, torch.Tensor):
        raise ValueError('edges must be a torch.Tensor')

    if len(edges.shape) != 2 or edges.shape[1] != 2:
        raise ValueError('edges shape must be (num_edges, 2)')

    if not isinstance(ratios, TrainValTest):
        raise ValueError('ratios must be a TrainValTest')

    # Compare with a tolerance: the ratios are floats and, e.g.,
    # 0.8 + 0.1 + 0.1 != 1.0 exactly in floating point.
    if abs(ratios.train + ratios.val + ratios.test - 1.0) > 1e-9:
        raise ValueError('Train, validation and test ratios must add up to 1')

    order = torch.randperm(len(edges))
    edges = edges[order, :]
    n = round(len(edges) * ratios.train)
    edges_train = edges[:n]
    n_1 = round(len(edges) * (ratios.train + ratios.val))
    edges_val = edges[n:n_1]
    edges_test = edges[n_1:]

    return TrainValTest(edges_train, edges_val, edges_test)


def get_edges_and_degrees(adj_mat: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
    if adj_mat.is_sparse:
        adj_mat = adj_mat.coalesce()
        degrees = torch.zeros(adj_mat.shape[1], dtype=torch.int64,
            device=adj_mat.device)
        degrees = degrees.index_add(0, adj_mat.indices()[1],
            torch.ones(adj_mat.indices().shape[1], dtype=torch.int64,
                device=adj_mat.device))
        edges_pos = adj_mat.indices().transpose(0, 1)
    else:
        degrees = adj_mat.sum(0)
        edges_pos = torch.nonzero(adj_mat)
    return edges_pos, degrees


def prepare_adj_mat(adj_mat: torch.Tensor,
    ratios: TrainValTest) -> Tuple[TrainValTest, TrainValTest]:

    if not isinstance(adj_mat, torch.Tensor):
        raise ValueError('adj_mat must be a torch.Tensor')

    edges_pos, degrees = get_edges_and_degrees(adj_mat)

    # Negative edges keep the row of a positive edge and draw the column
    # from the degree-distorted unigram distribution.
    neg_neighbors = fixed_unigram_candidate_sampler(
        edges_pos[:, 1].view(-1, 1), degrees, 0.75).to(adj_mat.device)
    edges_neg = torch.cat((edges_pos[:, 0].view(-1, 1), neg_neighbors.view(-1, 1)), 1)

    edges_pos = train_val_test_split_edges(edges_pos, ratios)
    edges_neg = train_val_test_split_edges(edges_neg, ratios)

    adj_mat_train = torch.sparse_coo_tensor(indices=edges_pos.train.transpose(0, 1),
        values=torch.ones(len(edges_pos.train)), size=adj_mat.shape, dtype=adj_mat.dtype,
        device=adj_mat.device)

    return adj_mat_train, edges_pos, edges_neg


def prep_rel_one_node_type(r: RelationType,
    ratios: TrainValTest) -> PreparedRelationType:

    adj_mat = r.adjacency_matrix
    adj_mat_train, edges_pos, edges_neg = prepare_adj_mat(adj_mat, ratios)
    adj_mat_back_train, edges_back_pos, edges_back_neg = \
        None, _empty_edge_list_tvt(), _empty_edge_list_tvt()

    adj_mat_train = norm_adj_mat_one_node_type(adj_mat_train)

    return PreparedRelationType(r.name, r.node_type_row, r.node_type_column,
        adj_mat_train, adj_mat_back_train, edges_pos, edges_neg,
        edges_back_pos, edges_back_neg)


def prep_rel_two_node_types_sym(r: RelationType,
    ratios: TrainValTest) -> PreparedRelationType:

    adj_mat = r.adjacency_matrix
    adj_mat_train, edges_pos, edges_neg = prepare_adj_mat(adj_mat, ratios)
    edges_back_pos, edges_back_neg = \
        _empty_edge_list_tvt(), _empty_edge_list_tvt()

    return PreparedRelationType(r.name, r.node_type_row,
        r.node_type_column,
        norm_adj_mat_two_node_types(adj_mat_train),
        norm_adj_mat_two_node_types(adj_mat_train.transpose(0, 1)),
        edges_pos, edges_neg, edges_back_pos, edges_back_neg)


def prep_rel_two_node_types_asym(r: RelationType,
    ratios: TrainValTest) -> PreparedRelationType:

    if r.adjacency_matrix is not None:
        adj_mat_train, edges_pos, edges_neg = \
            prepare_adj_mat(r.adjacency_matrix, ratios)
    else:
        adj_mat_train, edges_pos, edges_neg = \
            None, _empty_edge_list_tvt(), _empty_edge_list_tvt()

    if r.adjacency_matrix_backward is not None:
        adj_mat_back_train, edges_back_pos, edges_back_neg = \
            prepare_adj_mat(r.adjacency_matrix_backward, ratios)
    else:
        adj_mat_back_train, edges_back_pos, edges_back_neg = \
            None, _empty_edge_list_tvt(), _empty_edge_list_tvt()

    # Either one of the two matrices may be None here, so guard the
    # normalization accordingly.
    return PreparedRelationType(r.name, r.node_type_row,
        r.node_type_column,
        norm_adj_mat_two_node_types(adj_mat_train) \
            if adj_mat_train is not None else None,
        norm_adj_mat_two_node_types(adj_mat_back_train) \
            if adj_mat_back_train is not None else None,
        edges_pos, edges_neg, edges_back_pos, edges_back_neg)


def prepare_relation_type(r: RelationType,
    ratios: TrainValTest, is_symmetric: bool) -> PreparedRelationType:

    if not isinstance(r, RelationType):
        raise ValueError('r must be a RelationType')

    if not isinstance(ratios, TrainValTest):
        raise ValueError('ratios must be a TrainValTest')

    if r.node_type_row == r.node_type_column:
        return prep_rel_one_node_type(r, ratios)
    elif is_symmetric:
        return prep_rel_two_node_types_sym(r, ratios)
    else:
        return prep_rel_two_node_types_asym(r, ratios)


def prepare_relation_family(fam: RelationFamily,
    ratios: TrainValTest) -> PreparedRelationFamily:

    relation_types = []

    for r in fam.relation_types:
        relation_types.append(prepare_relation_type(r, ratios, fam.is_symmetric))

    return PreparedRelationFamily(fam.data, fam.name,
        fam.node_type_row, fam.node_type_column,
        fam.is_symmetric, fam.decoder_class,
        relation_types)


def prepare_training(data: Data, ratios: TrainValTest) -> PreparedData:
    if not isinstance(data, Data):
        raise ValueError('data must be of class Data')

    relation_families = [ prepare_relation_family(fam, ratios) \
        for fam in data.relation_families ]

    return PreparedData(data.node_types, relation_families)
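
# A minimal end-to-end sketch tying this module to icosagon.data:
#
#     import torch
#     from icosagon.data import Data
#
#     d = Data()
#     d.add_node_type('Gene', 100)
#     fam = d.add_relation_family('Gene-Gene', 0, 0, is_symmetric=True)
#     adj = (torch.rand(100, 100) < 0.05).to(torch.float32)
#     adj = ((adj + adj.transpose(0, 1)) > 0).to(torch.float32)   # symmetrize
#     fam.add_relation_type('Interacts', adj)
#
#     prep_d = prepare_training(d, TrainValTest(0.8, 0.1, 0.1))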
@@ -0,0 +1,19 @@
#
# Copyright (C) Stanislaw Adaszewski, 2020
# License: GPLv3
#

import torch
import numpy as np


def init_glorot(in_channels, out_channels, dtype=torch.float32):
    """Create a weight variable with Glorot & Bengio (AISTATS 2010)
    initialization.
    """
    init_range = np.sqrt(6.0 / (in_channels + out_channels))
    initial = -init_range + 2 * init_range * \
        torch.rand(( in_channels, out_channels ), dtype=dtype)
    initial = initial.requires_grad_(True)
    return initial
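
# Glorot/Xavier uniform initialization: weights are drawn from U(-r, r)
# with r = sqrt(6 / (fan_in + fan_out)). For example:
#
#     w = init_glorot(64, 32)
#     assert w.shape == (64, 32) and w.requires_grad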