IF YOU WOULD LIKE TO GET AN ACCOUNT, please write an email to s dot adaszewski at gmail dot com. User accounts are meant only to report issues and/or generate pull requests. This is a purpose-specific Git hosting for ADARED projects. Thank you for your understanding!
Sfoglia il codice sorgente

Rename _per_layer_required_rows to _vertices.

master
Stanislaw Adaszewski 4 anni fa
parent
commit
b607a727ae
3 ha cambiato i file con 78 aggiunte e 41 eliminazioni
  1. +0
    -40
      src/triacontagon/model.py
  2. +38
    -0
      src/triacontagon/util.py
  3. +40
    -1
      tests/triacontagon/test_util.py

+ 0
- 40
src/triacontagon/model.py Vedi File

@@ -19,46 +19,6 @@ class TrainingBatch(object):
edges: torch.Tensor
def _per_layer_required_rows(data: Data, batch: TrainingBatch,
num_layers: int) -> List[List[EdgeType]]:
Q = [
( batch.vertex_type_row, batch.edges[:, 0] ),
( batch.vertex_type_column, batch.edges[:, 1] )
]
print('Q:', Q)
res = []
for _ in range(num_layers):
R = []
required_rows = [ [] for _ in range(len(data.vertex_types)) ]
for vertex_type, vertices in Q:
for et in data.edge_types.values():
if et.vertex_type_row == vertex_type:
required_rows[vertex_type].append(vertices)
indices = et.total_connectivity.indices()
mask = torch.zeros(et.total_connectivity.shape[0])
mask[vertices] = 1
mask = torch.nonzero(mask[indices[0]], as_tuple=True)[0]
R.append((et.vertex_type_column,
indices[1, mask]))
else:
pass # required_rows[et.vertex_type_row].append(torch.zeros(0))
required_rows = [ torch.unique(torch.cat(x)) \
if len(x) > 0 \
else None \
for x in required_rows ]
res.append(required_rows)
Q = R
return res
class Model(torch.nn.Module):
def __init__(self, data: Data, layer_dimensions: List[int],
keep_prob: float,


+ 38
- 0
src/triacontagon/util.py Vedi File

@@ -190,3 +190,41 @@ def _cat(matrices: List[torch.Tensor]):
res = _sparse_coo_tensor(indices, values, size=(row_offset, matrices[0].shape[1]))
return res
def _per_layer_required_vertices(data: Data, batch: TrainingBatch,
num_layers: int) -> List[List[EdgeType]]:
Q = [
( batch.vertex_type_row, batch.edges[:, 0] ),
( batch.vertex_type_column, batch.edges[:, 1] )
]
print('Q:', Q)
res = []
for _ in range(num_layers):
R = []
required_rows = [ [] for _ in range(len(data.vertex_types)) ]
for vertex_type, vertices in Q:
for et in data.edge_types.values():
if et.vertex_type_row == vertex_type:
required_rows[vertex_type].append(vertices)
indices = et.total_connectivity.indices()
mask = torch.zeros(et.total_connectivity.shape[0])
mask[vertices] = 1
mask = torch.nonzero(mask[indices[0]], as_tuple=True)[0]
R.append((et.vertex_type_column,
indices[1, mask]))
else:
pass # required_rows[et.vertex_type_row].append(torch.zeros(0))
required_rows = [ torch.unique(torch.cat(x)) \
if len(x) > 0 \
else None \
for x in required_rows ]
res.append(required_rows)
Q = R
return res

+ 40
- 1
tests/triacontagon/test_util.py Vedi File

@@ -1,7 +1,11 @@
from triacontagon.util import \
_clear_adjacency_matrix_except_rows, \
_sparse_diag_cat, \
_equal
_equal, \
_per_layer_required_vertices
from triacontagon.model import TrainingBatch
from triacontagon.decode import dedicom_decoder
from triacontagon.data import Data
import torch
import time
@@ -121,3 +125,38 @@ def test_clear_adjacency_matrix_except_rows_05():
truth = _sparse_diag_cat([ adj_mat.to_sparse() ] * 1300)
assert _equal(res, truth).all()
def test_per_layer_required_vertices_01():
d = Data()
d.add_vertex_type('Gene', 4)
d.add_vertex_type('Drug', 5)
d.add_edge_type('Gene-Gene', 0, 0, [ torch.tensor([
[1, 0, 0, 1],
[0, 1, 1, 0],
[0, 0, 1, 0],
[0, 1, 0, 1]
]).to_sparse() ], dedicom_decoder)
d.add_edge_type('Gene-Drug', 0, 1, [ torch.tensor([
[0, 1, 0, 0, 1],
[0, 0, 1, 0, 0],
[1, 0, 0, 0, 1],
[0, 0, 1, 1, 0]
]).to_sparse() ], dedicom_decoder)
d.add_edge_type('Drug-Drug', 1, 1, [ torch.tensor([
[1, 0, 0, 0, 0],
[0, 1, 0, 0, 0],
[0, 0, 1, 0, 0],
[0, 0, 0, 1, 0],
[0, 0, 0, 0, 1]
]).to_sparse() ], dedicom_decoder)
batch = TrainingBatch(0, 1, 0, torch.tensor([
[0, 1]
]))
res = _per_layer_required_vertices(d, batch, 5)
print('res:', res)

Loading…
Annulla
Salva