From 5d8c2d08c41e734cb55a1957d85653526c897e0f Mon Sep 17 00:00:00 2001 From: Stanislaw Adaszewski Date: Thu, 28 May 2020 10:41:41 +0200 Subject: [PATCH] Add tests for OneHotInputLayer. --- tests/decagon_pytorch/test_layer.py | 40 +++++++++++++++++++++++++++++ 1 file changed, 40 insertions(+) diff --git a/tests/decagon_pytorch/test_layer.py b/tests/decagon_pytorch/test_layer.py index 4dee8b1..16354ab 100644 --- a/tests/decagon_pytorch/test_layer.py +++ b/tests/decagon_pytorch/test_layer.py @@ -71,6 +71,42 @@ def test_input_layer_03(): assert layer.node_reps[1].device == device +def test_one_hot_input_layer_01(): + d = _some_data() + layer = OneHotInputLayer(d) + assert layer.output_dim == [1000, 100] + assert len(layer.node_reps) == 2 + assert layer.node_reps[0].shape == (1000, 1000) + assert layer.node_reps[1].shape == (100, 100) + assert layer.data == d + assert layer.is_sparse + + +def test_one_hot_input_layer_02(): + d = _some_data() + layer = OneHotInputLayer(d) + res = layer() + assert isinstance(res[0], torch.Tensor) + assert isinstance(res[1], torch.Tensor) + assert res[0].shape == (1000, 1000) + assert res[1].shape == (100, 100) + assert torch.all(res[0].to_dense() == layer.node_reps[0].to_dense()) + assert torch.all(res[1].to_dense() == layer.node_reps[1].to_dense()) + + +def test_one_hot_input_layer_03(): + if torch.cuda.device_count() == 0: + pytest.skip('No CUDA devices on this host') + d = _some_data() + layer = OneHotInputLayer(d) + device = torch.device('cuda:0') + layer = layer.to(device) + print(list(layer.parameters())) + # assert layer.device.type == 'cuda:0' + assert layer.node_reps[0].device == device + assert layer.node_reps[1].device == device + + def test_decagon_layer_01(): d = _some_data_with_interactions() in_layer = InputLayer(d) @@ -82,3 +118,7 @@ def test_decagon_layer_02(): in_layer = OneHotInputLayer(d) d_layer = DecagonLayer(d, in_layer, output_dim=32) _ = d_layer() # dummy call + + +def test_decagon_layer_03(): + pass