diff --git a/src/icosagon/fastconv.py b/src/icosagon/fastconv.py
index 81519fb..8838c9a 100644
--- a/src/icosagon/fastconv.py
+++ b/src/icosagon/fastconv.py
@@ -195,7 +195,7 @@ class FastConvLayer(torch.nn.Module):
             self.adjacency_matrix.append(adj_mat)
             self.adjacency_matrix_backward.append(adj_mat_back)
             self.weight.append(weight)
-            self.weight_back.append(weight_back)
+            self.weight_backward.append(weight_back)
 
     def forward(self, prev_layer_repr):
         for i, fam in enumerate(self.data.relation_families):
@@ -210,16 +210,22 @@ class FastConvLayer(torch.nn.Module):
             x = torch.sparse.mm(x, self.weight[i]) \
                 if x.is_sparse \
                 else torch.mm(x, self.weight[i])
-            x = torch.sparse.mm(adj_mat, repr_row) \
+            x = torch.sparse.mm(adj_mat, x) \
                 if adj_mat.is_sparse \
-                else torch.mm(adj_mat, repr_row)
+                else torch.mm(adj_mat, x)
             x = self.rel_activation(x)
             x = x.view(len(fam.relation_types), len(repr_row), -1)
 
             if adj_mat_back is not None:
-                x = torch.sparse.mm(adj_mat_back, repr_row) \
+                x = dropout(repr_column, keep_prob=self.keep_prob)
+                x = torch.sparse.mm(x, self.weight_backward[i]) \
+                    if x.is_sparse \
+                    else torch.mm(x, self.weight_backward[i])
+                x = torch.sparse.mm(adj_mat_back, x) \
                     if adj_mat_back.is_sparse \
-                    else torch.mm(adj_mat_back, repr_row)
+                    else torch.mm(adj_mat_back, x)
+                x = self.rel_activation(x)
+                x = x.view(len(fam.relation_types), len(repr_row), -1)
 
     @staticmethod
     def _check_params(input_dim, output_dim, data, keep_prob,
diff --git a/tests/icosagon/test_fastconv.py b/tests/icosagon/test_fastconv.py
index 407248a..742173d 100644
--- a/tests/icosagon/test_fastconv.py
+++ b/tests/icosagon/test_fastconv.py
@@ -1,10 +1,47 @@
 from icosagon.fastconv import _sparse_diag_cat, \
     _cat, \
-    FastGraphConv
+    FastGraphConv, \
+    FastConvLayer
 from icosagon.data import _equal
 import torch
 import pdb
 import time
+from icosagon.data import Data
+from icosagon.input import OneHotInputLayer
+from icosagon.convlayer import DecagonLayer
+
+
+def _make_symmetric(x: torch.Tensor):
+    x = (x + x.transpose(0, 1)) / 2
+    return x
+
+
+def _symmetric_random(n_rows, n_columns):
+    return _make_symmetric(torch.rand((n_rows, n_columns),
+        dtype=torch.float32).round())
+
+
+def _some_data_with_interactions():
+    d = Data()
+    d.add_node_type('Gene', 1000)
+    d.add_node_type('Drug', 100)
+
+    fam = d.add_relation_family('Drug-Gene', 1, 0, True)
+    fam.add_relation_type('Target',
+        torch.rand((100, 1000), dtype=torch.float32).round())
+
+    fam = d.add_relation_family('Gene-Gene', 0, 0, True)
+    fam.add_relation_type('Interaction',
+        _symmetric_random(1000, 1000))
+
+    fam = d.add_relation_family('Drug-Drug', 1, 1, True)
+    fam.add_relation_type('Side Effect: Nausea',
+        _symmetric_random(100, 100))
+    fam.add_relation_type('Side Effect: Infertility',
+        _symmetric_random(100, 100))
+    fam.add_relation_type('Side Effect: Death',
+        _symmetric_random(100, 100))
+    return d
 
 
 def test_sparse_diag_cat_01():
@@ -86,3 +123,51 @@ def test_fast_graph_conv_02():
     t = time.time()
     _ = fgc(in_repr)
     print('FGC forward pass took:', time.time() - t)
+
+
+def test_fast_graph_conv_03():
+    adj_mat = [
+        [ 0, 0, 1, 0, 1 ],
+        [ 0, 1, 0, 1, 0 ],
+        [ 1, 0, 1, 0, 0 ]
+    ]
+    in_repr = torch.rand(5, 32)
+    adj_mat = torch.tensor(adj_mat, dtype=torch.float32)
+    fgc = FastGraphConv(32, 64, [ adj_mat.to_sparse() ])
+    out_repr = fgc(in_repr)
+    assert out_repr.shape == (1, 3, 64)
+    assert (torch.mm(adj_mat, torch.mm(in_repr, fgc.weights)).view(1, 3, 64) == out_repr).all()
+
+
+def test_fast_graph_conv_04():
+    adj_mat = [
+        [ 0, 0, 1, 0, 1 ],
+        [ 0, 1, 0, 1, 0 ],
+        [ 1, 0, 1, 0, 0 ]
+    ]
+    in_repr = torch.rand(5, 32)
+    adj_mat = torch.tensor(adj_mat, dtype=torch.float32)
+    fgc = FastGraphConv(32, 64, [ adj_mat.to_sparse(), adj_mat.to_sparse() ])
+    out_repr = fgc(in_repr)
+    assert out_repr.shape == (2, 3, 64)
+    adj_mat_1 = torch.zeros(adj_mat.shape[0] * 2, adj_mat.shape[1] * 2)
+    adj_mat_1[0:3, 0:5] = adj_mat
+    adj_mat_1[3:6, 5:10] = adj_mat
+    res = torch.mm(in_repr, fgc.weights)
+    res = torch.split(res, res.shape[1] // 2, dim=1)
+    res = torch.cat(res)
+    res = torch.mm(adj_mat_1, res)
+    assert (res.view(2, 3, 64) == out_repr).all()
+
+
+def test_fast_conv_layer_01():
+    d = _some_data_with_interactions()
+    in_layer = OneHotInputLayer(d)
+
+    d_layer = DecagonLayer(in_layer.output_dim, [32, 32], d)
+    seq_1 = torch.nn.Sequential(in_layer, d_layer)
+    out_repr_1 = seq_1(None)
+
+    conv_layer = FastConvLayer(in_layer.output_dim, [32, 32], d)
+    seq_2 = torch.nn.Sequential(in_layer, conv_layer)
+    out_repr_2 = seq_2(None)
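The equivalence that the new test_fast_graph_conv_04 asserts can be reproduced with plain torch outside the library: several per-relation graph convolutions collapse into a single multiplication against a block-diagonal adjacency matrix once the per-relation weights are concatenated column-wise and the projected features are re-stacked row-wise. The sketch below is illustrative only and assumes that layout from how the test builds adj_mat_1 and splits res; the names (weight_1, weight_2, adj_block) are not part of icosagon.

import torch

n_out, n_in, in_dim, out_dim = 3, 5, 32, 64

adj_mat = torch.rand(n_out, n_in).round()
in_repr = torch.rand(n_in, in_dim)
weight_1 = torch.rand(in_dim, out_dim)
weight_2 = torch.rand(in_dim, out_dim)

# Naive path: one projection + propagation per relation type.
naive = torch.stack([adj_mat @ (in_repr @ weight_1),
                     adj_mat @ (in_repr @ weight_2)])

# Batched path: project against column-concatenated weights, re-stack the
# per-relation chunks row-wise, then apply a block-diagonal adjacency matrix.
weights = torch.cat([weight_1, weight_2], dim=1)                  # (in_dim, 2 * out_dim)
proj = torch.cat(torch.split(in_repr @ weights, out_dim, dim=1))  # (2 * n_in, out_dim)

adj_block = torch.zeros(2 * n_out, 2 * n_in)
adj_block[:n_out, :n_in] = adj_mat
adj_block[n_out:, n_in:] = adj_mat

batched = (adj_block @ proj).view(2, n_out, out_dim)
assert torch.allclose(naive, batched)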