IF YOU WOULD LIKE TO GET AN ACCOUNT, please write an email to s dot adaszewski at gmail dot com. User accounts are meant only to report issues and/or generate pull requests. This is a purpose-specific Git hosting for ADARED projects. Thank you for your understanding!
Pārlūkot izejas kodu

Add tests for OneHotInputLayer.

master
Stanislaw Adaszewski pirms 4 gadiem
vecāks
revīzija
5d8c2d08c4
1 mainītis faili ar 40 papildinājumiem un 0 dzēšanām
  1. +40
    -0
      tests/decagon_pytorch/test_layer.py

+ 40
- 0
tests/decagon_pytorch/test_layer.py Parādīt failu

@@ -71,6 +71,42 @@ def test_input_layer_03():
assert layer.node_reps[1].device == device
def test_one_hot_input_layer_01():
d = _some_data()
layer = OneHotInputLayer(d)
assert layer.output_dim == [1000, 100]
assert len(layer.node_reps) == 2
assert layer.node_reps[0].shape == (1000, 1000)
assert layer.node_reps[1].shape == (100, 100)
assert layer.data == d
assert layer.is_sparse
def test_one_hot_input_layer_02():
d = _some_data()
layer = OneHotInputLayer(d)
res = layer()
assert isinstance(res[0], torch.Tensor)
assert isinstance(res[1], torch.Tensor)
assert res[0].shape == (1000, 1000)
assert res[1].shape == (100, 100)
assert torch.all(res[0].to_dense() == layer.node_reps[0].to_dense())
assert torch.all(res[1].to_dense() == layer.node_reps[1].to_dense())
def test_one_hot_input_layer_03():
if torch.cuda.device_count() == 0:
pytest.skip('No CUDA devices on this host')
d = _some_data()
layer = OneHotInputLayer(d)
device = torch.device('cuda:0')
layer = layer.to(device)
print(list(layer.parameters()))
# assert layer.device.type == 'cuda:0'
assert layer.node_reps[0].device == device
assert layer.node_reps[1].device == device
def test_decagon_layer_01():
d = _some_data_with_interactions()
in_layer = InputLayer(d)
@@ -82,3 +118,7 @@ def test_decagon_layer_02():
in_layer = OneHotInputLayer(d)
d_layer = DecagonLayer(d, in_layer, output_dim=32)
_ = d_layer() # dummy call
def test_decagon_layer_03():
pass

Notiek ielāde…
Atcelt
Saglabāt