@@ -43,7 +43,7 @@ class TrainLoop(object):
         self.model = model
         self.test_data = test_data
         self.initial_repr = list(initial_repr)
-        self.max_epochs = int(num_epochs)
+        self.max_epochs = int(max_epochs)
         self.batch_size = int(batch_size)
         self.loss = loss
         self.lr = float(lr)
@@ -20,7 +20,7 @@ def fixed_unigram_candidate_sampler(
     true_classes: torch.Tensor,
     num_repeats: torch.Tensor,
     unigrams: torch.Tensor,
-    distortion: float = 1.):
+    distortion: float = 1.) -> torch.Tensor:
 
     if len(true_classes.shape) != 2:
         raise ValueError('true_classes must be a 2D matrix with shape (num_samples, num_true)')
@@ -29,26 +29,34 @@ def fixed_unigram_candidate_sampler(
         raise ValueError('num_repeats must be 1D')
     num_rows = true_classes.shape[0]
+    print('true_classes.shape:', true_classes.shape)
     # unigrams = np.array(unigrams)
     if distortion != 1.:
         unigrams = unigrams.to(torch.float64) ** distortion
-    # print('unigrams:', unigrams)
+    print('unigrams:', unigrams)
     indices = torch.arange(num_rows)
     indices = torch.repeat_interleave(indices, num_repeats)
+    indices = torch.cat([ torch.arange(len(indices)).view(-1, 1),
+        indices.view(-1, 1) ], dim=1)
     num_samples = len(indices)
     result = torch.zeros(num_samples, dtype=torch.long)
+    print('num_rows:', num_rows, 'num_samples:', num_samples)
     while len(indices) > 0:
-        # print('len(indices):', len(indices))
+        print('len(indices):', len(indices))
+        print('indices:', indices)
         sampler = torch.utils.data.WeightedRandomSampler(unigrams, len(indices))
         candidates = torch.tensor(list(sampler))
         candidates = candidates.view(len(indices), 1)
-        # print('candidates:', candidates)
-        # print('true_classes:', true_classes[indices, :])
-        result[indices] = candidates.transpose(0, 1)
-        # print('result:', result)
-        mask = (candidates == true_classes[indices, :])
+        print('candidates:', candidates)
+        print('true_classes:', true_classes[indices[:, 1], :])
+        result[indices[:, 0]] = candidates.transpose(0, 1)
+        print('result:', result)
+        mask = (candidates == true_classes[indices[:, 1], :])
         mask = mask.sum(1).to(torch.bool)
-        # print('mask:', mask)
+        print('mask:', mask)
         indices = indices[mask]
         # result[indices] = 0
     return result
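After this change, indices is a two-column tensor: column 0 is the output slot in result and column 1 is the originating row of true_classes, so candidates that collide with a known positive can be re-drawn for exactly the slots that still need a value. A minimal standalone sketch of that rejection idea, with made-up weights (not part of the patch):

import torch

# Toy unigram weights (optionally distorted) and one known positive per row.
unigrams = torch.tensor([1., 2., 3., 4.], dtype=torch.float64) ** 0.75
true_classes = torch.tensor([[0], [1], [2]])

# Draw one candidate per row, proportional to the unigram weights.
sampler = torch.utils.data.WeightedRandomSampler(unigrams, len(true_classes))
candidates = torch.tensor(list(sampler)).view(-1, 1)

# Rows whose candidate equals a true class would be re-drawn in the real loop.
collisions = (candidates == true_classes).sum(1).to(torch.bool)
print(candidates.view(-1), collisions)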
@@ -164,3 +172,20 @@ def negative_sample_data(data: Data) -> Data:
     #new_edge_types[key] = new_et
     #res = Data(data.vertex_types, new_edge_types)
     return res
+
+
+def merge_data(pos_data: Data, neg_data: Data) -> Data:
+    assert isinstance(pos_data, Data)
+    assert isinstance(neg_data, Data)
+
+    res = PosNegData()
+
+    for vt in pos_data.vertex_types:
+        res.add_vertex_type(vt.name, vt.count)
+
+    for key, pos_et in pos_data.edge_types.items():
+        neg_et = neg_data.edge_types[key]
+        res.add_edge_type(pos_et.name,
+            pos_et.vertex_type_row, pos_et.vertex_type_column,
+            pos_et.adjacency_matrices, neg_et.adjacency_matrices,
+            pos_et.decoder_factory)
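merge_data pairs each positive edge type with its negatively-sampled counterpart inside a PosNegData container (a class assumed here to expose the add_vertex_type/add_edge_type calls used above). A hedged usage sketch, assuming merge_data lands next to negative_sample_data in triacontagon.sampling:

import torch
from triacontagon.data import Data
from triacontagon.decode import dedicom_decoder
from triacontagon.sampling import negative_sample_data, merge_data

d = Data()
d.add_vertex_type('Foo', 3)
adj = torch.tensor([[0, 1, 0],
                    [1, 0, 1],
                    [0, 1, 0]])
d.add_edge_type('Foo-Foo', 0, 0, [ adj.to_sparse().coalesce() ], dedicom_decoder)

neg = negative_sample_data(d)   # negatively-sampled counterpart of d
both = merge_data(d, neg)       # positive and negative matrices side by side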
@@ -1,8 +1,9 @@
 from .data import Data, \
-    TrainingBatch, \
     EdgeType
-from typing import Tuple
+from typing import Tuple, \
+    List
 from .util import _sparse_coo_tensor
+import torch
 
 
 def split_adj_mat(adj_mat: torch.Tensor, ratios: List[float]):
@@ -17,21 +18,30 @@ def split_adj_mat(adj_mat: torch.Tensor, ratios: List[float]):
     ofs = 0
     res = []
     for r in ratios:
-        cnt = r * len(values)
-        ind = indices[:, ofs:ofs+cnt]
-        val = values[ofs:ofs+cnt]
+        # cnt = r * len(values)
+        beg = int(ofs * len(values))
+        end = int((ofs + r) * len(values))
+        ofs += r
+        ind = indices[:, beg:end]
+        val = values[beg:end]
         res.append(_sparse_coo_tensor(ind, val, adj_mat.shape))
-        ofs += cnt
+        # ofs += cnt
     return res
 
 
 def split_edge_type(et: EdgeType, ratios: Tuple[float, float, float]):
-    res = [ [] for _ in range(len(et.adjacency_matrices)) ]
+    res = [ split_adj_mat(adj_mat, ratios) \
+        for adj_mat in et.adjacency_matrices ]
 
-    for adj_mat in et.adjacency_matrices:
-        for i, new_adj_mat in enumerate(split_adj_mat(adj_mat, ratios)):
-            res[i].append(new_adj_mat)
+    res = [ EdgeType(et.name,
+        et.vertex_type_row,
+        et.vertex_type_column,
+        [ a[i] for a in res ],
+        et.decoder_factory,
+        None ) for i in range(len(ratios)) ]
 
     return res
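The rewritten split_adj_mat keeps ofs as a running fraction and converts it to integer beg/end positions for each split, so the slices are valid integer indices and consecutive splits tile the non-zero values without gaps or overlap. A small worked example of the same slicing on a plain list (illustrative values only):

values = list(range(10))
ratios = [0.8, 0.1, 0.1]

ofs, parts = 0.0, []
for r in ratios:
    beg = int(ofs * len(values))          # start of this split
    end = int((ofs + r) * len(values))    # end of this split
    ofs += r
    parts.append(values[beg:end])

print(parts)   # [[0, 1, 2, 3, 4, 5, 6, 7], [8], [9]]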
@@ -49,11 +59,15 @@ def split_data(data: Data,
     res = [ {} for _ in range(len(ratios)) ]
 
-    for key, et in data.edge_types:
+    for key, et in data.edge_types.items():
         for i, new_et in enumerate(split_edge_type(et, ratios)):
             res[i][key] = new_et
 
-    res = [ Data(data.vertex_types, new_edge_types) \
-        for new_edge_types in res ]
+    res_1 = []
+    for new_edge_types in res:
+        d = Data()
+        d.vertex_types = data.vertex_types,
+        d.edge_types = new_edge_types
+        res_1.append(d)
 
-    return res
+    return res_1
@@ -1,5 +1,11 @@
-from triacontagon.loop import _merge_pos_neg_batches
-from triacontagon.model import TrainingBatch
+from triacontagon.loop import _merge_pos_neg_batches, \
+    TrainLoop
+from triacontagon.model import TrainingBatch, \
+    Model
+from triacontagon.data import Data
+from triacontagon.decode import dedicom_decoder
+from triacontagon.util import common_one_hot_encoding
+from triacontagon.split import split_data
 import torch
 import pytest
@@ -64,3 +70,62 @@ def test_merge_pos_neg_batches_02():
     print(b_1)
     with pytest.raises(AssertionError):
         _ = _merge_pos_neg_batches(b_1, b_2)
+
+
+def test_train_loop_01():
+    data = Data()
+    data.add_vertex_type('Foo', 5)
+    data.add_vertex_type('Bar', 4)
+
+    foo_foo = torch.tensor([
+        [0, 0, 0, 1, 0],
+        [0, 0, 1, 0, 0],
+        [1, 0, 0, 1, 0],
+        [0, 0, 1, 0, 1],
+        [0, 1, 0, 0, 0]
+    ])
+    foo_foo = (foo_foo + foo_foo.transpose(0, 1)) / 2
+
+    foo_bar = torch.tensor([
+        [0, 1, 0, 1],
+        [0, 0, 0, 1],
+        [0, 1, 0, 0],
+        [1, 0, 0, 0],
+        [0, 0, 1, 1]
+    ])
+    bar_foo = foo_bar.transpose(0, 1)
+
+    bar_bar = torch.tensor([
+        [0, 0, 1, 0],
+        [1, 0, 0, 0],
+        [0, 1, 0, 1],
+        [0, 1, 0, 0],
+    ])
+    bar_bar = (bar_bar + bar_bar.transpose(0, 1)) / 2
+
+    data.add_edge_type('Foo-Foo', 0, 0, [
+        foo_foo.to_sparse().coalesce()
+    ], dedicom_decoder)
+    data.add_edge_type('Foo-Bar', 0, 1, [
+        foo_bar.to_sparse().coalesce()
+    ], dedicom_decoder)
+    data.add_edge_type('Bar-Foo', 1, 0, [
+        bar_foo.to_sparse().coalesce()
+    ], dedicom_decoder)
+    data.add_edge_type('Bar-Bar', 1, 1, [
+        bar_bar.to_sparse().coalesce()
+    ], dedicom_decoder)
+
+    initial_repr = common_one_hot_encoding([5, 4])
+
+    model = Model(data, [9, 3, 6],
+        keep_prob=1.0,
+        conv_activation=torch.sigmoid,
+        dec_activation=torch.sigmoid)
+
+    train_data, val_data, test_data = split_data(data, (.9, .1, .0) )
+
+    loop = TrainLoop(model, val_data, test_data, initial_repr,
+        max_epochs=1, batch_size=1)
+
+    _ = loop.run()
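A side note on the dimensions used above: 'Foo' has 5 vertices and 'Bar' has 4, so if common_one_hot_encoding([5, 4]) gives every vertex a one-hot feature in a space shared by both types (an assumption here, not shown in the patch), the input width is 5 + 4 = 9, which matches the first entry of Model(data, [9, 3, 6], ...). A hedged sketch of that kind of encoding:

import torch

counts = [5, 4]                      # vertices per type, as in the test
total = sum(counts)                  # 9 == assumed first Model layer width
eye = torch.eye(total)

initial_repr, ofs = [], 0
for c in counts:
    initial_repr.append(eye[ofs:ofs + c, :])   # one block of rows per type
    ofs += c

print([r.shape for r in initial_repr])   # [torch.Size([5, 9]), torch.Size([4, 9])]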
@@ -1,5 +1,6 @@
 from triacontagon.data import Data
-from triacontagon.sampling import get_true_classes, \
+from triacontagon.sampling import fixed_unigram_candidate_sampler, \
+    get_true_classes, \
     negative_sample_adj_mat, \
     negative_sample_data
 from triacontagon.decode import dedicom_decoder
@@ -7,6 +8,21 @@ import torch
 import time
 
 
+def test_fixed_unigram_candidate_sampler_01():
+    true_classes = torch.tensor([[-1],
+                                 [-1],
+                                 [ 3],
+                                 [ 2],
+                                 [-1]])
+    num_repeats = torch.tensor([0, 0, 1, 1, 0])
+    unigrams = torch.tensor([0., 0., 1., 1., 0.], dtype=torch.float64)
+    distortion = 0.75
+
+    res = fixed_unigram_candidate_sampler(true_classes, num_repeats,
+        unigrams, distortion)
+
+    print('res:', res)
+
 def test_get_true_classes_01():
     adj_mat = torch.tensor([
         [0, 1, 0, 1, 0],