|
|
@@ -24,33 +24,8 @@ def _equal(x: torch.Tensor, y: torch.Tensor): |
|
|
|
if not x.is_sparse:
|
|
|
|
return (x == y)
|
|
|
|
|
|
|
|
# if x.shape != y.shape:
|
|
|
|
# return torch.tensor(0, dtype=torch.uint8)
|
|
|
|
|
|
|
|
return ((x - y).coalesce().values() == 0)
|
|
|
|
|
|
|
|
x = x.coalesce()
|
|
|
|
indices_x = np.empty(x.indices().shape[1], dtype=np.object)
|
|
|
|
indices_x[:] = list(map(tuple, x.indices().transpose(0, 1)))
|
|
|
|
order_x = np.argsort(indices_x)
|
|
|
|
#order_x = sorted(range(len(indices_x)), key=lambda idx: indices_x[idx])
|
|
|
|
|
|
|
|
y = y.coalesce()
|
|
|
|
indices_y = np.empty(y.indices().shape[1], dtype=np.object)
|
|
|
|
indices_y[:] = list(map(tuple, y.indices().transpose(0, 1)))
|
|
|
|
order_y = np.argsort(indices_y)
|
|
|
|
# order_y = sorted(range(len(indices_y)), key=lambda idx: indices_y[idx])
|
|
|
|
|
|
|
|
# print(indices_x.shape, indices_y.shape)
|
|
|
|
|
|
|
|
if not len(indices_x) == len(indices_y):
|
|
|
|
return torch.tensor(0, dtype=torch.uint8)
|
|
|
|
|
|
|
|
if not np.all(indices_x[order_x] == indices_y[order_y]):
|
|
|
|
return torch.tensor(0, dtype=torch.uint8)
|
|
|
|
|
|
|
|
return (x.values()[order_x] == y.values()[order_y])
|
|
|
|
|
|
|
|
|
|
|
|
@dataclass
|
|
|
|
class NodeType(object):
|
|
|
|