|
- from triacontagon.loop import _merge_pos_neg_batches
- from triacontagon.model import TrainingBatch
- import torch
- import pytest
-
-
def test_merge_pos_neg_batches_01():
    """Merging a positive and a negative batch of the same type.

    The merged batch keeps the shared type indices (all zero here), lists the
    positive batch's edges before the negative batch's, and concatenates the
    target values in the same order (ones followed by zeros).
    """
    positive_edges = torch.tensor([[0, 1], [2, 3], [4, 5], [5, 6]])
    negative_edges = torch.tensor([[1, 6], [3, 5], [5, 2], [4, 1]])
    positive = TrainingBatch(0, 0, 0, positive_edges, torch.ones(4))
    negative = TrainingBatch(0, 0, 0, negative_edges, torch.zeros(4))

    merged = _merge_pos_neg_batches(positive, negative)

    assert merged.vertex_type_row == 0
    assert merged.vertex_type_column == 0
    assert merged.relation_type_index == 0

    expected_edges = torch.tensor([
        [0, 1],
        [2, 3],
        [4, 5],
        [5, 6],
        [1, 6],
        [3, 5],
        [5, 2],
        [4, 1],
    ])
    assert torch.all(merged.edges == expected_edges)

    expected_targets = torch.cat([torch.ones(4), torch.zeros(4)])
    assert torch.all(merged.target_values == expected_targets)
-
-
def test_merge_pos_neg_batches_02():
    """Merging batches whose type indices disagree raises AssertionError.

    ``b_2`` is built with type indices ``(0, 0, 0)`` and ``b_1`` with
    ``(0, 1, 0)``. ``b_1`` is then mutated twice, and the merge is expected to
    fail after each step.

    Assuming the constructor's positional order is (vertex_type_row,
    vertex_type_column, relation_type_index) -- TODO confirm against
    TrainingBatch -- the single mismatching ``1`` moves from the column, to
    the row, to the relation index. That way each field's check is exercised
    on its own.

    NOTE(review): presumably ``_merge_pos_neg_batches`` validates with plain
    ``assert`` statements. Those are stripped under ``python -O``, in which
    case these expectations would not hold.
    """
    b_1 = TrainingBatch(0, 1, 0, torch.tensor([
        [0, 1],
        [2, 3],
        [4, 5],
        [5, 6],
    ]), torch.ones(4))
    b_2 = TrainingBatch(0, 0, 0, torch.tensor([
        [1, 6],
        [3, 5],
        [5, 2],
        [4, 1],
    ]), torch.zeros(4))

    # Initial state: the batches differ in the second type slot.
    with pytest.raises(AssertionError):
        _merge_pos_neg_batches(b_1, b_2)

    # Swap row and column so that the mismatch moves to the other vertex type.
    b_1.vertex_type_row, b_1.vertex_type_column = \
        b_1.vertex_type_column, b_1.vertex_type_row
    with pytest.raises(AssertionError):
        _merge_pos_neg_batches(b_1, b_2)

    # Swap row and relation index so that the mismatch moves to the relation.
    b_1.vertex_type_row, b_1.relation_type_index = \
        b_1.relation_type_index, b_1.vertex_type_row
    with pytest.raises(AssertionError):
        _merge_pos_neg_batches(b_1, b_2)
|