| @@ -36,6 +36,8 @@ def dropout_dense(x, keep_prob): | |||
| def dropout(x, keep_prob): | |||
| if keep_prob == 1: | |||
| return x | |||
| if x.is_sparse: | |||
| return dropout_sparse(x, keep_prob) | |||
| else: | |||
| @@ -22,6 +22,7 @@ class TrainingBatch(object): | |||
| vertex_type_column: int | |||
| relation_type_index: int | |||
| edges: torch.Tensor | |||
| target_values: torch.Tensor | |||
| def _per_layer_required_vertices(data: Data, batch: TrainingBatch, | |||
| @@ -133,7 +134,7 @@ class Model(torch.nn.Module): | |||
| List[torch.Tensor]: | |||
| cur_layer_repr = in_layer_repr | |||
| for i in range(len(self.layer_dimensions) - 1): | |||
| next_layer_repr = [ [] for _ in range(len(self.data.vertex_types)) ] | |||
| @@ -147,7 +148,7 @@ class Model(torch.nn.Module): | |||
| if self.keep_prob != 1: | |||
| x = dropout(x, self.keep_prob) | |||
| print('a, Layer:', i, 'x.shape:', x.shape) | |||
| # print('a, Layer:', i, 'x.shape:', x.shape) | |||
| x = _mm(x, conv_weights) | |||
| x = torch.split(x, | |||
| @@ -159,12 +160,12 @@ class Model(torch.nn.Module): | |||
| self.data.vertex_types[vt_row].count, | |||
| self.layer_dimensions[i + 1]) | |||
| print('b, Layer:', i, 'vt_row:', vt_row, 'x.shape:', x.shape) | |||
| # print('b, Layer:', i, 'vt_row:', vt_row, 'x.shape:', x.shape) | |||
| x = x.sum(dim=0) | |||
| x = torch.nn.functional.normalize(x, p=2, dim=1) | |||
| # x = self.rel_activation(x) | |||
| print('c, Layer:', i, 'vt_row:', vt_row, 'x.shape:', x.shape) | |||
| # print('c, Layer:', i, 'vt_row:', vt_row, 'x.shape:', x.shape) | |||
| next_layer_repr[vt_row].append(x) | |||
| @@ -174,6 +175,32 @@ class Model(torch.nn.Module): | |||
| cur_layer_repr = next_layer_repr | |||
| return next_layer_repr | |||
| def decode(self, last_layer_repr: List[torch.Tensor], | |||
| batch: TrainingBatch) -> torch.Tensor: | |||
| vt_row = batch.vertex_type_row | |||
| vt_col = batch.vertex_type_column | |||
| rel_idx = batch.relation_type_index | |||
| global_interaction = \ | |||
| self.dec_weights['%d-%d-global-interaction'] % (vt_row, vt_col) | |||
| local_variation = \ | |||
| self.dec_weights['%d-%d-local-variation-%d'] % (vt_row, vt_col, rel_idx) | |||
| in_row = dropout(last_layer_repr[vt_row], self.keep_prob) | |||
| in_col = dropout(last_layer_repr[vt_col], self.keep_prob) | |||
| x = torch.mm(in_row, local_variation) | |||
| x = torch.mm(x, global_interaction) | |||
| x = torch.mm(x, local_variation) | |||
| x = torch.bmm(x.view(x.shape[0], 1, x.shape[1]), | |||
| in_col.view(in_col.shape[0], in_col.shape[1], 1)) | |||
| x = torch.flatten(x) | |||
| x = self.dec_activation(x) | |||
| return x | |||
| def convolve_old(self, batch: TrainingBatch) -> List[torch.Tensor]: | |||
| edges = [] | |||
| cur_edges = batch.edges | |||
| @@ -197,3 +197,25 @@ def _mm(a: torch.Tensor, b: torch.Tensor): | |||
| return torch.sparse.mm(a, b) | |||
| else: | |||
| return torch.mm(a, b) | |||
| def common_one_hot_encoding(vertex_type_counts: List[int]) -> \ | |||
| List[torch.Tensor]: | |||
| tot = sum(vertex_type_counts) | |||
| # indices = torch.cat([ torch.arange(tot).view(1, -1) ] * 2, dim=0) | |||
| # print('indices.shape:', indices.shape) | |||
| ofs = 0 | |||
| res = [] | |||
| for cnt in vertex_type_counts: | |||
| ind = torch.cat([ | |||
| torch.arange(cnt).view(1, -1), | |||
| torch.arange(ofs, ofs+cnt).view(1, -1) | |||
| ]) | |||
| val = torch.ones(cnt) | |||
| x = _sparse_coo_tensor(ind, val, size=(cnt, tot)) | |||
| res.append(x) | |||
| ofs += cnt | |||
| return res | |||
| @@ -1,9 +1,46 @@ | |||
| import torch | |||
| from triacontagon.model import Model | |||
| from triacontagon.model import Model, \ | |||
| TrainingBatch, \ | |||
| _per_layer_required_vertices | |||
| from triacontagon.data import Data | |||
| from triacontagon.decode import dedicom_decoder | |||
| def test_per_layer_required_vertices_01(): | |||
| d = Data() | |||
| d.add_vertex_type('Gene', 4) | |||
| d.add_vertex_type('Drug', 5) | |||
| d.add_edge_type('Gene-Gene', 0, 0, [ torch.tensor([ | |||
| [1, 0, 0, 1], | |||
| [0, 1, 1, 0], | |||
| [0, 0, 1, 0], | |||
| [0, 1, 0, 1] | |||
| ]).to_sparse() ], dedicom_decoder) | |||
| d.add_edge_type('Gene-Drug', 0, 1, [ torch.tensor([ | |||
| [0, 1, 0, 0, 1], | |||
| [0, 0, 1, 0, 0], | |||
| [1, 0, 0, 0, 1], | |||
| [0, 0, 1, 1, 0] | |||
| ]).to_sparse() ], dedicom_decoder) | |||
| d.add_edge_type('Drug-Drug', 1, 1, [ torch.tensor([ | |||
| [1, 0, 0, 0, 0], | |||
| [0, 1, 0, 0, 0], | |||
| [0, 0, 1, 0, 0], | |||
| [0, 0, 0, 1, 0], | |||
| [0, 0, 0, 0, 1] | |||
| ]).to_sparse() ], dedicom_decoder) | |||
| batch = TrainingBatch(0, 1, 0, torch.tensor([ | |||
| [0, 1] | |||
| ])) | |||
| res = _per_layer_required_vertices(d, batch, 5) | |||
| print('res:', res) | |||
| def test_model_convolve_01(): | |||
| d = Data() | |||
| d.add_vertex_type('Gene', 4) | |||
| @@ -2,7 +2,7 @@ from triacontagon.util import \ | |||
| _clear_adjacency_matrix_except_rows, \ | |||
| _sparse_diag_cat, \ | |||
| _equal, \ | |||
| _per_layer_required_vertices | |||
| common_one_hot_encoding | |||
| from triacontagon.model import TrainingBatch | |||
| from triacontagon.decode import dedicom_decoder | |||
| from triacontagon.data import Data | |||
| @@ -127,36 +127,9 @@ def test_clear_adjacency_matrix_except_rows_05(): | |||
| assert _equal(res, truth).all() | |||
| def test_per_layer_required_vertices_01(): | |||
| d = Data() | |||
| d.add_vertex_type('Gene', 4) | |||
| d.add_vertex_type('Drug', 5) | |||
| def test_common_one_hot_encoding_01(): | |||
| in_repr = common_one_hot_encoding([2000, 200]) | |||
| d.add_edge_type('Gene-Gene', 0, 0, [ torch.tensor([ | |||
| [1, 0, 0, 1], | |||
| [0, 1, 1, 0], | |||
| [0, 0, 1, 0], | |||
| [0, 1, 0, 1] | |||
| ]).to_sparse() ], dedicom_decoder) | |||
| d.add_edge_type('Gene-Drug', 0, 1, [ torch.tensor([ | |||
| [0, 1, 0, 0, 1], | |||
| [0, 0, 1, 0, 0], | |||
| [1, 0, 0, 0, 1], | |||
| [0, 0, 1, 1, 0] | |||
| ]).to_sparse() ], dedicom_decoder) | |||
| d.add_edge_type('Drug-Drug', 1, 1, [ torch.tensor([ | |||
| [1, 0, 0, 0, 0], | |||
| [0, 1, 0, 0, 0], | |||
| [0, 0, 1, 0, 0], | |||
| [0, 0, 0, 1, 0], | |||
| [0, 0, 0, 0, 1] | |||
| ]).to_sparse() ], dedicom_decoder) | |||
| batch = TrainingBatch(0, 1, 0, torch.tensor([ | |||
| [0, 1] | |||
| ])) | |||
| res = _per_layer_required_vertices(d, batch, 5) | |||
| print('res:', res) | |||
| ref = torch.eye(2200) | |||
| assert torch.all(in_repr[0].to_dense() == ref[:2000, :]) | |||
| assert torch.all(in_repr[1].to_dense() == ref[2000:, :]) | |||