From d835b049b029fc27c7bd2d04b52665d28ee2438c Mon Sep 17 00:00:00 2001 From: Stanislaw Adaszewski Date: Fri, 15 May 2020 12:19:29 +0200 Subject: [PATCH] test_normalize --- tests/decagon_pytorch/test_normalize.py | 25 +++++++++++++++++++++++++ 1 file changed, 25 insertions(+) create mode 100644 tests/decagon_pytorch/test_normalize.py diff --git a/tests/decagon_pytorch/test_normalize.py b/tests/decagon_pytorch/test_normalize.py new file mode 100644 index 0000000..c7e0180 --- /dev/null +++ b/tests/decagon_pytorch/test_normalize.py @@ -0,0 +1,25 @@ +import decagon_pytorch.normalize +import decagon.deep.minibatch +import numpy as np + + +def test_normalize_adjacency_matrix_square(): + mx = np.random.rand(10, 10) + mx[mx < .5] = 0 + mx = np.ceil(mx) + res_torch = decagon_pytorch.normalize.normalize_adjacency_matrix(mx) + res_tf = decagon.deep.minibatch.EdgeMinibatchIterator.preprocess_graph(None, mx) + assert len(res_torch) == len(res_tf) + for i in range(len(res_torch)): + assert np.all(res_torch[i] == res_tf[i]) + + +def test_normalize_adjacency_matrix_nonsquare(): + mx = np.random.rand(5, 10) + mx[mx < .5] = 0 + mx = np.ceil(mx) + res_torch = decagon_pytorch.normalize.normalize_adjacency_matrix(mx) + res_tf = decagon.deep.minibatch.EdgeMinibatchIterator.preprocess_graph(None, mx) + assert len(res_torch) == len(res_tf) + for i in range(len(res_torch)): + assert np.all(res_torch[i] == res_tf[i])