@@ -36,6 +36,8 @@ def dropout_dense(x, keep_prob): | |||
def dropout(x, keep_prob): | |||
if keep_prob == 1: | |||
return x | |||
if x.is_sparse: | |||
return dropout_sparse(x, keep_prob) | |||
else: | |||
@@ -22,6 +22,7 @@ class TrainingBatch(object): | |||
vertex_type_column: int | |||
relation_type_index: int | |||
edges: torch.Tensor | |||
target_values: torch.Tensor | |||
def _per_layer_required_vertices(data: Data, batch: TrainingBatch, | |||
@@ -133,7 +134,7 @@ class Model(torch.nn.Module): | |||
List[torch.Tensor]: | |||
cur_layer_repr = in_layer_repr | |||
for i in range(len(self.layer_dimensions) - 1): | |||
next_layer_repr = [ [] for _ in range(len(self.data.vertex_types)) ] | |||
@@ -147,7 +148,7 @@ class Model(torch.nn.Module): | |||
if self.keep_prob != 1: | |||
x = dropout(x, self.keep_prob) | |||
print('a, Layer:', i, 'x.shape:', x.shape) | |||
# print('a, Layer:', i, 'x.shape:', x.shape) | |||
x = _mm(x, conv_weights) | |||
x = torch.split(x, | |||
@@ -159,12 +160,12 @@ class Model(torch.nn.Module): | |||
self.data.vertex_types[vt_row].count, | |||
self.layer_dimensions[i + 1]) | |||
print('b, Layer:', i, 'vt_row:', vt_row, 'x.shape:', x.shape) | |||
# print('b, Layer:', i, 'vt_row:', vt_row, 'x.shape:', x.shape) | |||
x = x.sum(dim=0) | |||
x = torch.nn.functional.normalize(x, p=2, dim=1) | |||
# x = self.rel_activation(x) | |||
print('c, Layer:', i, 'vt_row:', vt_row, 'x.shape:', x.shape) | |||
# print('c, Layer:', i, 'vt_row:', vt_row, 'x.shape:', x.shape) | |||
next_layer_repr[vt_row].append(x) | |||
@@ -174,6 +175,32 @@ class Model(torch.nn.Module): | |||
cur_layer_repr = next_layer_repr | |||
return next_layer_repr | |||
def decode(self, last_layer_repr: List[torch.Tensor], | |||
batch: TrainingBatch) -> torch.Tensor: | |||
vt_row = batch.vertex_type_row | |||
vt_col = batch.vertex_type_column | |||
rel_idx = batch.relation_type_index | |||
global_interaction = \ | |||
self.dec_weights['%d-%d-global-interaction'] % (vt_row, vt_col) | |||
local_variation = \ | |||
self.dec_weights['%d-%d-local-variation-%d'] % (vt_row, vt_col, rel_idx) | |||
in_row = dropout(last_layer_repr[vt_row], self.keep_prob) | |||
in_col = dropout(last_layer_repr[vt_col], self.keep_prob) | |||
x = torch.mm(in_row, local_variation) | |||
x = torch.mm(x, global_interaction) | |||
x = torch.mm(x, local_variation) | |||
x = torch.bmm(x.view(x.shape[0], 1, x.shape[1]), | |||
in_col.view(in_col.shape[0], in_col.shape[1], 1)) | |||
x = torch.flatten(x) | |||
x = self.dec_activation(x) | |||
return x | |||
def convolve_old(self, batch: TrainingBatch) -> List[torch.Tensor]: | |||
edges = [] | |||
cur_edges = batch.edges | |||
@@ -197,3 +197,25 @@ def _mm(a: torch.Tensor, b: torch.Tensor): | |||
return torch.sparse.mm(a, b) | |||
else: | |||
return torch.mm(a, b) | |||
def common_one_hot_encoding(vertex_type_counts: List[int]) -> \ | |||
List[torch.Tensor]: | |||
tot = sum(vertex_type_counts) | |||
# indices = torch.cat([ torch.arange(tot).view(1, -1) ] * 2, dim=0) | |||
# print('indices.shape:', indices.shape) | |||
ofs = 0 | |||
res = [] | |||
for cnt in vertex_type_counts: | |||
ind = torch.cat([ | |||
torch.arange(cnt).view(1, -1), | |||
torch.arange(ofs, ofs+cnt).view(1, -1) | |||
]) | |||
val = torch.ones(cnt) | |||
x = _sparse_coo_tensor(ind, val, size=(cnt, tot)) | |||
res.append(x) | |||
ofs += cnt | |||
return res |
@@ -1,9 +1,46 @@ | |||
import torch | |||
from triacontagon.model import Model | |||
from triacontagon.model import Model, \ | |||
TrainingBatch, \ | |||
_per_layer_required_vertices | |||
from triacontagon.data import Data | |||
from triacontagon.decode import dedicom_decoder | |||
def test_per_layer_required_vertices_01(): | |||
d = Data() | |||
d.add_vertex_type('Gene', 4) | |||
d.add_vertex_type('Drug', 5) | |||
d.add_edge_type('Gene-Gene', 0, 0, [ torch.tensor([ | |||
[1, 0, 0, 1], | |||
[0, 1, 1, 0], | |||
[0, 0, 1, 0], | |||
[0, 1, 0, 1] | |||
]).to_sparse() ], dedicom_decoder) | |||
d.add_edge_type('Gene-Drug', 0, 1, [ torch.tensor([ | |||
[0, 1, 0, 0, 1], | |||
[0, 0, 1, 0, 0], | |||
[1, 0, 0, 0, 1], | |||
[0, 0, 1, 1, 0] | |||
]).to_sparse() ], dedicom_decoder) | |||
d.add_edge_type('Drug-Drug', 1, 1, [ torch.tensor([ | |||
[1, 0, 0, 0, 0], | |||
[0, 1, 0, 0, 0], | |||
[0, 0, 1, 0, 0], | |||
[0, 0, 0, 1, 0], | |||
[0, 0, 0, 0, 1] | |||
]).to_sparse() ], dedicom_decoder) | |||
batch = TrainingBatch(0, 1, 0, torch.tensor([ | |||
[0, 1] | |||
])) | |||
res = _per_layer_required_vertices(d, batch, 5) | |||
print('res:', res) | |||
def test_model_convolve_01(): | |||
d = Data() | |||
d.add_vertex_type('Gene', 4) | |||
@@ -2,7 +2,7 @@ from triacontagon.util import \ | |||
_clear_adjacency_matrix_except_rows, \ | |||
_sparse_diag_cat, \ | |||
_equal, \ | |||
_per_layer_required_vertices | |||
common_one_hot_encoding | |||
from triacontagon.model import TrainingBatch | |||
from triacontagon.decode import dedicom_decoder | |||
from triacontagon.data import Data | |||
@@ -127,36 +127,9 @@ def test_clear_adjacency_matrix_except_rows_05(): | |||
assert _equal(res, truth).all() | |||
def test_per_layer_required_vertices_01(): | |||
d = Data() | |||
d.add_vertex_type('Gene', 4) | |||
d.add_vertex_type('Drug', 5) | |||
def test_common_one_hot_encoding_01(): | |||
in_repr = common_one_hot_encoding([2000, 200]) | |||
d.add_edge_type('Gene-Gene', 0, 0, [ torch.tensor([ | |||
[1, 0, 0, 1], | |||
[0, 1, 1, 0], | |||
[0, 0, 1, 0], | |||
[0, 1, 0, 1] | |||
]).to_sparse() ], dedicom_decoder) | |||
d.add_edge_type('Gene-Drug', 0, 1, [ torch.tensor([ | |||
[0, 1, 0, 0, 1], | |||
[0, 0, 1, 0, 0], | |||
[1, 0, 0, 0, 1], | |||
[0, 0, 1, 1, 0] | |||
]).to_sparse() ], dedicom_decoder) | |||
d.add_edge_type('Drug-Drug', 1, 1, [ torch.tensor([ | |||
[1, 0, 0, 0, 0], | |||
[0, 1, 0, 0, 0], | |||
[0, 0, 1, 0, 0], | |||
[0, 0, 0, 1, 0], | |||
[0, 0, 0, 0, 1] | |||
]).to_sparse() ], dedicom_decoder) | |||
batch = TrainingBatch(0, 1, 0, torch.tensor([ | |||
[0, 1] | |||
])) | |||
res = _per_layer_required_vertices(d, batch, 5) | |||
print('res:', res) | |||
ref = torch.eye(2200) | |||
assert torch.all(in_repr[0].to_dense() == ref[:2000, :]) | |||
assert torch.all(in_repr[1].to_dense() == ref[2000:, :]) |