#
# Copyright (C) Stanislaw Adaszewski, 2020
# License: GPLv3
#

"""Negative sampling helpers for adjacency matrices and ``Data`` objects."""

import numpy as np
import torch
import torch.utils.data
from typing import List, \
    Union, \
    Tuple
from .data import Data, \
    EdgeType
from .cumcount import cumcount
import time
import multiprocessing
import multiprocessing.pool
from itertools import product, \
    repeat
from functools import reduce


def fixed_unigram_candidate_sampler(
        true_classes: torch.Tensor,
        num_repeats: torch.Tensor,
        unigrams: torch.Tensor,
        distortion: float = 1.) -> torch.Tensor:
    """Draw weighted random classes per row, rejecting the row's true classes.

    For row ``i`` of ``true_classes``, ``num_repeats[i]`` class indices are
    drawn with weights ``unigrams ** distortion``.  A candidate is rejected
    and redrawn in the next round when it equals any entry already stored
    for its row (the original true classes plus candidates accepted in
    earlier rounds), or when it is flagged as a repeat within the current
    round.

    Args:
        true_classes: 2D integer tensor; ``-1`` marks unused slots.
        num_repeats: 1D tensor with the number of samples wanted per row.
        unigrams: per-class sampling weights.
        distortion: exponent applied to ``unigrams`` before sampling.

    Returns:
        1D ``torch.long`` tensor with ``num_repeats.sum()`` entries, grouped
        by row in row order.

    Raises:
        ValueError: if the shapes are wrong or some row has fewer eligible
            classes than requested samples.
    """
    assert isinstance(true_classes, torch.Tensor)
    assert isinstance(num_repeats, torch.Tensor)
    assert isinstance(unigrams, torch.Tensor)
    distortion = float(distortion)

    if len(true_classes.shape) != 2:
        raise ValueError('true_classes must be a 2D matrix with shape (num_samples, num_true)')

    if len(num_repeats.shape) != 1:
        raise ValueError('num_repeats must be 1D')

    if torch.any((unigrams > 0).sum() -
                 (true_classes >= 0).sum(dim=1) <
                 num_repeats):
        raise ValueError('Not enough classes to choose from')

    # Apply the distortion exponent, as the _slow and _old variants do.
    # Previously the argument was accepted but silently ignored here.
    if distortion != 1.:
        unigrams = unigrams.to(torch.float64) ** distortion

    # Number of occupied (non -1) slots in every row.
    true_class_count = true_classes.shape[1] - (true_classes == -1).sum(dim=1)

    # Append free slots so accepted candidates can be recorded next to the
    # true classes.  torch.cat builds a new tensor, so the caller's
    # true_classes is not modified.
    true_classes = torch.cat([
        true_classes,
        torch.full((len(true_classes), int(torch.max(num_repeats))), -1,
                   dtype=true_classes.dtype)
    ], dim=1)

    # Column 0: position in the output; column 1: row the sample belongs to.
    indices = torch.repeat_interleave(torch.arange(len(true_classes)),
                                      num_repeats)
    indices = torch.cat([
        torch.arange(len(indices)).view(-1, 1),
        indices.view(-1, 1)
    ], dim=1)

    result = torch.zeros(len(indices), dtype=torch.long)

    while len(indices) > 0:
        candidates = torch.utils.data.WeightedRandomSampler(unigrams,
                                                            len(indices))
        candidates = torch.tensor(list(candidates)).view(-1, 1)

        # Order by candidate value, then stably by row, so pending samples
        # of one row are contiguous and sorted by candidate.
        inner_order = torch.argsort(candidates[:, 0])
        indices_np = indices[inner_order].detach().cpu().numpy()
        outer_order = np.argsort(indices_np[:, 1], kind='stable')
        outer_order = torch.tensor(outer_order, device=inner_order.device)

        candidates = candidates[inner_order][outer_order]
        indices = indices[inner_order][outer_order]

        # True where the candidate collides with an entry stored for its row.
        mask = (true_classes[indices[:, 1]] ==
                candidates).sum(dim=1).to(torch.bool)

        # NOTE(review): cumcount is defined in .cumcount; presumably it
        # returns the running occurrence index of each value - confirm.
        can_cum = cumcount(candidates[:, 0])
        ind_cum = cumcount(indices[:, 1])
        repeated = (can_cum > 0) & (ind_cum > 0)
        mask = mask | repeated

        updated = indices[~mask]
        if len(updated) > 0:
            # Record accepted candidates in the next free slots of their rows.
            ofs = true_class_count[updated[:, 1]] + \
                cumcount(updated[:, 1])
            true_classes[updated[:, 1], ofs] = \
                candidates[~mask].transpose(0, 1)
            true_class_count[updated[:, 1]] = ofs + 1

        # Rejected candidates are written too; they are overwritten in a
        # later round because their positions stay in `indices`.
        result[indices[:, 0]] = candidates.transpose(0, 1)
        indices = indices[mask]

    return result


def fixed_unigram_candidate_sampler_slow(
        true_classes: torch.Tensor,
        num_repeats: torch.Tensor,
        unigrams: torch.Tensor,
        distortion: float = 1.) -> torch.Tensor:
    """Per-row reference sampler that draws without replacement.

    For every row the weights of its true classes are zeroed and
    ``num_repeats[i]`` classes are drawn without replacement from
    ``unigrams ** distortion``.  Rows are processed on a thread pool.

    Args:
        true_classes: 2D integer tensor; negative entries are ignored.
        num_repeats: 1D tensor with the number of samples wanted per row.
        unigrams: per-class sampling weights.
        distortion: exponent applied to ``unigrams`` before sampling.

    Returns:
        1D tensor with all samples concatenated in row order.

    Raises:
        ValueError: if the shapes are wrong or some row has fewer eligible
            classes than requested samples.
    """
    assert isinstance(true_classes, torch.Tensor)
    assert isinstance(num_repeats, torch.Tensor)
    assert isinstance(unigrams, torch.Tensor)
    distortion = float(distortion)

    if len(true_classes.shape) != 2:
        raise ValueError('true_classes must be a 2D matrix with shape (num_samples, num_true)')

    if len(num_repeats.shape) != 1:
        raise ValueError('num_repeats must be 1D')

    if torch.any((unigrams > 0).sum() -
                 (true_classes >= 0).sum(dim=1) <
                 num_repeats):
        raise ValueError('Not enough classes to choose from')

    if distortion != 1.:
        unigrams = unigrams.to(torch.float64)
        unigrams = unigrams ** distortion

    def fun(i):
        # Progress output every 100 rows.
        if i and i % 100 == 0:
            print(i)
        if num_repeats[i] == 0:
            return []
        pos = torch.flatten(true_classes[i, :])
        pos = pos[pos >= 0]
        # Zero the weights of this row's true classes on a private copy.
        w = unigrams.clone().detach()
        w[pos] = 0
        sampler = torch.utils.data.WeightedRandomSampler(
            w, num_repeats[i].item(), replacement=False)
        return list(sampler)

    with multiprocessing.pool.ThreadPool() as p:
        per_row = p.map(fun, range(len(num_repeats)))

    # Flatten in row order (linear, unlike repeated list concatenation).
    flat = [cls for row in per_row for cls in row]
    return torch.tensor(flat)


def fixed_unigram_candidate_sampler_old(
        true_classes: torch.Tensor,
        num_repeats: torch.Tensor,
        unigrams: torch.Tensor,
        distortion: float = 1.) -> torch.Tensor:
    """Legacy sampler with verbose debug output.

    Redraws only the candidates that collide with their row's entries in
    ``true_classes``; accepted candidates are not recorded, so the same
    class may be returned more than once for a row.

    Args:
        true_classes: 2D integer tensor of classes to avoid per row.
        num_repeats: 1D tensor with the number of samples wanted per row.
        unigrams: per-class sampling weights.
        distortion: exponent applied to ``unigrams`` before sampling.

    Returns:
        1D ``torch.long`` tensor with ``num_repeats.sum()`` entries.

    Raises:
        ValueError: if the shapes are wrong or some row has fewer eligible
            classes than requested samples.
    """
    if len(true_classes.shape) != 2:
        raise ValueError('true_classes must be a 2D matrix with shape (num_samples, num_true)')

    if len(num_repeats.shape) != 1:
        raise ValueError('num_repeats must be 1D')

    if torch.any((unigrams > 0).sum() -
                 (true_classes >= 0).sum(dim=1) <
                 num_repeats):
        raise ValueError('Not enough classes to choose from')

    num_rows = true_classes.shape[0]
    print('true_classes.shape:', true_classes.shape)
    if distortion != 1.:
        unigrams = unigrams.to(torch.float64) ** distortion
    print('unigrams:', unigrams)

    # Column 0: position in the output; column 1: row the sample belongs to.
    indices = torch.arange(num_rows)
    indices = torch.repeat_interleave(indices, num_repeats)
    indices = torch.cat([
        torch.arange(len(indices)).view(-1, 1),
        indices.view(-1, 1)
    ], dim=1)

    num_samples = len(indices)
    result = torch.zeros(num_samples, dtype=torch.long)
    print('num_rows:', num_rows, 'num_samples:', num_samples)

    while len(indices) > 0:
        print('len(indices):', len(indices))
        print('indices:', indices)
        sampler = torch.utils.data.WeightedRandomSampler(unigrams,
                                                         len(indices))
        candidates = torch.tensor(list(sampler))
        candidates = candidates.view(len(indices), 1)
        print('candidates:', candidates)
        print('true_classes:', true_classes[indices[:, 1], :])
        result[indices[:, 0]] = candidates.transpose(0, 1)
        print('result:', result)
        # Keep only the samples that collided with a true class.
        mask = (candidates == true_classes[indices[:, 1], :])
        mask = mask.sum(1).to(torch.bool)
        print('mask:', mask)
        indices = indices[mask]

    return result


def get_edges_and_degrees(adj_mat: torch.Tensor) -> \
        Tuple[torch.Tensor, torch.Tensor]:
    """Return the edge list and per-column degrees of an adjacency matrix.

    Args:
        adj_mat: sparse or dense adjacency matrix.

    Returns:
        ``(edges_pos, degrees)``.  ``edges_pos`` holds one ``(row, column)``
        pair per stored (sparse) or non-zero (dense) entry.  ``degrees`` is
        the number of stored entries per column for sparse input and the
        column sum for dense input.
    """
    if adj_mat.is_sparse:
        # indices() is only available on coalesced sparse tensors.
        adj_mat = adj_mat.coalesce()
        degrees = torch.zeros(adj_mat.shape[1], dtype=torch.int64,
                              device=adj_mat.device)
        degrees = degrees.index_add(
            0, adj_mat.indices()[1],
            torch.ones(adj_mat.indices().shape[1], dtype=torch.int64,
                       device=adj_mat.device))
        edges_pos = adj_mat.indices().transpose(0, 1)
    else:
        degrees = adj_mat.sum(0)
        edges_pos = torch.nonzero(adj_mat, as_tuple=False)
    return edges_pos, degrees


def get_true_classes(adj_mat: torch.Tensor) -> \
        Tuple[torch.Tensor, torch.Tensor]:
    """Collect, for every row, the column indices of its entries.

    Args:
        adj_mat: adjacency matrix; dense input is converted to sparse and
            sparse input is coalesced before its indices are read.

    Returns:
        ``(true_classes, row_count)``.  ``true_classes`` has shape
        ``(rows, max entries per row)``, is filled with column indices and
        padded with ``-1``.  ``row_count`` holds the number of entries per
        row.
    """
    # Mirror get_edges_and_degrees: accept dense input and make sure the
    # sparse tensor is coalesced before calling indices().
    if not adj_mat.is_sparse:
        adj_mat = adj_mat.to_sparse()
    indices = adj_mat.coalesce().indices()

    row_count = torch.zeros(adj_mat.shape[0], dtype=torch.long)
    row_count = row_count.index_add(
        0, indices[0], torch.ones(indices.shape[1], dtype=torch.long))
    max_true_classes = torch.max(row_count).item()
    true_classes = torch.full((adj_mat.shape[0], max_true_classes), -1,
                              dtype=torch.long)

    # NOTE(review): cumcount is defined in .cumcount; it is used here as the
    # slot of every entry within its row - confirm against that module.
    t = time.time()
    cc = cumcount(indices[0])
    print('cumcount() took:', time.time() - t)

    t = time.time()
    true_classes[indices[0], cc] = indices[1]
    print('assignment took:', time.time() - t)

    return true_classes, row_count


def negative_sample_adj_mat(adj_mat: torch.Tensor,
                            remove_diagonal: bool = False) -> torch.Tensor:
    """Build a sparse matrix of sampled negative edges for ``adj_mat``.

    For every row as many negative neighbours are requested as the row has
    entries.  Sampling weights are the column degrees plus
    ``1 / numel(adj_mat)``, passed to the sampler with distortion 0.75.

    Args:
        adj_mat: adjacency matrix of the positive edges.
        remove_diagonal: when True, each row's own index is added to the
            classes the sampler must avoid.

    Returns:
        Coalesced sparse tensor with the shape, dtype and device of
        ``adj_mat`` and value 1 at every sampled position.

    Raises:
        ValueError: if ``adj_mat`` is not a ``torch.Tensor``.
    """
    if not isinstance(adj_mat, torch.Tensor):
        raise ValueError('adj_mat must be a torch.Tensor, got: %s' % adj_mat.__class__.__name__)

    _, degrees = get_edges_and_degrees(adj_mat)
    # The small offset keeps every weight strictly positive.
    degrees = degrees.to(torch.float32) + 1.0 / torch.numel(adj_mat)

    true_classes, row_count = get_true_classes(adj_mat)
    if remove_diagonal:
        true_classes = torch.cat([
            torch.arange(len(adj_mat)).view(-1, 1),
            true_classes
        ], dim=1)

    neg_neighbors = fixed_unigram_candidate_sampler(
        true_classes, row_count, degrees, 0.75).to(adj_mat.device)
    print('neg_neighbors:', neg_neighbors)

    pos_vertices = torch.repeat_interleave(torch.arange(len(adj_mat)),
                                           row_count)
    edges_neg = torch.cat([
        pos_vertices.view(-1, 1),
        neg_neighbors.view(-1, 1)
    ], 1)

    adj_mat_neg = torch.sparse_coo_tensor(
        indices=edges_neg.transpose(0, 1),
        values=torch.ones(len(edges_neg)),
        size=adj_mat.shape,
        dtype=adj_mat.dtype,
        device=adj_mat.device)
    adj_mat_neg = adj_mat_neg.coalesce()

    # Coalescing sums duplicate edges; rebuild with ones so that every
    # stored value is exactly 1.
    indices = adj_mat_neg.indices()
    adj_mat_neg = torch.sparse_coo_tensor(
        indices,
        torch.ones(indices.shape[1]),
        adj_mat.shape,
        dtype=adj_mat.dtype,
        device=adj_mat.device)
    adj_mat_neg = adj_mat_neg.coalesce()

    return adj_mat_neg


def negative_sample_data(data: Data) -> Data:
    """Create a ``Data`` object holding negative samples for ``data``.

    Vertex types are copied by name and count.  For every edge type each
    adjacency matrix is replaced by the output of
    ``negative_sample_adj_mat``; the diagonal is excluded when the row and
    column vertex types are equal.

    Args:
        data: source of vertex types and positive edge types.

    Returns:
        A new ``Data`` constructed with ``target_value=0``.
    """
    res = Data(target_value=0)
    for vt in data.vertex_types:
        res.add_vertex_type(vt.name, vt.count)

    for key, et in data.edge_types.items():
        print('key:', key)
        remove_diagonal = et.vertex_type_row == et.vertex_type_column
        adjacency_matrices_neg = [
            negative_sample_adj_mat(adj_mat, remove_diagonal)
            for adj_mat in et.adjacency_matrices
        ]
        res.add_edge_type(et.name,
                          et.vertex_type_row, et.vertex_type_column,
                          adjacency_matrices_neg, et.decoder_factory)

    return res


def merge_data(pos_data: Data, neg_data: Data) -> Data:
    """Combine positive and negative edge data into a single object.

    Vertex types are copied from ``pos_data``.  Every edge type is added
    with the positive adjacency matrices from ``pos_data`` and the negative
    ones stored under the same key in ``neg_data``.

    Args:
        pos_data: data with the positive edges.
        neg_data: data with the negative edges, keyed like ``pos_data``.

    Returns:
        The populated ``PosNegData`` instance.
    """
    assert isinstance(pos_data, Data)
    assert isinstance(neg_data, Data)

    # NOTE(review): PosNegData is not imported in the visible part of this
    # file - confirm where it is defined and add the import.
    res = PosNegData()

    for vt in pos_data.vertex_types:
        res.add_vertex_type(vt.name, vt.count)

    for key, pos_et in pos_data.edge_types.items():
        neg_et = neg_data.edge_types[key]
        res.add_edge_type(pos_et.name,
                          pos_et.vertex_type_row, pos_et.vertex_type_column,
                          pos_et.adjacency_matrices,
                          neg_et.adjacency_matrices,
                          pos_et.decoder_factory)

    # The signature promises a Data-like result; previously nothing was
    # returned.
    return res