IF YOU WOULD LIKE TO GET AN ACCOUNT, please write an email to s dot adaszewski at gmail dot com. User accounts are meant only to report issues and/or generate pull requests. This is a purpose-specific Git hosting for ADARED projects. Thank you for your understanding!
Browse Source

Add trainprep and negative_sample_(adj_mat|data)().

master
Stanislaw Adaszewski 3 years ago
parent
commit
e557b18762
3 changed files with 168 additions and 2 deletions
  1. +71
    -2
      src/triacontagon/sampling.py
  2. +59
    -0
      src/triacontagon/trainprep.py
  3. +38
    -0
      tests/triacontagon/test_sampling.py

+ 71
- 2
src/triacontagon/sampling.py View File

@@ -8,7 +8,10 @@ import numpy as np
import torch
import torch.utils.data
from typing import List, \
Union
Union, \
Tuple
from .data import Data, \
EdgeType
def fixed_unigram_candidate_sampler(
@@ -24,7 +27,7 @@ def fixed_unigram_candidate_sampler(
if len(true_classes.shape) != 2:
raise ValueError('true_classes must be a 2D matrix with shape (num_samples, num_true)')
num_samples = true_classes.shape[0]
unigrams = np.array(unigrams)
if distortion != 1.:
@@ -40,8 +43,74 @@ def fixed_unigram_candidate_sampler(
# print('candidates:', candidates)
# print('true_classes:', true_classes[indices, :])
result[indices] = candidates.T
# print('result:', result)
mask = (candidates == true_classes[indices, :])
mask = mask.sum(1).astype(np.bool)
# print('mask:', mask)
indices = indices[mask]
# result[indices] = 0
return torch.tensor(result)
def get_edges_and_degrees(adj_mat: torch.Tensor) -> \
Tuple[torch.Tensor, torch.Tensor]:
if adj_mat.is_sparse:
adj_mat = adj_mat.coalesce()
degrees = torch.zeros(adj_mat.shape[1], dtype=torch.int64,
device=adj_mat.device)
degrees = degrees.index_add(0, adj_mat.indices()[1],
torch.ones(adj_mat.indices().shape[1], dtype=torch.int64,
device=adj_mat.device))
edges_pos = adj_mat.indices().transpose(0, 1)
else:
degrees = adj_mat.sum(0)
edges_pos = torch.nonzero(adj_mat, as_tuple=False)
return edges_pos, degrees
def negative_sample_adj_mat(adj_mat: torch.Tensor) -> torch.Tensor:
if not isinstance(adj_mat, torch.Tensor):
raise ValueError('adj_mat must be a torch.Tensor, got: %s' % adj_mat.__class__.__name__)
edges_pos, degrees = get_edges_and_degrees(adj_mat)
neg_neighbors = fixed_unigram_candidate_sampler(
edges_pos[:, 1].view(-1, 1), degrees, 0.75).to(adj_mat.device)
edges_neg = torch.cat([ edges_pos[:, 0].view(-1, 1),
neg_neighbors.view(-1, 1) ], 1)
adj_mat_neg = torch.sparse_coo_tensor(indices = edges_neg.transpose(0, 1),
values=torch.ones(len(edges_neg)), size=adj_mat.shape,
dtype=adj_mat.dtype, device=adj_mat.device)
adj_mat_neg = adj_mat_neg.coalesce()
indices = adj_mat_neg.indices()
adj_mat_neg = torch.sparse_coo_tensor(indices,
torch.ones(indices.shape[1]), adj_mat.shape,
dtype=adj_mat.dtype, device=adj_mat.device)
adj_mat_neg = adj_mat_neg.coalesce()
return adj_mat_neg
def negative_sample_data(data: Data) -> Data:
new_edge_types = {}
res = Data()
for vt in data.vertex_types:
res.add_vertex_type(vt.name, vt.count)
for key, et in data.edge_types.items():
adjacency_matrices_neg = []
for adj_mat in et.adjacency_matrices:
adj_mat_neg = negative_sample_adj_mat(adj_mat)
adjacency_matrices_neg.append(adj_mat_neg)
res.add_edge_type(et.name,
et.vertex_type_row, et.vertex_type_column,
adjacency_matrices_neg, et.decoder_factory)
#new_et = EdgeType(et.name, et.vertex_type_row,
# et.vertex_type_column, adjacency_matrices_neg,
# et.decoder_factory, et.total_connectivity)
#new_edge_types[key] = new_et
#res = Data(data.vertex_types, new_edge_types)
return res

+ 59
- 0
src/triacontagon/trainprep.py View File

@@ -0,0 +1,59 @@
from .data import Data, \
TrainingBatch, \
EdgeType
from typing import Tuple
from .util import _sparse_coo_tensor
def split_adj_mat(adj_mat: torch.Tensor, ratios: List[float]):
indices = adj_mat.indices()
values = adj_mat.values()
order = torch.randperm(indices.shape[1])
indices = indices[:, order]
values = values[order]
ofs = 0
res = []
for r in ratios:
cnt = r * len(values)
ind = indices[:, ofs:ofs+cnt]
val = values[ofs:ofs+cnt]
res.append(_sparse_coo_tensor(ind, val, adj_mat.shape))
ofs += cnt
return res
def split_edge_type(et: EdgeType, ratios: Tuple[float, float, float]):
res = [ [] for _ in range(len(et.adjacency_matrices)) ]
for adj_mat in et.adjacency_matrices:
for i, new_adj_mat in enumerate(split_adj_mat(adj_mat, ratios)):
res[i].append(new_adj_mat)
return res
def split_data(data: Data,
ratios: List[float]):
if not isinstance(data, Data):
raise TypeError('data must be an instance of Data')
ratios = list(ratios)
if sum(ratios) != 1:
raise ValueError('ratios must sum to 1')
res = [ {} for _ in range(len(ratios)) ]
for key, et in data.edge_types:
for i, new_et in enumerate(split_edge_type(et, ratios)):
res[i][key] = new_et
res = [ Data(data.vertex_types, new_edge_types) \
for new_edge_types in res ]
return res

+ 38
- 0
tests/triacontagon/test_sampling.py View File

@@ -0,0 +1,38 @@
from triacontagon.data import Data
from triacontagon.sampling import negative_sample_adj_mat, \
negative_sample_data
from triacontagon.decode import dedicom_decoder
import torch
def test_negative_sample_adj_mat_01():
adj_mat = torch.tensor([
[0, 1, 0, 1, 0],
[0, 0, 0, 0, 1],
[1, 1, 0, 0, 0],
[0, 0, 1, 0, 1],
[0, 1, 0, 0, 0]
])
print('adj_mat:', adj_mat)
adj_mat_neg = negative_sample_adj_mat(adj_mat)
print('adj_mat_neg:', adj_mat_neg.to_dense())
def test_negative_sample_data_01():
d = Data()
d.add_vertex_type('Gene', 5)
d.add_edge_type('Gene-Gene', 0, 0, [
torch.tensor([
[0, 1, 0, 1, 0],
[0, 0, 0, 0, 1],
[1, 1, 0, 0, 0],
[0, 0, 1, 0, 1],
[0, 1, 0, 0, 0]
], dtype=torch.float).to_sparse()
], dedicom_decoder)
d_neg = negative_sample_data(d)

Loading…
Cancel
Save