IF YOU WOULD LIKE TO GET AN ACCOUNT, please write an email to s dot adaszewski at gmail dot com. User accounts are meant only to report issues and/or generate pull requests. This is a purpose-specific Git hosting for ADARED projects. Thank you for your understanding!
Kaynağa Gözat

Clean up _equal().

master
Stanislaw Adaszewski 4 yıl önce
ebeveyn
işleme
8aafb6fa01
1 değiştirilmiş dosya ile 0 ekleme ve 25 silme
  1. +0
    -25
      src/icosagon/data.py

+ 0
- 25
src/icosagon/data.py Dosyayı Görüntüle

@@ -24,33 +24,8 @@ def _equal(x: torch.Tensor, y: torch.Tensor):
if not x.is_sparse:
return (x == y)
# if x.shape != y.shape:
# return torch.tensor(0, dtype=torch.uint8)
return ((x - y).coalesce().values() == 0)
x = x.coalesce()
indices_x = np.empty(x.indices().shape[1], dtype=np.object)
indices_x[:] = list(map(tuple, x.indices().transpose(0, 1)))
order_x = np.argsort(indices_x)
#order_x = sorted(range(len(indices_x)), key=lambda idx: indices_x[idx])
y = y.coalesce()
indices_y = np.empty(y.indices().shape[1], dtype=np.object)
indices_y[:] = list(map(tuple, y.indices().transpose(0, 1)))
order_y = np.argsort(indices_y)
# order_y = sorted(range(len(indices_y)), key=lambda idx: indices_y[idx])
# print(indices_x.shape, indices_y.shape)
if not len(indices_x) == len(indices_y):
return torch.tensor(0, dtype=torch.uint8)
if not np.all(indices_x[order_x] == indices_y[order_y]):
return torch.tensor(0, dtype=torch.uint8)
return (x.values()[order_x] == y.values()[order_y])
@dataclass
class NodeType(object):


Yükleniyor…
İptal
Kaydet