IF YOU WOULD LIKE TO GET AN ACCOUNT, please write an email to s dot adaszewski at gmail dot com. User accounts are meant only to report issues and/or generate pull requests. This is a purpose-specific Git hosting for ADARED projects. Thank you for your understanding!
소스 검색

Add test_decode_layer_parameter_count_[01-03]().

master
Stanislaw Adaszewski 3 년 전
부모
커밋
6c8fdb7091
1개의 변경된 파일53개의 추가작업 그리고 0개의 파일을 삭제
  1. +53
    -0
      tests/icosagon/test_declayer.py

+ 53
- 0
tests/icosagon/test_declayer.py 파일 보기

@@ -232,3 +232,56 @@ def test_decode_layer_05():
# assert isinstance(rel_pred.edges_neg, TrainValTest)
# assert isinstance(rel_pred.edges_back_pos, TrainValTest)
# assert isinstance(rel_pred.edges_back_neg, TrainValTest)
def test_decode_layer_parameter_count_01():
d = Data()
d.add_node_type('Dummy', 100)
fam = d.add_relation_family('Dummy-Dummy', 0, 0, False)
fam.add_relation_type('Dummy Relation 1',
torch.rand((100, 100), dtype=torch.float32).round().to_sparse())
prep_d = prepare_training(d, TrainValTest(.8, .1, .1))
dec = DecodeLayer(input_dim=[ 32 ], data=prep_d, keep_prob=1.,
activation=lambda x: x)
assert len(list(dec.parameters())) == 2
def test_decode_layer_parameter_count_02():
d = Data()
d.add_node_type('Dummy', 100)
fam = d.add_relation_family('Dummy-Dummy', 0, 0, False)
fam.add_relation_type('Dummy Relation 1',
torch.rand((100, 100), dtype=torch.float32).round().to_sparse())
fam.add_relation_type('Dummy Relation 2',
torch.rand((100, 100), dtype=torch.float32).round().to_sparse())
prep_d = prepare_training(d, TrainValTest(.8, .1, .1))
dec = DecodeLayer(input_dim=[ 32 ], data=prep_d, keep_prob=1.,
activation=lambda x: x)
assert len(list(dec.parameters())) == 3
def test_decode_layer_parameter_count_03():
d = Data()
d.add_node_type('Dummy', 100)
for _ in range(2):
fam = d.add_relation_family('Dummy-Dummy', 0, 0, False)
fam.add_relation_type('Dummy Relation 1',
torch.rand((100, 100), dtype=torch.float32).round().to_sparse())
fam.add_relation_type('Dummy Relation 2',
torch.rand((100, 100), dtype=torch.float32).round().to_sparse())
prep_d = prepare_training(d, TrainValTest(.8, .1, .1))
dec = DecodeLayer(input_dim=[ 32 ], data=prep_d, keep_prob=1.,
activation=lambda x: x)
assert len(list(dec.parameters())) == 6

불러오는 중...
취소
저장