|
- from triacontagon.loop import _merge_pos_neg_batches, \
- TrainLoop
- from triacontagon.model import TrainingBatch, \
- Model
- from triacontagon.data import Data
- from triacontagon.decode import dedicom_decoder
- from triacontagon.util import common_one_hot_encoding
- from triacontagon.split import split_data
- import torch
- import pytest
-
-
def test_merge_pos_neg_batches_01():
    """Merge a positive and a negative batch whose type fields all match.

    Both batches are built with 0 for each of the three leading
    arguments.  The merged batch is expected to report 0 for
    ``vertex_type_row``, ``vertex_type_column`` and
    ``relation_type_index``, and to hold the positive edges/targets
    followed by the negative edges/targets.
    """
    positive_edges = torch.tensor([
        [0, 1],
        [2, 3],
        [4, 5],
        [5, 6],
    ])
    negative_edges = torch.tensor([
        [1, 6],
        [3, 5],
        [5, 2],
        [4, 1],
    ])
    positive = TrainingBatch(0, 0, 0, positive_edges, torch.ones(4))
    negative = TrainingBatch(0, 0, 0, negative_edges, torch.zeros(4))

    merged = _merge_pos_neg_batches(positive, negative)

    # The type information must be carried over unchanged.
    assert merged.vertex_type_row == 0
    assert merged.vertex_type_column == 0
    assert merged.relation_type_index == 0

    # Edges: the four positive rows first, then the four negative rows.
    expected_edges = torch.tensor([
        [0, 1],
        [2, 3],
        [4, 5],
        [5, 6],
        [1, 6],
        [3, 5],
        [5, 2],
        [4, 1],
    ])
    assert torch.all(merged.edges == expected_edges)

    # Targets: ones for the positive rows, zeros for the negative rows.
    expected_targets = torch.cat([torch.ones(4), torch.zeros(4)])
    assert torch.all(merged.target_values == expected_targets)
-
-
def test_merge_pos_neg_batches_02():
    """Merging batches whose type fields differ must raise AssertionError.

    The positive batch is created with one leading argument set to 1
    while the negative batch uses 0 everywhere.  The odd value is then
    moved between the type attributes by swapping, and the merge is
    expected to fail with ``AssertionError`` after every move.
    """
    positive = TrainingBatch(0, 1, 0, torch.tensor([
        [0, 1],
        [2, 3],
        [4, 5],
        [5, 6],
    ]), torch.ones(4))
    negative = TrainingBatch(0, 0, 0, torch.tensor([
        [1, 6],
        [3, 5],
        [5, 2],
        [4, 1],
    ]), torch.zeros(4))

    # Case 1: batch as constructed.
    # NOTE(review): presumably the positional arguments are
    # (vertex_type_row, vertex_type_column, relation_type_index, ...),
    # making this the column mismatch -- confirm against TrainingBatch.
    print(positive)
    with pytest.raises(AssertionError):
        _merge_pos_neg_batches(positive, negative)

    # Case 2: exchange the row and column vertex types.
    column_type, row_type = \
        positive.vertex_type_column, positive.vertex_type_row
    positive.vertex_type_row = column_type
    positive.vertex_type_column = row_type
    print(positive)
    with pytest.raises(AssertionError):
        _merge_pos_neg_batches(positive, negative)

    # Case 3: exchange the row vertex type and the relation type index.
    relation_index, row_type = \
        positive.relation_type_index, positive.vertex_type_row
    positive.vertex_type_row = relation_index
    positive.relation_type_index = row_type
    print(positive)
    with pytest.raises(AssertionError):
        _merge_pos_neg_batches(positive, negative)
-
-
def test_train_loop_01():
    """Smoke-test ``TrainLoop.run`` on a tiny graph with two vertex types.

    Builds a ``Data`` object with vertex types 'Foo' (5 vertices) and
    'Bar' (4 vertices), registers one adjacency matrix for each of the
    four type pairings, splits the data, and runs a ``TrainLoop`` with
    ``max_epochs=1`` and ``batch_size=1``.  There are no assertions;
    the test passes if nothing raises.
    """
    data = Data()
    data.add_vertex_type('Foo', 5)
    data.add_vertex_type('Bar', 4)

    # Foo-Foo adjacency, averaged with its transpose so it is symmetric.
    foo_foo_raw = torch.tensor([
        [0, 0, 1, 0, 0],
        [0, 0, 0, 1, 0],
        [1, 0, 0, 0, 0],
        [0, 1, 0, 0, 0],
        [0, 0, 0, 0, 0],
    ], dtype=torch.float32)
    foo_foo = (foo_foo_raw + torch.transpose(foo_foo_raw, 0, 1)) / 2

    # Foo-Bar adjacency; Bar-Foo is simply its transpose.
    foo_bar = torch.tensor([
        [0, 0, 1, 0],
        [0, 0, 0, 1],
        [0, 1, 0, 0],
        [1, 0, 0, 0],
        [0, 0, 0, 1],
    ], dtype=torch.float32)
    bar_foo = torch.transpose(foo_bar, 0, 1)

    # Bar-Bar adjacency, averaged with its transpose so it is symmetric.
    bar_bar_raw = torch.tensor([
        [0, 1, 0, 0],
        [1, 0, 0, 0],
        [0, 0, 0, 1],
        [0, 0, 1, 0],
    ], dtype=torch.float32)
    bar_bar = (bar_bar_raw + torch.transpose(bar_bar_raw, 0, 1)) / 2

    # Register every edge type with a single coalesced sparse matrix,
    # in the order Foo-Foo, Foo-Bar, Bar-Foo, Bar-Bar.
    edge_type_specs = (
        ('Foo-Foo', 0, 0, foo_foo),
        ('Foo-Bar', 0, 1, foo_bar),
        ('Bar-Foo', 1, 0, bar_foo),
        ('Bar-Bar', 1, 1, bar_bar),
    )
    for name, row_type, column_type, dense_adj in edge_type_specs:
        data.add_edge_type(name, row_type, column_type,
                           [dense_adj.to_sparse().coalesce()],
                           dedicom_decoder)

    initial_repr = common_one_hot_encoding([5, 4])

    model = Model(
        data,
        [9, 3, 6],
        keep_prob=1.0,
        conv_activation=torch.sigmoid,
        dec_activation=torch.sigmoid,
    )

    # NOTE(review): the first split is never used below and TrainLoop is
    # handed val_data/test_data -- confirm this matches TrainLoop's
    # expected arguments.
    _train_data, val_data, test_data = split_data(data, (.5, .5, .0))

    print('val_data:', val_data)
    print('val_data.vertex_types:', val_data.vertex_types)

    loop = TrainLoop(model, val_data, test_data, initial_repr,
                     max_epochs=1, batch_size=1)

    loop.run()
|