diff --git a/src/icosagon/convlayer.py b/src/icosagon/convlayer.py index 3fe1313..465f2a5 100644 --- a/src/icosagon/convlayer.py +++ b/src/icosagon/convlayer.py @@ -50,10 +50,8 @@ class DecagonLayer(torch.nn.Module): self.next_layer_repr = None self.build() - def build(self): - self.next_layer_repr = [ [] for _ in range(len(self.data.node_types)) ] - - for (node_type_row, node_type_column), rels in self.data.relation_types.items(): + def build_family(self, fam): + for (node_type_row, node_type_column), rels in fam.relation_types.items(): if len(rels) == 0: continue @@ -69,6 +67,11 @@ class DecagonLayer(torch.nn.Module): self.next_layer_repr[node_type_row].append( Convolutions(node_type_column, convolutions)) + def build(self): + self.next_layer_repr = [ [] for _ in range(len(self.data.node_types)) ] + for fam in self.data.relation_families: + self.build_family(fam) + def __call__(self, prev_layer_repr): next_layer_repr = [ [] for _ in range(len(self.data.node_types)) ] n = len(self.data.node_types) diff --git a/src/icosagon/data.py b/src/icosagon/data.py index cd55b29..b04432e 100644 --- a/src/icosagon/data.py +++ b/src/icosagon/data.py @@ -16,6 +16,24 @@ from .decode import DEDICOMDecoder, \ BilinearDecoder +def _equal(x: torch.Tensor, y: torch.Tensor): + if x.is_sparse ^ y.is_sparse: + raise ValueError('Cannot mix sparse and dense tensors') + + if not x.is_sparse: + return (x == y) + + x = x.coalesce() + indices_x = list(map(tuple, x.indices().transpose(0, 1))) + order_x = sorted(range(len(indices_x)), key=lambda idx: indices_x[idx]) + + y = y.coalesce() + indices_y = list(map(tuple, y.indices().transpose(0, 1))) + order_y = sorted(range(len(indices_y)), key=lambda idx: indices_y[idx]) + + return (x.values()[order_x] == y.values()[order_y]) + + @dataclass class NodeType(object): name: str @@ -97,7 +115,8 @@ class RelationFamily(object): raise ValueError('Cannot use a custom adjacency_matrix_backward in a symmetric relation family') if self.is_symmetric and node_type_row == node_type_column and \ - not torch.all(adjacency_matrix == adjacency_matrix.transpose(0, 1)): + not torch.all(_equal(adjacency_matrix, + adjacency_matrix.transpose(0, 1))): raise ValueError('Relation family is symmetric but adjacency_matrix is assymetric') two_way = bool(two_way) diff --git a/tests/icosagon/test_convlayer.py b/tests/icosagon/test_convlayer.py index 8e713c2..cc0458f 100644 --- a/tests/icosagon/test_convlayer.py +++ b/tests/icosagon/test_convlayer.py @@ -10,20 +10,36 @@ from decagon_pytorch.convolve import MultiDGCA import decagon_pytorch.convolve +def _make_symmetric(x: torch.Tensor): + x = (x + x.transpose(0, 1)) / 2 + return x + + +def _symmetric_random(n_rows, n_columns): + return _make_symmetric(torch.rand((n_rows, n_columns), + dtype=torch.float32).round()) + + def _some_data_with_interactions(): d = Data() d.add_node_type('Gene', 1000) d.add_node_type('Drug', 100) - d.add_relation_type('Target', 1, 0, + + fam = d.add_relation_family('Drug-Gene', 1, 0, True) + fam.add_relation_type('Target', 1, 0, torch.rand((100, 1000), dtype=torch.float32).round()) - d.add_relation_type('Interaction', 0, 0, - torch.rand((1000, 1000), dtype=torch.float32).round()) - d.add_relation_type('Side Effect: Nausea', 1, 1, - torch.rand((100, 100), dtype=torch.float32).round()) - d.add_relation_type('Side Effect: Infertility', 1, 1, - torch.rand((100, 100), dtype=torch.float32).round()) - d.add_relation_type('Side Effect: Death', 1, 1, - torch.rand((100, 100), dtype=torch.float32).round()) + + fam = d.add_relation_family('Gene-Gene', 0, 0, True) + fam.add_relation_type('Interaction', 0, 0, + _symmetric_random(1000, 1000)) + + fam = d.add_relation_family('Drug-Drug', 1, 1, True) + fam.add_relation_type('Side Effect: Nausea', 1, 1, + _symmetric_random(100, 100)) + fam.add_relation_type('Side Effect: Infertility', 1, 1, + _symmetric_random(100, 100)) + fam.add_relation_type('Side Effect: Death', 1, 1, + _symmetric_random(100, 100)) return d @@ -80,13 +96,14 @@ def test_decagon_layer_04(): d = Data() d.add_node_type('Dummy', 100) - d.add_relation_type('Dummy Relation', 0, 0, - torch.rand((100, 100), dtype=torch.float32).round().to_sparse()) + fam = d.add_relation_family('Dummy-Dummy', 0, 0, True) + fam.add_relation_type('Dummy Relation', 0, 0, + _symmetric_random(100, 100).to_sparse()) in_layer = OneHotInputLayer(d) multi_dgca = MultiDGCA([10], 32, - [r.adjacency_matrix for r in d.relation_types[0, 0]], + [r.adjacency_matrix for r in fam.relation_types[0, 0]], keep_prob=1., activation=lambda x: x) d_layer = DecagonLayer(in_layer.output_dim, 32, d, @@ -121,15 +138,16 @@ def test_decagon_layer_05(): d = Data() d.add_node_type('Dummy', 100) - d.add_relation_type('Dummy Relation 1', 0, 0, - torch.rand((100, 100), dtype=torch.float32).round().to_sparse()) - d.add_relation_type('Dummy Relation 2', 0, 0, - torch.rand((100, 100), dtype=torch.float32).round().to_sparse()) + fam = d.add_relation_family('Dummy-Dummy', 0, 0, True) + fam.add_relation_type('Dummy Relation 1', 0, 0, + _symmetric_random(100, 100).to_sparse()) + fam.add_relation_type('Dummy Relation 2', 0, 0, + _symmetric_random(100, 100).to_sparse()) in_layer = OneHotInputLayer(d) multi_dgca = MultiDGCA([100, 100], 32, - [r.adjacency_matrix for r in d.relation_types[0, 0]], + [r.adjacency_matrix for r in fam.relation_types[0, 0]], keep_prob=1., activation=lambda x: x) d_layer = DecagonLayer(in_layer.output_dim, output_dim=32, data=d,