IF YOU WOULD LIKE TO GET AN ACCOUNT, please write an email to s dot adaszewski at gmail dot com. User accounts are meant only to report issues and/or generate pull requests. This is a purpose-specific Git hosting for ADARED projects. Thank you for your understanding!
Browse Source

Add tests for _same_data_org() and DualBatcher.

master
Stanislaw Adaszewski 3 years ago
parent
commit
c206df638a
2 changed files with 121 additions and 4 deletions
  1. +3
    -3
      src/triacontagon/batch.py
  2. +118
    -1
      tests/triacontagon/test_batch.py

+ 3
- 3
src/triacontagon/batch.py View File

@@ -38,9 +38,9 @@ def _same_data_org(pos_data: Data, neg_data: Data):
test = [ [ len(pos_data.edge_types[i].adjacency_matrices[k].values()) == \
len(neg_data.edge_types[i].adjacency_matrices[k].values()) \
for k in range(len(pos_data.edge_types[i])) ] \
for k in range(len(pos_data.edge_types[i].adjacency_matrices)) ] \
for i in pos_data.edge_types.keys() ]
test = reduce(list.__add__, test)
test = reduce(list.__add__, test, [])
if not all(test):
return False
@@ -112,7 +112,7 @@ class DualBatcher(object):
offsets[edge_idx][rel_idx] += self.batch_size
res = TrainingBatch(et.vertex_type_row, et.vertex_type_column,
rel_idx, lst, torch.full(len(lst), target_value,
rel_idx, lst, torch.full(( len(lst), ), target_value,
dtype=torch.float32))
return res


+ 118
- 1
tests/triacontagon/test_batch.py View File

@@ -1,9 +1,59 @@
from triacontagon.batch import Batcher
from triacontagon.batch import _same_data_org, \
DualBatcher, \
Batcher
from triacontagon.data import Data
from triacontagon.decode import dedicom_decoder
import torch
def test_same_data_org_01():
data = Data()
assert _same_data_org(data, data)
data.add_vertex_type('Foo', 10)
assert _same_data_org(data, data)
data.add_vertex_type('Bar', 10)
assert _same_data_org(data, data)
data_1 = Data()
assert not _same_data_org(data, data_1)
data_1.add_vertex_type('Foo', 10)
assert not _same_data_org(data, data_1)
data_1.add_vertex_type('Bar', 10)
assert _same_data_org(data, data_1)
def test_same_data_org_02():
data = Data()
data.add_vertex_type('Foo', 4)
data.add_edge_type('Foo-Foo', 0, 0, [
torch.tensor([
[0, 0, 0, 1],
[1, 0, 0, 0],
[0, 1, 1, 0],
[1, 0, 1, 0]
]).to_sparse()
], dedicom_decoder)
assert _same_data_org(data, data)
data_1 = Data()
data_1.add_vertex_type('Foo', 4)
data_1.add_edge_type('Foo-Foo', 0, 0, [
torch.tensor([
[0, 0, 0, 1],
[1, 0, 0, 0],
[0, 1, 1, 0],
[1, 0, 0, 0]
]).to_sparse()
], dedicom_decoder)
assert not _same_data_org(data, data_1)
def test_batcher_01():
d = Data()
d.add_vertex_type('Gene', 5)
@@ -197,3 +247,70 @@ def test_batcher_05():
(0, 1, 0, 0, 1), (0, 1, 0, 1, 0), (0, 1, 0, 1, 3),
(0, 1, 0, 2, 1), (0, 1, 0, 3, 2), (0, 1, 0, 4, 1),
(0, 1, 0, 4, 2) }
def test_dual_batcher_01():
d = Data()
d.add_vertex_type('Gene', 5)
d.add_vertex_type('Drug', 4)
d.add_edge_type('Gene-Gene', 0, 0, [
torch.tensor([
[0, 1, 0, 1, 0],
[0, 0, 0, 0, 1],
[1, 0, 0, 0, 0],
[0, 0, 1, 0, 0],
[0, 0, 0, 1, 0]
]).to_sparse(),
torch.tensor([
[1, 0, 1, 0, 0],
[0, 0, 0, 1, 0],
[0, 0, 0, 0, 1],
[0, 1, 0, 0, 0],
[0, 0, 1, 0, 0]
]).to_sparse()
], dedicom_decoder)
d.add_edge_type('Gene-Drug', 0, 1, [
torch.tensor([
[0, 1, 0, 0],
[1, 0, 0, 1],
[0, 1, 0, 0],
[0, 0, 1, 0],
[0, 1, 1, 0]
]).to_sparse()
], dedicom_decoder)
b = DualBatcher(d, d, batch_size=5)
visited_pos = set()
visited_neg = set()
for t_pos, t_neg in b:
assert t_pos.vertex_type_row == t_neg.vertex_type_row
assert t_pos.vertex_type_column == t_neg.vertex_type_column
assert t_pos.relation_type_index == t_neg.relation_type_index
assert len(t_pos.edges) == len(t_neg.edges)
for e in t_pos.edges:
k = (t_pos.vertex_type_row, t_pos.vertex_type_column,
t_pos.relation_type_index,) + \
tuple(e.tolist())
visited_pos.add(k)
for e in t_neg.edges:
k = (t_neg.vertex_type_row, t_neg.vertex_type_column,
t_neg.relation_type_index,) + \
tuple(e.tolist())
visited_neg.add(k)
expected = { (0, 0, 0, 0, 1), (0, 0, 0, 0, 3),
(0, 0, 0, 1, 4), (0, 0, 0, 2, 0), (0, 0, 0, 3, 2), (0, 0, 0, 4, 3),
(0, 0, 1, 0, 0), (0, 0, 1, 0, 2), (0, 0, 1, 1, 3), (0, 0, 1, 2, 4),
(0, 0, 1, 3, 1), (0, 0, 1, 4, 2),
(0, 1, 0, 0, 1), (0, 1, 0, 1, 0), (0, 1, 0, 1, 3),
(0, 1, 0, 2, 1), (0, 1, 0, 3, 2), (0, 1, 0, 4, 1),
(0, 1, 0, 4, 2) }
assert visited_pos == expected
assert visited_neg == expected

Loading…
Cancel
Save