def _require(condition, message):
    """Raise ``AssertionError(message)`` unless *condition* is truthy.

    Unlike a bare ``assert`` statement this check is not stripped when
    Python runs with ``-O``, and callers that already catch
    ``AssertionError`` keep working unchanged.
    """
    if not condition:
        raise AssertionError(message)


class AdjListRelationType(object):
    """A named relation between two node types, stored as an edge list.

    Attributes:
        name: Label of the relation.
        node_type_row: Identifier of the node type in column 0 of
            ``adjacency_list``.
        node_type_column: Identifier of the node type in column 1 of
            ``adjacency_list``.
        adjacency_list: Tensor with one edge per row.
        adjacency_list_transposed: Optional explicit edge list for the
            reverse direction; ``None`` means "derive it on demand".
    """

    def __init__(self, name, node_type_row, node_type_column,
                 adjacency_list, adjacency_list_transposed=None):
        self.name = name
        self.node_type_row = node_type_row
        self.node_type_column = node_type_column
        self.adjacency_list = adjacency_list
        self.adjacency_list_transposed = adjacency_list_transposed

    def get_adjacency_list(self, node_type_row, node_type_column):
        """Return the edge list oriented as (node_type_row, node_type_column).

        Args:
            node_type_row: Node type requested for the first column.
            node_type_column: Node type requested for the second column.

        Returns:
            ``adjacency_list`` when the requested orientation matches the
            stored one. For the reverse orientation, returns
            ``adjacency_list_transposed`` if one was supplied, otherwise a
            copy of ``adjacency_list`` with its two columns swapped.

        Raises:
            ValueError: If the requested pair of node types is neither the
                stored orientation nor its reverse.
        """
        if (self.node_type_row == node_type_row and
                self.node_type_column == node_type_column):
            return self.adjacency_list

        if (self.node_type_row == node_type_column and
                self.node_type_column == node_type_row):
            if self.adjacency_list_transposed is not None:
                return self.adjacency_list_transposed
            # Build the column permutation on the same device as the data;
            # a CPU index tensor makes index_select fail for lists that
            # live on another device.
            swap = torch.tensor([1, 0], dtype=torch.long,
                                device=self.adjacency_list.device)
            return torch.index_select(self.adjacency_list, 1, swap)

        raise ValueError('Specified row/column types do not correspond to this relation')


def _verify_adjacency_list(adjacency_list, node_count_row, node_count_col):
    """Check that *adjacency_list* is a valid N x 2 edge list.

    Args:
        adjacency_list: Candidate edge list.
        node_count_row: Exclusive upper bound for the indices in column 0.
        node_count_col: Exclusive upper bound for the indices in column 1.

    Raises:
        AssertionError: If the argument is not a ``torch.Tensor``, is not
            two-dimensional with exactly two columns, or contains an index
            outside ``[0, node_count)`` for its column.
    """
    _require(isinstance(adjacency_list, torch.Tensor),
             'adjacency_list must be a torch.Tensor')
    _require(len(adjacency_list.shape) == 2,
             'adjacency_list must be two-dimensional')
    # Without this check an N x 1 tensor fails later with IndexError and
    # an N x 3 tensor would be silently accepted.
    _require(adjacency_list.shape[1] == 2,
             'adjacency_list must have exactly two columns')
    _require(torch.all(adjacency_list[:, 0] >= 0),
             'adjacency_list contains a negative row index')
    _require(torch.all(adjacency_list[:, 0] < node_count_row),
             'adjacency_list contains a row index >= %d' % node_count_row)
    _require(torch.all(adjacency_list[:, 1] >= 0),
             'adjacency_list contains a negative column index')
    _require(torch.all(adjacency_list[:, 1] < node_count_col),
             'adjacency_list contains a column index >= %d' % node_count_col)


class AdjListData(object):
    """Container for node types and the edge-list relations between them.

    Attributes:
        node_types: List of node types; a node type is referred to
            elsewhere by its position in this list.
        relation_types: ``defaultdict(list)`` mapping a
            ``(node_type_row, node_type_col)`` tuple to the list of
            ``AdjListRelationType`` objects registered for that pair.
    """

    def __init__(self):
        self.node_types = []
        self.relation_types = defaultdict(list)

    def add_node_type(self, name, count):
        """Append a new ``NodeType(name, count)`` to ``node_types``."""
        self.node_types.append(NodeType(name, count))

    def add_relation_type(self, name, node_type_row, node_type_col,
                          adjacency_list, adjacency_list_transposed=None):
        """Validate and register a relation between two known node types.

        Args:
            name: Label of the relation.
            node_type_row: Index into ``node_types`` for column 0.
            node_type_col: Index into ``node_types`` for column 1.
            adjacency_list: N x 2 tensor of (row, column) node indices.
            adjacency_list_transposed: Optional N x 2 tensor for the
                reverse direction, validated against the swapped counts.

        Raises:
            AssertionError: If a node type index is out of range or an
                adjacency list fails validation.
        """
        type_count = len(self.node_types)
        _require(0 <= node_type_row < type_count,
                 'node_type_row is not a known node type index')
        _require(0 <= node_type_col < type_count,
                 'node_type_col is not a known node type index')

        node_count_row = self.node_types[node_type_row].count
        node_count_col = self.node_types[node_type_col].count

        _verify_adjacency_list(adjacency_list, node_count_row, node_count_col)
        if adjacency_list_transposed is not None:
            # The transposed list has columns swapped, so the bounds swap too.
            _verify_adjacency_list(adjacency_list_transposed,
                                   node_count_col, node_count_row)

        self.relation_types[node_type_row, node_type_col].append(
            AdjListRelationType(name, node_type_row, node_type_col,
                                adjacency_list, adjacency_list_transposed))
def _get_list():
    """Return a fixed 4 x 2 edge list linking node 0 to nodes 1, 3, 5, 7."""
    return torch.tensor([[0, 1], [0, 3], [0, 5], [0, 7]])


def test_adj_list_relation_type_01():
    """A same-type relation hands back the list it was given."""
    edges = _get_list()
    relation = AdjListRelationType('Test', 0, 0, edges)
    assert (relation.get_adjacency_list(0, 0) == edges).all()


def test_adj_list_relation_type_02():
    """Without an explicit transpose the reverse list has swapped columns."""
    edges = _get_list()
    relation = AdjListRelationType('Test', 0, 1, edges)
    assert (relation.get_adjacency_list(0, 1) == edges).all()

    expected_reverse = torch.tensor([[1, 0], [3, 0], [5, 0], [7, 0]])
    assert (relation.get_adjacency_list(1, 0) == expected_reverse).all()


def test_adj_list_relation_type_03():
    """An explicitly supplied transpose is returned for the reverse lookup."""
    edges = _get_list()
    explicit_reverse = torch.tensor([[2, 0], [4, 0], [6, 0], [8, 0]])
    relation = AdjListRelationType('Test', 0, 1, edges, explicit_reverse)
    assert (relation.get_adjacency_list(0, 1) == edges).all()
    assert (relation.get_adjacency_list(1, 0) == explicit_reverse).all()


def test_adj_list_data_01():
    """Relations are rejected until node types exist and are large enough."""
    edges = _get_list()

    # No node types registered yet: both type indices are out of range.
    data = AdjListData()
    with pytest.raises(AssertionError):
        data.add_relation_type('Test', 0, 1, edges)

    # Five nodes are too few for an edge list that mentions node 7.
    data.add_node_type('Drugs', 5)
    with pytest.raises(AssertionError):
        data.add_relation_type('Test', 0, 0, edges)

    # Eight nodes cover every index in the list, so this must succeed.
    data = AdjListData()
    data.add_node_type('Drugs', 8)
    data.add_relation_type('Test', 0, 0, edges)


def test_adj_list_data_02():
    """A relation between two distinct, large-enough node types is accepted."""
    edges = _get_list()
    data = AdjListData()
    for type_name in ('Drugs', 'Proteins'):
        data.add_node_type(type_name, 10)
    data.add_relation_type('Target', 0, 1, edges)