| @@ -71,6 +71,42 @@ def test_input_layer_03(): | |||||
| assert layer.node_reps[1].device == device | assert layer.node_reps[1].device == device | ||||
| def test_one_hot_input_layer_01(): | |||||
| d = _some_data() | |||||
| layer = OneHotInputLayer(d) | |||||
| assert layer.output_dim == [1000, 100] | |||||
| assert len(layer.node_reps) == 2 | |||||
| assert layer.node_reps[0].shape == (1000, 1000) | |||||
| assert layer.node_reps[1].shape == (100, 100) | |||||
| assert layer.data == d | |||||
| assert layer.is_sparse | |||||
| def test_one_hot_input_layer_02(): | |||||
| d = _some_data() | |||||
| layer = OneHotInputLayer(d) | |||||
| res = layer() | |||||
| assert isinstance(res[0], torch.Tensor) | |||||
| assert isinstance(res[1], torch.Tensor) | |||||
| assert res[0].shape == (1000, 1000) | |||||
| assert res[1].shape == (100, 100) | |||||
| assert torch.all(res[0].to_dense() == layer.node_reps[0].to_dense()) | |||||
| assert torch.all(res[1].to_dense() == layer.node_reps[1].to_dense()) | |||||
| def test_one_hot_input_layer_03(): | |||||
| if torch.cuda.device_count() == 0: | |||||
| pytest.skip('No CUDA devices on this host') | |||||
| d = _some_data() | |||||
| layer = OneHotInputLayer(d) | |||||
| device = torch.device('cuda:0') | |||||
| layer = layer.to(device) | |||||
| print(list(layer.parameters())) | |||||
| # assert layer.device.type == 'cuda:0' | |||||
| assert layer.node_reps[0].device == device | |||||
| assert layer.node_reps[1].device == device | |||||
| def test_decagon_layer_01(): | def test_decagon_layer_01(): | ||||
| d = _some_data_with_interactions() | d = _some_data_with_interactions() | ||||
| in_layer = InputLayer(d) | in_layer = InputLayer(d) | ||||
| @@ -82,3 +118,7 @@ def test_decagon_layer_02(): | |||||
| in_layer = OneHotInputLayer(d) | in_layer = OneHotInputLayer(d) | ||||
| d_layer = DecagonLayer(d, in_layer, output_dim=32) | d_layer = DecagonLayer(d, in_layer, output_dim=32) | ||||
| _ = d_layer() # dummy call | _ = d_layer() # dummy call | ||||
| def test_decagon_layer_03(): | |||||
| pass | |||||