@@ -36,6 +36,8 @@ def dropout_dense(x, keep_prob): | |||||
def dropout(x, keep_prob): | def dropout(x, keep_prob): | ||||
if keep_prob == 1: | |||||
return x | |||||
if x.is_sparse: | if x.is_sparse: | ||||
return dropout_sparse(x, keep_prob) | return dropout_sparse(x, keep_prob) | ||||
else: | else: | ||||
@@ -22,6 +22,7 @@ class TrainingBatch(object): | |||||
vertex_type_column: int | vertex_type_column: int | ||||
relation_type_index: int | relation_type_index: int | ||||
edges: torch.Tensor | edges: torch.Tensor | ||||
target_values: torch.Tensor | |||||
def _per_layer_required_vertices(data: Data, batch: TrainingBatch, | def _per_layer_required_vertices(data: Data, batch: TrainingBatch, | ||||
@@ -133,7 +134,7 @@ class Model(torch.nn.Module): | |||||
List[torch.Tensor]: | List[torch.Tensor]: | ||||
cur_layer_repr = in_layer_repr | cur_layer_repr = in_layer_repr | ||||
for i in range(len(self.layer_dimensions) - 1): | for i in range(len(self.layer_dimensions) - 1): | ||||
next_layer_repr = [ [] for _ in range(len(self.data.vertex_types)) ] | next_layer_repr = [ [] for _ in range(len(self.data.vertex_types)) ] | ||||
@@ -147,7 +148,7 @@ class Model(torch.nn.Module): | |||||
if self.keep_prob != 1: | if self.keep_prob != 1: | ||||
x = dropout(x, self.keep_prob) | x = dropout(x, self.keep_prob) | ||||
print('a, Layer:', i, 'x.shape:', x.shape) | |||||
# print('a, Layer:', i, 'x.shape:', x.shape) | |||||
x = _mm(x, conv_weights) | x = _mm(x, conv_weights) | ||||
x = torch.split(x, | x = torch.split(x, | ||||
@@ -159,12 +160,12 @@ class Model(torch.nn.Module): | |||||
self.data.vertex_types[vt_row].count, | self.data.vertex_types[vt_row].count, | ||||
self.layer_dimensions[i + 1]) | self.layer_dimensions[i + 1]) | ||||
print('b, Layer:', i, 'vt_row:', vt_row, 'x.shape:', x.shape) | |||||
# print('b, Layer:', i, 'vt_row:', vt_row, 'x.shape:', x.shape) | |||||
x = x.sum(dim=0) | x = x.sum(dim=0) | ||||
x = torch.nn.functional.normalize(x, p=2, dim=1) | x = torch.nn.functional.normalize(x, p=2, dim=1) | ||||
# x = self.rel_activation(x) | # x = self.rel_activation(x) | ||||
print('c, Layer:', i, 'vt_row:', vt_row, 'x.shape:', x.shape) | |||||
# print('c, Layer:', i, 'vt_row:', vt_row, 'x.shape:', x.shape) | |||||
next_layer_repr[vt_row].append(x) | next_layer_repr[vt_row].append(x) | ||||
@@ -174,6 +175,32 @@ class Model(torch.nn.Module): | |||||
cur_layer_repr = next_layer_repr | cur_layer_repr = next_layer_repr | ||||
return next_layer_repr | return next_layer_repr | ||||
def decode(self, last_layer_repr: List[torch.Tensor], | |||||
batch: TrainingBatch) -> torch.Tensor: | |||||
vt_row = batch.vertex_type_row | |||||
vt_col = batch.vertex_type_column | |||||
rel_idx = batch.relation_type_index | |||||
global_interaction = \ | |||||
self.dec_weights['%d-%d-global-interaction'] % (vt_row, vt_col) | |||||
local_variation = \ | |||||
self.dec_weights['%d-%d-local-variation-%d'] % (vt_row, vt_col, rel_idx) | |||||
in_row = dropout(last_layer_repr[vt_row], self.keep_prob) | |||||
in_col = dropout(last_layer_repr[vt_col], self.keep_prob) | |||||
x = torch.mm(in_row, local_variation) | |||||
x = torch.mm(x, global_interaction) | |||||
x = torch.mm(x, local_variation) | |||||
x = torch.bmm(x.view(x.shape[0], 1, x.shape[1]), | |||||
in_col.view(in_col.shape[0], in_col.shape[1], 1)) | |||||
x = torch.flatten(x) | |||||
x = self.dec_activation(x) | |||||
return x | |||||
def convolve_old(self, batch: TrainingBatch) -> List[torch.Tensor]: | def convolve_old(self, batch: TrainingBatch) -> List[torch.Tensor]: | ||||
edges = [] | edges = [] | ||||
cur_edges = batch.edges | cur_edges = batch.edges | ||||
@@ -197,3 +197,25 @@ def _mm(a: torch.Tensor, b: torch.Tensor): | |||||
return torch.sparse.mm(a, b) | return torch.sparse.mm(a, b) | ||||
else: | else: | ||||
return torch.mm(a, b) | return torch.mm(a, b) | ||||
def common_one_hot_encoding(vertex_type_counts: List[int]) -> \ | |||||
List[torch.Tensor]: | |||||
tot = sum(vertex_type_counts) | |||||
# indices = torch.cat([ torch.arange(tot).view(1, -1) ] * 2, dim=0) | |||||
# print('indices.shape:', indices.shape) | |||||
ofs = 0 | |||||
res = [] | |||||
for cnt in vertex_type_counts: | |||||
ind = torch.cat([ | |||||
torch.arange(cnt).view(1, -1), | |||||
torch.arange(ofs, ofs+cnt).view(1, -1) | |||||
]) | |||||
val = torch.ones(cnt) | |||||
x = _sparse_coo_tensor(ind, val, size=(cnt, tot)) | |||||
res.append(x) | |||||
ofs += cnt | |||||
return res |
@@ -1,9 +1,46 @@ | |||||
import torch | import torch | ||||
from triacontagon.model import Model | |||||
from triacontagon.model import Model, \ | |||||
TrainingBatch, \ | |||||
_per_layer_required_vertices | |||||
from triacontagon.data import Data | from triacontagon.data import Data | ||||
from triacontagon.decode import dedicom_decoder | from triacontagon.decode import dedicom_decoder | ||||
def test_per_layer_required_vertices_01(): | |||||
d = Data() | |||||
d.add_vertex_type('Gene', 4) | |||||
d.add_vertex_type('Drug', 5) | |||||
d.add_edge_type('Gene-Gene', 0, 0, [ torch.tensor([ | |||||
[1, 0, 0, 1], | |||||
[0, 1, 1, 0], | |||||
[0, 0, 1, 0], | |||||
[0, 1, 0, 1] | |||||
]).to_sparse() ], dedicom_decoder) | |||||
d.add_edge_type('Gene-Drug', 0, 1, [ torch.tensor([ | |||||
[0, 1, 0, 0, 1], | |||||
[0, 0, 1, 0, 0], | |||||
[1, 0, 0, 0, 1], | |||||
[0, 0, 1, 1, 0] | |||||
]).to_sparse() ], dedicom_decoder) | |||||
d.add_edge_type('Drug-Drug', 1, 1, [ torch.tensor([ | |||||
[1, 0, 0, 0, 0], | |||||
[0, 1, 0, 0, 0], | |||||
[0, 0, 1, 0, 0], | |||||
[0, 0, 0, 1, 0], | |||||
[0, 0, 0, 0, 1] | |||||
]).to_sparse() ], dedicom_decoder) | |||||
batch = TrainingBatch(0, 1, 0, torch.tensor([ | |||||
[0, 1] | |||||
])) | |||||
res = _per_layer_required_vertices(d, batch, 5) | |||||
print('res:', res) | |||||
def test_model_convolve_01(): | def test_model_convolve_01(): | ||||
d = Data() | d = Data() | ||||
d.add_vertex_type('Gene', 4) | d.add_vertex_type('Gene', 4) | ||||
@@ -2,7 +2,7 @@ from triacontagon.util import \ | |||||
_clear_adjacency_matrix_except_rows, \ | _clear_adjacency_matrix_except_rows, \ | ||||
_sparse_diag_cat, \ | _sparse_diag_cat, \ | ||||
_equal, \ | _equal, \ | ||||
_per_layer_required_vertices | |||||
common_one_hot_encoding | |||||
from triacontagon.model import TrainingBatch | from triacontagon.model import TrainingBatch | ||||
from triacontagon.decode import dedicom_decoder | from triacontagon.decode import dedicom_decoder | ||||
from triacontagon.data import Data | from triacontagon.data import Data | ||||
@@ -127,36 +127,9 @@ def test_clear_adjacency_matrix_except_rows_05(): | |||||
assert _equal(res, truth).all() | assert _equal(res, truth).all() | ||||
def test_per_layer_required_vertices_01(): | |||||
d = Data() | |||||
d.add_vertex_type('Gene', 4) | |||||
d.add_vertex_type('Drug', 5) | |||||
def test_common_one_hot_encoding_01(): | |||||
in_repr = common_one_hot_encoding([2000, 200]) | |||||
d.add_edge_type('Gene-Gene', 0, 0, [ torch.tensor([ | |||||
[1, 0, 0, 1], | |||||
[0, 1, 1, 0], | |||||
[0, 0, 1, 0], | |||||
[0, 1, 0, 1] | |||||
]).to_sparse() ], dedicom_decoder) | |||||
d.add_edge_type('Gene-Drug', 0, 1, [ torch.tensor([ | |||||
[0, 1, 0, 0, 1], | |||||
[0, 0, 1, 0, 0], | |||||
[1, 0, 0, 0, 1], | |||||
[0, 0, 1, 1, 0] | |||||
]).to_sparse() ], dedicom_decoder) | |||||
d.add_edge_type('Drug-Drug', 1, 1, [ torch.tensor([ | |||||
[1, 0, 0, 0, 0], | |||||
[0, 1, 0, 0, 0], | |||||
[0, 0, 1, 0, 0], | |||||
[0, 0, 0, 1, 0], | |||||
[0, 0, 0, 0, 1] | |||||
]).to_sparse() ], dedicom_decoder) | |||||
batch = TrainingBatch(0, 1, 0, torch.tensor([ | |||||
[0, 1] | |||||
])) | |||||
res = _per_layer_required_vertices(d, batch, 5) | |||||
print('res:', res) | |||||
ref = torch.eye(2200) | |||||
assert torch.all(in_repr[0].to_dense() == ref[:2000, :]) | |||||
assert torch.all(in_repr[1].to_dense() == ref[2000:, :]) |