"""Tests for triacontagon.loop: batch merging and a minimal TrainLoop run."""

import pytest
import torch

from triacontagon.data import Data
from triacontagon.decode import dedicom_decoder
from triacontagon.loop import TrainLoop, _merge_pos_neg_batches
from triacontagon.model import Model, TrainingBatch
from triacontagon.split import split_data
from triacontagon.util import common_one_hot_encoding


# Edge lists shared by the batch-merging tests (one [row, column] pair each).
_POS_EDGES = [[0, 1], [2, 3], [4, 5], [5, 6]]
_NEG_EDGES = [[1, 6], [3, 5], [5, 2], [4, 1]]


def _symmetrize(mat):
    """Return the elementwise average of ``mat`` and its transpose."""
    return (mat + mat.transpose(0, 1)) / 2


def _expect_merge_rejected(pos_batch, neg_batch):
    """Print ``pos_batch`` and require merging it with ``neg_batch`` to fail."""
    print(pos_batch)
    with pytest.raises(AssertionError):
        _merge_pos_neg_batches(pos_batch, neg_batch)


def test_merge_pos_neg_batches_01():
    """Merging batches of matching type keeps positives first, then negatives."""
    pos_batch = TrainingBatch(0, 0, 0, torch.tensor(_POS_EDGES), torch.ones(4))
    neg_batch = TrainingBatch(0, 0, 0, torch.tensor(_NEG_EDGES), torch.zeros(4))

    merged = _merge_pos_neg_batches(pos_batch, neg_batch)

    assert merged.vertex_type_row == 0
    assert merged.vertex_type_column == 0
    assert merged.relation_type_index == 0

    expected_edges = torch.tensor(_POS_EDGES + _NEG_EDGES)
    expected_targets = torch.cat([torch.ones(4), torch.zeros(4)])
    assert torch.all(merged.edges == expected_edges)
    assert torch.all(merged.target_values == expected_targets)


def test_merge_pos_neg_batches_02():
    """Merging is rejected whenever row, column or relation type differ."""
    pos_batch = TrainingBatch(0, 1, 0, torch.tensor(_POS_EDGES), torch.ones(4))
    neg_batch = TrainingBatch(0, 0, 0, torch.tensor(_NEG_EDGES), torch.zeros(4))

    # Column type differs: (0, 1, 0) vs (0, 0, 0).
    _expect_merge_rejected(pos_batch, neg_batch)

    # Each swap moves the single mismatching '1' to another field:
    # first to the row type (1, 0, 0), then to the relation index (0, 0, 1).
    swaps = [
        ('vertex_type_row', 'vertex_type_column'),
        ('vertex_type_row', 'relation_type_index'),
    ]
    for first, second in swaps:
        first_val = getattr(pos_batch, first)
        second_val = getattr(pos_batch, second)
        setattr(pos_batch, first, second_val)
        setattr(pos_batch, second, first_val)
        _expect_merge_rejected(pos_batch, neg_batch)


def test_train_loop_01():
    """Smoke-test one epoch of TrainLoop on a tiny two-vertex-type graph."""
    graph = Data()
    graph.add_vertex_type('Foo', 5)
    graph.add_vertex_type('Bar', 4)

    foo_foo = _symmetrize(torch.tensor([
        [0, 0, 1, 0, 0],
        [0, 0, 0, 1, 0],
        [1, 0, 0, 0, 0],
        [0, 1, 0, 0, 0],
        [0, 0, 0, 0, 0],
    ], dtype=torch.float32))

    foo_bar = torch.tensor([
        [0, 0, 1, 0],
        [0, 0, 0, 1],
        [0, 1, 0, 0],
        [1, 0, 0, 0],
        [0, 0, 0, 1],
    ], dtype=torch.float32)

    bar_foo = foo_bar.transpose(0, 1)

    bar_bar = _symmetrize(torch.tensor([
        [0, 1, 0, 0],
        [1, 0, 0, 0],
        [0, 0, 0, 1],
        [0, 0, 1, 0],
    ], dtype=torch.float32))

    # (name, row vertex type, column vertex type, dense adjacency)
    edge_types = [
        ('Foo-Foo', 0, 0, foo_foo),
        ('Foo-Bar', 0, 1, foo_bar),
        ('Bar-Foo', 1, 0, bar_foo),
        ('Bar-Bar', 1, 1, bar_bar),
    ]
    for name, row_type, col_type, adjacency in edge_types:
        graph.add_edge_type(name, row_type, col_type,
                            [adjacency.to_sparse().coalesce()],
                            dedicom_decoder)

    initial_repr = common_one_hot_encoding([5, 4])

    model = Model(graph, [9, 3, 6], keep_prob=1.0,
                  conv_activation=torch.sigmoid,
                  dec_activation=torch.sigmoid)

    # NOTE(review): train_data is unused, and the model above is built from
    # the full graph before splitting. Confirm whether Model should receive
    # train_data instead. test_data is presumably empty given the 0.0 fraction.
    train_data, val_data, test_data = split_data(graph, (0.5, 0.5, 0.0))

    print('val_data:', val_data)
    print('val_data.vertex_types:', val_data.vertex_types)

    train_loop = TrainLoop(model, val_data, test_data, initial_repr,
                           max_epochs=1, batch_size=1)
    train_loop.run()