import torch from typing import List, \ Set import time def _equal(x: torch.Tensor, y: torch.Tensor): if x.is_sparse ^ y.is_sparse: raise ValueError('Cannot mix sparse and dense tensors') if not x.is_sparse: return (x == y) return ((x - y).coalesce().values() == 0) def _sparse_coo_tensor(indices, values, size): ctor = { torch.float32: torch.sparse.FloatTensor, torch.float32: torch.sparse.DoubleTensor, torch.uint8: torch.sparse.ByteTensor, torch.long: torch.sparse.LongTensor, torch.int: torch.sparse.IntTensor, torch.short: torch.sparse.ShortTensor, torch.bool: torch.sparse.ByteTensor }[values.dtype] return ctor(indices, values, size) def _nonzero_sum(adjacency_matrices: List[torch.Tensor]): if len(adjacency_matrices) == 0: raise ValueError('adjacency_matrices must be non-empty') if not all([x.is_sparse for x in adjacency_matrices]): raise ValueError('All adjacency matrices must be sparse') indices = [ x.indices() for x in adjacency_matrices ] indices = torch.cat(indices, dim=1) values = torch.ones(indices.shape[1]) res = _sparse_coo_tensor(indices, values, adjacency_matrices[0].shape) res = res.coalesce() indices = res.indices() res = _sparse_coo_tensor(indices, torch.ones(indices.shape[1], dtype=torch.uint8), adjacency_matrices[0].shape) res = res.coalesce() return res def _clear_adjacency_matrix_except_rows(adjacency_matrix: torch.Tensor, rows: torch.Tensor, row_vertex_count: int, num_relation_types: int) -> torch.Tensor: if not adjacency_matrix.is_sparse: raise ValueError('adjacency_matrix must be sparse') if not adjacency_matrix.shape[0] == row_vertex_count * num_relation_types: raise ValueError('adjacency_matrix must have as many rows as row vertex count times number of relation types') t = time.time() rows = [ rows + row_vertex_count * i \ for i in range(num_relation_types) ] print('rows took:', time.time() - t) t = time.time() rows = torch.cat(rows) print('cat took:', time.time() - t) # print('rows:', rows) # rows = set(rows.tolist()) # print('rows:', rows) t = time.time() adj_mat = adjacency_matrix.coalesce() indices = adj_mat.indices() values = adj_mat.values() print('indices[0]:', indices[0]) # print('indices[0][1]:', indices[0][1], indices[0][1] in rows) lookup = torch.zeros(row_vertex_count * num_relation_types, dtype=torch.uint8, device=adj_mat.device) lookup[rows] = 1 values = values * lookup[indices[0]] mask = torch.nonzero(values > 0, as_tuple=True)[0] indices = indices[:, mask] values = values[mask] res = _sparse_coo_tensor(indices, values, adjacency_matrix.shape) # res = res.coalesce() print('res:', res) print('"index_select()" took:', time.time() - t) return res selection = torch.tensor([ (idx.item() in rows) for idx in indices[0] ]) # print('selection:', selection) selection = torch.nonzero(selection, as_tuple=True)[0] # print('selection:', selection) indices = indices[:, selection] values = values[selection] print('"index_select()" took:', time.time() - t) t = time.time() res = _sparse_coo_tensor(indices, values, adjacency_matrix.shape) print('_sparse_coo_tensor() took:', time.time() - t) return res # t = time.time() # adj_mat = torch.index_select(adjacency_matrix, 0, rows) # print('index_select took:', time.time() - t) t = time.time() adj_mat = adj_mat.coalesce() print('coalesce() took:', time.time() - t) indices = adj_mat.indices() # print('indices:', indices) values = adj_mat.values() t = time.time() indices[0] = rows[indices[0]] print('Lookup took:', time.time() - t) t = time.time() adj_mat = _sparse_coo_tensor(indices, values, adjacency_matrix.shape) print('_sparse_coo_tensor() took:', time.time() - t) return adj_mat def _sparse_diag_cat(matrices: List[torch.Tensor]): if len(matrices) == 0: raise ValueError('The list of matrices must be non-empty') if not all(m.is_sparse for m in matrices): raise ValueError('All matrices must be sparse') if not all(len(m.shape) == 2 for m in matrices): raise ValueError('All matrices must be 2D') indices = [] values = [] row_offset = 0 col_offset = 0 for m in matrices: ind = m._indices().clone() ind[0] += row_offset ind[1] += col_offset indices.append(ind) values.append(m._values()) row_offset += m.shape[0] col_offset += m.shape[1] indices = torch.cat(indices, dim=1) values = torch.cat(values) return _sparse_coo_tensor(indices, values, size=(row_offset, col_offset)) def _cat(matrices: List[torch.Tensor]): if len(matrices) == 0: raise ValueError('Empty list passed to _cat()') n = sum(a.is_sparse for a in matrices) if n != 0 and n != len(matrices): raise ValueError('All matrices must have the same layout (dense or sparse)') if not all(a.shape[1:] == matrices[0].shape[1:] for a in matrices): raise ValueError('All matrices must have the same dimensions apart from dimension 0') if not matrices[0].is_sparse: return torch.cat(matrices) total_rows = sum(a.shape[0] for a in matrices) indices = [] values = [] row_offset = 0 for a in matrices: ind = a._indices().clone() val = a._values() ind[0] += row_offset ind = ind.transpose(0, 1) indices.append(ind) values.append(val) row_offset += a.shape[0] indices = torch.cat(indices).transpose(0, 1) values = torch.cat(values) res = _sparse_coo_tensor(indices, values, size=(row_offset, matrices[0].shape[1])) return res def _mm(a: torch.Tensor, b: torch.Tensor): if a.is_sparse: return torch.sparse.mm(a, b) else: return torch.mm(a, b) def _select_rows(a: torch.Tensor, rows: torch.Tensor): if not a.is_sparse: return a[rows] indices = a.indices() values = a.values() mask = torch.zeros(a.shape[0]) mask[rows] = 1 if mask.sum() != len(rows): raise ValueError('Rows must be unique') mask = mask[indices[0]] mask = torch.nonzero(mask, as_tuple=True)[0] new_rows[rows] = torch.arange(len(rows)) new_rows = new_rows[indices[0]] indices = indices[:, mask] indices[0] = new_rows values = values[mask] res = _sparse_coo_tensor(indices, values, size=(len(rows), a.shape[1])) return res def common_one_hot_encoding(vertex_type_counts: List[int], device=None) -> \ List[torch.Tensor]: tot = sum(vertex_type_counts) # indices = torch.cat([ torch.arange(tot).view(1, -1) ] * 2, dim=0) # print('indices.shape:', indices.shape) ofs = 0 res = [] for cnt in vertex_type_counts: ind = torch.cat([ torch.arange(cnt).view(1, -1), torch.arange(ofs, ofs+cnt).view(1, -1) ]) val = torch.ones(cnt) x = _sparse_coo_tensor(ind, val, size=(cnt, tot)) x = x.to(device) res.append(x) ofs += cnt return res