IF YOU WOULD LIKE TO GET AN ACCOUNT, please write an email to s dot adaszewski at gmail dot com. User accounts are meant only to report issues and/or generate pull requests. This is a purpose-specific Git hosting for ADARED projects. Thank you for your understanding!
Browse Source

Add test_split_edge_type_01/02/03/04().

master
Stanislaw Adaszewski 3 years ago
parent
commit
6d27e26fa7
2 changed files with 92 additions and 1 deletions
  1. +8
    -0
      src/triacontagon/split.py
  2. +84
    -1
      tests/triacontagon/test_split.py

+ 8
- 0
src/triacontagon/split.py View File

@@ -7,6 +7,10 @@ import torch
def split_adj_mat(adj_mat: torch.Tensor, ratios: List[float]):
ratios = list(ratios)
if sum(ratios) != 1:
raise ValueError('Sum of ratios must be 1')
indices = adj_mat.indices()
values = adj_mat.values()
@@ -33,6 +37,10 @@ def split_adj_mat(adj_mat: torch.Tensor, ratios: List[float]):
def split_edge_type(et: EdgeType, ratios: Tuple[float, float, float]):
ratios = list(ratios)
if sum(ratios) != 1:
raise ValueError('Sum of ratios must be 1')
res = [ split_adj_mat(adj_mat, ratios) \
for adj_mat in et.adjacency_matrices ]


+ 84
- 1
tests/triacontagon/test_split.py View File

@@ -1,5 +1,7 @@
from triacontagon.split import split_adj_mat
from triacontagon.split import split_adj_mat, \
split_edge_type
from triacontagon.util import _equal
from triacontagon.data import EdgeType
import torch
@@ -39,3 +41,84 @@ def test_split_adj_mat_03():
print('a:', a.to_dense(), 'b:', b.to_dense(), 'c:', c.to_dense())
assert torch.all(_equal(a+b+c, adj_mat))
def test_split_edge_type_01():
et = EdgeType('Dummy', 0, 1, [
torch.tensor([
[0, 1, 0, 0, 0],
[0, 0, 1, 0, 1],
[1, 0, 0, 0, 1],
[0, 1, 0, 1, 0]
]).to_sparse()
], None, None)
res = split_edge_type(et, (1.,))
assert torch.all(_equal(et.adjacency_matrices[0],
res[0].adjacency_matrices[0]))
def test_split_edge_type_02():
et = EdgeType('Dummy', 0, 1, [
torch.tensor([
[0, 1, 0, 0, 0],
[0, 0, 1, 0, 1],
[1, 0, 0, 0, 1],
[0, 1, 0, 1, 0]
]).to_sparse()
], None, None)
res = split_edge_type(et, (.5, .5))
assert torch.all(_equal(et.adjacency_matrices[0],
res[0].adjacency_matrices[0] + \
res[1].adjacency_matrices[0]))
def test_split_edge_type_03():
et = EdgeType('Dummy', 0, 1, [
torch.tensor([
[0, 1, 0, 0, 0],
[0, 0, 1, 0, 1],
[1, 0, 0, 0, 1],
[0, 1, 0, 1, 0]
]).to_sparse()
], None, None)
res = split_edge_type(et, (.4, .4, .2))
assert torch.all(_equal(et.adjacency_matrices[0],
res[0].adjacency_matrices[0] + \
res[1].adjacency_matrices[0] + \
res[2].adjacency_matrices[0]))
def test_split_edge_type_04():
et = EdgeType('Dummy', 0, 1, [
torch.tensor([
[0, 1, 0, 0, 0],
[0, 0, 1, 0, 1],
[1, 0, 0, 0, 1],
[0, 1, 0, 1, 0]
]).to_sparse(),
torch.tensor([
[1, 0, 0, 0, 0],
[0, 1, 0, 1, 0],
[0, 0, 1, 1, 0],
[1, 0, 1, 0, 0]
]).to_sparse()
], None, None)
res = split_edge_type(et, (.4, .4, .2))
assert torch.all(_equal(et.adjacency_matrices[0],
res[0].adjacency_matrices[0] + \
res[1].adjacency_matrices[0] + \
res[2].adjacency_matrices[0]))
assert torch.all(_equal(et.adjacency_matrices[1],
res[0].adjacency_matrices[1] + \
res[1].adjacency_matrices[1] + \
res[2].adjacency_matrices[1]))

Loading…
Cancel
Save