IF YOU WOULD LIKE TO GET AN ACCOUNT, please write an email to s dot adaszewski at gmail dot com. User accounts are meant only to report issues and/or generate pull requests. This is a purpose-specific Git hosting for ADARED projects. Thank you for your understanding!
ソースを参照

Work on fixed_unigram_candidate_sampler_new().

master
Stanislaw Adaszewski 4年前
コミット
670237c3f8
3個のファイルの変更11行の追加6行の削除
  1. +1
    -0
      src/triacontagon/cumcount.py
  2. +3
    -3
      src/triacontagon/sampling.py
  3. +7
    -3
      tests/triacontagon/test_sampling.py

+ 1
- 0
src/triacontagon/cumcount.py ファイルの表示

@@ -9,6 +9,7 @@ def dfill(a):
torch.nonzero(a[:-1] != a[1:], as_tuple=True)[0] + 1,
torch.tensor([n])
])
print('b:',b)
res = torch.arange(n)[b[:-1]]
res = torch.repeat_interleave(res, b[1:] - b[:-1])
return res


+ 3
- 3
src/triacontagon/sampling.py ファイルの表示

@@ -21,7 +21,7 @@ from itertools import product, \
from functools import reduce
def fixed_unigram_candidate_sampler_new(
def fixed_unigram_candidate_sampler(
true_classes: torch.Tensor,
num_repeats: torch.Tensor,
unigrams: torch.Tensor,
@@ -50,7 +50,7 @@ def fixed_unigram_candidate_sampler_new(
dtype=true_classes.dtype)
], dim=1)
indices = torch.repeat_interleave(torch.arange(len(unigrams)), num_repeats)
indices = torch.repeat_interleave(torch.arange(len(true_classes)), num_repeats)
indices = torch.cat([ torch.arange(len(indices)).view(-1, 1),
indices.view(-1, 1) ], dim=1)
@@ -135,7 +135,7 @@ def fixed_unigram_candidate_sampler_slow(
return torch.tensor(res)
def fixed_unigram_candidate_sampler(
def fixed_unigram_candidate_sampler_old(
true_classes: torch.Tensor,
num_repeats: torch.Tensor,
unigrams: torch.Tensor,


+ 7
- 3
tests/triacontagon/test_sampling.py ファイルの表示

@@ -3,11 +3,12 @@ from triacontagon.sampling import fixed_unigram_candidate_sampler, \
get_true_classes, \
negative_sample_adj_mat, \
negative_sample_data, \
get_edges_and_degrees, \
fixed_unigram_candidate_sampler_new
get_edges_and_degrees
import triacontagon.sampling
from triacontagon.decode import dedicom_decoder
import torch
import time
import pytest
def test_fixed_unigram_candidate_sampler_01():
@@ -41,6 +42,7 @@ def test_fixed_unigram_candidate_sampler_02():
print('row_count:', row_count)
edges_pos, degrees = get_edges_and_degrees(bar_foo)
print('degrees:', degrees)
res = fixed_unigram_candidate_sampler(true_classes, row_count,
degrees, 0.75)
@@ -117,7 +119,9 @@ def test_negative_sample_data_01():
def test_fixed_unigram_candidate_sampler_new_01():
x = (torch.rand((10, 10)) < .1).to(torch.float32).to_sparse()
if 'fixed_unigram_candidate_sampler_new' not in dir(triacontagon.sampling):
pytest.skip('fixed_unigram_candidate_sampler_new not found')
x = (torch.rand((10, 10)) < .05).to(torch.float32).to_sparse()
true_classes, row_count = get_true_classes(x)
edges, degrees = get_edges_and_degrees(x)
# import pdb


読み込み中…
キャンセル
保存