From 7fa7b7372c87b831cbf4b2dc9d1ac126d95d6703 Mon Sep 17 00:00:00 2001 From: Stanislaw Adaszewski Date: Thu, 20 Aug 2020 12:21:32 +0200 Subject: [PATCH] Work on loop, split and sampling. --- src/triacontagon/loop.py | 2 +- src/triacontagon/sampling.py | 43 ++++++++++++++---- src/triacontagon/split.py | 42 ++++++++++++------ tests/triacontagon/test_loop.py | 69 ++++++++++++++++++++++++++++- tests/triacontagon/test_sampling.py | 18 +++++++- 5 files changed, 147 insertions(+), 27 deletions(-) diff --git a/src/triacontagon/loop.py b/src/triacontagon/loop.py index f52af87..e5f96bf 100644 --- a/src/triacontagon/loop.py +++ b/src/triacontagon/loop.py @@ -43,7 +43,7 @@ class TrainLoop(object): self.model = model self.test_data = test_data self.initial_repr = list(initial_repr) - self.max_epochs = int(num_epochs) + self.max_epochs = int(max_epochs) self.batch_size = int(batch_size) self.loss = loss self.lr = float(lr) diff --git a/src/triacontagon/sampling.py b/src/triacontagon/sampling.py index c85ee1d..cc30402 100644 --- a/src/triacontagon/sampling.py +++ b/src/triacontagon/sampling.py @@ -20,7 +20,7 @@ def fixed_unigram_candidate_sampler( true_classes: torch.Tensor, num_repeats: torch.Tensor, unigrams: torch.Tensor, - distortion: float = 1.): + distortion: float = 1.) -> torch.Tensor: if len(true_classes.shape) != 2: raise ValueError('true_classes must be a 2D matrix with shape (num_samples, num_true)') @@ -29,26 +29,34 @@ def fixed_unigram_candidate_sampler( raise ValueError('num_repeats must be 1D') num_rows = true_classes.shape[0] + print('true_classes.shape:', true_classes.shape) # unigrams = np.array(unigrams) if distortion != 1.: unigrams = unigrams.to(torch.float64) ** distortion - # print('unigrams:', unigrams) + print('unigrams:', unigrams) + indices = torch.arange(num_rows) indices = torch.repeat_interleave(indices, num_repeats) + indices = torch.cat([ torch.arange(len(indices)).view(-1, 1), + indices.view(-1, 1) ], dim=1) + num_samples = len(indices) result = torch.zeros(num_samples, dtype=torch.long) + print('num_rows:', num_rows, 'num_samples:', num_samples) + while len(indices) > 0: - # print('len(indices):', len(indices)) + print('len(indices):', len(indices)) + print('indices:', indices) sampler = torch.utils.data.WeightedRandomSampler(unigrams, len(indices)) candidates = torch.tensor(list(sampler)) candidates = candidates.view(len(indices), 1) - # print('candidates:', candidates) - # print('true_classes:', true_classes[indices, :]) - result[indices] = candidates.transpose(0, 1) - # print('result:', result) - mask = (candidates == true_classes[indices, :]) + print('candidates:', candidates) + print('true_classes:', true_classes[indices[:, 1], :]) + result[indices[:, 0]] = candidates.transpose(0, 1) + print('result:', result) + mask = (candidates == true_classes[indices[:, 1], :]) mask = mask.sum(1).to(torch.bool) - # print('mask:', mask) + print('mask:', mask) indices = indices[mask] # result[indices] = 0 return result @@ -164,3 +172,20 @@ def negative_sample_data(data: Data) -> Data: #new_edge_types[key] = new_et #res = Data(data.vertex_types, new_edge_types) return res + + +def merge_data(pos_data: Data, neg_data: Data) -> Data: + assert isinstance(pos_data, Data) + assert isinstance(neg_data, Data) + + res = PosNegData() + + for vt in pos_data.vertex_types: + res.add_vertex_type(vt.name, vt.count) + + for key, pos_et in pos_data.edge_types.items(): + neg_et = neg_data.edge_types[key] + res.add_edge_type(pos_et.name, + pos_et.vertex_type_row, pos_et.vertex_type_column, + pos_et.adjacency_matrices, neg_et.adjacency_matrices, + pos_et.decoder_factory) diff --git a/src/triacontagon/split.py b/src/triacontagon/split.py index dd7a12c..68826f1 100644 --- a/src/triacontagon/split.py +++ b/src/triacontagon/split.py @@ -1,8 +1,9 @@ from .data import Data, \ - TrainingBatch, \ EdgeType -from typing import Tuple +from typing import Tuple, \ + List from .util import _sparse_coo_tensor +import torch def split_adj_mat(adj_mat: torch.Tensor, ratios: List[float]): @@ -17,21 +18,30 @@ def split_adj_mat(adj_mat: torch.Tensor, ratios: List[float]): ofs = 0 res = [] for r in ratios: - cnt = r * len(values) - ind = indices[:, ofs:ofs+cnt] - val = values[ofs:ofs+cnt] + # cnt = r * len(values) + + beg = int(ofs * len(values)) + end = int((ofs + r) * len(values)) + ofs += r + + ind = indices[:, beg:end] + val = values[beg:end] res.append(_sparse_coo_tensor(ind, val, adj_mat.shape)) - ofs += cnt + # ofs += cnt return res def split_edge_type(et: EdgeType, ratios: Tuple[float, float, float]): - res = [ [] for _ in range(len(et.adjacency_matrices)) ] + res = [ split_adj_mat(adj_mat, ratios) \ + for adj_mat in et.adjacency_matrices ] - for adj_mat in et.adjacency_matrices: - for i, new_adj_mat in enumerate(split_adj_mat(adj_mat, ratios)): - res[i].append(new_adj_mat) + res = [ EdgeType(et.name, + et.vertex_type_row, + et.vertex_type_column, + [ a[i] for a in res ], + et.decoder_factory, + None ) for i in range(len(ratios)) ] return res @@ -49,11 +59,15 @@ def split_data(data: Data, res = [ {} for _ in range(len(ratios)) ] - for key, et in data.edge_types: + for key, et in data.edge_types.items(): for i, new_et in enumerate(split_edge_type(et, ratios)): res[i][key] = new_et - res = [ Data(data.vertex_types, new_edge_types) \ - for new_edge_types in res ] + res_1 = [] + for new_edge_types in res: + d = Data() + d.vertex_types = data.vertex_types, + d.edge_types = new_edge_types + res_1.append(d) - return res + return res_1 diff --git a/tests/triacontagon/test_loop.py b/tests/triacontagon/test_loop.py index dde1299..7456ab5 100644 --- a/tests/triacontagon/test_loop.py +++ b/tests/triacontagon/test_loop.py @@ -1,5 +1,11 @@ -from triacontagon.loop import _merge_pos_neg_batches -from triacontagon.model import TrainingBatch +from triacontagon.loop import _merge_pos_neg_batches, \ + TrainLoop +from triacontagon.model import TrainingBatch, \ + Model +from triacontagon.data import Data +from triacontagon.decode import dedicom_decoder +from triacontagon.util import common_one_hot_encoding +from triacontagon.split import split_data import torch import pytest @@ -64,3 +70,62 @@ def test_merge_pos_neg_batches_02(): print(b_1) with pytest.raises(AssertionError): _ = _merge_pos_neg_batches(b_1, b_2) + + +def test_train_loop_01(): + data = Data() + data.add_vertex_type('Foo', 5) + data.add_vertex_type('Bar', 4) + + foo_foo = torch.tensor([ + [0, 0, 0, 1, 0], + [0, 0, 1, 0, 0], + [1, 0, 0, 1, 0], + [0, 0, 1, 0, 1], + [0, 1, 0, 0, 0] + ]) + foo_foo = (foo_foo + foo_foo.transpose(0, 1)) / 2 + + foo_bar = torch.tensor([ + [0, 1, 0, 1], + [0, 0, 0, 1], + [0, 1, 0, 0], + [1, 0, 0, 0], + [0, 0, 1, 1] + ]) + bar_foo = foo_bar.transpose(0, 1) + + bar_bar = torch.tensor([ + [0, 0, 1, 0], + [1, 0, 0, 0], + [0, 1, 0, 1], + [0, 1, 0, 0], + ]) + bar_bar = (bar_bar + bar_bar.transpose(0, 1)) / 2 + + data.add_edge_type('Foo-Foo', 0, 0, [ + foo_foo.to_sparse().coalesce() + ], dedicom_decoder) + data.add_edge_type('Foo-Bar', 0, 1, [ + foo_bar.to_sparse().coalesce() + ], dedicom_decoder) + data.add_edge_type('Bar-Foo', 1, 0, [ + bar_foo.to_sparse().coalesce() + ], dedicom_decoder) + data.add_edge_type('Bar-Bar', 1, 1, [ + bar_bar.to_sparse().coalesce() + ], dedicom_decoder) + + initial_repr = common_one_hot_encoding([5, 4]) + + model = Model(data, [9, 3, 6], + keep_prob=1.0, + conv_activation=torch.sigmoid, + dec_activation=torch.sigmoid) + + train_data, val_data, test_data = split_data(data, (.9, .1, .0) ) + + loop = TrainLoop(model, val_data, test_data, initial_repr, + max_epochs=1, batch_size=1) + + _ = loop.run() diff --git a/tests/triacontagon/test_sampling.py b/tests/triacontagon/test_sampling.py index 6bba237..0b45769 100644 --- a/tests/triacontagon/test_sampling.py +++ b/tests/triacontagon/test_sampling.py @@ -1,5 +1,6 @@ from triacontagon.data import Data -from triacontagon.sampling import get_true_classes, \ +from triacontagon.sampling import fixed_unigram_candidate_sampler, \ + get_true_classes, \ negative_sample_adj_mat, \ negative_sample_data from triacontagon.decode import dedicom_decoder @@ -7,6 +8,21 @@ import torch import time +def test_fixed_unigram_candidate_sampler_01(): + true_classes = torch.tensor([[-1], + [-1], + [ 3], + [ 2], + [-1]]) + num_repeats = torch.tensor([0, 0, 1, 1, 0]) + unigrams = torch.tensor([0., 0., 1., 1., 0.], dtype=torch.float64) + distortion = 0.75 + res = fixed_unigram_candidate_sampler(true_classes, num_repeats, + unigrams, distortion) + print('res:', res) + + + def test_get_true_classes_01(): adj_mat = torch.tensor([ [0, 1, 0, 1, 0],