|
|
@@ -35,6 +35,16 @@ def _check_2d(adj_mat): |
|
|
|
raise ValueError('adj_mat must be a square matrix')
|
|
|
|
|
|
|
|
|
|
|
|
def _sparse_coo_tensor(indices, values, size):
|
|
|
|
ctor = { torch.float32: torch.sparse.FloatTensor,
|
|
|
|
torch.float32: torch.sparse.DoubleTensor,
|
|
|
|
torch.uint8: torch.sparse.ByteTensor,
|
|
|
|
torch.long: torch.sparse.LongTensor,
|
|
|
|
torch.int: torch.sparse.IntTensor,
|
|
|
|
torch.short: torch.sparse.ShortTensor }[values.dtype]
|
|
|
|
return ctor(indices, values, size)
|
|
|
|
|
|
|
|
|
|
|
|
def add_eye_sparse(adj_mat: torch.Tensor) -> torch.Tensor:
|
|
|
|
_check_tensor(adj_mat)
|
|
|
|
_check_sparse(adj_mat)
|
|
|
@@ -53,7 +63,7 @@ def add_eye_sparse(adj_mat: torch.Tensor) -> torch.Tensor: |
|
|
|
indices = torch.cat((indices, eye_indices), 1)
|
|
|
|
values = torch.cat((values, eye_values), 0)
|
|
|
|
|
|
|
|
adj_mat = torch.sparse_coo_tensor(indices=indices, values=values, size=adj_mat.shape)
|
|
|
|
adj_mat = _sparse_coo_tensor(indices, values, adj_mat.shape)
|
|
|
|
|
|
|
|
return adj_mat
|
|
|
|
|
|
|
@@ -104,7 +114,7 @@ def norm_adj_mat_two_node_types_sparse(adj_mat: torch.Tensor) -> torch.Tensor: |
|
|
|
degrees_col = torch.zeros(adj_mat.shape[1], device=adj_mat.device)
|
|
|
|
degrees_col = degrees_col.index_add(0, indices[1], values.to(degrees_col.dtype))
|
|
|
|
values = values.to(degrees_row.dtype) / torch.sqrt(degrees_row[indices[0]] * degrees_col[indices[1]])
|
|
|
|
adj_mat = torch.sparse_coo_tensor(indices=indices, values=values, size=adj_mat.shape)
|
|
|
|
adj_mat = _sparse_coo_tensor(indices, values, adj_mat.shape)
|
|
|
|
|
|
|
|
return adj_mat
|
|
|
|
|
|
|
|