| @@ -173,7 +173,7 @@ class Model(torch.nn.Module): | |||
| for x in next_layer_repr ] | |||
| cur_layer_repr = next_layer_repr | |||
| return next_layer_repr | |||
| return cur_layer_repr | |||
| def decode(self, last_layer_repr: List[torch.Tensor], | |||
| batch: TrainingBatch) -> torch.Tensor: | |||
| @@ -182,12 +182,27 @@ class Model(torch.nn.Module): | |||
| vt_col = batch.vertex_type_column | |||
| rel_idx = batch.relation_type_index | |||
| global_interaction = \ | |||
| self.dec_weights['%d-%d-global-interaction'] % (vt_row, vt_col) | |||
| self.dec_weights['%d-%d-global-interaction' % (vt_row, vt_col)] | |||
| local_variation = \ | |||
| self.dec_weights['%d-%d-local-variation-%d'] % (vt_row, vt_col, rel_idx) | |||
| self.dec_weights['%d-%d-local-variation-%d' % (vt_row, vt_col, rel_idx)] | |||
| in_row = dropout(last_layer_repr[vt_row], self.keep_prob) | |||
| in_col = dropout(last_layer_repr[vt_col], self.keep_prob) | |||
| in_row = last_layer_repr[vt_row] | |||
| in_col = last_layer_repr[vt_col] | |||
| if in_row.is_sparse or in_col.is_sparse: | |||
| raise ValueError('Inputs to Model.decode() must be dense') | |||
| in_row = in_row[batch.edges[:, 0]] | |||
| in_col = in_col[batch.edges[:, 1]] | |||
| in_row = dropout(in_row, self.keep_prob) | |||
| in_col = dropout(in_col, self.keep_prob) | |||
| # in_row = in_row.to_dense() | |||
| # in_col = in_col.to_dense() | |||
| print('in_row.is_sparse:', in_row.is_sparse) | |||
| print('in_col.is_sparse:', in_col.is_sparse) | |||
| x = torch.mm(in_row, local_variation) | |||
| x = torch.mm(x, global_interaction) | |||
| @@ -197,7 +212,7 @@ class Model(torch.nn.Module): | |||
| x = torch.flatten(x) | |||
| x = self.dec_activation(x) | |||
| return x | |||
| @@ -199,6 +199,32 @@ def _mm(a: torch.Tensor, b: torch.Tensor): | |||
| return torch.mm(a, b) | |||
| def _select_rows(a: torch.Tensor, rows: torch.Tensor): | |||
| if not a.is_sparse: | |||
| return a[rows] | |||
| indices = a.indices() | |||
| values = a.values() | |||
| mask = torch.zeros(a.shape[0]) | |||
| mask[rows] = 1 | |||
| if mask.sum() != len(rows): | |||
| raise ValueError('Rows must be unique') | |||
| mask = mask[indices[0]] | |||
| mask = torch.nonzero(mask, as_tuple=True)[0] | |||
| new_rows[rows] = torch.arange(len(rows)) | |||
| new_rows = new_rows[indices[0]] | |||
| indices = indices[:, mask] | |||
| indices[0] = new_rows | |||
| values = values[mask] | |||
| res = _sparse_coo_tensor(indices, values, | |||
| size=(len(rows), a.shape[1])) | |||
| return res | |||
| def common_one_hot_encoding(vertex_type_counts: List[int]) -> \ | |||
| List[torch.Tensor]: | |||
| @@ -4,6 +4,7 @@ from triacontagon.model import Model, \ | |||
| _per_layer_required_vertices | |||
| from triacontagon.data import Data | |||
| from triacontagon.decode import dedicom_decoder | |||
| from triacontagon.util import common_one_hot_encoding | |||
| def test_per_layer_required_vertices_01(): | |||
| @@ -83,3 +84,28 @@ def test_model_convolve_01(): | |||
| ] | |||
| _ = model.convolve(in_layer_repr) | |||
| def test_model_decode_01(): | |||
| d = Data() | |||
| d.add_vertex_type('Gene', 100) | |||
| d.add_edge_type('Gene-Gene', 0, 0, [ | |||
| torch.rand(100, 100).round().to_sparse() | |||
| ], dedicom_decoder) | |||
| b = TrainingBatch(0, 0, 0, torch.tensor([ | |||
| [0, 1], | |||
| [10, 51], | |||
| [50, 60], | |||
| [70, 90], | |||
| [98, 99] | |||
| ]), torch.ones(5)) | |||
| in_repr = common_one_hot_encoding([100]) | |||
| in_repr = [ in_repr[0].to_dense() ] | |||
| m = Model(d, [100], 1.0, torch.sigmoid, torch.sigmoid) | |||
| _ = m.decode(in_repr, b) | |||