@@ -10,6 +10,7 @@ from typing import Callable, \ | |||
List | |||
import types | |||
from .util import _nonzero_sum | |||
import torch | |||
@dataclass | |||
@@ -66,6 +67,6 @@ class Data(object): | |||
if (vertex_type_row, vertex_type_column) in self.edge_types: | |||
raise KeyError('Edge type for given combination of row and column already exists') | |||
total_connectivity = _nonzero_sum(adjacency_matrices) | |||
self.edges_types[vertex_type_row, vertex_type_column] = \ | |||
VertexType(name, vertex_type_row, vertex_type_column, | |||
self.edge_types[vertex_type_row, vertex_type_column] = \ | |||
EdgeType(name, vertex_type_row, vertex_type_column, | |||
adjacency_matrices, decoder_factory, total_connectivity) |
@@ -11,7 +11,7 @@ from typing import Tuple, \ | |||
List | |||
def dedicom_decoder(input_dim: int, num_relation_types: int) -> | |||
def dedicom_decoder(input_dim: int, num_relation_types: int) -> \ | |||
Tuple[torch.Tensor, List[torch.Tensor]]: | |||
global_interaction = init_glorot(input_dim, input_dim) | |||
@@ -22,18 +22,18 @@ def dedicom_decoder(input_dim: int, num_relation_types: int) -> | |||
return (global_interaction, local_variation) | |||
def dist_mult_decoder(input_dim: int, num_relation_types: int) -> | |||
def dist_mult_decoder(input_dim: int, num_relation_types: int) -> \ | |||
Tuple[torch.Tensor, List[torch.Tensor]]: | |||
global_interaction = torch.eye(input_dim, input_dim) | |||
local_variation = [ | |||
torch.diag(torch.flatten(init_glorot(input_dim, 1)))) \ | |||
torch.diag(torch.flatten(init_glorot(input_dim, 1))) \ | |||
for _ in range(num_relation_types) | |||
] | |||
return (global_interaction, local_variation) | |||
def bilinear_decoder(input_dim: int, num_relation_types: int) -> | |||
def bilinear_decoder(input_dim: int, num_relation_types: int) -> \ | |||
Tuple[torch.Tensor, List[torch.Tensor]]: | |||
global_interaction = torch.eye(input_dim, input_dim) | |||
@@ -44,7 +44,7 @@ def bilinear_decoder(input_dim: int, num_relation_types: int) -> | |||
return (global_interaction, local_variation) | |||
def inner_product_decoder(input_dim: int, num_relation_types: int) -> | |||
def inner_product_decoder(input_dim: int, num_relation_types: int) -> \ | |||
Tuple[torch.Tensor, List[torch.Tensor]]: | |||
global_interaction = torch.eye(input_dim, input_dim) | |||
@@ -6,7 +6,8 @@ from .weights import init_glorot | |||
import types | |||
from typing import List, \ | |||
Dict, \ | |||
Callable | |||
Callable, \ | |||
Tuple | |||
from .util import _sparse_coo_tensor | |||
@@ -18,6 +19,46 @@ class TrainingBatch(object): | |||
edges: torch.Tensor | |||
def _per_layer_required_rows(data: Data, batch: TrainingBatch, | |||
num_layers: int) -> List[List[EdgeType]]: | |||
Q = [ | |||
( batch.vertex_type_row, batch.edges[:, 0] ), | |||
( batch.vertex_type_column, batch.edges[:, 1] ) | |||
] | |||
print('Q:', Q) | |||
res = [] | |||
for _ in range(num_layers): | |||
R = [] | |||
required_rows = [ [] for _ in range(len(data.vertex_types)) ] | |||
for vertex_type, vertices in Q: | |||
for et in data.edge_types.values(): | |||
if et.vertex_type_row == vertex_type: | |||
required_rows[vertex_type].append(vertices) | |||
indices = et.total_connectivity.indices() | |||
mask = torch.zeros(et.total_connectivity.shape[0]) | |||
mask[vertices] = 1 | |||
mask = torch.nonzero(mask[indices[0]], as_tuple=True)[0] | |||
R.append((et.vertex_type_column, | |||
indices[1, mask])) | |||
else: | |||
pass # required_rows[et.vertex_type_row].append(torch.zeros(0)) | |||
required_rows = [ torch.unique(torch.cat(x)) \ | |||
if len(x) > 0 \ | |||
else None \ | |||
for x in required_rows ] | |||
res.append(required_rows) | |||
Q = R | |||
return res | |||
class Model(torch.nn.Module): | |||
def __init__(self, data: Data, layer_dimensions: List[int], | |||
keep_prob: float, | |||
@@ -68,17 +109,16 @@ class Model(torch.nn.Module): | |||
torch.nn.Parameter(local_variation) | |||
]) | |||
def limit_adjacency_matrix_to_rows(self, adjacency_matrix: torch.Tensor, | |||
rows: torch.Tensor) -> torch.Tensor: | |||
adj_mat = adjacency_matrix.coalesce() | |||
adj_mat = torch.index_select(adj_mat, 0, rows) | |||
adj_mat = adj_mat.coalesce() | |||
indices = adj_mat.indices() | |||
indices[0] = rows | |||
def convolve(self, batch: TrainingBatch) -> List[torch.Tensor]: | |||
edges = [] | |||
cur_edges = batch.edges | |||
for _ in range(len(self.layer_dimensions) - 1): | |||
edges.append(cur_edges) | |||
key = (batch.vertex_type_row, batch.vertex_type_column) | |||
tot_conn = self.data.relation_types[key].total_connectivity | |||
cur_edges = _edges_for_rows(tot_conn, cur_edges[:, 1]) | |||
adj_mat = _sparse_coo_tensor(indices, adj_mat.values(), adjacency_matrix.shape) | |||
def temporary_adjacency_matrix(self, adjacency_matrix: torch.Tensor, | |||
batch: TrainingBatch, total_connectivity: torch.Tensor) -> torch.Tensor: | |||
@@ -90,12 +130,13 @@ class Model(torch.nn.Module): | |||
columns = torch.nonzero(columns) | |||
for i in range(len(self.layer_dimensions) - 1): | |||
pass # columns = | |||
# TODO: finish | |||
columns = | |||
return None | |||
def temporary_adjacency_matrices(self, batch: TrainingBatch) -> | |||
Dict[Tuple[int, int], List[List[torch.Tensor]]]: | |||
def temporary_adjacency_matrices(self, batch: TrainingBatch) -> Dict[Tuple[int, int], List[List[torch.Tensor]]]: | |||
col = batch.vertex_type_column | |||
batch.edges[:, 1] | |||
@@ -41,7 +41,9 @@ def _nonzero_sum(adjacency_matrices: List[torch.Tensor]): | |||
indices = res.indices() | |||
res = _sparse_coo_tensor(indices, | |||
torch.ones(indices.shape[1], dtype=torch.uint8)) | |||
torch.ones(indices.shape[1], dtype=torch.uint8), | |||
adjacency_matrices[0].shape) | |||
res = res.coalesce() | |||
return res | |||
@@ -0,0 +1,40 @@ | |||
from triacontagon.model import _per_layer_required_rows, \ | |||
TrainingBatch | |||
from triacontagon.decode import dedicom_decoder | |||
from triacontagon.data import Data | |||
import torch | |||
def test_per_layer_required_rows_01(): | |||
d = Data() | |||
d.add_vertex_type('Gene', 4) | |||
d.add_vertex_type('Drug', 5) | |||
d.add_edge_type('Gene-Gene', 0, 0, [ torch.tensor([ | |||
[1, 0, 0, 1], | |||
[0, 1, 1, 0], | |||
[0, 0, 1, 0], | |||
[0, 1, 0, 1] | |||
]).to_sparse() ], dedicom_decoder) | |||
d.add_edge_type('Gene-Drug', 0, 1, [ torch.tensor([ | |||
[0, 1, 0, 0, 1], | |||
[0, 0, 1, 0, 0], | |||
[1, 0, 0, 0, 1], | |||
[0, 0, 1, 1, 0] | |||
]).to_sparse() ], dedicom_decoder) | |||
d.add_edge_type('Drug-Drug', 1, 1, [ torch.tensor([ | |||
[1, 0, 0, 0, 0], | |||
[0, 1, 0, 0, 0], | |||
[0, 0, 1, 0, 0], | |||
[0, 0, 0, 1, 0], | |||
[0, 0, 0, 0, 1] | |||
]).to_sparse() ], dedicom_decoder) | |||
batch = TrainingBatch(0, 1, 0, torch.tensor([ | |||
[0, 1] | |||
])) | |||
res = _per_layer_required_rows(d, batch, 5) | |||
print('res:', res) |