@@ -173,7 +173,7 @@ class Model(torch.nn.Module): | |||||
for x in next_layer_repr ] | for x in next_layer_repr ] | ||||
cur_layer_repr = next_layer_repr | cur_layer_repr = next_layer_repr | ||||
return next_layer_repr | |||||
return cur_layer_repr | |||||
def decode(self, last_layer_repr: List[torch.Tensor], | def decode(self, last_layer_repr: List[torch.Tensor], | ||||
batch: TrainingBatch) -> torch.Tensor: | batch: TrainingBatch) -> torch.Tensor: | ||||
@@ -182,12 +182,27 @@ class Model(torch.nn.Module): | |||||
vt_col = batch.vertex_type_column | vt_col = batch.vertex_type_column | ||||
rel_idx = batch.relation_type_index | rel_idx = batch.relation_type_index | ||||
global_interaction = \ | global_interaction = \ | ||||
self.dec_weights['%d-%d-global-interaction'] % (vt_row, vt_col) | |||||
self.dec_weights['%d-%d-global-interaction' % (vt_row, vt_col)] | |||||
local_variation = \ | local_variation = \ | ||||
self.dec_weights['%d-%d-local-variation-%d'] % (vt_row, vt_col, rel_idx) | |||||
self.dec_weights['%d-%d-local-variation-%d' % (vt_row, vt_col, rel_idx)] | |||||
in_row = dropout(last_layer_repr[vt_row], self.keep_prob) | |||||
in_col = dropout(last_layer_repr[vt_col], self.keep_prob) | |||||
in_row = last_layer_repr[vt_row] | |||||
in_col = last_layer_repr[vt_col] | |||||
if in_row.is_sparse or in_col.is_sparse: | |||||
raise ValueError('Inputs to Model.decode() must be dense') | |||||
in_row = in_row[batch.edges[:, 0]] | |||||
in_col = in_col[batch.edges[:, 1]] | |||||
in_row = dropout(in_row, self.keep_prob) | |||||
in_col = dropout(in_col, self.keep_prob) | |||||
# in_row = in_row.to_dense() | |||||
# in_col = in_col.to_dense() | |||||
print('in_row.is_sparse:', in_row.is_sparse) | |||||
print('in_col.is_sparse:', in_col.is_sparse) | |||||
x = torch.mm(in_row, local_variation) | x = torch.mm(in_row, local_variation) | ||||
x = torch.mm(x, global_interaction) | x = torch.mm(x, global_interaction) | ||||
@@ -197,7 +212,7 @@ class Model(torch.nn.Module): | |||||
x = torch.flatten(x) | x = torch.flatten(x) | ||||
x = self.dec_activation(x) | x = self.dec_activation(x) | ||||
return x | return x | ||||
@@ -199,6 +199,32 @@ def _mm(a: torch.Tensor, b: torch.Tensor): | |||||
return torch.mm(a, b) | return torch.mm(a, b) | ||||
def _select_rows(a: torch.Tensor, rows: torch.Tensor): | |||||
if not a.is_sparse: | |||||
return a[rows] | |||||
indices = a.indices() | |||||
values = a.values() | |||||
mask = torch.zeros(a.shape[0]) | |||||
mask[rows] = 1 | |||||
if mask.sum() != len(rows): | |||||
raise ValueError('Rows must be unique') | |||||
mask = mask[indices[0]] | |||||
mask = torch.nonzero(mask, as_tuple=True)[0] | |||||
new_rows[rows] = torch.arange(len(rows)) | |||||
new_rows = new_rows[indices[0]] | |||||
indices = indices[:, mask] | |||||
indices[0] = new_rows | |||||
values = values[mask] | |||||
res = _sparse_coo_tensor(indices, values, | |||||
size=(len(rows), a.shape[1])) | |||||
return res | |||||
def common_one_hot_encoding(vertex_type_counts: List[int]) -> \ | def common_one_hot_encoding(vertex_type_counts: List[int]) -> \ | ||||
List[torch.Tensor]: | List[torch.Tensor]: | ||||
@@ -4,6 +4,7 @@ from triacontagon.model import Model, \ | |||||
_per_layer_required_vertices | _per_layer_required_vertices | ||||
from triacontagon.data import Data | from triacontagon.data import Data | ||||
from triacontagon.decode import dedicom_decoder | from triacontagon.decode import dedicom_decoder | ||||
from triacontagon.util import common_one_hot_encoding | |||||
def test_per_layer_required_vertices_01(): | def test_per_layer_required_vertices_01(): | ||||
@@ -83,3 +84,28 @@ def test_model_convolve_01(): | |||||
] | ] | ||||
_ = model.convolve(in_layer_repr) | _ = model.convolve(in_layer_repr) | ||||
def test_model_decode_01(): | |||||
d = Data() | |||||
d.add_vertex_type('Gene', 100) | |||||
d.add_edge_type('Gene-Gene', 0, 0, [ | |||||
torch.rand(100, 100).round().to_sparse() | |||||
], dedicom_decoder) | |||||
b = TrainingBatch(0, 0, 0, torch.tensor([ | |||||
[0, 1], | |||||
[10, 51], | |||||
[50, 60], | |||||
[70, 90], | |||||
[98, 99] | |||||
]), torch.ones(5)) | |||||
in_repr = common_one_hot_encoding([100]) | |||||
in_repr = [ in_repr[0].to_dense() ] | |||||
m = Model(d, [100], 1.0, torch.sigmoid, torch.sigmoid) | |||||
_ = m.decode(in_repr, b) |