@@ -173,7 +173,7 @@ class Model(torch.nn.Module): | |||
for x in next_layer_repr ] | |||
cur_layer_repr = next_layer_repr | |||
return next_layer_repr | |||
return cur_layer_repr | |||
def decode(self, last_layer_repr: List[torch.Tensor], | |||
batch: TrainingBatch) -> torch.Tensor: | |||
@@ -182,12 +182,27 @@ class Model(torch.nn.Module): | |||
vt_col = batch.vertex_type_column | |||
rel_idx = batch.relation_type_index | |||
global_interaction = \ | |||
self.dec_weights['%d-%d-global-interaction'] % (vt_row, vt_col) | |||
self.dec_weights['%d-%d-global-interaction' % (vt_row, vt_col)] | |||
local_variation = \ | |||
self.dec_weights['%d-%d-local-variation-%d'] % (vt_row, vt_col, rel_idx) | |||
self.dec_weights['%d-%d-local-variation-%d' % (vt_row, vt_col, rel_idx)] | |||
in_row = dropout(last_layer_repr[vt_row], self.keep_prob) | |||
in_col = dropout(last_layer_repr[vt_col], self.keep_prob) | |||
in_row = last_layer_repr[vt_row] | |||
in_col = last_layer_repr[vt_col] | |||
if in_row.is_sparse or in_col.is_sparse: | |||
raise ValueError('Inputs to Model.decode() must be dense') | |||
in_row = in_row[batch.edges[:, 0]] | |||
in_col = in_col[batch.edges[:, 1]] | |||
in_row = dropout(in_row, self.keep_prob) | |||
in_col = dropout(in_col, self.keep_prob) | |||
# in_row = in_row.to_dense() | |||
# in_col = in_col.to_dense() | |||
print('in_row.is_sparse:', in_row.is_sparse) | |||
print('in_col.is_sparse:', in_col.is_sparse) | |||
x = torch.mm(in_row, local_variation) | |||
x = torch.mm(x, global_interaction) | |||
@@ -197,7 +212,7 @@ class Model(torch.nn.Module): | |||
x = torch.flatten(x) | |||
x = self.dec_activation(x) | |||
return x | |||
@@ -199,6 +199,32 @@ def _mm(a: torch.Tensor, b: torch.Tensor): | |||
return torch.mm(a, b) | |||
def _select_rows(a: torch.Tensor, rows: torch.Tensor): | |||
if not a.is_sparse: | |||
return a[rows] | |||
indices = a.indices() | |||
values = a.values() | |||
mask = torch.zeros(a.shape[0]) | |||
mask[rows] = 1 | |||
if mask.sum() != len(rows): | |||
raise ValueError('Rows must be unique') | |||
mask = mask[indices[0]] | |||
mask = torch.nonzero(mask, as_tuple=True)[0] | |||
new_rows[rows] = torch.arange(len(rows)) | |||
new_rows = new_rows[indices[0]] | |||
indices = indices[:, mask] | |||
indices[0] = new_rows | |||
values = values[mask] | |||
res = _sparse_coo_tensor(indices, values, | |||
size=(len(rows), a.shape[1])) | |||
return res | |||
def common_one_hot_encoding(vertex_type_counts: List[int]) -> \ | |||
List[torch.Tensor]: | |||
@@ -4,6 +4,7 @@ from triacontagon.model import Model, \ | |||
_per_layer_required_vertices | |||
from triacontagon.data import Data | |||
from triacontagon.decode import dedicom_decoder | |||
from triacontagon.util import common_one_hot_encoding | |||
def test_per_layer_required_vertices_01(): | |||
@@ -83,3 +84,28 @@ def test_model_convolve_01(): | |||
] | |||
_ = model.convolve(in_layer_repr) | |||
def test_model_decode_01(): | |||
d = Data() | |||
d.add_vertex_type('Gene', 100) | |||
d.add_edge_type('Gene-Gene', 0, 0, [ | |||
torch.rand(100, 100).round().to_sparse() | |||
], dedicom_decoder) | |||
b = TrainingBatch(0, 0, 0, torch.tensor([ | |||
[0, 1], | |||
[10, 51], | |||
[50, 60], | |||
[70, 90], | |||
[98, 99] | |||
]), torch.ones(5)) | |||
in_repr = common_one_hot_encoding([100]) | |||
in_repr = [ in_repr[0].to_dense() ] | |||
m = Model(d, [100], 1.0, torch.sigmoid, torch.sigmoid) | |||
_ = m.decode(in_repr, b) |