diff --git a/tests/decagon_pytorch/test_normalize.py b/tests/decagon_pytorch/test_normalize.py new file mode 100644 index 0000000..c7e0180 --- /dev/null +++ b/tests/decagon_pytorch/test_normalize.py @@ -0,0 +1,25 @@ +import decagon_pytorch.normalize +import decagon.deep.minibatch +import numpy as np + + +def test_normalize_adjacency_matrix_square(): + mx = np.random.rand(10, 10) + mx[mx < .5] = 0 + mx = np.ceil(mx) + res_torch = decagon_pytorch.normalize.normalize_adjacency_matrix(mx) + res_tf = decagon.deep.minibatch.EdgeMinibatchIterator.preprocess_graph(None, mx) + assert len(res_torch) == len(res_tf) + for i in range(len(res_torch)): + assert np.all(res_torch[i] == res_tf[i]) + + +def test_normalize_adjacency_matrix_nonsquare(): + mx = np.random.rand(5, 10) + mx[mx < .5] = 0 + mx = np.ceil(mx) + res_torch = decagon_pytorch.normalize.normalize_adjacency_matrix(mx) + res_tf = decagon.deep.minibatch.EdgeMinibatchIterator.preprocess_graph(None, mx) + assert len(res_torch) == len(res_tf) + for i in range(len(res_torch)): + assert np.all(res_torch[i] == res_tf[i])