| @@ -19,46 +19,6 @@ class TrainingBatch(object): | |||
| edges: torch.Tensor | |||
| def _per_layer_required_rows(data: Data, batch: TrainingBatch, | |||
| num_layers: int) -> List[List[EdgeType]]: | |||
| Q = [ | |||
| ( batch.vertex_type_row, batch.edges[:, 0] ), | |||
| ( batch.vertex_type_column, batch.edges[:, 1] ) | |||
| ] | |||
| print('Q:', Q) | |||
| res = [] | |||
| for _ in range(num_layers): | |||
| R = [] | |||
| required_rows = [ [] for _ in range(len(data.vertex_types)) ] | |||
| for vertex_type, vertices in Q: | |||
| for et in data.edge_types.values(): | |||
| if et.vertex_type_row == vertex_type: | |||
| required_rows[vertex_type].append(vertices) | |||
| indices = et.total_connectivity.indices() | |||
| mask = torch.zeros(et.total_connectivity.shape[0]) | |||
| mask[vertices] = 1 | |||
| mask = torch.nonzero(mask[indices[0]], as_tuple=True)[0] | |||
| R.append((et.vertex_type_column, | |||
| indices[1, mask])) | |||
| else: | |||
| pass # required_rows[et.vertex_type_row].append(torch.zeros(0)) | |||
| required_rows = [ torch.unique(torch.cat(x)) \ | |||
| if len(x) > 0 \ | |||
| else None \ | |||
| for x in required_rows ] | |||
| res.append(required_rows) | |||
| Q = R | |||
| return res | |||
| class Model(torch.nn.Module): | |||
| def __init__(self, data: Data, layer_dimensions: List[int], | |||
| keep_prob: float, | |||
| @@ -190,3 +190,41 @@ def _cat(matrices: List[torch.Tensor]): | |||
| res = _sparse_coo_tensor(indices, values, size=(row_offset, matrices[0].shape[1])) | |||
| return res | |||
| def _per_layer_required_vertices(data: Data, batch: TrainingBatch, | |||
| num_layers: int) -> List[List[EdgeType]]: | |||
| Q = [ | |||
| ( batch.vertex_type_row, batch.edges[:, 0] ), | |||
| ( batch.vertex_type_column, batch.edges[:, 1] ) | |||
| ] | |||
| print('Q:', Q) | |||
| res = [] | |||
| for _ in range(num_layers): | |||
| R = [] | |||
| required_rows = [ [] for _ in range(len(data.vertex_types)) ] | |||
| for vertex_type, vertices in Q: | |||
| for et in data.edge_types.values(): | |||
| if et.vertex_type_row == vertex_type: | |||
| required_rows[vertex_type].append(vertices) | |||
| indices = et.total_connectivity.indices() | |||
| mask = torch.zeros(et.total_connectivity.shape[0]) | |||
| mask[vertices] = 1 | |||
| mask = torch.nonzero(mask[indices[0]], as_tuple=True)[0] | |||
| R.append((et.vertex_type_column, | |||
| indices[1, mask])) | |||
| else: | |||
| pass # required_rows[et.vertex_type_row].append(torch.zeros(0)) | |||
| required_rows = [ torch.unique(torch.cat(x)) \ | |||
| if len(x) > 0 \ | |||
| else None \ | |||
| for x in required_rows ] | |||
| res.append(required_rows) | |||
| Q = R | |||
| return res | |||
| @@ -1,7 +1,11 @@ | |||
| from triacontagon.util import \ | |||
| _clear_adjacency_matrix_except_rows, \ | |||
| _sparse_diag_cat, \ | |||
| _equal | |||
| _equal, \ | |||
| _per_layer_required_vertices | |||
| from triacontagon.model import TrainingBatch | |||
| from triacontagon.decode import dedicom_decoder | |||
| from triacontagon.data import Data | |||
| import torch | |||
| import time | |||
| @@ -121,3 +125,38 @@ def test_clear_adjacency_matrix_except_rows_05(): | |||
| truth = _sparse_diag_cat([ adj_mat.to_sparse() ] * 1300) | |||
| assert _equal(res, truth).all() | |||
| def test_per_layer_required_vertices_01(): | |||
| d = Data() | |||
| d.add_vertex_type('Gene', 4) | |||
| d.add_vertex_type('Drug', 5) | |||
| d.add_edge_type('Gene-Gene', 0, 0, [ torch.tensor([ | |||
| [1, 0, 0, 1], | |||
| [0, 1, 1, 0], | |||
| [0, 0, 1, 0], | |||
| [0, 1, 0, 1] | |||
| ]).to_sparse() ], dedicom_decoder) | |||
| d.add_edge_type('Gene-Drug', 0, 1, [ torch.tensor([ | |||
| [0, 1, 0, 0, 1], | |||
| [0, 0, 1, 0, 0], | |||
| [1, 0, 0, 0, 1], | |||
| [0, 0, 1, 1, 0] | |||
| ]).to_sparse() ], dedicom_decoder) | |||
| d.add_edge_type('Drug-Drug', 1, 1, [ torch.tensor([ | |||
| [1, 0, 0, 0, 0], | |||
| [0, 1, 0, 0, 0], | |||
| [0, 0, 1, 0, 0], | |||
| [0, 0, 0, 1, 0], | |||
| [0, 0, 0, 0, 1] | |||
| ]).to_sparse() ], dedicom_decoder) | |||
| batch = TrainingBatch(0, 1, 0, torch.tensor([ | |||
| [0, 1] | |||
| ])) | |||
| res = _per_layer_required_vertices(d, batch, 5) | |||
| print('res:', res) | |||