| @@ -5,14 +5,12 @@ | |||||
| import torch | import torch | ||||
| from .dropout import dropout_sparse, \ | |||||
| dropout_dense | |||||
| from .dropout import dropout | |||||
| from .weights import init_glorot | from .weights import init_glorot | ||||
| from typing import List, Callable | from typing import List, Callable | ||||
| class GraphConv(torch.nn.Module): | class GraphConv(torch.nn.Module): | ||||
| """Convolution layer for sparse AND dense inputs.""" | |||||
| def __init__(self, in_channels: int, out_channels: int, | def __init__(self, in_channels: int, out_channels: int, | ||||
| adjacency_matrix: torch.Tensor, **kwargs) -> None: | adjacency_matrix: torch.Tensor, **kwargs) -> None: | ||||
| super().__init__(**kwargs) | super().__init__(**kwargs) | ||||
| @@ -46,40 +44,7 @@ class DropoutGraphConvActivation(torch.nn.Module): | |||||
| self.graph_conv = GraphConv(input_dim, output_dim, adjacency_matrix) | self.graph_conv = GraphConv(input_dim, output_dim, adjacency_matrix) | ||||
| def forward(self, x: torch.Tensor) -> torch.Tensor: | def forward(self, x: torch.Tensor) -> torch.Tensor: | ||||
| x = dropout_sparse(x, self.keep_prob) \ | |||||
| if x.is_sparse \ | |||||
| else dropout_dense(x, self.keep_prob) | |||||
| x = dropout(x, self.keep_prob) | |||||
| x = self.graph_conv(x) | x = self.graph_conv(x) | ||||
| x = self.activation(x) | x = self.activation(x) | ||||
| return x | return x | ||||
| class MultiDGCA(torch.nn.Module): | |||||
| def __init__(self, input_dim: List[int], output_dim: int, | |||||
| adjacency_matrices: List[torch.Tensor], keep_prob: float=1., | |||||
| activation: Callable[[torch.Tensor], torch.Tensor]=torch.nn.functional.relu, | |||||
| **kwargs) -> None: | |||||
| super().__init__(**kwargs) | |||||
| self.input_dim = input_dim | |||||
| self.output_dim = output_dim | |||||
| self.adjacency_matrices = adjacency_matrices | |||||
| self.keep_prob = keep_prob | |||||
| self.activation = activation | |||||
| self.dgca = None | |||||
| self.build() | |||||
| def build(self): | |||||
| if len(self.input_dim) != len(self.adjacency_matrices): | |||||
| raise ValueError('input_dim must have the same length as adjacency_matrices') | |||||
| self.dgca = [] | |||||
| for input_dim, adj_mat in zip(self.input_dim, self.adjacency_matrices): | |||||
| self.dgca.append(DenseDropoutGraphConvActivation(input_dim, self.output_dim, adj_mat, self.keep_prob, self.activation)) | |||||
| def forward(self, x: List[torch.Tensor]) -> List[torch.Tensor]: | |||||
| if not isinstance(x, list): | |||||
| raise ValueError('x must be a list of tensors') | |||||
| out = torch.zeros(len(x[0]), self.output_dim, dtype=x[0].dtype) | |||||
| for i, f in enumerate(self.dgca): | |||||
| out += f(x[i]) | |||||
| out = torch.nn.functional.normalize(out, p=2, dim=1) | |||||
| return out | |||||
| @@ -0,0 +1,131 @@ | |||||
| # | |||||
| # Copyright (C) Stanislaw Adaszewski, 2020 | |||||
| # License: GPLv3 | |||||
| # | |||||
| import torch | |||||
| from .weights import init_glorot | |||||
| from .dropout import dropout | |||||
| class DEDICOMDecoder(torch.nn.Module): | |||||
| """DEDICOM Tensor Factorization Decoder model layer for link prediction.""" | |||||
| def __init__(self, input_dim, num_relation_types, drop_prob=0., | |||||
| activation=torch.sigmoid, **kwargs): | |||||
| super().__init__(**kwargs) | |||||
| self.input_dim = input_dim | |||||
| self.num_relation_types = num_relation_types | |||||
| self.drop_prob = drop_prob | |||||
| self.activation = activation | |||||
| self.global_interaction = init_glorot(input_dim, input_dim) | |||||
| self.local_variation = [ | |||||
| torch.flatten(init_glorot(input_dim, 1)) \ | |||||
| for _ in range(num_relation_types) | |||||
| ] | |||||
| def forward(self, inputs_row, inputs_col): | |||||
| outputs = [] | |||||
| for k in range(self.num_relation_types): | |||||
| inputs_row = dropout(inputs_row, 1.-self.drop_prob) | |||||
| inputs_col = dropout(inputs_col, 1.-self.drop_prob) | |||||
| relation = torch.diag(self.local_variation[k]) | |||||
| product1 = torch.mm(inputs_row, relation) | |||||
| product2 = torch.mm(product1, self.global_interaction) | |||||
| product3 = torch.mm(product2, relation) | |||||
| rec = torch.bmm(product3.view(product3.shape[0], 1, product3.shape[1]), | |||||
| inputs_col.view(inputs_col.shape[0], inputs_col.shape[1], 1)) | |||||
| rec = torch.flatten(rec) | |||||
| outputs.append(self.activation(rec)) | |||||
| return outputs | |||||
| class DistMultDecoder(torch.nn.Module): | |||||
| """DEDICOM Tensor Factorization Decoder model layer for link prediction.""" | |||||
| def __init__(self, input_dim, num_relation_types, drop_prob=0., | |||||
| activation=torch.sigmoid, **kwargs): | |||||
| super().__init__(**kwargs) | |||||
| self.input_dim = input_dim | |||||
| self.num_relation_types = num_relation_types | |||||
| self.drop_prob = drop_prob | |||||
| self.activation = activation | |||||
| self.relation = [ | |||||
| torch.flatten(init_glorot(input_dim, 1)) \ | |||||
| for _ in range(num_relation_types) | |||||
| ] | |||||
| def forward(self, inputs_row, inputs_col): | |||||
| outputs = [] | |||||
| for k in range(self.num_relation_types): | |||||
| inputs_row = dropout(inputs_row, 1.-self.drop_prob) | |||||
| inputs_col = dropout(inputs_col, 1.-self.drop_prob) | |||||
| relation = torch.diag(self.relation[k]) | |||||
| intermediate_product = torch.mm(inputs_row, relation) | |||||
| rec = torch.bmm(intermediate_product.view(intermediate_product.shape[0], 1, intermediate_product.shape[1]), | |||||
| inputs_col.view(inputs_col.shape[0], inputs_col.shape[1], 1)) | |||||
| rec = torch.flatten(rec) | |||||
| outputs.append(self.activation(rec)) | |||||
| return outputs | |||||
| class BilinearDecoder(torch.nn.Module): | |||||
| """DEDICOM Tensor Factorization Decoder model layer for link prediction.""" | |||||
| def __init__(self, input_dim, num_relation_types, drop_prob=0., | |||||
| activation=torch.sigmoid, **kwargs): | |||||
| super().__init__(**kwargs) | |||||
| self.input_dim = input_dim | |||||
| self.num_relation_types = num_relation_types | |||||
| self.drop_prob = drop_prob | |||||
| self.activation = activation | |||||
| self.relation = [ | |||||
| init_glorot(input_dim, input_dim) \ | |||||
| for _ in range(num_relation_types) | |||||
| ] | |||||
| def forward(self, inputs_row, inputs_col): | |||||
| outputs = [] | |||||
| for k in range(self.num_relation_types): | |||||
| inputs_row = dropout(inputs_row, 1.-self.drop_prob) | |||||
| inputs_col = dropout(inputs_col, 1.-self.drop_prob) | |||||
| intermediate_product = torch.mm(inputs_row, self.relation[k]) | |||||
| rec = torch.bmm(intermediate_product.view(intermediate_product.shape[0], 1, intermediate_product.shape[1]), | |||||
| inputs_col.view(inputs_col.shape[0], inputs_col.shape[1], 1)) | |||||
| rec = torch.flatten(rec) | |||||
| outputs.append(self.activation(rec)) | |||||
| return outputs | |||||
| class InnerProductDecoder(torch.nn.Module): | |||||
| """DEDICOM Tensor Factorization Decoder model layer for link prediction.""" | |||||
| def __init__(self, input_dim, num_relation_types, drop_prob=0., | |||||
| activation=torch.sigmoid, **kwargs): | |||||
| super().__init__(**kwargs) | |||||
| self.input_dim = input_dim | |||||
| self.num_relation_types = num_relation_types | |||||
| self.drop_prob = drop_prob | |||||
| self.activation = activation | |||||
| def forward(self, inputs_row, inputs_col): | |||||
| outputs = [] | |||||
| for k in range(self.num_relation_types): | |||||
| inputs_row = dropout(inputs_row, 1.-self.drop_prob) | |||||
| inputs_col = dropout(inputs_col, 1.-self.drop_prob) | |||||
| rec = torch.bmm(inputs_row.view(inputs_row.shape[0], 1, inputs_row.shape[1]), | |||||
| inputs_col.view(inputs_col.shape[0], inputs_col.shape[1], 1)) | |||||
| rec = torch.flatten(rec) | |||||
| outputs.append(self.activation(rec)) | |||||
| return outputs | |||||
| @@ -31,3 +31,10 @@ def dropout_dense(x, keep_prob): | |||||
| x[i[n, 0], i[n, 1]] = 0. | x[i[n, 0], i[n, 1]] = 0. | ||||
| return x * (1./keep_prob) | return x * (1./keep_prob) | ||||
| def dropout(x, keep_prob): | |||||
| if x.is_sparse: | |||||
| return dropout_sparse(x, keep_prob) | |||||
| else: | |||||
| return dropout_dense(x, keep_prob) | |||||
| @@ -0,0 +1,94 @@ | |||||
| import torch | |||||
| from .convolve import DropoutGraphConvActivation | |||||
| from .data import Data | |||||
| from .trainprep import PreparedData | |||||
| from typing import List, \ | |||||
| Union, \ | |||||
| Callable | |||||
| from collections import defaultdict | |||||
| from dataclasses import dataclass | |||||
| @dataclass | |||||
| class Convolutions(object): | |||||
| node_type_column: int | |||||
| convolutions: List[DropoutGraphConvActivation] | |||||
| class DecagonLayer(torch.nn.Module): | |||||
| def __init__(self, | |||||
| input_dim: List[int], | |||||
| output_dim: List[int], | |||||
| data: Union[Data, PreparedData], | |||||
| keep_prob: float = 1., | |||||
| rel_activation: Callable[[torch.Tensor], torch.Tensor] = lambda x: x, | |||||
| layer_activation: Callable[[torch.Tensor], torch.Tensor] = torch.nn.functional.relu, | |||||
| **kwargs): | |||||
| super().__init__(**kwargs) | |||||
| if not isinstance(input_dim, list): | |||||
| raise ValueError('input_dim must be a list') | |||||
| if not isinstance(output_dim, list): | |||||
| raise ValueError('output_dim must be a list') | |||||
| if not isinstance(data, Data) and not isinstance(data, PreparedData): | |||||
| raise ValueError('data must be of type Data or PreparedData') | |||||
| self.input_dim = input_dim | |||||
| self.output_dim = output_dim | |||||
| self.data = data | |||||
| self.keep_prob = float(keep_prob) | |||||
| self.rel_activation = rel_activation | |||||
| self.layer_activation = layer_activation | |||||
| self.is_sparse = False | |||||
| self.next_layer_repr = None | |||||
| self.build() | |||||
| def build(self): | |||||
| n = len(self.data.node_types) | |||||
| rel_types = self.data.relation_types | |||||
| self.next_layer_repr = [ [] for _ in range(n) ] | |||||
| for node_type_row in range(n): | |||||
| if node_type_row not in rel_types: | |||||
| continue | |||||
| for node_type_column in range(n): | |||||
| if node_type_column not in rel_types[node_type_row]: | |||||
| continue | |||||
| rels = rel_types[node_type_row][node_type_column] | |||||
| if len(rels) == 0: | |||||
| continue | |||||
| convolutions = [] | |||||
| for r in rels: | |||||
| conv = DropoutGraphConvActivation(self.input_dim[node_type_column], | |||||
| self.output_dim[node_type_row], r.adjacency_matrix, | |||||
| self.keep_prob, self.rel_activation) | |||||
| convolutions.append(conv) | |||||
| self.next_layer_repr[node_type_row].append( | |||||
| Convolutions(node_type_column, convolutions)) | |||||
| def __call__(self, prev_layer_repr): | |||||
| next_layer_repr = [ [] for _ in range(len(self.data.node_types)) ] | |||||
| n = len(self.data.node_types) | |||||
| for node_type_row in range(n): | |||||
| for convolutions in self.next_layer_repr[node_type_row]: | |||||
| repr_ = [ conv(prev_layer_repr[convolutions.node_type_column]) \ | |||||
| for conv in convolutions.convolutions ] | |||||
| repr_ = sum(repr_) | |||||
| repr_ = torch.nn.functional.normalize(repr_, p=2, dim=1) | |||||
| next_layer_repr[i].append(repr_) | |||||
| next_layer_repr[i] = sum(next_layer_repr[i]) | |||||
| next_layer_repr[i] = self.layer_activation(next_layer_repr[i]) | |||||
| return next_layer_repr | |||||
| @@ -38,7 +38,7 @@ class PreparedRelationType(object): | |||||
| name: str | name: str | ||||
| node_type_row: int | node_type_row: int | ||||
| node_type_column: int | node_type_column: int | ||||
| adj_mat_train: torch.Tensor | |||||
| adjacency_matrix: torch.Tensor | |||||
| edges_pos: TrainValTest | edges_pos: TrainValTest | ||||
| edges_neg: TrainValTest | edges_neg: TrainValTest | ||||
| @@ -1,7 +1,7 @@ | |||||
| from icosagon.convolve import GraphConv, \ | from icosagon.convolve import GraphConv, \ | ||||
| DropoutGraphConvActivation, \ | |||||
| MultiDGCA | |||||
| DropoutGraphConvActivation | |||||
| import torch | import torch | ||||
| from icosagon.dropout import dropout | |||||
| def _test_graph_conv_01(use_sparse: bool): | def _test_graph_conv_01(use_sparse: bool): | ||||
| @@ -92,3 +92,99 @@ def test_graph_conv_sparse_02(): | |||||
| def test_graph_conv_sparse_03(): | def test_graph_conv_sparse_03(): | ||||
| _test_graph_conv_03(use_sparse=True) | _test_graph_conv_03(use_sparse=True) | ||||
| def _test_dropout_graph_conv_activation_01(use_sparse: bool): | |||||
| adj_mat = torch.rand((10, 20)) | |||||
| adj_mat[adj_mat < .5] = 0 | |||||
| adj_mat = torch.ceil(adj_mat) | |||||
| node_reprs = torch.eye(20) | |||||
| conv_1 = DropoutGraphConvActivation(20, 20, adj_mat.to_sparse() \ | |||||
| if use_sparse else adj_mat, keep_prob=1., | |||||
| activation=lambda x: x) | |||||
| conv_2 = GraphConv(20, 20, adj_mat.to_sparse() \ | |||||
| if use_sparse else adj_mat) | |||||
| conv_2.weight = conv_1.graph_conv.weight | |||||
| res_1 = conv_1(node_reprs) | |||||
| res_2 = conv_2(node_reprs) | |||||
| print('res_1:', res_1.detach().cpu().numpy()) | |||||
| print('res_2:', res_2.detach().cpu().numpy()) | |||||
| assert torch.all(res_1 == res_2) | |||||
| def _test_dropout_graph_conv_activation_02(use_sparse: bool): | |||||
| adj_mat = torch.rand((10, 20)) | |||||
| adj_mat[adj_mat < .5] = 0 | |||||
| adj_mat = torch.ceil(adj_mat) | |||||
| node_reprs = torch.eye(20) | |||||
| conv_1 = DropoutGraphConvActivation(20, 20, adj_mat.to_sparse() \ | |||||
| if use_sparse else adj_mat, keep_prob=1., | |||||
| activation=lambda x: x * 2) | |||||
| conv_2 = GraphConv(20, 20, adj_mat.to_sparse() \ | |||||
| if use_sparse else adj_mat) | |||||
| conv_2.weight = conv_1.graph_conv.weight | |||||
| res_1 = conv_1(node_reprs) | |||||
| res_2 = conv_2(node_reprs) | |||||
| print('res_1:', res_1.detach().cpu().numpy()) | |||||
| print('res_2:', res_2.detach().cpu().numpy()) | |||||
| assert torch.all(res_1 == res_2 * 2) | |||||
| def _test_dropout_graph_conv_activation_03(use_sparse: bool): | |||||
| adj_mat = torch.rand((10, 20)) | |||||
| adj_mat[adj_mat < .5] = 0 | |||||
| adj_mat = torch.ceil(adj_mat) | |||||
| node_reprs = torch.eye(20) | |||||
| conv_1 = DropoutGraphConvActivation(20, 20, adj_mat.to_sparse() \ | |||||
| if use_sparse else adj_mat, keep_prob=.5, | |||||
| activation=lambda x: x) | |||||
| conv_2 = GraphConv(20, 20, adj_mat.to_sparse() \ | |||||
| if use_sparse else adj_mat) | |||||
| conv_2.weight = conv_1.graph_conv.weight | |||||
| torch.random.manual_seed(0) | |||||
| res_1 = conv_1(node_reprs) | |||||
| torch.random.manual_seed(0) | |||||
| res_2 = conv_2(dropout(node_reprs, 0.5)) | |||||
| print('res_1:', res_1.detach().cpu().numpy()) | |||||
| print('res_2:', res_2.detach().cpu().numpy()) | |||||
| assert torch.all(res_1 == res_2) | |||||
| def test_dropout_graph_conv_activation_dense_01(): | |||||
| _test_dropout_graph_conv_activation_01(False) | |||||
| def test_dropout_graph_conv_activation_sparse_01(): | |||||
| _test_dropout_graph_conv_activation_01(True) | |||||
| def test_dropout_graph_conv_activation_dense_02(): | |||||
| _test_dropout_graph_conv_activation_02(False) | |||||
| def test_dropout_graph_conv_activation_sparse_02(): | |||||
| _test_dropout_graph_conv_activation_02(True) | |||||
| def test_dropout_graph_conv_activation_dense_03(): | |||||
| _test_dropout_graph_conv_activation_03(False) | |||||
| def test_dropout_graph_conv_activation_sparse_03(): | |||||
| _test_dropout_graph_conv_activation_03(True) | |||||