@@ -43,7 +43,7 @@ class TrainLoop(object):
         self.model = model
         self.test_data = test_data
         self.initial_repr = list(initial_repr)
-        self.max_epochs = int(num_epochs)
+        self.max_epochs = int(max_epochs)
         self.batch_size = int(batch_size)
         self.loss = loss
         self.lr = float(lr)
@@ -20,7 +20,7 @@ def fixed_unigram_candidate_sampler(
     true_classes: torch.Tensor,
     num_repeats: torch.Tensor,
     unigrams: torch.Tensor,
-    distortion: float = 1.):
+    distortion: float = 1.) -> torch.Tensor:
 
     if len(true_classes.shape) != 2:
         raise ValueError('true_classes must be a 2D matrix with shape (num_samples, num_true)')
@@ -29,26 +29,34 @@ def fixed_unigram_candidate_sampler(
         raise ValueError('num_repeats must be 1D')
 
     num_rows = true_classes.shape[0]
+    print('true_classes.shape:', true_classes.shape)
     # unigrams = np.array(unigrams)
     if distortion != 1.:
         unigrams = unigrams.to(torch.float64) ** distortion
-    # print('unigrams:', unigrams)
+    print('unigrams:', unigrams)
 
     indices = torch.arange(num_rows)
     indices = torch.repeat_interleave(indices, num_repeats)
+    indices = torch.cat([ torch.arange(len(indices)).view(-1, 1),
+        indices.view(-1, 1) ], dim=1)
 
     num_samples = len(indices)
    result = torch.zeros(num_samples, dtype=torch.long)
+    print('num_rows:', num_rows, 'num_samples:', num_samples)
 
     while len(indices) > 0:
-        # print('len(indices):', len(indices))
+        print('len(indices):', len(indices))
+        print('indices:', indices)
 
         sampler = torch.utils.data.WeightedRandomSampler(unigrams, len(indices))
         candidates = torch.tensor(list(sampler))
         candidates = candidates.view(len(indices), 1)
-        # print('candidates:', candidates)
-        # print('true_classes:', true_classes[indices, :])
-        result[indices] = candidates.transpose(0, 1)
-        # print('result:', result)
-        mask = (candidates == true_classes[indices, :])
+        print('candidates:', candidates)
+        print('true_classes:', true_classes[indices[:, 1], :])
+        result[indices[:, 0]] = candidates.transpose(0, 1)
+        print('result:', result)
+        mask = (candidates == true_classes[indices[:, 1], :])
         mask = mask.sum(1).to(torch.bool)
-        # print('mask:', mask)
+        print('mask:', mask)
         indices = indices[mask]
         # result[indices] = 0
 
     return result
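A minimal, self-contained sketch of the redraw logic above (illustrative helper name, not part of this patch): each pending entry carries a pair of indices, column 0 being the slot in the output and column 1 the row of true_classes it must avoid, and only entries whose candidate collides with a true class of their row are redrawn.

import torch
import torch.utils.data

def resample_until_no_collision(true_classes, num_repeats, unigrams):
    # one output slot per repeat; column 0: output position, column 1: source row
    rows = torch.repeat_interleave(torch.arange(len(num_repeats)), num_repeats)
    indices = torch.stack([torch.arange(len(rows)), rows], dim=1)
    result = torch.zeros(len(rows), dtype=torch.long)
    while len(indices) > 0:
        # draw one candidate per unresolved slot from the unigram weights
        sampler = torch.utils.data.WeightedRandomSampler(unigrams, len(indices))
        candidates = torch.tensor(list(sampler)).view(-1, 1)
        result[indices[:, 0]] = candidates.view(-1)
        # keep only the slots whose draw hit one of their row's true classes
        collided = (candidates == true_classes[indices[:, 1], :]).any(dim=1)
        indices = indices[collided]
    return result

With the tensors used in test_fixed_unigram_candidate_sampler_01 further below, this sketch (which omits the distortion exponent) resolves rows 2 and 3 to the only non-colliding weighted classes.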
@@ -164,3 +172,20 @@ def negative_sample_data(data: Data) -> Data:
         #new_edge_types[key] = new_et
     #res = Data(data.vertex_types, new_edge_types)
     return res
+
+
+def merge_data(pos_data: Data, neg_data: Data) -> Data:
+    assert isinstance(pos_data, Data)
+    assert isinstance(neg_data, Data)
+
+    res = PosNegData()
+
+    for vt in pos_data.vertex_types:
+        res.add_vertex_type(vt.name, vt.count)
+
+    for key, pos_et in pos_data.edge_types.items():
+        neg_et = neg_data.edge_types[key]
+        res.add_edge_type(pos_et.name,
+            pos_et.vertex_type_row, pos_et.vertex_type_column,
+            pos_et.adjacency_matrices, neg_et.adjacency_matrices,
+            pos_et.decoder_factory)
+
+    return res
@@ -1,8 +1,9 @@
 from .data import Data, \
-    TrainingBatch, \
     EdgeType
-from typing import Tuple
+from typing import Tuple, \
+    List
 from .util import _sparse_coo_tensor
+import torch
 
 
 def split_adj_mat(adj_mat: torch.Tensor, ratios: List[float]):
@@ -17,21 +18,30 @@ def split_adj_mat(adj_mat: torch.Tensor, ratios: List[float]):
     ofs = 0
     res = []
     for r in ratios:
-        cnt = r * len(values)
-        ind = indices[:, ofs:ofs+cnt]
-        val = values[ofs:ofs+cnt]
+        # cnt = r * len(values)
+        beg = int(ofs * len(values))
+        end = int((ofs + r) * len(values))
+        ofs += r
+        ind = indices[:, beg:end]
+        val = values[beg:end]
         res.append(_sparse_coo_tensor(ind, val, adj_mat.shape))
-        ofs += cnt
+        # ofs += cnt
 
     return res
 
 
 def split_edge_type(et: EdgeType, ratios: Tuple[float, float, float]):
-    res = [ [] for _ in range(len(et.adjacency_matrices)) ]
+    res = [ split_adj_mat(adj_mat, ratios) \
+        for adj_mat in et.adjacency_matrices ]
 
-    for adj_mat in et.adjacency_matrices:
-        for i, new_adj_mat in enumerate(split_adj_mat(adj_mat, ratios)):
-            res[i].append(new_adj_mat)
+    res = [ EdgeType(et.name,
+        et.vertex_type_row,
+        et.vertex_type_column,
+        [ a[i] for a in res ],
+        et.decoder_factory,
+        None ) for i in range(len(ratios)) ]
 
     return res
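For reference, the corrected slicing in split_adj_mat derives each slice from a cumulative fractional offset, so consecutive ratios cover adjacent, disjoint ranges of the non-zero entries. A standalone sketch of the same arithmetic on a plain list (split_by_ratios is an illustrative name, not project code):

def split_by_ratios(values, ratios):
    # each ratio r takes values[int(ofs * n) : int((ofs + r) * n)], then the offset advances by r
    n = len(values)
    ofs = 0.
    parts = []
    for r in ratios:
        beg, end = int(ofs * n), int((ofs + r) * n)
        parts.append(values[beg:end])
        ofs += r
    return parts

# split_by_ratios(list(range(10)), (.9, .1, .0))
# -> [[0, 1, 2, 3, 4, 5, 6, 7, 8], [9], []]   (every element lands in exactly one part)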
@@ -49,11 +59,15 @@ def split_data(data: Data,
     res = [ {} for _ in range(len(ratios)) ]
 
-    for key, et in data.edge_types:
+    for key, et in data.edge_types.items():
         for i, new_et in enumerate(split_edge_type(et, ratios)):
             res[i][key] = new_et
 
-    res = [ Data(data.vertex_types, new_edge_types) \
-        for new_edge_types in res ]
+    res_1 = []
+    for new_edge_types in res:
+        d = Data()
+        d.vertex_types = data.vertex_types
+        d.edge_types = new_edge_types
+        res_1.append(d)
 
-    return res
+    return res_1
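The rewritten split_data returns one Data object per ratio, each holding the corresponding per-ratio edge-type dictionary. A plain-dict sketch of the transposition it performs (illustrative values, not the project's Data/EdgeType classes):

edge_splits = {
    ('Foo', 'Bar'): ['train_part', 'val_part', 'test_part'],
    ('Bar', 'Bar'): ['train_part', 'val_part', 'test_part'],
}
# from "per edge type, a list of per-ratio parts" to "per ratio, a dict of edge types"
per_ratio = [ {} for _ in range(3) ]
for key, parts in edge_splits.items():
    for i, part in enumerate(parts):
        per_ratio[i][key] = part
# per_ratio[0] is the training dict, per_ratio[1] the validation dict, per_ratio[2] the test dict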
@@ -1,5 +1,11 @@
-from triacontagon.loop import _merge_pos_neg_batches
-from triacontagon.model import TrainingBatch
+from triacontagon.loop import _merge_pos_neg_batches, \
+    TrainLoop
+from triacontagon.model import TrainingBatch, \
+    Model
+from triacontagon.data import Data
+from triacontagon.decode import dedicom_decoder
+from triacontagon.util import common_one_hot_encoding
+from triacontagon.split import split_data
 import torch
 import pytest
@@ -64,3 +70,62 @@ def test_merge_pos_neg_batches_02():
     print(b_1)
     with pytest.raises(AssertionError):
         _ = _merge_pos_neg_batches(b_1, b_2)
+
+
+def test_train_loop_01():
+    data = Data()
+    data.add_vertex_type('Foo', 5)
+    data.add_vertex_type('Bar', 4)
+
+    foo_foo = torch.tensor([
+        [0, 0, 0, 1, 0],
+        [0, 0, 1, 0, 0],
+        [1, 0, 0, 1, 0],
+        [0, 0, 1, 0, 1],
+        [0, 1, 0, 0, 0]
+    ])
+    foo_foo = (foo_foo + foo_foo.transpose(0, 1)) / 2
+
+    foo_bar = torch.tensor([
+        [0, 1, 0, 1],
+        [0, 0, 0, 1],
+        [0, 1, 0, 0],
+        [1, 0, 0, 0],
+        [0, 0, 1, 1]
+    ])
+    bar_foo = foo_bar.transpose(0, 1)
+
+    bar_bar = torch.tensor([
+        [0, 0, 1, 0],
+        [1, 0, 0, 0],
+        [0, 1, 0, 1],
+        [0, 1, 0, 0],
+    ])
+    bar_bar = (bar_bar + bar_bar.transpose(0, 1)) / 2
+
+    data.add_edge_type('Foo-Foo', 0, 0, [
+        foo_foo.to_sparse().coalesce()
+    ], dedicom_decoder)
+    data.add_edge_type('Foo-Bar', 0, 1, [
+        foo_bar.to_sparse().coalesce()
+    ], dedicom_decoder)
+    data.add_edge_type('Bar-Foo', 1, 0, [
+        bar_foo.to_sparse().coalesce()
+    ], dedicom_decoder)
+    data.add_edge_type('Bar-Bar', 1, 1, [
+        bar_bar.to_sparse().coalesce()
+    ], dedicom_decoder)
+
+    initial_repr = common_one_hot_encoding([5, 4])
+
+    model = Model(data, [9, 3, 6],
+        keep_prob=1.0,
+        conv_activation=torch.sigmoid,
+        dec_activation=torch.sigmoid)
+
+    train_data, val_data, test_data = split_data(data, (.9, .1, .0) )
+
+    loop = TrainLoop(model, val_data, test_data, initial_repr,
+        max_epochs=1, batch_size=1)
+
+    _ = loop.run()
@@ -1,5 +1,6 @@
 from triacontagon.data import Data
-from triacontagon.sampling import get_true_classes, \
+from triacontagon.sampling import fixed_unigram_candidate_sampler, \
+    get_true_classes, \
     negative_sample_adj_mat, \
     negative_sample_data
 from triacontagon.decode import dedicom_decoder
@@ -7,6 +8,21 @@ import torch
 import time
 
 
+def test_fixed_unigram_candidate_sampler_01():
+    true_classes = torch.tensor([[-1],
+        [-1],
+        [ 3],
+        [ 2],
+        [-1]])
+    num_repeats = torch.tensor([0, 0, 1, 1, 0])
+    unigrams = torch.tensor([0., 0., 1., 1., 0.], dtype=torch.float64)
+    distortion = 0.75
+
+    res = fixed_unigram_candidate_sampler(true_classes, num_repeats,
+        unigrams, distortion)
+
+    print('res:', res)
+
+
 def test_get_true_classes_01():
     adj_mat = torch.tensor([
         [0, 1, 0, 1, 0],