IF YOU WOULD LIKE TO GET AN ACCOUNT, please write an email to s dot adaszewski at gmail dot com. User accounts are meant only to report issues and/or generate pull requests. This is a purpose-specific Git hosting for ADARED projects. Thank you for your understanding!
Browse Source

Add test_decode_layer_parameter_count_[01-03]().

master
Stanislaw Adaszewski 4 years ago
parent
commit
6c8fdb7091
1 changed files with 53 additions and 0 deletions
  1. +53
    -0
      tests/icosagon/test_declayer.py

+ 53
- 0
tests/icosagon/test_declayer.py View File

@@ -232,3 +232,56 @@ def test_decode_layer_05():
# assert isinstance(rel_pred.edges_neg, TrainValTest)
# assert isinstance(rel_pred.edges_back_pos, TrainValTest)
# assert isinstance(rel_pred.edges_back_neg, TrainValTest)
def test_decode_layer_parameter_count_01():
d = Data()
d.add_node_type('Dummy', 100)
fam = d.add_relation_family('Dummy-Dummy', 0, 0, False)
fam.add_relation_type('Dummy Relation 1',
torch.rand((100, 100), dtype=torch.float32).round().to_sparse())
prep_d = prepare_training(d, TrainValTest(.8, .1, .1))
dec = DecodeLayer(input_dim=[ 32 ], data=prep_d, keep_prob=1.,
activation=lambda x: x)
assert len(list(dec.parameters())) == 2
def test_decode_layer_parameter_count_02():
d = Data()
d.add_node_type('Dummy', 100)
fam = d.add_relation_family('Dummy-Dummy', 0, 0, False)
fam.add_relation_type('Dummy Relation 1',
torch.rand((100, 100), dtype=torch.float32).round().to_sparse())
fam.add_relation_type('Dummy Relation 2',
torch.rand((100, 100), dtype=torch.float32).round().to_sparse())
prep_d = prepare_training(d, TrainValTest(.8, .1, .1))
dec = DecodeLayer(input_dim=[ 32 ], data=prep_d, keep_prob=1.,
activation=lambda x: x)
assert len(list(dec.parameters())) == 3
def test_decode_layer_parameter_count_03():
d = Data()
d.add_node_type('Dummy', 100)
for _ in range(2):
fam = d.add_relation_family('Dummy-Dummy', 0, 0, False)
fam.add_relation_type('Dummy Relation 1',
torch.rand((100, 100), dtype=torch.float32).round().to_sparse())
fam.add_relation_type('Dummy Relation 2',
torch.rand((100, 100), dtype=torch.float32).round().to_sparse())
prep_d = prepare_training(d, TrainValTest(.8, .1, .1))
dec = DecodeLayer(input_dim=[ 32 ], data=prep_d, keep_prob=1.,
activation=lambda x: x)
assert len(list(dec.parameters())) == 6

Loading…
Cancel
Save