@@ -10,6 +10,7 @@ from typing import Callable, \ | |||||
List | List | ||||
import types | import types | ||||
from .util import _nonzero_sum | from .util import _nonzero_sum | ||||
import torch | |||||
@dataclass | @dataclass | ||||
@@ -66,6 +67,6 @@ class Data(object): | |||||
if (vertex_type_row, vertex_type_column) in self.edge_types: | if (vertex_type_row, vertex_type_column) in self.edge_types: | ||||
raise KeyError('Edge type for given combination of row and column already exists') | raise KeyError('Edge type for given combination of row and column already exists') | ||||
total_connectivity = _nonzero_sum(adjacency_matrices) | total_connectivity = _nonzero_sum(adjacency_matrices) | ||||
self.edges_types[vertex_type_row, vertex_type_column] = \ | |||||
VertexType(name, vertex_type_row, vertex_type_column, | |||||
self.edge_types[vertex_type_row, vertex_type_column] = \ | |||||
EdgeType(name, vertex_type_row, vertex_type_column, | |||||
adjacency_matrices, decoder_factory, total_connectivity) | adjacency_matrices, decoder_factory, total_connectivity) |
@@ -11,7 +11,7 @@ from typing import Tuple, \ | |||||
List | List | ||||
def dedicom_decoder(input_dim: int, num_relation_types: int) -> | |||||
def dedicom_decoder(input_dim: int, num_relation_types: int) -> \ | |||||
Tuple[torch.Tensor, List[torch.Tensor]]: | Tuple[torch.Tensor, List[torch.Tensor]]: | ||||
global_interaction = init_glorot(input_dim, input_dim) | global_interaction = init_glorot(input_dim, input_dim) | ||||
@@ -22,18 +22,18 @@ def dedicom_decoder(input_dim: int, num_relation_types: int) -> | |||||
return (global_interaction, local_variation) | return (global_interaction, local_variation) | ||||
def dist_mult_decoder(input_dim: int, num_relation_types: int) -> | |||||
def dist_mult_decoder(input_dim: int, num_relation_types: int) -> \ | |||||
Tuple[torch.Tensor, List[torch.Tensor]]: | Tuple[torch.Tensor, List[torch.Tensor]]: | ||||
global_interaction = torch.eye(input_dim, input_dim) | global_interaction = torch.eye(input_dim, input_dim) | ||||
local_variation = [ | local_variation = [ | ||||
torch.diag(torch.flatten(init_glorot(input_dim, 1)))) \ | |||||
torch.diag(torch.flatten(init_glorot(input_dim, 1))) \ | |||||
for _ in range(num_relation_types) | for _ in range(num_relation_types) | ||||
] | ] | ||||
return (global_interaction, local_variation) | return (global_interaction, local_variation) | ||||
def bilinear_decoder(input_dim: int, num_relation_types: int) -> | |||||
def bilinear_decoder(input_dim: int, num_relation_types: int) -> \ | |||||
Tuple[torch.Tensor, List[torch.Tensor]]: | Tuple[torch.Tensor, List[torch.Tensor]]: | ||||
global_interaction = torch.eye(input_dim, input_dim) | global_interaction = torch.eye(input_dim, input_dim) | ||||
@@ -44,7 +44,7 @@ def bilinear_decoder(input_dim: int, num_relation_types: int) -> | |||||
return (global_interaction, local_variation) | return (global_interaction, local_variation) | ||||
def inner_product_decoder(input_dim: int, num_relation_types: int) -> | |||||
def inner_product_decoder(input_dim: int, num_relation_types: int) -> \ | |||||
Tuple[torch.Tensor, List[torch.Tensor]]: | Tuple[torch.Tensor, List[torch.Tensor]]: | ||||
global_interaction = torch.eye(input_dim, input_dim) | global_interaction = torch.eye(input_dim, input_dim) | ||||
@@ -6,7 +6,8 @@ from .weights import init_glorot | |||||
import types | import types | ||||
from typing import List, \ | from typing import List, \ | ||||
Dict, \ | Dict, \ | ||||
Callable | |||||
Callable, \ | |||||
Tuple | |||||
from .util import _sparse_coo_tensor | from .util import _sparse_coo_tensor | ||||
@@ -18,6 +19,46 @@ class TrainingBatch(object): | |||||
edges: torch.Tensor | edges: torch.Tensor | ||||
def _per_layer_required_rows(data: Data, batch: TrainingBatch, | |||||
num_layers: int) -> List[List[EdgeType]]: | |||||
Q = [ | |||||
( batch.vertex_type_row, batch.edges[:, 0] ), | |||||
( batch.vertex_type_column, batch.edges[:, 1] ) | |||||
] | |||||
print('Q:', Q) | |||||
res = [] | |||||
for _ in range(num_layers): | |||||
R = [] | |||||
required_rows = [ [] for _ in range(len(data.vertex_types)) ] | |||||
for vertex_type, vertices in Q: | |||||
for et in data.edge_types.values(): | |||||
if et.vertex_type_row == vertex_type: | |||||
required_rows[vertex_type].append(vertices) | |||||
indices = et.total_connectivity.indices() | |||||
mask = torch.zeros(et.total_connectivity.shape[0]) | |||||
mask[vertices] = 1 | |||||
mask = torch.nonzero(mask[indices[0]], as_tuple=True)[0] | |||||
R.append((et.vertex_type_column, | |||||
indices[1, mask])) | |||||
else: | |||||
pass # required_rows[et.vertex_type_row].append(torch.zeros(0)) | |||||
required_rows = [ torch.unique(torch.cat(x)) \ | |||||
if len(x) > 0 \ | |||||
else None \ | |||||
for x in required_rows ] | |||||
res.append(required_rows) | |||||
Q = R | |||||
return res | |||||
class Model(torch.nn.Module): | class Model(torch.nn.Module): | ||||
def __init__(self, data: Data, layer_dimensions: List[int], | def __init__(self, data: Data, layer_dimensions: List[int], | ||||
keep_prob: float, | keep_prob: float, | ||||
@@ -68,17 +109,16 @@ class Model(torch.nn.Module): | |||||
torch.nn.Parameter(local_variation) | torch.nn.Parameter(local_variation) | ||||
]) | ]) | ||||
def limit_adjacency_matrix_to_rows(self, adjacency_matrix: torch.Tensor, | |||||
rows: torch.Tensor) -> torch.Tensor: | |||||
adj_mat = adjacency_matrix.coalesce() | |||||
adj_mat = torch.index_select(adj_mat, 0, rows) | |||||
adj_mat = adj_mat.coalesce() | |||||
indices = adj_mat.indices() | |||||
indices[0] = rows | |||||
def convolve(self, batch: TrainingBatch) -> List[torch.Tensor]: | |||||
edges = [] | |||||
cur_edges = batch.edges | |||||
for _ in range(len(self.layer_dimensions) - 1): | |||||
edges.append(cur_edges) | |||||
key = (batch.vertex_type_row, batch.vertex_type_column) | |||||
tot_conn = self.data.relation_types[key].total_connectivity | |||||
cur_edges = _edges_for_rows(tot_conn, cur_edges[:, 1]) | |||||
adj_mat = _sparse_coo_tensor(indices, adj_mat.values(), adjacency_matrix.shape) | |||||
def temporary_adjacency_matrix(self, adjacency_matrix: torch.Tensor, | def temporary_adjacency_matrix(self, adjacency_matrix: torch.Tensor, | ||||
batch: TrainingBatch, total_connectivity: torch.Tensor) -> torch.Tensor: | batch: TrainingBatch, total_connectivity: torch.Tensor) -> torch.Tensor: | ||||
@@ -90,12 +130,13 @@ class Model(torch.nn.Module): | |||||
columns = torch.nonzero(columns) | columns = torch.nonzero(columns) | ||||
for i in range(len(self.layer_dimensions) - 1): | for i in range(len(self.layer_dimensions) - 1): | ||||
pass # columns = | |||||
# TODO: finish | |||||
columns = | |||||
return None | |||||
def temporary_adjacency_matrices(self, batch: TrainingBatch) -> | |||||
Dict[Tuple[int, int], List[List[torch.Tensor]]]: | |||||
def temporary_adjacency_matrices(self, batch: TrainingBatch) -> Dict[Tuple[int, int], List[List[torch.Tensor]]]: | |||||
col = batch.vertex_type_column | col = batch.vertex_type_column | ||||
batch.edges[:, 1] | batch.edges[:, 1] | ||||
@@ -41,7 +41,9 @@ def _nonzero_sum(adjacency_matrices: List[torch.Tensor]): | |||||
indices = res.indices() | indices = res.indices() | ||||
res = _sparse_coo_tensor(indices, | res = _sparse_coo_tensor(indices, | ||||
torch.ones(indices.shape[1], dtype=torch.uint8)) | |||||
torch.ones(indices.shape[1], dtype=torch.uint8), | |||||
adjacency_matrices[0].shape) | |||||
res = res.coalesce() | |||||
return res | return res | ||||
@@ -0,0 +1,40 @@ | |||||
from triacontagon.model import _per_layer_required_rows, \ | |||||
TrainingBatch | |||||
from triacontagon.decode import dedicom_decoder | |||||
from triacontagon.data import Data | |||||
import torch | |||||
def test_per_layer_required_rows_01(): | |||||
d = Data() | |||||
d.add_vertex_type('Gene', 4) | |||||
d.add_vertex_type('Drug', 5) | |||||
d.add_edge_type('Gene-Gene', 0, 0, [ torch.tensor([ | |||||
[1, 0, 0, 1], | |||||
[0, 1, 1, 0], | |||||
[0, 0, 1, 0], | |||||
[0, 1, 0, 1] | |||||
]).to_sparse() ], dedicom_decoder) | |||||
d.add_edge_type('Gene-Drug', 0, 1, [ torch.tensor([ | |||||
[0, 1, 0, 0, 1], | |||||
[0, 0, 1, 0, 0], | |||||
[1, 0, 0, 0, 1], | |||||
[0, 0, 1, 1, 0] | |||||
]).to_sparse() ], dedicom_decoder) | |||||
d.add_edge_type('Drug-Drug', 1, 1, [ torch.tensor([ | |||||
[1, 0, 0, 0, 0], | |||||
[0, 1, 0, 0, 0], | |||||
[0, 0, 1, 0, 0], | |||||
[0, 0, 0, 1, 0], | |||||
[0, 0, 0, 0, 1] | |||||
]).to_sparse() ], dedicom_decoder) | |||||
batch = TrainingBatch(0, 1, 0, torch.tensor([ | |||||
[0, 1] | |||||
])) | |||||
res = _per_layer_required_rows(d, batch, 5) | |||||
print('res:', res) |