diff --git a/src/icosagon/dropout.py b/src/icosagon/dropout.py index 95d0575..74bdd57 100644 --- a/src/icosagon/dropout.py +++ b/src/icosagon/dropout.py @@ -5,6 +5,7 @@ import torch +from .normalize import _sparse_coo_tensor def dropout_sparse(x, keep_prob): @@ -17,7 +18,7 @@ def dropout_sparse(x, keep_prob): n = torch.floor(n).to(torch.bool) i = i[:,n] v = v[n] - x = torch.sparse_coo_tensor(i, v, size=size) + x = _sparse_coo_tensor(i, v, size=size) return x * (1./keep_prob) diff --git a/src/icosagon/normalize.py b/src/icosagon/normalize.py index b41dfb1..ec09b3c 100644 --- a/src/icosagon/normalize.py +++ b/src/icosagon/normalize.py @@ -35,6 +35,16 @@ def _check_2d(adj_mat): raise ValueError('adj_mat must be a square matrix') +def _sparse_coo_tensor(indices, values, size): + ctor = { torch.float32: torch.sparse.FloatTensor, + torch.float32: torch.sparse.DoubleTensor, + torch.uint8: torch.sparse.ByteTensor, + torch.long: torch.sparse.LongTensor, + torch.int: torch.sparse.IntTensor, + torch.short: torch.sparse.ShortTensor }[values.dtype] + return ctor(indices, values, size) + + def add_eye_sparse(adj_mat: torch.Tensor) -> torch.Tensor: _check_tensor(adj_mat) _check_sparse(adj_mat) @@ -53,7 +63,7 @@ def add_eye_sparse(adj_mat: torch.Tensor) -> torch.Tensor: indices = torch.cat((indices, eye_indices), 1) values = torch.cat((values, eye_values), 0) - adj_mat = torch.sparse_coo_tensor(indices=indices, values=values, size=adj_mat.shape) + adj_mat = _sparse_coo_tensor(indices, values, adj_mat.shape) return adj_mat @@ -104,7 +114,7 @@ def norm_adj_mat_two_node_types_sparse(adj_mat: torch.Tensor) -> torch.Tensor: degrees_col = torch.zeros(adj_mat.shape[1], device=adj_mat.device) degrees_col = degrees_col.index_add(0, indices[1], values.to(degrees_col.dtype)) values = values.to(degrees_row.dtype) / torch.sqrt(degrees_row[indices[0]] * degrees_col[indices[1]]) - adj_mat = torch.sparse_coo_tensor(indices=indices, values=values, size=adj_mat.shape) + adj_mat = _sparse_coo_tensor(indices, values, adj_mat.shape) return adj_mat diff --git a/tests/icosagon/test_trainloop.py b/tests/icosagon/test_trainloop.py index 0548462..5271a06 100644 --- a/tests/icosagon/test_trainloop.py +++ b/tests/icosagon/test_trainloop.py @@ -42,6 +42,7 @@ def test_train_loop_02(): def test_train_loop_03(): + # pdb.set_trace() if torch.cuda.device_count() == 0: pytest.skip('CUDA required for this test')