|
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190 |
- import torch
- from typing import List, \
- Set
- import time
-
-
- def _equal(x: torch.Tensor, y: torch.Tensor):
- if x.is_sparse ^ y.is_sparse:
- raise ValueError('Cannot mix sparse and dense tensors')
-
- if not x.is_sparse:
- return (x == y)
-
- return ((x - y).coalesce().values() == 0)
-
-
- def _sparse_coo_tensor(indices, values, size):
- ctor = { torch.float32: torch.sparse.FloatTensor,
- torch.float32: torch.sparse.DoubleTensor,
- torch.uint8: torch.sparse.ByteTensor,
- torch.long: torch.sparse.LongTensor,
- torch.int: torch.sparse.IntTensor,
- torch.short: torch.sparse.ShortTensor,
- torch.bool: torch.sparse.ByteTensor }[values.dtype]
- return ctor(indices, values, size)
-
-
- def _nonzero_sum(adjacency_matrices: List[torch.Tensor]):
- if len(adjacency_matrices) == 0:
- raise ValueError('adjacency_matrices must be non-empty')
-
- if not all([x.is_sparse for x in adjacency_matrices]):
- raise ValueError('All adjacency matrices must be sparse')
-
- indices = [ x.indices() for x in adjacency_matrices ]
- indices = torch.cat(indices, dim=1)
-
- values = torch.ones(indices.shape[1])
- res = _sparse_coo_tensor(indices, values, adjacency_matrices[0].shape)
- res = res.coalesce()
-
- indices = res.indices()
- res = _sparse_coo_tensor(indices,
- torch.ones(indices.shape[1], dtype=torch.uint8))
-
- return res
-
-
- def _clear_adjacency_matrix_except_rows(adjacency_matrix: torch.Tensor,
- rows: torch.Tensor, row_vertex_count: int, num_relation_types: int) -> torch.Tensor:
-
- if not adjacency_matrix.is_sparse:
- raise ValueError('adjacency_matrix must be sparse')
-
- if not adjacency_matrix.shape[0] == row_vertex_count * num_relation_types:
- raise ValueError('adjacency_matrix must have as many rows as row vertex count times number of relation types')
-
- t = time.time()
- rows = [ rows + row_vertex_count * i \
- for i in range(num_relation_types) ]
- print('rows took:', time.time() - t)
- t = time.time()
- rows = torch.cat(rows)
- print('cat took:', time.time() - t)
- # print('rows:', rows)
- # rows = set(rows.tolist())
- # print('rows:', rows)
-
- t = time.time()
- adj_mat = adjacency_matrix.coalesce()
- indices = adj_mat.indices()
- values = adj_mat.values()
- print('indices[0]:', indices[0])
- # print('indices[0][1]:', indices[0][1], indices[0][1] in rows)
-
- lookup = torch.zeros(row_vertex_count * num_relation_types,
- dtype=torch.uint8, device=adj_mat.device)
- lookup[rows] = 1
- values = values * lookup[indices[0]]
- mask = torch.nonzero(values > 0, as_tuple=True)[0]
- indices = indices[:, mask]
- values = values[mask]
-
- res = _sparse_coo_tensor(indices, values, adjacency_matrix.shape)
- # res = res.coalesce()
- print('res:', res)
- print('"index_select()" took:', time.time() - t)
-
- return res
-
- selection = torch.tensor([ (idx.item() in rows) for idx in indices[0] ])
- # print('selection:', selection)
- selection = torch.nonzero(selection, as_tuple=True)[0]
- # print('selection:', selection)
- indices = indices[:, selection]
- values = values[selection]
- print('"index_select()" took:', time.time() - t)
-
- t = time.time()
- res = _sparse_coo_tensor(indices, values, adjacency_matrix.shape)
- print('_sparse_coo_tensor() took:', time.time() - t)
-
- return res
-
- # t = time.time()
- # adj_mat = torch.index_select(adjacency_matrix, 0, rows)
- # print('index_select took:', time.time() - t)
-
- t = time.time()
- adj_mat = adj_mat.coalesce()
- print('coalesce() took:', time.time() - t)
- indices = adj_mat.indices()
-
- # print('indices:', indices)
-
- values = adj_mat.values()
- t = time.time()
- indices[0] = rows[indices[0]]
- print('Lookup took:', time.time() - t)
-
- t = time.time()
- adj_mat = _sparse_coo_tensor(indices, values, adjacency_matrix.shape)
- print('_sparse_coo_tensor() took:', time.time() - t)
-
- return adj_mat
-
-
- def _sparse_diag_cat(matrices: List[torch.Tensor]):
- if len(matrices) == 0:
- raise ValueError('The list of matrices must be non-empty')
-
- if not all(m.is_sparse for m in matrices):
- raise ValueError('All matrices must be sparse')
-
- if not all(len(m.shape) == 2 for m in matrices):
- raise ValueError('All matrices must be 2D')
-
- indices = []
- values = []
- row_offset = 0
- col_offset = 0
-
- for m in matrices:
- ind = m._indices().clone()
- ind[0] += row_offset
- ind[1] += col_offset
- indices.append(ind)
- values.append(m._values())
- row_offset += m.shape[0]
- col_offset += m.shape[1]
-
- indices = torch.cat(indices, dim=1)
- values = torch.cat(values)
-
- return _sparse_coo_tensor(indices, values, size=(row_offset, col_offset))
-
-
- def _cat(matrices: List[torch.Tensor]):
- if len(matrices) == 0:
- raise ValueError('Empty list passed to _cat()')
-
- n = sum(a.is_sparse for a in matrices)
- if n != 0 and n != len(matrices):
- raise ValueError('All matrices must have the same layout (dense or sparse)')
-
- if not all(a.shape[1:] == matrices[0].shape[1:] for a in matrices):
- raise ValueError('All matrices must have the same dimensions apart from dimension 0')
-
- if not matrices[0].is_sparse:
- return torch.cat(matrices)
-
- total_rows = sum(a.shape[0] for a in matrices)
- indices = []
- values = []
- row_offset = 0
-
- for a in matrices:
- ind = a._indices().clone()
- val = a._values()
- ind[0] += row_offset
- ind = ind.transpose(0, 1)
- indices.append(ind)
- values.append(val)
- row_offset += a.shape[0]
-
- indices = torch.cat(indices).transpose(0, 1)
- values = torch.cat(values)
-
- res = _sparse_coo_tensor(indices, values, size=(row_offset, matrices[0].shape[1]))
- return res
|