IF YOU WOULD LIKE TO GET AN ACCOUNT, please write an email to s dot adaszewski at gmail dot com. User accounts are meant only to report issues and/or generate pull requests. This is a purpose-specific Git hosting for ADARED projects. Thank you for your understanding!
瀏覽代碼

Basic constructor test for DecagonLayer passes.

master
Stanislaw Adaszewski 4 年之前
父節點
當前提交
2188dc7ae2
共有 2 個檔案被更改,包括 12 行新增12 行删除
  1. +11
    -10
      src/decagon_pytorch/layer.py
  2. +1
    -2
      tests/decagon_pytorch/test_layer.py

+ 11
- 10
src/decagon_pytorch/layer.py 查看文件

@@ -91,16 +91,17 @@ class DecagonLayer(Layer):
def build(self):
self.next_layer_repr = defaultdict(list)
for (nt_row, nt_col), rel in self.data.relation_types.items():
conv = SparseDropoutGraphConvActivation(self.input_dim[nt_col],
self.output_dim[nt_row], rel.adjacency_matrix,
self.keep_prob, self.rel_activation)
self.next_layer_repr[nt_row].append((conv, nt_col))
conv = SparseDropoutGraphConvActivation(self.input_dim[nt_row],
self.output_dim[nt_col], rel.adjacency_matrix.transpose(0, 1),
self.keep_prob, self.rel_activation)
self.next_layer_repr[nt_col].append((conv, nt_row))
for (nt_row, nt_col), relation_types in self.data.relation_types.items():
for rel in relation_types:
conv = SparseDropoutGraphConvActivation(self.input_dim[nt_col],
self.output_dim[nt_row], rel.adjacency_matrix,
self.keep_prob, self.rel_activation)
self.next_layer_repr[nt_row].append((conv, nt_col))
conv = SparseDropoutGraphConvActivation(self.input_dim[nt_row],
self.output_dim[nt_col], rel.adjacency_matrix.transpose(0, 1),
self.keep_prob, self.rel_activation)
self.next_layer_repr[nt_col].append((conv, nt_row))
def __call__(self):
prev_layer_repr = self.previous_layer()


+ 1
- 2
tests/decagon_pytorch/test_layer.py 查看文件

@@ -70,8 +70,7 @@ def test_input_layer_03():
assert layer.node_reps[1].device == device
@pytest.mark.skip()
def test_decagon_layer_01():
d = _some_data_with_interactions()
in_layer = InputLayer(d)
d_layer = DecagonLayer(in_layer, output_dim=32)
d_layer = DecagonLayer(d, in_layer, output_dim=32)

Loading…
取消
儲存