|
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269 |
- #
- # Copyright (C) Stanislaw Adaszewski, 2020
- # License: GPLv3
- #
-
-
- import numpy as np
- import torch
- import torch.utils.data
- from typing import List, \
- Union, \
- Tuple
- from .data import Data, \
- EdgeType
- from .cumcount import cumcount
- import time
- import multiprocessing
- import multiprocessing.pool
- from itertools import product, \
- repeat
- from functools import reduce
-
-
def fixed_unigram_candidate_sampler_slow(
    true_classes: torch.Tensor,
    num_repeats: torch.Tensor,
    unigrams: torch.Tensor,
    distortion: float = 1.) -> torch.Tensor:
    """Draw negative samples one row at a time (slow reference implementation).

    For every row ``i`` of ``true_classes``, draws ``num_repeats[i]``
    samples without replacement from the classes whose unigram weight is
    positive, excluding that row's true (positive) classes.

    Args:
        true_classes: 2D LongTensor ``(num_samples, num_true)`` of known
            positive class indices; entries ``< 0`` are padding and ignored.
        num_repeats: 1D tensor — number of negatives to draw per row.
        unigrams: 1D tensor of per-class sampling weights.
        distortion: exponent applied to the weights (e.g. 0.75 flattens
            the distribution towards the tail).

    Returns:
        1D tensor of ``num_repeats.sum()`` sampled class indices,
        concatenated in row order.

    Raises:
        ValueError: on malformed shapes, or when some row has fewer
            eligible classes than requested samples.
    """
    assert isinstance(true_classes, torch.Tensor)
    assert isinstance(num_repeats, torch.Tensor)
    assert isinstance(unigrams, torch.Tensor)
    distortion = float(distortion)

    if len(true_classes.shape) != 2:
        raise ValueError('true_classes must be a 2D matrix with shape (num_samples, num_true)')

    if len(num_repeats.shape) != 1:
        raise ValueError('num_repeats must be 1D')

    if torch.any((unigrams > 0).sum() -
                 (true_classes >= 0).sum(dim=1) <
                 num_repeats):
        raise ValueError('Not enough classes to choose from')

    if distortion != 1.:
        unigrams = unigrams.to(torch.float64)
        unigrams = unigrams ** distortion

    def sample_row(i):
        if num_repeats[i] == 0:
            return []
        pos = torch.flatten(true_classes[i, :])
        pos = pos[pos >= 0]  # drop the -1 padding entries
        w = unigrams.clone().detach()
        w[pos] = 0  # never sample a known-positive class
        sampler = torch.utils.data.WeightedRandomSampler(w,
            num_repeats[i].item(), replacement=False)
        return list(sampler)

    with multiprocessing.pool.ThreadPool() as p:
        per_row = p.map(sample_row, range(len(num_repeats)))

    # Flatten in one pass; reduce(list.__add__, ...) was accidentally O(n^2).
    res = [c for row in per_row for c in row]

    return torch.tensor(res)
-
-
def fixed_unigram_candidate_sampler(
    true_classes: torch.Tensor,
    num_repeats: torch.Tensor,
    unigrams: torch.Tensor,
    distortion: float = 1.) -> torch.Tensor:
    """Vectorised negative sampler using rejection sampling.

    Draws ``num_repeats[i]`` candidates for each row ``i`` of
    ``true_classes`` from the unigram distribution, then iteratively
    re-draws only the slots whose candidate collided with one of the
    row's true classes, until every slot holds a non-positive class.
    Sampling is with replacement, so duplicates within a row may occur.

    Args:
        true_classes: 2D LongTensor ``(num_samples, num_true)`` of known
            positive class indices; entries ``< 0`` are padding.
        num_repeats: 1D tensor — number of negatives to draw per row.
        unigrams: 1D tensor of per-class sampling weights.
        distortion: exponent applied to the weights (e.g. 0.75).

    Returns:
        1D LongTensor of ``num_repeats.sum()`` class indices in row order.

    Raises:
        ValueError: on malformed shapes, or when some row has fewer
            eligible classes than requested samples (which would also
            make the rejection loop non-terminating).
    """
    if len(true_classes.shape) != 2:
        raise ValueError('true_classes must be a 2D matrix with shape (num_samples, num_true)')

    if len(num_repeats.shape) != 1:
        raise ValueError('num_repeats must be 1D')

    if torch.any((unigrams > 0).sum() -
                 (true_classes >= 0).sum(dim=1) <
                 num_repeats):
        raise ValueError('Not enough classes to choose from')

    num_rows = true_classes.shape[0]
    if distortion != 1.:
        unigrams = unigrams.to(torch.float64) ** distortion

    # indices[:, 0] — slot in the output; indices[:, 1] — source row.
    indices = torch.arange(num_rows)
    indices = torch.repeat_interleave(indices, num_repeats)
    indices = torch.cat([ torch.arange(len(indices)).view(-1, 1),
        indices.view(-1, 1) ], dim=1)

    num_samples = len(indices)
    result = torch.zeros(num_samples, dtype=torch.long)

    while len(indices) > 0:
        sampler = torch.utils.data.WeightedRandomSampler(unigrams, len(indices))
        candidates = torch.tensor(list(sampler))
        candidates = candidates.view(len(indices), 1)
        result[indices[:, 0]] = candidates.view(-1)
        # A slot is rejected iff its candidate equals any true class of
        # its row; only rejected slots are kept for the next round.
        mask = (candidates == true_classes[indices[:, 1], :])
        mask = mask.sum(1).to(torch.bool)
        indices = indices[mask]
    return result
-
-
def get_edges_and_degrees(adj_mat: torch.Tensor) -> \
        Tuple[torch.Tensor, torch.Tensor]:
    """Return ``(edges_pos, degrees)`` for a dense or sparse adjacency matrix.

    ``edges_pos`` is an ``(nnz, 2)`` tensor of (row, col) index pairs of
    the non-zero entries; ``degrees`` counts non-zeros per column.
    """
    if not adj_mat.is_sparse:
        # Dense path: column sums and nonzero coordinates directly.
        return torch.nonzero(adj_mat, as_tuple=False), adj_mat.sum(0)

    adj_mat = adj_mat.coalesce()
    col_idx = adj_mat.indices()[1]
    ones = torch.ones(col_idx.shape[0], dtype=torch.int64,
        device=adj_mat.device)
    degrees = torch.zeros(adj_mat.shape[1], dtype=torch.int64,
        device=adj_mat.device).index_add(0, col_idx, ones)
    return adj_mat.indices().transpose(0, 1), degrees
-
-
def get_true_classes(adj_mat: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
    """Build a per-row matrix of positive class indices from a sparse matrix.

    Args:
        adj_mat: sparse COO adjacency matrix. Must be coalesced —
            ``.indices()`` raises otherwise — so its indices are in
            sorted row-major order.

    Returns:
        ``(true_classes, row_count)`` where ``true_classes`` is a
        ``(num_rows, max_row_degree)`` LongTensor whose row ``i`` lists
        the column indices adjacent to ``i``, right-padded with ``-1``,
        and ``row_count`` is the per-row non-zero count.
    """
    indices = adj_mat.indices()

    row_count = torch.zeros(adj_mat.shape[0], dtype=torch.long)
    row_count = row_count.index_add(0, indices[0],
        torch.ones(indices.shape[1], dtype=torch.long))

    max_true_classes = torch.max(row_count).item()
    true_classes = torch.full((adj_mat.shape[0], max_true_classes),
        -1, dtype=torch.long)

    # Occurrence number of each entry within its row. Because a
    # coalesced tensor's indices are sorted row-major, this is the
    # entry's global position minus its row's starting offset —
    # computed entirely in torch (the previous version round-tripped
    # through numpy on the CPU via a cumcount() helper).
    row_offsets = torch.cumsum(row_count, 0) - row_count
    cc = torch.arange(indices.shape[1]) - \
        torch.repeat_interleave(row_offsets, row_count)
    true_classes[indices[0], cc] = indices[1]

    return true_classes, row_count
-
-
def negative_sample_adj_mat(adj_mat: torch.Tensor,
    remove_diagonal: bool=False) -> torch.Tensor:
    """Sample a sparse matrix of negative (non-)edges matching ``adj_mat``.

    For every positive edge ``(u, v)`` a negative neighbour ``w`` is
    sampled for ``u`` (degree-biased, distortion 0.75) such that
    ``(u, w)`` is not a positive edge.

    Args:
        adj_mat: adjacency matrix of positive edges. NOTE(review): the
            call to ``get_true_classes`` uses ``.indices()``, so in
            practice this expects a sparse tensor — confirm callers.
        remove_diagonal: if True, self-loops are treated as positives
            and therefore never sampled (use when rows and columns index
            the same vertex set).

    Returns:
        Coalesced sparse COO matrix of the same shape/dtype/device as
        ``adj_mat``, with value 1 at each sampled negative edge.
    """
    if not isinstance(adj_mat, torch.Tensor):
        raise ValueError('adj_mat must be a torch.Tensor, got: %s' % adj_mat.__class__.__name__)

    edges_pos, degrees = get_edges_and_degrees(adj_mat)
    # Additive smoothing of 1/numel so zero-degree vertices keep a small
    # nonzero sampling weight (precedence: degrees + (1/numel)).
    degrees = degrees.to(torch.float32) + 1.0 / torch.numel(adj_mat)

    true_classes, row_count = get_true_classes(adj_mat)
    if remove_diagonal:
        # Prepend each row's own index so the sampler rejects self-loops.
        true_classes = torch.cat([ torch.arange(len(adj_mat)).view(-1, 1),
            true_classes ], dim=1)

    neg_neighbors = fixed_unigram_candidate_sampler(
        true_classes, row_count, degrees, 0.75).to(adj_mat.device)

    pos_vertices = torch.repeat_interleave(torch.arange(len(adj_mat)),
        row_count)

    edges_neg = torch.cat([ pos_vertices.view(-1, 1),
        neg_neighbors.view(-1, 1) ], 1)

    adj_mat_neg = torch.sparse_coo_tensor(indices=edges_neg.transpose(0, 1),
        values=torch.ones(len(edges_neg)), size=adj_mat.shape,
        dtype=adj_mat.dtype, device=adj_mat.device)

    # Coalescing sums duplicate sampled edges; rebuild with unit values
    # so every negative edge ends up with weight exactly 1.
    adj_mat_neg = adj_mat_neg.coalesce()
    indices = adj_mat_neg.indices()
    adj_mat_neg = torch.sparse_coo_tensor(indices,
        torch.ones(indices.shape[1]), adj_mat.shape,
        dtype=adj_mat.dtype, device=adj_mat.device)

    return adj_mat_neg.coalesce()
-
-
def negative_sample_data(data: Data) -> Data:
    """Return a new ``Data`` whose edge types hold negative-sampled matrices.

    Vertex types are copied as-is; every adjacency matrix of every edge
    type is replaced by a negative sample of the same shape. The result
    is built with ``target_value=0`` (these edges are negatives).
    """
    res = Data(target_value=0)
    for vt in data.vertex_types:
        res.add_vertex_type(vt.name, vt.count)
    for et in data.edge_types.values():
        # Self-loops can only occur when rows and columns index the same
        # vertex type, so only then must the diagonal be excluded.
        remove_diagonal = (et.vertex_type_row == et.vertex_type_column)
        adjacency_matrices_neg = [
            negative_sample_adj_mat(adj_mat, remove_diagonal)
            for adj_mat in et.adjacency_matrices ]
        res.add_edge_type(et.name,
            et.vertex_type_row, et.vertex_type_column,
            adjacency_matrices_neg, et.decoder_factory)
    return res
-
-
def merge_data(pos_data: Data, neg_data: Data) -> Data:
    """Combine positive and negative samples into one paired dataset.

    Copies the vertex types of ``pos_data`` and, for each edge-type key,
    registers an edge type carrying both the positive and the negative
    adjacency matrices.

    NOTE(review): this definition appears truncated at the end of the
    visible chunk — no ``return res`` is visible. Also, ``PosNegData``
    is not defined or imported anywhere in this file; verify it exists
    in the package.
    """
    assert isinstance(pos_data, Data)
    assert isinstance(neg_data, Data)

    res = PosNegData()

    for vt in pos_data.vertex_types:
        res.add_vertex_type(vt.name, vt.count)

    for key, pos_et in pos_data.edge_types.items():
        # Pair each positive edge type with the negative one of the same key;
        # raises KeyError if neg_data is missing a key present in pos_data.
        neg_et = neg_data.edge_types[key]
        res.add_edge_type(pos_et.name,
            pos_et.vertex_type_row, pos_et.vertex_type_column,
            pos_et.adjacency_matrices, neg_et.adjacency_matrices,
            pos_et.decoder_factory)
|