|
|
@@ -0,0 +1,66 @@ |
|
|
|
from triacontagon.loop import _merge_pos_neg_batches
|
|
|
|
from triacontagon.model import TrainingBatch
|
|
|
|
import torch
|
|
|
|
import pytest
|
|
|
|
|
|
|
|
|
|
|
|
def test_merge_pos_neg_batches_01():
|
|
|
|
b_1 = TrainingBatch(0, 0, 0, torch.tensor([
|
|
|
|
[0, 1],
|
|
|
|
[2, 3],
|
|
|
|
[4, 5],
|
|
|
|
[5, 6]
|
|
|
|
]), torch.ones(4))
|
|
|
|
b_2 = TrainingBatch(0, 0, 0, torch.tensor([
|
|
|
|
[1, 6],
|
|
|
|
[3, 5],
|
|
|
|
[5, 2],
|
|
|
|
[4, 1]
|
|
|
|
]), torch.zeros(4))
|
|
|
|
b = _merge_pos_neg_batches(b_1, b_2)
|
|
|
|
|
|
|
|
assert b.vertex_type_row == 0
|
|
|
|
assert b.vertex_type_column == 0
|
|
|
|
assert b.relation_type_index == 0
|
|
|
|
assert torch.all(b.edges == torch.tensor([
|
|
|
|
[0, 1],
|
|
|
|
[2, 3],
|
|
|
|
[4, 5],
|
|
|
|
[5, 6],
|
|
|
|
[1, 6],
|
|
|
|
[3, 5],
|
|
|
|
[5, 2],
|
|
|
|
[4, 1]
|
|
|
|
]))
|
|
|
|
assert torch.all(b.target_values == \
|
|
|
|
torch.cat([ torch.ones(4), torch.zeros(4) ]))
|
|
|
|
|
|
|
|
|
|
|
|
def test_merge_pos_neg_batches_02():
|
|
|
|
b_1 = TrainingBatch(0, 1, 0, torch.tensor([
|
|
|
|
[0, 1],
|
|
|
|
[2, 3],
|
|
|
|
[4, 5],
|
|
|
|
[5, 6]
|
|
|
|
]), torch.ones(4))
|
|
|
|
b_2 = TrainingBatch(0, 0, 0, torch.tensor([
|
|
|
|
[1, 6],
|
|
|
|
[3, 5],
|
|
|
|
[5, 2],
|
|
|
|
[4, 1]
|
|
|
|
]), torch.zeros(4))
|
|
|
|
print(b_1)
|
|
|
|
with pytest.raises(AssertionError):
|
|
|
|
_ = _merge_pos_neg_batches(b_1, b_2)
|
|
|
|
|
|
|
|
b_1.vertex_type_row, b_1.vertex_type_column = \
|
|
|
|
b_1.vertex_type_column, b_1.vertex_type_row
|
|
|
|
print(b_1)
|
|
|
|
with pytest.raises(AssertionError):
|
|
|
|
_ = _merge_pos_neg_batches(b_1, b_2)
|
|
|
|
|
|
|
|
b_1.vertex_type_row, b_1.relation_type_index = \
|
|
|
|
b_1.relation_type_index, b_1.vertex_type_row
|
|
|
|
print(b_1)
|
|
|
|
with pytest.raises(AssertionError):
|
|
|
|
_ = _merge_pos_neg_batches(b_1, b_2)
|