diff --git a/src/triacontagon/loop.py b/src/triacontagon/loop.py index cc20b5a..f52af87 100644 --- a/src/triacontagon/loop.py +++ b/src/triacontagon/loop.py @@ -1,7 +1,11 @@ -from .model import Model +from .model import Model, \ + TrainingBatch from .batch import Batcher from .sampling import negative_sample_data from .data import Data +import torch +from typing import List, \ + Callable def _merge_pos_neg_batches(pos_batch, neg_batch): diff --git a/tests/triacontagon/test_loop.py b/tests/triacontagon/test_loop.py new file mode 100644 index 0000000..dde1299 --- /dev/null +++ b/tests/triacontagon/test_loop.py @@ -0,0 +1,66 @@ +from triacontagon.loop import _merge_pos_neg_batches +from triacontagon.model import TrainingBatch +import torch +import pytest + + +def test_merge_pos_neg_batches_01(): + b_1 = TrainingBatch(0, 0, 0, torch.tensor([ + [0, 1], + [2, 3], + [4, 5], + [5, 6] + ]), torch.ones(4)) + b_2 = TrainingBatch(0, 0, 0, torch.tensor([ + [1, 6], + [3, 5], + [5, 2], + [4, 1] + ]), torch.zeros(4)) + b = _merge_pos_neg_batches(b_1, b_2) + + assert b.vertex_type_row == 0 + assert b.vertex_type_column == 0 + assert b.relation_type_index == 0 + assert torch.all(b.edges == torch.tensor([ + [0, 1], + [2, 3], + [4, 5], + [5, 6], + [1, 6], + [3, 5], + [5, 2], + [4, 1] + ])) + assert torch.all(b.target_values == \ + torch.cat([ torch.ones(4), torch.zeros(4) ])) + + +def test_merge_pos_neg_batches_02(): + b_1 = TrainingBatch(0, 1, 0, torch.tensor([ + [0, 1], + [2, 3], + [4, 5], + [5, 6] + ]), torch.ones(4)) + b_2 = TrainingBatch(0, 0, 0, torch.tensor([ + [1, 6], + [3, 5], + [5, 2], + [4, 1] + ]), torch.zeros(4)) + print(b_1) + with pytest.raises(AssertionError): + _ = _merge_pos_neg_batches(b_1, b_2) + + b_1.vertex_type_row, b_1.vertex_type_column = \ + b_1.vertex_type_column, b_1.vertex_type_row + print(b_1) + with pytest.raises(AssertionError): + _ = _merge_pos_neg_batches(b_1, b_2) + + b_1.vertex_type_row, b_1.relation_type_index = \ + b_1.relation_type_index, b_1.vertex_type_row + print(b_1) + with pytest.raises(AssertionError): + _ = _merge_pos_neg_batches(b_1, b_2)