IF YOU WOULD LIKE TO GET AN ACCOUNT, please write an email to s dot adaszewski at gmail dot com. User accounts are meant only to report issues and/or generate pull requests. This is a purpose-specific Git hosting for ADARED projects. Thank you for your understanding!
ソースを参照

Add common_one_hot_encoding().

master
Stanislaw Adaszewski 4年前
コミット
b1a5f0d0ee
6個のファイルの変更99行の追加38行の削除
  1. +0
    -0
      src/triacontagon/deprecated/input.py
  2. +2
    -0
      src/triacontagon/dropout.py
  3. +31
    -4
      src/triacontagon/model.py
  4. +22
    -0
      src/triacontagon/util.py
  5. +38
    -1
      tests/triacontagon/test_model.py
  6. +6
    -33
      tests/triacontagon/test_util.py

src/triacontagon/input.py → src/triacontagon/deprecated/input.py ファイルの表示


+ 2
- 0
src/triacontagon/dropout.py ファイルの表示

@@ -36,6 +36,8 @@ def dropout_dense(x, keep_prob):
def dropout(x, keep_prob):
if keep_prob == 1:
return x
if x.is_sparse:
return dropout_sparse(x, keep_prob)
else:


+ 31
- 4
src/triacontagon/model.py ファイルの表示

@@ -22,6 +22,7 @@ class TrainingBatch(object):
vertex_type_column: int
relation_type_index: int
edges: torch.Tensor
target_values: torch.Tensor
def _per_layer_required_vertices(data: Data, batch: TrainingBatch,
@@ -133,7 +134,7 @@ class Model(torch.nn.Module):
List[torch.Tensor]:
cur_layer_repr = in_layer_repr
for i in range(len(self.layer_dimensions) - 1):
next_layer_repr = [ [] for _ in range(len(self.data.vertex_types)) ]
@@ -147,7 +148,7 @@ class Model(torch.nn.Module):
if self.keep_prob != 1:
x = dropout(x, self.keep_prob)
print('a, Layer:', i, 'x.shape:', x.shape)
# print('a, Layer:', i, 'x.shape:', x.shape)
x = _mm(x, conv_weights)
x = torch.split(x,
@@ -159,12 +160,12 @@ class Model(torch.nn.Module):
self.data.vertex_types[vt_row].count,
self.layer_dimensions[i + 1])
print('b, Layer:', i, 'vt_row:', vt_row, 'x.shape:', x.shape)
# print('b, Layer:', i, 'vt_row:', vt_row, 'x.shape:', x.shape)
x = x.sum(dim=0)
x = torch.nn.functional.normalize(x, p=2, dim=1)
# x = self.rel_activation(x)
print('c, Layer:', i, 'vt_row:', vt_row, 'x.shape:', x.shape)
# print('c, Layer:', i, 'vt_row:', vt_row, 'x.shape:', x.shape)
next_layer_repr[vt_row].append(x)
@@ -174,6 +175,32 @@ class Model(torch.nn.Module):
cur_layer_repr = next_layer_repr
return next_layer_repr
def decode(self, last_layer_repr: List[torch.Tensor],
batch: TrainingBatch) -> torch.Tensor:
vt_row = batch.vertex_type_row
vt_col = batch.vertex_type_column
rel_idx = batch.relation_type_index
global_interaction = \
self.dec_weights['%d-%d-global-interaction'] % (vt_row, vt_col)
local_variation = \
self.dec_weights['%d-%d-local-variation-%d'] % (vt_row, vt_col, rel_idx)
in_row = dropout(last_layer_repr[vt_row], self.keep_prob)
in_col = dropout(last_layer_repr[vt_col], self.keep_prob)
x = torch.mm(in_row, local_variation)
x = torch.mm(x, global_interaction)
x = torch.mm(x, local_variation)
x = torch.bmm(x.view(x.shape[0], 1, x.shape[1]),
in_col.view(in_col.shape[0], in_col.shape[1], 1))
x = torch.flatten(x)
x = self.dec_activation(x)
return x
def convolve_old(self, batch: TrainingBatch) -> List[torch.Tensor]:
edges = []
cur_edges = batch.edges


+ 22
- 0
src/triacontagon/util.py ファイルの表示

@@ -197,3 +197,25 @@ def _mm(a: torch.Tensor, b: torch.Tensor):
return torch.sparse.mm(a, b)
else:
return torch.mm(a, b)
def common_one_hot_encoding(vertex_type_counts: List[int]) -> \
List[torch.Tensor]:
tot = sum(vertex_type_counts)
# indices = torch.cat([ torch.arange(tot).view(1, -1) ] * 2, dim=0)
# print('indices.shape:', indices.shape)
ofs = 0
res = []
for cnt in vertex_type_counts:
ind = torch.cat([
torch.arange(cnt).view(1, -1),
torch.arange(ofs, ofs+cnt).view(1, -1)
])
val = torch.ones(cnt)
x = _sparse_coo_tensor(ind, val, size=(cnt, tot))
res.append(x)
ofs += cnt
return res

+ 38
- 1
tests/triacontagon/test_model.py ファイルの表示

@@ -1,9 +1,46 @@
import torch
from triacontagon.model import Model
from triacontagon.model import Model, \
TrainingBatch, \
_per_layer_required_vertices
from triacontagon.data import Data
from triacontagon.decode import dedicom_decoder
def test_per_layer_required_vertices_01():
d = Data()
d.add_vertex_type('Gene', 4)
d.add_vertex_type('Drug', 5)
d.add_edge_type('Gene-Gene', 0, 0, [ torch.tensor([
[1, 0, 0, 1],
[0, 1, 1, 0],
[0, 0, 1, 0],
[0, 1, 0, 1]
]).to_sparse() ], dedicom_decoder)
d.add_edge_type('Gene-Drug', 0, 1, [ torch.tensor([
[0, 1, 0, 0, 1],
[0, 0, 1, 0, 0],
[1, 0, 0, 0, 1],
[0, 0, 1, 1, 0]
]).to_sparse() ], dedicom_decoder)
d.add_edge_type('Drug-Drug', 1, 1, [ torch.tensor([
[1, 0, 0, 0, 0],
[0, 1, 0, 0, 0],
[0, 0, 1, 0, 0],
[0, 0, 0, 1, 0],
[0, 0, 0, 0, 1]
]).to_sparse() ], dedicom_decoder)
batch = TrainingBatch(0, 1, 0, torch.tensor([
[0, 1]
]))
res = _per_layer_required_vertices(d, batch, 5)
print('res:', res)
def test_model_convolve_01():
d = Data()
d.add_vertex_type('Gene', 4)


+ 6
- 33
tests/triacontagon/test_util.py ファイルの表示

@@ -2,7 +2,7 @@ from triacontagon.util import \
_clear_adjacency_matrix_except_rows, \
_sparse_diag_cat, \
_equal, \
_per_layer_required_vertices
common_one_hot_encoding
from triacontagon.model import TrainingBatch
from triacontagon.decode import dedicom_decoder
from triacontagon.data import Data
@@ -127,36 +127,9 @@ def test_clear_adjacency_matrix_except_rows_05():
assert _equal(res, truth).all()
def test_per_layer_required_vertices_01():
d = Data()
d.add_vertex_type('Gene', 4)
d.add_vertex_type('Drug', 5)
def test_common_one_hot_encoding_01():
in_repr = common_one_hot_encoding([2000, 200])
d.add_edge_type('Gene-Gene', 0, 0, [ torch.tensor([
[1, 0, 0, 1],
[0, 1, 1, 0],
[0, 0, 1, 0],
[0, 1, 0, 1]
]).to_sparse() ], dedicom_decoder)
d.add_edge_type('Gene-Drug', 0, 1, [ torch.tensor([
[0, 1, 0, 0, 1],
[0, 0, 1, 0, 0],
[1, 0, 0, 0, 1],
[0, 0, 1, 1, 0]
]).to_sparse() ], dedicom_decoder)
d.add_edge_type('Drug-Drug', 1, 1, [ torch.tensor([
[1, 0, 0, 0, 0],
[0, 1, 0, 0, 0],
[0, 0, 1, 0, 0],
[0, 0, 0, 1, 0],
[0, 0, 0, 0, 1]
]).to_sparse() ], dedicom_decoder)
batch = TrainingBatch(0, 1, 0, torch.tensor([
[0, 1]
]))
res = _per_layer_required_vertices(d, batch, 5)
print('res:', res)
ref = torch.eye(2200)
assert torch.all(in_repr[0].to_dense() == ref[:2000, :])
assert torch.all(in_repr[1].to_dense() == ref[2000:, :])

読み込み中…
キャンセル
保存