IF YOU WOULD LIKE TO GET AN ACCOUNT, please write an email to s dot adaszewski at gmail dot com. User accounts are meant only to report issues and/or generate pull requests. This is a purpose-specific Git hosting for ADARED projects. Thank you for your understanding!
Преглед на файлове

Add AdjListRelationType, AdjListData.

master
Stanislaw Adaszewski преди 4 години
родител
ревизия
28272e3d29
променени са 5 файла, в които са добавени 140 реда и са изтрити 2 реда
  1. +2
    -0
      src/decagon_pytorch/data/__init__.py
  2. +68
    -0
      src/decagon_pytorch/data/list.py
  3. +3
    -2
      src/decagon_pytorch/data/matrix.py
  4. +67
    -0
      tests/decagon_pytorch/test_data_list.py
  5. +0
    -0
      tests/decagon_pytorch/test_data_matrix.py

+ 2
- 0
src/decagon_pytorch/data/__init__.py Целия файл

@@ -0,0 +1,2 @@
from .matrix import *
from .list import *

+ 68
- 0
src/decagon_pytorch/data/list.py Целия файл

@@ -0,0 +1,68 @@
from .matrix import NodeType
import torch
from collections import defaultdict
class AdjListRelationType(object):
def __init__(self, name, node_type_row, node_type_column,
adjacency_list, adjacency_list_transposed=None):
#if adjacency_matrix_transposed is not None and \
# adjacency_matrix_transposed.shape != adjacency_matrix.transpose(0, 1).shape:
# raise ValueError('adjacency_matrix_transposed has incorrect shape')
self.name = name
self.node_type_row = node_type_row
self.node_type_column = node_type_column
self.adjacency_list = adjacency_list
self.adjacency_list_transposed = adjacency_list_transposed
def get_adjacency_list(self, node_type_row, node_type_column):
if self.node_type_row == node_type_row and \
self.node_type_column == node_type_column:
return self.adjacency_list
elif self.node_type_row == node_type_column and \
self.node_type_column == node_type_row:
if self.adjacency_list_transposed is not None:
return self.adjacency_list_transposed
else:
return torch.index_select(self.adjacency_list, 1,
torch.LongTensor([1, 0]))
else:
raise ValueError('Specified row/column types do not correspond to this relation')
def _verify_adjacency_list(adjacency_list, node_count_row, node_count_col):
assert isinstance(adjacency_list, torch.Tensor)
assert len(adjacency_list.shape) == 2
assert torch.all(adjacency_list[:, 0] >= 0)
assert torch.all(adjacency_list[:, 0] < node_count_row)
assert torch.all(adjacency_list[:, 1] >= 0)
assert torch.all(adjacency_list[:, 1] < node_count_col)
class AdjListData(object):
def __init__(self):
self.node_types = []
self.relation_types = defaultdict(list)
def add_node_type(self, name, count): # , latent_length):
self.node_types.append(NodeType(name, count))
def add_relation_type(self, name, node_type_row, node_type_col, adjacency_list, adjacency_list_transposed=None):
assert node_type_row >= 0 and node_type_row < len(self.node_types)
assert node_type_col >= 0 and node_type_col < len(self.node_types)
node_count_row = self.node_types[node_type_row].count
node_count_col = self.node_types[node_type_col].count
_verify_adjacency_list(adjacency_list, node_count_row, node_count_col)
if adjacency_list_transposed is not None:
_verify_adjacency_list(adjacency_list_transposed,
node_count_col, node_count_row)
self.relation_types[node_type_row, node_type_col].append(
AdjListRelationType(name, node_type_row, node_type_col,
adjacency_list, adjacency_list_transposed))

src/decagon_pytorch/data.py → src/decagon_pytorch/data/matrix.py Целия файл

@@ -5,7 +5,7 @@
from collections import defaultdict
from .weights import init_glorot
from ..weights import init_glorot
class NodeType(object):
@@ -18,7 +18,8 @@ class RelationType(object):
def __init__(self, name, node_type_row, node_type_column,
adjacency_matrix, adjacency_matrix_transposed):
if adjacency_matrix_transposed.shape != adjacency_matrix.transpose(0, 1).shape:
if adjacency_matrix_transposed is not None and \
adjacency_matrix_transposed.shape != adjacency_matrix.transpose(0, 1).shape:
raise ValueError('adjacency_matrix_transposed has incorrect shape')
self.name = name

+ 67
- 0
tests/decagon_pytorch/test_data_list.py Целия файл

@@ -0,0 +1,67 @@
from decagon_pytorch.data import AdjListData, \
AdjListRelationType
import torch
import pytest
def _get_list():
lst = torch.tensor([
[0, 1],
[0, 3],
[0, 5],
[0, 7]
])
return lst
def test_adj_list_relation_type_01():
lst = _get_list()
rel = AdjListRelationType('Test', 0, 0, lst)
assert torch.all(rel.get_adjacency_list(0, 0) == lst)
def test_adj_list_relation_type_02():
lst = _get_list()
rel = AdjListRelationType('Test', 0, 1, lst)
assert torch.all(rel.get_adjacency_list(0, 1) == lst)
lst_2 = torch.tensor([
[1, 0],
[3, 0],
[5, 0],
[7, 0]
])
assert torch.all(rel.get_adjacency_list(1, 0) == lst_2)
def test_adj_list_relation_type_03():
lst = _get_list()
lst_2 = torch.tensor([
[2, 0],
[4, 0],
[6, 0],
[8, 0]
])
rel = AdjListRelationType('Test', 0, 1, lst, lst_2)
assert torch.all(rel.get_adjacency_list(0, 1) == lst)
assert torch.all(rel.get_adjacency_list(1, 0) == lst_2)
def test_adj_list_data_01():
lst = _get_list()
d = AdjListData()
with pytest.raises(AssertionError):
d.add_relation_type('Test', 0, 1, lst)
d.add_node_type('Drugs', 5)
with pytest.raises(AssertionError):
d.add_relation_type('Test', 0, 0, lst)
d = AdjListData()
d.add_node_type('Drugs', 8)
d.add_relation_type('Test', 0, 0, lst)
def test_adj_list_data_02():
lst = _get_list()
d = AdjListData()
d.add_node_type('Drugs', 10)
d.add_node_type('Proteins', 10)
d.add_relation_type('Target', 0, 1, lst)

tests/decagon_pytorch/test_data.py → tests/decagon_pytorch/test_data_matrix.py Целия файл


Loading…
Отказ
Запис