IF YOU WOULD LIKE TO GET AN ACCOUNT, please write an email to s dot adaszewski at gmail dot com. User accounts are meant only to report issues and/or generate pull requests. This is a purpose-specific Git hosting for ADARED projects. Thank you for your understanding!
Quellcode durchsuchen

Add test_decode_layer_parameter_count_[01-03]().

master
Stanislaw Adaszewski vor 4 Jahren
Ursprung
Commit
6c8fdb7091
1 geänderte Dateien mit 53 neuen und 0 gelöschten Zeilen
  1. +53
    -0
      tests/icosagon/test_declayer.py

+ 53
- 0
tests/icosagon/test_declayer.py Datei anzeigen

@@ -232,3 +232,56 @@ def test_decode_layer_05():
# assert isinstance(rel_pred.edges_neg, TrainValTest)
# assert isinstance(rel_pred.edges_back_pos, TrainValTest)
# assert isinstance(rel_pred.edges_back_neg, TrainValTest)
def test_decode_layer_parameter_count_01():
d = Data()
d.add_node_type('Dummy', 100)
fam = d.add_relation_family('Dummy-Dummy', 0, 0, False)
fam.add_relation_type('Dummy Relation 1',
torch.rand((100, 100), dtype=torch.float32).round().to_sparse())
prep_d = prepare_training(d, TrainValTest(.8, .1, .1))
dec = DecodeLayer(input_dim=[ 32 ], data=prep_d, keep_prob=1.,
activation=lambda x: x)
assert len(list(dec.parameters())) == 2
def test_decode_layer_parameter_count_02():
d = Data()
d.add_node_type('Dummy', 100)
fam = d.add_relation_family('Dummy-Dummy', 0, 0, False)
fam.add_relation_type('Dummy Relation 1',
torch.rand((100, 100), dtype=torch.float32).round().to_sparse())
fam.add_relation_type('Dummy Relation 2',
torch.rand((100, 100), dtype=torch.float32).round().to_sparse())
prep_d = prepare_training(d, TrainValTest(.8, .1, .1))
dec = DecodeLayer(input_dim=[ 32 ], data=prep_d, keep_prob=1.,
activation=lambda x: x)
assert len(list(dec.parameters())) == 3
def test_decode_layer_parameter_count_03():
d = Data()
d.add_node_type('Dummy', 100)
for _ in range(2):
fam = d.add_relation_family('Dummy-Dummy', 0, 0, False)
fam.add_relation_type('Dummy Relation 1',
torch.rand((100, 100), dtype=torch.float32).round().to_sparse())
fam.add_relation_type('Dummy Relation 2',
torch.rand((100, 100), dtype=torch.float32).round().to_sparse())
prep_d = prepare_training(d, TrainValTest(.8, .1, .1))
dec = DecodeLayer(input_dim=[ 32 ], data=prep_d, keep_prob=1.,
activation=lambda x: x)
assert len(list(dec.parameters())) == 6

Laden…
Abbrechen
Speichern