| @@ -10,6 +10,7 @@ from typing import Callable, \ | |||
| List | |||
| import types | |||
| from .util import _nonzero_sum | |||
| import torch | |||
| @dataclass | |||
| @@ -66,6 +67,6 @@ class Data(object): | |||
| if (vertex_type_row, vertex_type_column) in self.edge_types: | |||
| raise KeyError('Edge type for given combination of row and column already exists') | |||
| total_connectivity = _nonzero_sum(adjacency_matrices) | |||
| self.edges_types[vertex_type_row, vertex_type_column] = \ | |||
| VertexType(name, vertex_type_row, vertex_type_column, | |||
| self.edge_types[vertex_type_row, vertex_type_column] = \ | |||
| EdgeType(name, vertex_type_row, vertex_type_column, | |||
| adjacency_matrices, decoder_factory, total_connectivity) | |||
| @@ -11,7 +11,7 @@ from typing import Tuple, \ | |||
| List | |||
| def dedicom_decoder(input_dim: int, num_relation_types: int) -> | |||
| def dedicom_decoder(input_dim: int, num_relation_types: int) -> \ | |||
| Tuple[torch.Tensor, List[torch.Tensor]]: | |||
| global_interaction = init_glorot(input_dim, input_dim) | |||
| @@ -22,18 +22,18 @@ def dedicom_decoder(input_dim: int, num_relation_types: int) -> | |||
| return (global_interaction, local_variation) | |||
| def dist_mult_decoder(input_dim: int, num_relation_types: int) -> | |||
| def dist_mult_decoder(input_dim: int, num_relation_types: int) -> \ | |||
| Tuple[torch.Tensor, List[torch.Tensor]]: | |||
| global_interaction = torch.eye(input_dim, input_dim) | |||
| local_variation = [ | |||
| torch.diag(torch.flatten(init_glorot(input_dim, 1)))) \ | |||
| torch.diag(torch.flatten(init_glorot(input_dim, 1))) \ | |||
| for _ in range(num_relation_types) | |||
| ] | |||
| return (global_interaction, local_variation) | |||
| def bilinear_decoder(input_dim: int, num_relation_types: int) -> | |||
| def bilinear_decoder(input_dim: int, num_relation_types: int) -> \ | |||
| Tuple[torch.Tensor, List[torch.Tensor]]: | |||
| global_interaction = torch.eye(input_dim, input_dim) | |||
| @@ -44,7 +44,7 @@ def bilinear_decoder(input_dim: int, num_relation_types: int) -> | |||
| return (global_interaction, local_variation) | |||
| def inner_product_decoder(input_dim: int, num_relation_types: int) -> | |||
| def inner_product_decoder(input_dim: int, num_relation_types: int) -> \ | |||
| Tuple[torch.Tensor, List[torch.Tensor]]: | |||
| global_interaction = torch.eye(input_dim, input_dim) | |||
| @@ -6,7 +6,8 @@ from .weights import init_glorot | |||
| import types | |||
| from typing import List, \ | |||
| Dict, \ | |||
| Callable | |||
| Callable, \ | |||
| Tuple | |||
| from .util import _sparse_coo_tensor | |||
| @@ -18,6 +19,46 @@ class TrainingBatch(object): | |||
| edges: torch.Tensor | |||
| def _per_layer_required_rows(data: Data, batch: TrainingBatch, | |||
| num_layers: int) -> List[List[EdgeType]]: | |||
| Q = [ | |||
| ( batch.vertex_type_row, batch.edges[:, 0] ), | |||
| ( batch.vertex_type_column, batch.edges[:, 1] ) | |||
| ] | |||
| print('Q:', Q) | |||
| res = [] | |||
| for _ in range(num_layers): | |||
| R = [] | |||
| required_rows = [ [] for _ in range(len(data.vertex_types)) ] | |||
| for vertex_type, vertices in Q: | |||
| for et in data.edge_types.values(): | |||
| if et.vertex_type_row == vertex_type: | |||
| required_rows[vertex_type].append(vertices) | |||
| indices = et.total_connectivity.indices() | |||
| mask = torch.zeros(et.total_connectivity.shape[0]) | |||
| mask[vertices] = 1 | |||
| mask = torch.nonzero(mask[indices[0]], as_tuple=True)[0] | |||
| R.append((et.vertex_type_column, | |||
| indices[1, mask])) | |||
| else: | |||
| pass # required_rows[et.vertex_type_row].append(torch.zeros(0)) | |||
| required_rows = [ torch.unique(torch.cat(x)) \ | |||
| if len(x) > 0 \ | |||
| else None \ | |||
| for x in required_rows ] | |||
| res.append(required_rows) | |||
| Q = R | |||
| return res | |||
| class Model(torch.nn.Module): | |||
| def __init__(self, data: Data, layer_dimensions: List[int], | |||
| keep_prob: float, | |||
| @@ -68,17 +109,16 @@ class Model(torch.nn.Module): | |||
| torch.nn.Parameter(local_variation) | |||
| ]) | |||
| def limit_adjacency_matrix_to_rows(self, adjacency_matrix: torch.Tensor, | |||
| rows: torch.Tensor) -> torch.Tensor: | |||
| adj_mat = adjacency_matrix.coalesce() | |||
| adj_mat = torch.index_select(adj_mat, 0, rows) | |||
| adj_mat = adj_mat.coalesce() | |||
| indices = adj_mat.indices() | |||
| indices[0] = rows | |||
| def convolve(self, batch: TrainingBatch) -> List[torch.Tensor]: | |||
| edges = [] | |||
| cur_edges = batch.edges | |||
| for _ in range(len(self.layer_dimensions) - 1): | |||
| edges.append(cur_edges) | |||
| key = (batch.vertex_type_row, batch.vertex_type_column) | |||
| tot_conn = self.data.relation_types[key].total_connectivity | |||
| cur_edges = _edges_for_rows(tot_conn, cur_edges[:, 1]) | |||
| adj_mat = _sparse_coo_tensor(indices, adj_mat.values(), adjacency_matrix.shape) | |||
| def temporary_adjacency_matrix(self, adjacency_matrix: torch.Tensor, | |||
| batch: TrainingBatch, total_connectivity: torch.Tensor) -> torch.Tensor: | |||
| @@ -90,12 +130,13 @@ class Model(torch.nn.Module): | |||
| columns = torch.nonzero(columns) | |||
| for i in range(len(self.layer_dimensions) - 1): | |||
| pass # columns = | |||
| # TODO: finish | |||
| columns = | |||
| return None | |||
| def temporary_adjacency_matrices(self, batch: TrainingBatch) -> | |||
| Dict[Tuple[int, int], List[List[torch.Tensor]]]: | |||
| def temporary_adjacency_matrices(self, batch: TrainingBatch) -> Dict[Tuple[int, int], List[List[torch.Tensor]]]: | |||
| col = batch.vertex_type_column | |||
| batch.edges[:, 1] | |||
| @@ -41,7 +41,9 @@ def _nonzero_sum(adjacency_matrices: List[torch.Tensor]): | |||
| indices = res.indices() | |||
| res = _sparse_coo_tensor(indices, | |||
| torch.ones(indices.shape[1], dtype=torch.uint8)) | |||
| torch.ones(indices.shape[1], dtype=torch.uint8), | |||
| adjacency_matrices[0].shape) | |||
| res = res.coalesce() | |||
| return res | |||
| @@ -0,0 +1,40 @@ | |||
| from triacontagon.model import _per_layer_required_rows, \ | |||
| TrainingBatch | |||
| from triacontagon.decode import dedicom_decoder | |||
| from triacontagon.data import Data | |||
| import torch | |||
| def test_per_layer_required_rows_01(): | |||
| d = Data() | |||
| d.add_vertex_type('Gene', 4) | |||
| d.add_vertex_type('Drug', 5) | |||
| d.add_edge_type('Gene-Gene', 0, 0, [ torch.tensor([ | |||
| [1, 0, 0, 1], | |||
| [0, 1, 1, 0], | |||
| [0, 0, 1, 0], | |||
| [0, 1, 0, 1] | |||
| ]).to_sparse() ], dedicom_decoder) | |||
| d.add_edge_type('Gene-Drug', 0, 1, [ torch.tensor([ | |||
| [0, 1, 0, 0, 1], | |||
| [0, 0, 1, 0, 0], | |||
| [1, 0, 0, 0, 1], | |||
| [0, 0, 1, 1, 0] | |||
| ]).to_sparse() ], dedicom_decoder) | |||
| d.add_edge_type('Drug-Drug', 1, 1, [ torch.tensor([ | |||
| [1, 0, 0, 0, 0], | |||
| [0, 1, 0, 0, 0], | |||
| [0, 0, 1, 0, 0], | |||
| [0, 0, 0, 1, 0], | |||
| [0, 0, 0, 0, 1] | |||
| ]).to_sparse() ], dedicom_decoder) | |||
| batch = TrainingBatch(0, 1, 0, torch.tensor([ | |||
| [0, 1] | |||
| ])) | |||
| res = _per_layer_required_rows(d, batch, 5) | |||
| print('res:', res) | |||