@@ -10,6 +10,7 @@ from typing import Callable, \
     List
 import types
 from .util import _nonzero_sum
+import torch


 @dataclass
@@ -66,6 +67,6 @@ class Data(object):
         if (vertex_type_row, vertex_type_column) in self.edge_types:
             raise KeyError('Edge type for given combination of row and column already exists')
         total_connectivity = _nonzero_sum(adjacency_matrices)
-        self.edges_types[vertex_type_row, vertex_type_column] = \
-            VertexType(name, vertex_type_row, vertex_type_column,
+        self.edge_types[vertex_type_row, vertex_type_column] = \
+            EdgeType(name, vertex_type_row, vertex_type_column,
                 adjacency_matrices, decoder_factory, total_connectivity)
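The corrected branch keys the edge-type registry by the (row, column) vertex-type pair and stores an EdgeType record rather than a VertexType. A minimal sketch of the record shape implied by the constructor call above follows; the field names come from that call, not from the repository's actual EdgeType definition, so treat it as illustrative only.

# Illustrative sketch only -- not code from this patch.
from dataclasses import dataclass
from typing import Callable, List
import torch

@dataclass
class EdgeTypeSketch:
    name: str
    vertex_type_row: int
    vertex_type_column: int
    adjacency_matrices: List[torch.Tensor]
    decoder_factory: Callable
    total_connectivity: torch.Tensor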
@@ -11,7 +11,7 @@ from typing import Tuple, \
     List


-def dedicom_decoder(input_dim: int, num_relation_types: int) ->
+def dedicom_decoder(input_dim: int, num_relation_types: int) -> \
     Tuple[torch.Tensor, List[torch.Tensor]]:

     global_interaction = init_glorot(input_dim, input_dim)
@@ -22,18 +22,18 @@ def dedicom_decoder(input_dim: int, num_relation_types: int) ->
     return (global_interaction, local_variation)


-def dist_mult_decoder(input_dim: int, num_relation_types: int) ->
+def dist_mult_decoder(input_dim: int, num_relation_types: int) -> \
     Tuple[torch.Tensor, List[torch.Tensor]]:

     global_interaction = torch.eye(input_dim, input_dim)
     local_variation = [
-        torch.diag(torch.flatten(init_glorot(input_dim, 1)))) \
+        torch.diag(torch.flatten(init_glorot(input_dim, 1))) \
             for _ in range(num_relation_types)
     ]
     return (global_interaction, local_variation)


-def bilinear_decoder(input_dim: int, num_relation_types: int) ->
+def bilinear_decoder(input_dim: int, num_relation_types: int) -> \
     Tuple[torch.Tensor, List[torch.Tensor]]:

     global_interaction = torch.eye(input_dim, input_dim)
@@ -44,7 +44,7 @@ def bilinear_decoder(input_dim: int, num_relation_types: int) ->
     return (global_interaction, local_variation)


-def inner_product_decoder(input_dim: int, num_relation_types: int) ->
+def inner_product_decoder(input_dim: int, num_relation_types: int) -> \
     Tuple[torch.Tensor, List[torch.Tensor]]:

     global_interaction = torch.eye(input_dim, input_dim)
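Each decoder factory returns a pair (global_interaction, local_variation): one shared input_dim x input_dim matrix plus one matrix per relation type (diagonal glorot-initialized for dedicom and dist_mult, identity-derived otherwise). A sketch of how such a pair is typically combined into an edge score in a Decagon/DEDICOM-style decoder is shown below; the function name and exact formula are assumptions for illustration, not code from this patch.

# Sketch under the assumption of a DEDICOM-style bilinear score (not from this repo).
import torch

def score_edge_sketch(z_row: torch.Tensor, z_col: torch.Tensor,
        global_interaction: torch.Tensor,
        local_variation: list, k: int) -> torch.Tensor:
    # z_row, z_col: endpoint embeddings of size [input_dim]
    # global_interaction: shared relation matrix R
    # local_variation[k]: per-relation matrix D_k
    d = local_variation[k]
    return z_row @ d @ global_interaction @ d @ z_col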
@@ -6,7 +6,8 @@ from .weights import init_glorot
 import types
 from typing import List, \
     Dict, \
-    Callable
+    Callable, \
+    Tuple
 from .util import _sparse_coo_tensor
@@ -18,6 +19,46 @@ class TrainingBatch(object):
     edges: torch.Tensor


+def _per_layer_required_rows(data: Data, batch: TrainingBatch,
+    num_layers: int) -> List[List[EdgeType]]:
+
+    Q = [
+        ( batch.vertex_type_row, batch.edges[:, 0] ),
+        ( batch.vertex_type_column, batch.edges[:, 1] )
+    ]
+    print('Q:', Q)
+    res = []
+
+    for _ in range(num_layers):
+        R = []
+        required_rows = [ [] for _ in range(len(data.vertex_types)) ]
+
+        for vertex_type, vertices in Q:
+            for et in data.edge_types.values():
+                if et.vertex_type_row == vertex_type:
+                    required_rows[vertex_type].append(vertices)
+                    indices = et.total_connectivity.indices()
+                    mask = torch.zeros(et.total_connectivity.shape[0])
+                    mask[vertices] = 1
+                    mask = torch.nonzero(mask[indices[0]], as_tuple=True)[0]
+                    R.append((et.vertex_type_column,
+                        indices[1, mask]))
+                else:
+                    pass # required_rows[et.vertex_type_row].append(torch.zeros(0))
+
+        required_rows = [ torch.unique(torch.cat(x)) \
+            if len(x) > 0 \
+            else None \
+            for x in required_rows ]
+
+        res.append(required_rows)
+        Q = R
+
+    return res
+
+
 class Model(torch.nn.Module):
     def __init__(self, data: Data, layer_dimensions: List[int],
         keep_prob: float,
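The new _per_layer_required_rows walks the graph one layer at a time: for each (vertex type, vertices) pair in the current frontier it marks the queried rows of every matching edge type's total_connectivity and reads off the columns those rows touch, which become the next frontier. The masking step can be seen in isolation in the standalone sketch below (plain PyTorch, not part of the patch).

import torch

# A 4x4 connectivity pattern, like one of the total_connectivity tensors in this patch.
conn = torch.tensor([
    [1, 0, 0, 1],
    [0, 1, 1, 0],
    [0, 0, 1, 0],
    [0, 1, 0, 1]
]).to_sparse().coalesce()

rows = torch.tensor([0])               # queried rows (one side of a batch edge)
indices = conn.indices()               # [2, nnz]; indices[0] = rows, indices[1] = columns
mask = torch.zeros(conn.shape[0])
mask[rows] = 1                         # mark the queried rows
hits = torch.nonzero(mask[indices[0]], as_tuple=True)[0]   # nonzero entries lying in those rows
print(indices[1, hits])                # tensor([0, 3]): columns required on the other side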
@@ -68,17 +109,16 @@ class Model(torch.nn.Module):
             torch.nn.Parameter(local_variation)
         ])

-    def limit_adjacency_matrix_to_rows(self, adjacency_matrix: torch.Tensor,
-        rows: torch.Tensor) -> torch.Tensor:
-
-        adj_mat = adjacency_matrix.coalesce()
-        adj_mat = torch.index_select(adj_mat, 0, rows)
-        adj_mat = adj_mat.coalesce()
-        indices = adj_mat.indices()
-        indices[0] = rows
+    def convolve(self, batch: TrainingBatch) -> List[torch.Tensor]:
+        edges = []
+        cur_edges = batch.edges
+        for _ in range(len(self.layer_dimensions) - 1):
+            edges.append(cur_edges)
+            key = (batch.vertex_type_row, batch.vertex_type_column)
+            tot_conn = self.data.relation_types[key].total_connectivity
+            cur_edges = _edges_for_rows(tot_conn, cur_edges[:, 1])
-
-        adj_mat = _sparse_coo_tensor(indices, adj_mat.values(), adjacency_matrix.shape)

     def temporary_adjacency_matrix(self, adjacency_matrix: torch.Tensor,
         batch: TrainingBatch, total_connectivity: torch.Tensor) -> torch.Tensor:
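The removed limit_adjacency_matrix_to_rows zeroed out every row of a sparse adjacency matrix except a selected set while preserving the original shape. The same pattern in plain PyTorch is sketched below (illustrative, not part of the patch); note the sketch maps the selected rows back via rows[indices[0]], whereas the removed code assigned indices[0] = rows directly, which only lines up when each selected row holds exactly one nonzero.

import torch

adj = torch.tensor([
    [1, 0, 1],
    [0, 1, 0],
    [1, 1, 0]
]).to_sparse().coalesce()

rows = torch.tensor([0, 2])
sub = torch.index_select(adj, 0, rows).coalesce()    # keep rows 0 and 2, renumbered 0 and 1
indices = sub.indices().clone()
indices[0] = rows[indices[0]]                        # map renumbered rows back to 0 and 2
limited = torch.sparse_coo_tensor(indices, sub.values(), adj.shape)
print(limited.to_dense())                            # row 1 is now all zeros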
@@ -90,12 +130,13 @@ class Model(torch.nn.Module):
         columns = torch.nonzero(columns)

         for i in range(len(self.layer_dimensions) - 1):
-            pass # columns =
+            # TODO: finish
+            columns =

         return None

-    def temporary_adjacency_matrices(self, batch: TrainingBatch) ->
-        Dict[Tuple[int, int], List[List[torch.Tensor]]]:
+    def temporary_adjacency_matrices(self, batch: TrainingBatch) -> Dict[Tuple[int, int], List[List[torch.Tensor]]]:

         col = batch.vertex_type_column
         batch.edges[:, 1]
@@ -41,7 +41,9 @@ def _nonzero_sum(adjacency_matrices: List[torch.Tensor]):
     indices = res.indices()
     res = _sparse_coo_tensor(indices,
-        torch.ones(indices.shape[1], dtype=torch.uint8))
+        torch.ones(indices.shape[1], dtype=torch.uint8),
+        adjacency_matrices[0].shape)
+    res = res.coalesce()

     return res
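The added shape argument matters because _sparse_coo_tensor (presumably a thin wrapper over torch.sparse_coo_tensor) would otherwise infer the size from the largest index, so a connectivity matrix whose trailing rows or columns are all zero would come back smaller than the adjacency matrices it was built from. A plain-PyTorch illustration:

import torch

indices = torch.tensor([[0, 1], [0, 1]])    # nonzeros only at (0, 0) and (1, 1)
values = torch.ones(2, dtype=torch.uint8)

inferred = torch.sparse_coo_tensor(indices, values)          # size inferred from indices: (2, 2)
explicit = torch.sparse_coo_tensor(indices, values, (4, 4))  # matches a 4x4 adjacency

print(inferred.shape, explicit.shape)   # torch.Size([2, 2]) torch.Size([4, 4])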
@@ -0,0 +1,40 @@
+from triacontagon.model import _per_layer_required_rows, \
+    TrainingBatch
+from triacontagon.decode import dedicom_decoder
+from triacontagon.data import Data
+import torch
+
+
+def test_per_layer_required_rows_01():
+    d = Data()
+    d.add_vertex_type('Gene', 4)
+    d.add_vertex_type('Drug', 5)
+
+    d.add_edge_type('Gene-Gene', 0, 0, [ torch.tensor([
+        [1, 0, 0, 1],
+        [0, 1, 1, 0],
+        [0, 0, 1, 0],
+        [0, 1, 0, 1]
+    ]).to_sparse() ], dedicom_decoder)
+
+    d.add_edge_type('Gene-Drug', 0, 1, [ torch.tensor([
+        [0, 1, 0, 0, 1],
+        [0, 0, 1, 0, 0],
+        [1, 0, 0, 0, 1],
+        [0, 0, 1, 1, 0]
+    ]).to_sparse() ], dedicom_decoder)
+
+    d.add_edge_type('Drug-Drug', 1, 1, [ torch.tensor([
+        [1, 0, 0, 0, 0],
+        [0, 1, 0, 0, 0],
+        [0, 0, 1, 0, 0],
+        [0, 0, 0, 1, 0],
+        [0, 0, 0, 0, 1]
+    ]).to_sparse() ], dedicom_decoder)
+
+    batch = TrainingBatch(0, 1, 0, torch.tensor([
+        [0, 1]
+    ]))
+
+    res = _per_layer_required_rows(d, batch, 5)
+    print('res:', res)
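The test only prints the result for now. Worked by hand from the three adjacency matrices, the first layer for this single (Gene 0, Drug 1) batch edge should require exactly Gene row 0 and Drug row 1, so a natural follow-up would be an assertion along these lines (hypothetical helper, values derived by hand and worth re-checking against the printed output):

import torch

def check_first_layer(res):
    # Layer 0 of the result: required rows per vertex type (Gene, Drug).
    assert torch.equal(res[0][0], torch.tensor([0]))   # Gene rows needed
    assert torch.equal(res[0][1], torch.tensor([1]))   # Drug rows needed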