IF YOU WOULD LIKE TO GET AN ACCOUNT, please write an email to s dot adaszewski at gmail dot com. User accounts are meant only to report issues and/or generate pull requests. This is a purpose-specific Git hosting for ADARED projects. Thank you for your understanding!
Parcourir la source

Fixes for bool sparse tensors.

master
Stanislaw Adaszewski il y a 4 ans
Parent
révision
1ee89509ee
3 fichiers modifiés avec 5 ajouts et 4 suppressions
  1. +2
    -1
      src/icosagon/normalize.py
  2. +2
    -2
      tests/icosagon/test_normalize.py
  3. +1
    -1
      tests/icosagon/test_trainprep.py

+ 2
- 1
src/icosagon/normalize.py Voir le fichier

@@ -41,7 +41,8 @@ def _sparse_coo_tensor(indices, values, size):
torch.uint8: torch.sparse.ByteTensor,
torch.long: torch.sparse.LongTensor,
torch.int: torch.sparse.IntTensor,
torch.short: torch.sparse.ShortTensor }[values.dtype]
torch.short: torch.sparse.ShortTensor,
torch.bool: torch.sparse.ByteTensor }[values.dtype]
return ctor(indices, values, size)


+ 2
- 2
tests/icosagon/test_normalize.py Voir le fichier

@@ -46,14 +46,14 @@ def test_add_eye_sparse_04():
def test_norm_adj_mat_one_node_type_sparse_01():
adj_mat = torch.rand((10, 10))
adj_mat = (adj_mat > .5)
adj_mat = (adj_mat > .5).to(torch.float32)
adj_mat = adj_mat.to_sparse()
_ = norm_adj_mat_one_node_type_sparse(adj_mat)
def test_norm_adj_mat_one_node_type_sparse_02():
adj_mat_dense = torch.rand((10, 10))
adj_mat_dense = (adj_mat_dense > .5)
adj_mat_dense = (adj_mat_dense > .5).to(torch.float32)
adj_mat_sparse = adj_mat_dense.to_sparse()
adj_mat_sparse = norm_adj_mat_one_node_type_sparse(adj_mat_sparse)
adj_mat_dense = norm_adj_mat_one_node_type_dense(adj_mat_dense)


+ 1
- 1
tests/icosagon/test_trainprep.py Voir le fichier

@@ -108,7 +108,7 @@ def test_prepare_adj_mat_02():
def test_prepare_relation_type_01():
adj_mat = (torch.rand((10, 10)) > .5)
adj_mat = (torch.rand((10, 10)) > .5).to(torch.float32)
r = RelationType('Test', 0, 0, adj_mat, True)
ratios = TrainValTest(.8, .1, .1)
_ = prepare_relation_type(r, ratios, False)


Chargement…
Annuler
Enregistrer