|
|
@@ -1,7 +1,10 @@ |
|
|
|
from triacontagon.split import split_adj_mat, \
|
|
|
|
split_edge_type
|
|
|
|
split_edge_type, \
|
|
|
|
split_data
|
|
|
|
from triacontagon.util import _equal
|
|
|
|
from triacontagon.data import EdgeType
|
|
|
|
from triacontagon.data import EdgeType, \
|
|
|
|
Data
|
|
|
|
from triacontagon.decode import dedicom_decoder
|
|
|
|
import torch
|
|
|
|
|
|
|
|
|
|
|
@@ -122,3 +125,125 @@ def test_split_edge_type_04(): |
|
|
|
res[0].adjacency_matrices[1] + \
|
|
|
|
res[1].adjacency_matrices[1] + \
|
|
|
|
res[2].adjacency_matrices[1]))
|
|
|
|
|
|
|
|
|
|
|
|
def test_split_data_01():
|
|
|
|
data = Data()
|
|
|
|
data.add_vertex_type('Foo', 5)
|
|
|
|
data.add_vertex_type('Bar', 4)
|
|
|
|
|
|
|
|
foo_foo = torch.tensor([
|
|
|
|
[0, 1, 0, 1, 0],
|
|
|
|
[0, 0, 0, 1, 0],
|
|
|
|
[0, 1, 0, 0, 1],
|
|
|
|
[0, 1, 0, 0, 0],
|
|
|
|
[1, 0, 0, 1, 0]
|
|
|
|
], dtype=torch.float32)
|
|
|
|
foo_foo = (foo_foo + foo_foo.transpose(0, 1)) / 2
|
|
|
|
|
|
|
|
foo_bar = torch.tensor([
|
|
|
|
[0, 1, 0, 1],
|
|
|
|
[0, 0, 0, 1],
|
|
|
|
[0, 1, 0, 0],
|
|
|
|
[1, 0, 0, 0],
|
|
|
|
[0, 0, 1, 1]
|
|
|
|
], dtype=torch.float32)
|
|
|
|
bar_foo = foo_bar.transpose(0, 1)
|
|
|
|
|
|
|
|
bar_bar = torch.tensor([
|
|
|
|
[0, 0, 1, 0],
|
|
|
|
[1, 0, 0, 0],
|
|
|
|
[0, 1, 0, 1],
|
|
|
|
[0, 1, 0, 0],
|
|
|
|
], dtype=torch.float32)
|
|
|
|
bar_bar = (bar_bar + bar_bar.transpose(0, 1)) / 2
|
|
|
|
|
|
|
|
data.add_edge_type('Foo-Foo', 0, 0, [
|
|
|
|
foo_foo.to_sparse().coalesce()
|
|
|
|
], dedicom_decoder)
|
|
|
|
data.add_edge_type('Foo-Bar', 0, 1, [
|
|
|
|
foo_bar.to_sparse().coalesce()
|
|
|
|
], dedicom_decoder)
|
|
|
|
data.add_edge_type('Bar-Foo', 1, 0, [
|
|
|
|
bar_foo.to_sparse().coalesce()
|
|
|
|
], dedicom_decoder)
|
|
|
|
data.add_edge_type('Bar-Bar', 1, 1, [
|
|
|
|
bar_bar.to_sparse().coalesce()
|
|
|
|
], dedicom_decoder)
|
|
|
|
|
|
|
|
(res,) = split_data(data, (1.,))
|
|
|
|
|
|
|
|
assert torch.all(_equal(res.edge_types[0, 0].adjacency_matrices[0],
|
|
|
|
data.edge_types[0, 0].adjacency_matrices[0]))
|
|
|
|
|
|
|
|
assert torch.all(_equal(res.edge_types[0, 1].adjacency_matrices[0],
|
|
|
|
data.edge_types[0, 1].adjacency_matrices[0]))
|
|
|
|
|
|
|
|
assert torch.all(_equal(res.edge_types[1, 0].adjacency_matrices[0],
|
|
|
|
data.edge_types[1, 0].adjacency_matrices[0]))
|
|
|
|
|
|
|
|
assert torch.all(_equal(res.edge_types[1, 1].adjacency_matrices[0],
|
|
|
|
data.edge_types[1, 1].adjacency_matrices[0]))
|
|
|
|
|
|
|
|
|
|
|
|
def test_split_data_02():
|
|
|
|
data = Data()
|
|
|
|
data.add_vertex_type('Foo', 5)
|
|
|
|
data.add_vertex_type('Bar', 4)
|
|
|
|
|
|
|
|
foo_foo = torch.tensor([
|
|
|
|
[0, 1, 0, 1, 0],
|
|
|
|
[0, 0, 0, 1, 0],
|
|
|
|
[0, 1, 0, 0, 1],
|
|
|
|
[0, 1, 0, 0, 0],
|
|
|
|
[1, 0, 0, 1, 0]
|
|
|
|
], dtype=torch.float32)
|
|
|
|
foo_foo = (foo_foo + foo_foo.transpose(0, 1)) / 2
|
|
|
|
|
|
|
|
foo_bar = torch.tensor([
|
|
|
|
[0, 1, 0, 1],
|
|
|
|
[0, 0, 0, 1],
|
|
|
|
[0, 1, 0, 0],
|
|
|
|
[1, 0, 0, 0],
|
|
|
|
[0, 0, 1, 1]
|
|
|
|
], dtype=torch.float32)
|
|
|
|
bar_foo = foo_bar.transpose(0, 1)
|
|
|
|
|
|
|
|
bar_bar = torch.tensor([
|
|
|
|
[0, 0, 1, 0],
|
|
|
|
[1, 0, 0, 0],
|
|
|
|
[0, 1, 0, 1],
|
|
|
|
[0, 1, 0, 0],
|
|
|
|
], dtype=torch.float32)
|
|
|
|
bar_bar = (bar_bar + bar_bar.transpose(0, 1)) / 2
|
|
|
|
|
|
|
|
data.add_edge_type('Foo-Foo', 0, 0, [
|
|
|
|
foo_foo.to_sparse().coalesce()
|
|
|
|
], dedicom_decoder)
|
|
|
|
data.add_edge_type('Foo-Bar', 0, 1, [
|
|
|
|
foo_bar.to_sparse().coalesce()
|
|
|
|
], dedicom_decoder)
|
|
|
|
data.add_edge_type('Bar-Foo', 1, 0, [
|
|
|
|
bar_foo.to_sparse().coalesce()
|
|
|
|
], dedicom_decoder)
|
|
|
|
data.add_edge_type('Bar-Bar', 1, 1, [
|
|
|
|
bar_bar.to_sparse().coalesce()
|
|
|
|
], dedicom_decoder)
|
|
|
|
|
|
|
|
a, b = split_data(data, (.5,.5))
|
|
|
|
|
|
|
|
assert torch.all(_equal(a.edge_types[0, 0].adjacency_matrices[0] + \
|
|
|
|
b.edge_types[0, 0].adjacency_matrices[0],
|
|
|
|
data.edge_types[0, 0].adjacency_matrices[0]))
|
|
|
|
|
|
|
|
assert torch.all(_equal(a.edge_types[0, 1].adjacency_matrices[0] + \
|
|
|
|
b.edge_types[0, 1].adjacency_matrices[0],
|
|
|
|
data.edge_types[0, 1].adjacency_matrices[0]))
|
|
|
|
|
|
|
|
assert torch.all(_equal(a.edge_types[1, 0].adjacency_matrices[0] + \
|
|
|
|
b.edge_types[1, 0].adjacency_matrices[0],
|
|
|
|
data.edge_types[1, 0].adjacency_matrices[0]))
|
|
|
|
|
|
|
|
assert torch.all(_equal(a.edge_types[1, 1].adjacency_matrices[0] + \
|
|
|
|
b.edge_types[1, 1].adjacency_matrices[0],
|
|
|
|
data.edge_types[1, 1].adjacency_matrices[0]))
|