|  |  | @@ -17,41 +17,41 @@ import time | 
		
	
		
			
			|  |  |  |  | 
		
	
		
			
			|  |  |  |  | 
		
	
		
			
			|  |  |  | def fixed_unigram_candidate_sampler( | 
		
	
		
			
			|  |  |  | true_classes: Union[np.array, torch.Tensor], | 
		
	
		
			
			|  |  |  | unigrams: List[Union[int, float]], | 
		
	
		
			
			|  |  |  | true_classes: torch.Tensor, | 
		
	
		
			
			|  |  |  | num_repeats: torch.Tensor, | 
		
	
		
			
			|  |  |  | unigrams: torch.Tensor, | 
		
	
		
			
			|  |  |  | distortion: float = 1.): | 
		
	
		
			
			|  |  |  |  | 
		
	
		
			
			|  |  |  | if isinstance(true_classes, torch.Tensor): | 
		
	
		
			
			|  |  |  | true_classes = true_classes.detach().cpu().numpy() | 
		
	
		
			
			|  |  |  |  | 
		
	
		
			
			|  |  |  | if isinstance(unigrams, torch.Tensor): | 
		
	
		
			
			|  |  |  | unigrams = unigrams.detach().cpu().numpy() | 
		
	
		
			
			|  |  |  |  | 
		
	
		
			
			|  |  |  | if len(true_classes.shape) != 2: | 
		
	
		
			
			|  |  |  | raise ValueError('true_classes must be a 2D matrix with shape (num_samples, num_true)') | 
		
	
		
			
			|  |  |  |  | 
		
	
		
			
			|  |  |  | num_samples = true_classes.shape[0] | 
		
	
		
			
			|  |  |  | unigrams = np.array(unigrams) | 
		
	
		
			
			|  |  |  | if len(num_repeats.shape) != 1: | 
		
	
		
			
			|  |  |  | raise ValueError('num_repeats must be 1D') | 
		
	
		
			
			|  |  |  |  | 
		
	
		
			
			|  |  |  | num_rows = true_classes.shape[0] | 
		
	
		
			
			|  |  |  | # unigrams = np.array(unigrams) | 
		
	
		
			
			|  |  |  | if distortion != 1.: | 
		
	
		
			
			|  |  |  | unigrams = unigrams.astype(np.float64) ** distortion | 
		
	
		
			
			|  |  |  | unigrams = unigrams.to(torch.float64) ** distortion | 
		
	
		
			
			|  |  |  | # print('unigrams:', unigrams) | 
		
	
		
			
			|  |  |  | indices = np.arange(num_samples) | 
		
	
		
			
			|  |  |  | result = np.zeros(num_samples, dtype=np.int64) | 
		
	
		
			
			|  |  |  | indices = torch.arange(num_rows) | 
		
	
		
			
			|  |  |  | indices = torch.repeat_interleave(indices, num_repeats) | 
		
	
		
			
			|  |  |  | num_samples = len(indices) | 
		
	
		
			
			|  |  |  | result = torch.zeros(num_samples, dtype=torch.long) | 
		
	
		
			
			|  |  |  | while len(indices) > 0: | 
		
	
		
			
			|  |  |  | # print('len(indices):', len(indices)) | 
		
	
		
			
			|  |  |  | sampler = torch.utils.data.WeightedRandomSampler(unigrams, len(indices)) | 
		
	
		
			
			|  |  |  | candidates = np.array(list(sampler)) | 
		
	
		
			
			|  |  |  | candidates = np.reshape(candidates, (len(indices), 1)) | 
		
	
		
			
			|  |  |  | candidates = torch.tensor(list(sampler)) | 
		
	
		
			
			|  |  |  | candidates = candidates.view(len(indices), 1) | 
		
	
		
			
			|  |  |  | # print('candidates:', candidates) | 
		
	
		
			
			|  |  |  | # print('true_classes:', true_classes[indices, :]) | 
		
	
		
			
			|  |  |  | result[indices] = candidates.T | 
		
	
		
			
			|  |  |  | result[indices] = candidates.transpose(0, 1) | 
		
	
		
			
			|  |  |  | # print('result:', result) | 
		
	
		
			
			|  |  |  | mask = (candidates == true_classes[indices, :]) | 
		
	
		
			
			|  |  |  | mask = mask.sum(1).astype(np.bool) | 
		
	
		
			
			|  |  |  | mask = mask.sum(1).to(torch.bool) | 
		
	
		
			
			|  |  |  | # print('mask:', mask) | 
		
	
		
			
			|  |  |  | indices = indices[mask] | 
		
	
		
			
			|  |  |  | # result[indices] = 0 | 
		
	
		
			
			|  |  |  | return torch.tensor(result) | 
		
	
		
			
			|  |  |  | return result | 
		
	
		
			
			|  |  |  |  | 
		
	
		
			
			|  |  |  |  | 
		
	
		
			
			|  |  |  | def get_edges_and_degrees(adj_mat: torch.Tensor) -> \ | 
		
	
	
		
			
				|  |  | @@ -71,7 +71,7 @@ def get_edges_and_degrees(adj_mat: torch.Tensor) -> \ | 
		
	
		
			
			|  |  |  | return edges_pos, degrees | 
		
	
		
			
			|  |  |  |  | 
		
	
		
			
			|  |  |  |  | 
		
	
		
			
			|  |  |  | def get_true_classes(adj_mat: torch.Tensor) -> torch.Tensor: | 
		
	
		
			
			|  |  |  | def get_true_classes(adj_mat: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: | 
		
	
		
			
			|  |  |  | indices = adj_mat.indices() | 
		
	
		
			
			|  |  |  | row_count = torch.zeros(adj_mat.shape[0], dtype=torch.long) | 
		
	
		
			
			|  |  |  | #print('indices[0]:', indices[0], count[indices[0]]) | 
		
	
	
		
			
				|  |  | @@ -105,11 +105,11 @@ def get_true_classes(adj_mat: torch.Tensor) -> torch.Tensor: | 
		
	
		
			
			|  |  |  | true_classes[row, count[row]] = col | 
		
	
		
			
			|  |  |  | count[row] += 1 ''' | 
		
	
		
			
			|  |  |  |  | 
		
	
		
			
			|  |  |  | t = time.time() | 
		
	
		
			
			|  |  |  | true_classes = torch.repeat_interleave(true_classes, row_count, dim=0) | 
		
	
		
			
			|  |  |  | print('repeat_interleave() took:', time.time() - t) | 
		
	
		
			
			|  |  |  | # t = time.time() | 
		
	
		
			
			|  |  |  | # true_classes = torch.repeat_interleave(true_classes, row_count, dim=0) | 
		
	
		
			
			|  |  |  | # print('repeat_interleave() took:', time.time() - t) | 
		
	
		
			
			|  |  |  |  | 
		
	
		
			
			|  |  |  | return true_classes | 
		
	
		
			
			|  |  |  | return true_classes, row_count | 
		
	
		
			
			|  |  |  |  | 
		
	
		
			
			|  |  |  |  | 
		
	
		
			
			|  |  |  | def negative_sample_adj_mat(adj_mat: torch.Tensor) -> torch.Tensor: | 
		
	
	
		
			
				|  |  | @@ -118,12 +118,12 @@ def negative_sample_adj_mat(adj_mat: torch.Tensor) -> torch.Tensor: | 
		
	
		
			
			|  |  |  |  | 
		
	
		
			
			|  |  |  | edges_pos, degrees = get_edges_and_degrees(adj_mat) | 
		
	
		
			
			|  |  |  |  | 
		
	
		
			
			|  |  |  | true_classes = get_true_classes(adj_mat) | 
		
	
		
			
			|  |  |  | true_classes, row_count = get_true_classes(adj_mat) | 
		
	
		
			
			|  |  |  | # true_classes = edges_pos[:, 1].view(-1, 1) | 
		
	
		
			
			|  |  |  | # print('true_classes:', true_classes) | 
		
	
		
			
			|  |  |  |  | 
		
	
		
			
			|  |  |  | neg_neighbors = fixed_unigram_candidate_sampler( | 
		
	
		
			
			|  |  |  | true_classes, degrees, 0.75).to(adj_mat.device) | 
		
	
		
			
			|  |  |  | true_classes, row_count, degrees, 0.75).to(adj_mat.device) | 
		
	
		
			
			|  |  |  |  | 
		
	
		
			
			|  |  |  | print('neg_neighbors:', neg_neighbors) | 
		
	
		
			
			|  |  |  |  | 
		
	
	
		
			
				|  |  | 
 |