@@ -19,46 +19,6 @@ class TrainingBatch(object): | |||
edges: torch.Tensor | |||
def _per_layer_required_rows(data: Data, batch: TrainingBatch, | |||
num_layers: int) -> List[List[EdgeType]]: | |||
Q = [ | |||
( batch.vertex_type_row, batch.edges[:, 0] ), | |||
( batch.vertex_type_column, batch.edges[:, 1] ) | |||
] | |||
print('Q:', Q) | |||
res = [] | |||
for _ in range(num_layers): | |||
R = [] | |||
required_rows = [ [] for _ in range(len(data.vertex_types)) ] | |||
for vertex_type, vertices in Q: | |||
for et in data.edge_types.values(): | |||
if et.vertex_type_row == vertex_type: | |||
required_rows[vertex_type].append(vertices) | |||
indices = et.total_connectivity.indices() | |||
mask = torch.zeros(et.total_connectivity.shape[0]) | |||
mask[vertices] = 1 | |||
mask = torch.nonzero(mask[indices[0]], as_tuple=True)[0] | |||
R.append((et.vertex_type_column, | |||
indices[1, mask])) | |||
else: | |||
pass # required_rows[et.vertex_type_row].append(torch.zeros(0)) | |||
required_rows = [ torch.unique(torch.cat(x)) \ | |||
if len(x) > 0 \ | |||
else None \ | |||
for x in required_rows ] | |||
res.append(required_rows) | |||
Q = R | |||
return res | |||
class Model(torch.nn.Module): | |||
def __init__(self, data: Data, layer_dimensions: List[int], | |||
keep_prob: float, | |||
@@ -190,3 +190,41 @@ def _cat(matrices: List[torch.Tensor]): | |||
res = _sparse_coo_tensor(indices, values, size=(row_offset, matrices[0].shape[1])) | |||
return res | |||
def _per_layer_required_vertices(data: Data, batch: TrainingBatch, | |||
num_layers: int) -> List[List[EdgeType]]: | |||
Q = [ | |||
( batch.vertex_type_row, batch.edges[:, 0] ), | |||
( batch.vertex_type_column, batch.edges[:, 1] ) | |||
] | |||
print('Q:', Q) | |||
res = [] | |||
for _ in range(num_layers): | |||
R = [] | |||
required_rows = [ [] for _ in range(len(data.vertex_types)) ] | |||
for vertex_type, vertices in Q: | |||
for et in data.edge_types.values(): | |||
if et.vertex_type_row == vertex_type: | |||
required_rows[vertex_type].append(vertices) | |||
indices = et.total_connectivity.indices() | |||
mask = torch.zeros(et.total_connectivity.shape[0]) | |||
mask[vertices] = 1 | |||
mask = torch.nonzero(mask[indices[0]], as_tuple=True)[0] | |||
R.append((et.vertex_type_column, | |||
indices[1, mask])) | |||
else: | |||
pass # required_rows[et.vertex_type_row].append(torch.zeros(0)) | |||
required_rows = [ torch.unique(torch.cat(x)) \ | |||
if len(x) > 0 \ | |||
else None \ | |||
for x in required_rows ] | |||
res.append(required_rows) | |||
Q = R | |||
return res |
@@ -1,7 +1,11 @@ | |||
from triacontagon.util import \ | |||
_clear_adjacency_matrix_except_rows, \ | |||
_sparse_diag_cat, \ | |||
_equal | |||
_equal, \ | |||
_per_layer_required_vertices | |||
from triacontagon.model import TrainingBatch | |||
from triacontagon.decode import dedicom_decoder | |||
from triacontagon.data import Data | |||
import torch | |||
import time | |||
@@ -121,3 +125,38 @@ def test_clear_adjacency_matrix_except_rows_05(): | |||
truth = _sparse_diag_cat([ adj_mat.to_sparse() ] * 1300) | |||
assert _equal(res, truth).all() | |||
def test_per_layer_required_vertices_01(): | |||
d = Data() | |||
d.add_vertex_type('Gene', 4) | |||
d.add_vertex_type('Drug', 5) | |||
d.add_edge_type('Gene-Gene', 0, 0, [ torch.tensor([ | |||
[1, 0, 0, 1], | |||
[0, 1, 1, 0], | |||
[0, 0, 1, 0], | |||
[0, 1, 0, 1] | |||
]).to_sparse() ], dedicom_decoder) | |||
d.add_edge_type('Gene-Drug', 0, 1, [ torch.tensor([ | |||
[0, 1, 0, 0, 1], | |||
[0, 0, 1, 0, 0], | |||
[1, 0, 0, 0, 1], | |||
[0, 0, 1, 1, 0] | |||
]).to_sparse() ], dedicom_decoder) | |||
d.add_edge_type('Drug-Drug', 1, 1, [ torch.tensor([ | |||
[1, 0, 0, 0, 0], | |||
[0, 1, 0, 0, 0], | |||
[0, 0, 1, 0, 0], | |||
[0, 0, 0, 1, 0], | |||
[0, 0, 0, 0, 1] | |||
]).to_sparse() ], dedicom_decoder) | |||
batch = TrainingBatch(0, 1, 0, torch.tensor([ | |||
[0, 1] | |||
])) | |||
res = _per_layer_required_vertices(d, batch, 5) | |||
print('res:', res) |