"""Tests for ``_merge_pos_neg_batches`` from ``triacontagon.loop``."""

import pytest
import torch

from triacontagon.loop import _merge_pos_neg_batches
from triacontagon.model import TrainingBatch


def test_merge_pos_neg_batches_01():
    """Check the merged batch when both inputs share all type indices.

    The merged batch is expected to keep the common type indices and to
    hold the edges and target values of the first batch followed by those
    of the second batch.
    """
    pos_edges = [[0, 1], [2, 3], [4, 5], [5, 6]]
    neg_edges = [[1, 6], [3, 5], [5, 2], [4, 1]]

    b_1 = TrainingBatch(0, 0, 0, torch.tensor(pos_edges), torch.ones(4))
    b_2 = TrainingBatch(0, 0, 0, torch.tensor(neg_edges), torch.zeros(4))

    b = _merge_pos_neg_batches(b_1, b_2)

    assert b.vertex_type_row == 0
    assert b.vertex_type_column == 0
    assert b.relation_type_index == 0

    expected_edges = torch.tensor(pos_edges + neg_edges)
    expected_targets = torch.cat([torch.ones(4), torch.zeros(4)])

    # ``==`` broadcasts, so a wrongly shaped result could still compare
    # element-wise equal; pin the shapes before comparing the values.
    assert b.edges.shape == expected_edges.shape
    assert torch.all(b.edges == expected_edges)
    assert b.target_values.shape == expected_targets.shape
    assert torch.all(b.target_values == expected_targets)


def test_merge_pos_neg_batches_02():
    """Check that merging batches with differing type indices fails.

    Starting from a batch that differs from ``b_2`` in a single type
    index, the odd value is moved between attributes by swapping, and
    every resulting combination must make the merge raise
    ``AssertionError``.
    """
    # NOTE(review): presumably the first three positional arguments are
    # (vertex_type_row, vertex_type_column, relation_type_index), so b_1
    # starts out differing from b_2 in the column type only -- confirm
    # against the TrainingBatch definition.
    b_1 = TrainingBatch(
        0, 1, 0,
        torch.tensor([[0, 1], [2, 3], [4, 5], [5, 6]]),
        torch.ones(4),
    )
    b_2 = TrainingBatch(
        0, 0, 0,
        torch.tensor([[1, 6], [3, 5], [5, 2], [4, 1]]),
        torch.zeros(4),
    )

    # The prints only surface the batch state in pytest's captured output
    # when one of the checks below fails.
    print(b_1)
    with pytest.raises(AssertionError):
        _ = _merge_pos_neg_batches(b_1, b_2)

    # Exchange the row and column types of b_1.
    b_1.vertex_type_row, b_1.vertex_type_column = (
        b_1.vertex_type_column,
        b_1.vertex_type_row,
    )

    print(b_1)
    with pytest.raises(AssertionError):
        _ = _merge_pos_neg_batches(b_1, b_2)

    # Exchange the row type and the relation type index of b_1.
    b_1.vertex_type_row, b_1.relation_type_index = (
        b_1.relation_type_index,
        b_1.vertex_type_row,
    )

    print(b_1)
    with pytest.raises(AssertionError):
        _ = _merge_pos_neg_batches(b_1, b_2)