@@ -36,6 +36,8 @@ def dropout_dense(x, keep_prob):

 def dropout(x, keep_prob):
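+    # keep_prob == 1 makes dropout the identity, so return the input unchanged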
+    if keep_prob == 1:
+        return x
     if x.is_sparse:
         return dropout_sparse(x, keep_prob)
     else:
@@ -22,6 +22,7 @@ class TrainingBatch(object):
     vertex_type_column: int
     relation_type_index: int
    edges: torch.Tensor
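+    # ground-truth values for the edges in this batch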
+    target_values: torch.Tensor


 def _per_layer_required_vertices(data: Data, batch: TrainingBatch,
@@ -133,7 +134,7 @@ class Model(torch.nn.Module):
             List[torch.Tensor]:

         cur_layer_repr = in_layer_repr

         for i in range(len(self.layer_dimensions) - 1):
             next_layer_repr = [ [] for _ in range(len(self.data.vertex_types)) ]
@@ -147,7 +148,7 @@ class Model(torch.nn.Module):
                 if self.keep_prob != 1:
                     x = dropout(x, self.keep_prob)
-                print('a, Layer:', i, 'x.shape:', x.shape)
+                # print('a, Layer:', i, 'x.shape:', x.shape)
                 x = _mm(x, conv_weights)
                 x = torch.split(x,
@@ -159,12 +160,12 @@ class Model(torch.nn.Module):
                     self.data.vertex_types[vt_row].count,
                     self.layer_dimensions[i + 1])
-                print('b, Layer:', i, 'vt_row:', vt_row, 'x.shape:', x.shape)
+                # print('b, Layer:', i, 'vt_row:', vt_row, 'x.shape:', x.shape)
                 x = x.sum(dim=0)
                 x = torch.nn.functional.normalize(x, p=2, dim=1)
                 # x = self.rel_activation(x)
-                print('c, Layer:', i, 'vt_row:', vt_row, 'x.shape:', x.shape)
+                # print('c, Layer:', i, 'vt_row:', vt_row, 'x.shape:', x.shape)
                 next_layer_repr[vt_row].append(x)
@@ -174,6 +175,32 @@ class Model(torch.nn.Module):
             cur_layer_repr = next_layer_repr

         return next_layer_repr

+    def decode(self, last_layer_repr: List[torch.Tensor],
+        batch: TrainingBatch) -> torch.Tensor:
+
+        vt_row = batch.vertex_type_row
+        vt_col = batch.vertex_type_column
+        rel_idx = batch.relation_type_index
+
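+        # look up the DEDICOM decoder parameters: one global interaction matrix
+        # per (row, column) vertex-type pair, one local variation matrix per relation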
+        global_interaction = \
+            self.dec_weights['%d-%d-global-interaction' % (vt_row, vt_col)]
+        local_variation = \
+            self.dec_weights['%d-%d-local-variation-%d' % (vt_row, vt_col, rel_idx)]
+
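+        # apply dropout to the final-layer representations before scoring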
+        in_row = dropout(last_layer_repr[vt_row], self.keep_prob)
+        in_col = dropout(last_layer_repr[vt_col], self.keep_prob)
+
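+        # DEDICOM bilinear form: in_row . D_rel . R . D_rel, then paired
+        # row-by-row with in_col via a batched matmul, one score per vertex pair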
+        x = torch.mm(in_row, local_variation)
+        x = torch.mm(x, global_interaction)
+        x = torch.mm(x, local_variation)
+        x = torch.bmm(x.view(x.shape[0], 1, x.shape[1]),
+            in_col.view(in_col.shape[0], in_col.shape[1], 1))
+        x = torch.flatten(x)
+
+        x = self.dec_activation(x)
+
+        return x
+

     def convolve_old(self, batch: TrainingBatch) -> List[torch.Tensor]:
         edges = []
         cur_edges = batch.edges
@@ -197,3 +197,25 @@ def _mm(a: torch.Tensor, b: torch.Tensor):
         return torch.sparse.mm(a, b)
     else:
         return torch.mm(a, b)
+
+
+def common_one_hot_encoding(vertex_type_counts: List[int]) -> \
+    List[torch.Tensor]:
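+    # Build one sparse feature matrix per vertex type: matrix i has shape
+    # (count_i, total_vertex_count) and holds ones on the diagonal block at the
+    # type's offset, giving every vertex a globally unique one-hot feature row.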
+
+    tot = sum(vertex_type_counts)
+    # indices = torch.cat([ torch.arange(tot).view(1, -1) ] * 2, dim=0)
+    # print('indices.shape:', indices.shape)
+    ofs = 0
+    res = []
+
+    for cnt in vertex_type_counts:
+        ind = torch.cat([
+            torch.arange(cnt).view(1, -1),
+            torch.arange(ofs, ofs+cnt).view(1, -1)
+        ])
+        val = torch.ones(cnt)
+        x = _sparse_coo_tensor(ind, val, size=(cnt, tot))
+        res.append(x)
+        ofs += cnt
+
+    return res
@@ -1,9 +1,46 @@
 import torch
-from triacontagon.model import Model
+from triacontagon.model import Model, \
+    TrainingBatch, \
+    _per_layer_required_vertices
 from triacontagon.data import Data
 from triacontagon.decode import dedicom_decoder


+def test_per_layer_required_vertices_01():
+    d = Data()
+    d.add_vertex_type('Gene', 4)
+    d.add_vertex_type('Drug', 5)
+
+    d.add_edge_type('Gene-Gene', 0, 0, [ torch.tensor([
+        [1, 0, 0, 1],
+        [0, 1, 1, 0],
+        [0, 0, 1, 0],
+        [0, 1, 0, 1]
+    ]).to_sparse() ], dedicom_decoder)
+
+    d.add_edge_type('Gene-Drug', 0, 1, [ torch.tensor([
+        [0, 1, 0, 0, 1],
+        [0, 0, 1, 0, 0],
+        [1, 0, 0, 0, 1],
+        [0, 0, 1, 1, 0]
+    ]).to_sparse() ], dedicom_decoder)
+
+    d.add_edge_type('Drug-Drug', 1, 1, [ torch.tensor([
+        [1, 0, 0, 0, 0],
+        [0, 1, 0, 0, 0],
+        [0, 0, 1, 0, 0],
+        [0, 0, 0, 1, 0],
+        [0, 0, 0, 0, 1]
+    ]).to_sparse() ], dedicom_decoder)
+
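+    # one training edge from Gene 0 to Drug 1 (vertex types 0 -> 1, relation 0)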
+    batch = TrainingBatch(0, 1, 0, torch.tensor([
+        [0, 1]
+    ]))
+
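+    # compute which vertices each layer needs for this batch
+    # (the last argument is presumably the layer count)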
+    res = _per_layer_required_vertices(d, batch, 5)
+    print('res:', res)
+
+
 def test_model_convolve_01():
     d = Data()
     d.add_vertex_type('Gene', 4)
@@ -2,7 +2,7 @@ from triacontagon.util import \
     _clear_adjacency_matrix_except_rows, \
     _sparse_diag_cat, \
     _equal, \
-    _per_layer_required_vertices
+    common_one_hot_encoding
 from triacontagon.model import TrainingBatch
 from triacontagon.decode import dedicom_decoder
 from triacontagon.data import Data
@@ -127,36 +127,9 @@ def test_clear_adjacency_matrix_except_rows_05():
     assert _equal(res, truth).all()


-def test_per_layer_required_vertices_01():
-    d = Data()
-    d.add_vertex_type('Gene', 4)
-    d.add_vertex_type('Drug', 5)
+def test_common_one_hot_encoding_01():
+    in_repr = common_one_hot_encoding([2000, 200])
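+    # the two sparse blocks should tile the rows of the 2200 x 2200 identity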
-
-    d.add_edge_type('Gene-Gene', 0, 0, [ torch.tensor([
-        [1, 0, 0, 1],
-        [0, 1, 1, 0],
-        [0, 0, 1, 0],
-        [0, 1, 0, 1]
-    ]).to_sparse() ], dedicom_decoder)
-
-    d.add_edge_type('Gene-Drug', 0, 1, [ torch.tensor([
-        [0, 1, 0, 0, 1],
-        [0, 0, 1, 0, 0],
-        [1, 0, 0, 0, 1],
-        [0, 0, 1, 1, 0]
-    ]).to_sparse() ], dedicom_decoder)
-
-    d.add_edge_type('Drug-Drug', 1, 1, [ torch.tensor([
-        [1, 0, 0, 0, 0],
-        [0, 1, 0, 0, 0],
-        [0, 0, 1, 0, 0],
-        [0, 0, 0, 1, 0],
-        [0, 0, 0, 0, 1]
-    ]).to_sparse() ], dedicom_decoder)
-
-    batch = TrainingBatch(0, 1, 0, torch.tensor([
-        [0, 1]
-    ]))
-
-    res = _per_layer_required_vertices(d, batch, 5)
-    print('res:', res)
+
+    ref = torch.eye(2200)
+    assert torch.all(in_repr[0].to_dense() == ref[:2000, :])
+    assert torch.all(in_repr[1].to_dense() == ref[2000:, :])