IF YOU WOULD LIKE TO GET AN ACCOUNT, please write an email to s dot adaszewski at gmail dot com. User accounts are meant only to report issues and/or generate pull requests. This is a purpose-specific Git hosting for ADARED projects. Thank you for your understanding!
浏览代码

Add test_merge_pos_neg_batch_01/02().

master
Stanislaw Adaszewski 4 年前
父节点
当前提交
2260f55399
共有 2 个文件被更改,包括 71 次插入1 次删除
  1. +5
    -1
      src/triacontagon/loop.py
  2. +66
    -0
      tests/triacontagon/test_loop.py

+ 5
- 1
src/triacontagon/loop.py 查看文件

@@ -1,7 +1,11 @@
from .model import Model
from .model import Model, \
TrainingBatch
from .batch import Batcher
from .sampling import negative_sample_data
from .data import Data
import torch
from typing import List, \
Callable
def _merge_pos_neg_batches(pos_batch, neg_batch):


+ 66
- 0
tests/triacontagon/test_loop.py 查看文件

@@ -0,0 +1,66 @@
from triacontagon.loop import _merge_pos_neg_batches
from triacontagon.model import TrainingBatch
import torch
import pytest
def test_merge_pos_neg_batches_01():
b_1 = TrainingBatch(0, 0, 0, torch.tensor([
[0, 1],
[2, 3],
[4, 5],
[5, 6]
]), torch.ones(4))
b_2 = TrainingBatch(0, 0, 0, torch.tensor([
[1, 6],
[3, 5],
[5, 2],
[4, 1]
]), torch.zeros(4))
b = _merge_pos_neg_batches(b_1, b_2)
assert b.vertex_type_row == 0
assert b.vertex_type_column == 0
assert b.relation_type_index == 0
assert torch.all(b.edges == torch.tensor([
[0, 1],
[2, 3],
[4, 5],
[5, 6],
[1, 6],
[3, 5],
[5, 2],
[4, 1]
]))
assert torch.all(b.target_values == \
torch.cat([ torch.ones(4), torch.zeros(4) ]))
def test_merge_pos_neg_batches_02():
b_1 = TrainingBatch(0, 1, 0, torch.tensor([
[0, 1],
[2, 3],
[4, 5],
[5, 6]
]), torch.ones(4))
b_2 = TrainingBatch(0, 0, 0, torch.tensor([
[1, 6],
[3, 5],
[5, 2],
[4, 1]
]), torch.zeros(4))
print(b_1)
with pytest.raises(AssertionError):
_ = _merge_pos_neg_batches(b_1, b_2)
b_1.vertex_type_row, b_1.vertex_type_column = \
b_1.vertex_type_column, b_1.vertex_type_row
print(b_1)
with pytest.raises(AssertionError):
_ = _merge_pos_neg_batches(b_1, b_2)
b_1.vertex_type_row, b_1.relation_type_index = \
b_1.relation_type_index, b_1.vertex_type_row
print(b_1)
with pytest.raises(AssertionError):
_ = _merge_pos_neg_batches(b_1, b_2)

正在加载...
取消
保存