IF YOU WOULD LIKE TO GET AN ACCOUNT, please write an email to s dot adaszewski at gmail dot com. User accounts are meant only to report issues and/or generate pull requests. This is a purpose-specific Git hosting for ADARED projects. Thank you for your understanding!
ソースを参照

Add test_model_decode_01().

master
Stanislaw Adaszewski 4年前
コミット
7ea82a6267
3個のファイルの変更73行の追加6行の削除
  1. +21
    -6
      src/triacontagon/model.py
  2. +26
    -0
      src/triacontagon/util.py
  3. +26
    -0
      tests/triacontagon/test_model.py

+ 21
- 6
src/triacontagon/model.py ファイルの表示

@@ -173,7 +173,7 @@ class Model(torch.nn.Module):
for x in next_layer_repr ]
cur_layer_repr = next_layer_repr
return next_layer_repr
return cur_layer_repr
def decode(self, last_layer_repr: List[torch.Tensor],
batch: TrainingBatch) -> torch.Tensor:
@@ -182,12 +182,27 @@ class Model(torch.nn.Module):
vt_col = batch.vertex_type_column
rel_idx = batch.relation_type_index
global_interaction = \
self.dec_weights['%d-%d-global-interaction'] % (vt_row, vt_col)
self.dec_weights['%d-%d-global-interaction' % (vt_row, vt_col)]
local_variation = \
self.dec_weights['%d-%d-local-variation-%d'] % (vt_row, vt_col, rel_idx)
self.dec_weights['%d-%d-local-variation-%d' % (vt_row, vt_col, rel_idx)]
in_row = dropout(last_layer_repr[vt_row], self.keep_prob)
in_col = dropout(last_layer_repr[vt_col], self.keep_prob)
in_row = last_layer_repr[vt_row]
in_col = last_layer_repr[vt_col]
if in_row.is_sparse or in_col.is_sparse:
raise ValueError('Inputs to Model.decode() must be dense')
in_row = in_row[batch.edges[:, 0]]
in_col = in_col[batch.edges[:, 1]]
in_row = dropout(in_row, self.keep_prob)
in_col = dropout(in_col, self.keep_prob)
# in_row = in_row.to_dense()
# in_col = in_col.to_dense()
print('in_row.is_sparse:', in_row.is_sparse)
print('in_col.is_sparse:', in_col.is_sparse)
x = torch.mm(in_row, local_variation)
x = torch.mm(x, global_interaction)
@@ -197,7 +212,7 @@ class Model(torch.nn.Module):
x = torch.flatten(x)
x = self.dec_activation(x)
return x


+ 26
- 0
src/triacontagon/util.py ファイルの表示

@@ -199,6 +199,32 @@ def _mm(a: torch.Tensor, b: torch.Tensor):
return torch.mm(a, b)
def _select_rows(a: torch.Tensor, rows: torch.Tensor):
if not a.is_sparse:
return a[rows]
indices = a.indices()
values = a.values()
mask = torch.zeros(a.shape[0])
mask[rows] = 1
if mask.sum() != len(rows):
raise ValueError('Rows must be unique')
mask = mask[indices[0]]
mask = torch.nonzero(mask, as_tuple=True)[0]
new_rows[rows] = torch.arange(len(rows))
new_rows = new_rows[indices[0]]
indices = indices[:, mask]
indices[0] = new_rows
values = values[mask]
res = _sparse_coo_tensor(indices, values,
size=(len(rows), a.shape[1]))
return res
def common_one_hot_encoding(vertex_type_counts: List[int]) -> \
List[torch.Tensor]:


+ 26
- 0
tests/triacontagon/test_model.py ファイルの表示

@@ -4,6 +4,7 @@ from triacontagon.model import Model, \
_per_layer_required_vertices
from triacontagon.data import Data
from triacontagon.decode import dedicom_decoder
from triacontagon.util import common_one_hot_encoding
def test_per_layer_required_vertices_01():
@@ -83,3 +84,28 @@ def test_model_convolve_01():
]
_ = model.convolve(in_layer_repr)
def test_model_decode_01():
d = Data()
d.add_vertex_type('Gene', 100)
d.add_edge_type('Gene-Gene', 0, 0, [
torch.rand(100, 100).round().to_sparse()
], dedicom_decoder)
b = TrainingBatch(0, 0, 0, torch.tensor([
[0, 1],
[10, 51],
[50, 60],
[70, 90],
[98, 99]
]), torch.ones(5))
in_repr = common_one_hot_encoding([100])
in_repr = [ in_repr[0].to_dense() ]
m = Model(d, [100], 1.0, torch.sigmoid, torch.sigmoid)
_ = m.decode(in_repr, b)

読み込み中…
キャンセル
保存