IF YOU WOULD LIKE TO GET AN ACCOUNT, please write an email to s dot adaszewski at gmail dot com. User accounts are meant only to report issues and/or generate pull requests. This is a purpose-specific Git hosting for ADARED projects. Thank you for your understanding!
Browse Source

Add test for DecagonLayer.

master
Stanislaw Adaszewski 3 years ago
parent
commit
29e81f4eba
2 changed files with 34 additions and 2 deletions
  1. +4
    -1
      src/decagon_pytorch/convolve.py
  2. +30
    -1
      tests/decagon_pytorch/test_layer.py

+ 4
- 1
src/decagon_pytorch/convolve.py View File

@@ -195,9 +195,12 @@ class SparseDropoutGraphConvActivation(torch.nn.Module):
activation: Callable[[torch.Tensor], torch.Tensor]=torch.nn.functional.relu,
**kwargs) -> None:
super().__init__(**kwargs)
self.sparse_graph_conv = SparseGraphConv(input_dim, output_dim, adjacency_matrix)
self.input_dim = input_dim
self.output_dim = output_dim
self.adjacency_matrix = adjacency_matrix
self.keep_prob = keep_prob
self.activation = activation
self.sparse_graph_conv = SparseGraphConv(input_dim, output_dim, adjacency_matrix)

def forward(self, x: torch.Tensor) -> torch.Tensor:
x = dropout_sparse(x, self.keep_prob)


+ 30
- 1
tests/decagon_pytorch/test_layer.py View File

@@ -4,6 +4,7 @@ from decagon_pytorch.layer import InputLayer, \
from decagon_pytorch.data import Data
import torch
import pytest
from decagon_pytorch.convolve import SparseDropoutGraphConvActivation
def _some_data():
@@ -121,4 +122,32 @@ def test_decagon_layer_02():
def test_decagon_layer_03():
pass
d = _some_data_with_interactions()
in_layer = OneHotInputLayer(d)
d_layer = DecagonLayer(d, in_layer, output_dim=32)
assert d_layer.data == d
assert d_layer.previous_layer == in_layer
assert d_layer.input_dim == [ 1000, 100 ]
assert not d_layer.is_sparse
assert d_layer.keep_prob == 1.
assert d_layer.rel_activation(0.5) == 0.5
x = torch.tensor([-1, 0, 0.5, 1])
assert (d_layer.layer_activation(x) == torch.nn.functional.relu(x)).all()
assert len(d_layer.next_layer_repr) == 2
assert len(d_layer.next_layer_repr[0]) == 2
assert len(d_layer.next_layer_repr[1]) == 4
assert all(map(lambda a: isinstance(a[0], SparseDropoutGraphConvActivation),
d_layer.next_layer_repr[0]))
assert all(map(lambda a: isinstance(a[0], SparseDropoutGraphConvActivation),
d_layer.next_layer_repr[1]))
assert all(map(lambda a: a[0].output_dim == 32,
d_layer.next_layer_repr[0]))
assert all(map(lambda a: a[0].output_dim == 32,
d_layer.next_layer_repr[1]))
def test_decagon_layer_04():
d = _some_data_with_interactions()
in_layer = OneHotInputLayer(d)
d_layer = DecagonLayer(d, in_layer, output_dim=32)
_ = d_layer()

Loading…
Cancel
Save