|
@@ -232,3 +232,56 @@ def test_decode_layer_05(): |
|
|
# assert isinstance(rel_pred.edges_neg, TrainValTest)
|
|
|
# assert isinstance(rel_pred.edges_neg, TrainValTest)
|
|
|
# assert isinstance(rel_pred.edges_back_pos, TrainValTest)
|
|
|
# assert isinstance(rel_pred.edges_back_pos, TrainValTest)
|
|
|
# assert isinstance(rel_pred.edges_back_neg, TrainValTest)
|
|
|
# assert isinstance(rel_pred.edges_back_neg, TrainValTest)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def test_decode_layer_parameter_count_01():
|
|
|
|
|
|
d = Data()
|
|
|
|
|
|
d.add_node_type('Dummy', 100)
|
|
|
|
|
|
|
|
|
|
|
|
fam = d.add_relation_family('Dummy-Dummy', 0, 0, False)
|
|
|
|
|
|
fam.add_relation_type('Dummy Relation 1',
|
|
|
|
|
|
torch.rand((100, 100), dtype=torch.float32).round().to_sparse())
|
|
|
|
|
|
|
|
|
|
|
|
prep_d = prepare_training(d, TrainValTest(.8, .1, .1))
|
|
|
|
|
|
|
|
|
|
|
|
dec = DecodeLayer(input_dim=[ 32 ], data=prep_d, keep_prob=1.,
|
|
|
|
|
|
activation=lambda x: x)
|
|
|
|
|
|
|
|
|
|
|
|
assert len(list(dec.parameters())) == 2
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def test_decode_layer_parameter_count_02():
|
|
|
|
|
|
d = Data()
|
|
|
|
|
|
d.add_node_type('Dummy', 100)
|
|
|
|
|
|
|
|
|
|
|
|
fam = d.add_relation_family('Dummy-Dummy', 0, 0, False)
|
|
|
|
|
|
fam.add_relation_type('Dummy Relation 1',
|
|
|
|
|
|
torch.rand((100, 100), dtype=torch.float32).round().to_sparse())
|
|
|
|
|
|
fam.add_relation_type('Dummy Relation 2',
|
|
|
|
|
|
torch.rand((100, 100), dtype=torch.float32).round().to_sparse())
|
|
|
|
|
|
|
|
|
|
|
|
prep_d = prepare_training(d, TrainValTest(.8, .1, .1))
|
|
|
|
|
|
|
|
|
|
|
|
dec = DecodeLayer(input_dim=[ 32 ], data=prep_d, keep_prob=1.,
|
|
|
|
|
|
activation=lambda x: x)
|
|
|
|
|
|
|
|
|
|
|
|
assert len(list(dec.parameters())) == 3
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def test_decode_layer_parameter_count_03():
|
|
|
|
|
|
d = Data()
|
|
|
|
|
|
d.add_node_type('Dummy', 100)
|
|
|
|
|
|
|
|
|
|
|
|
for _ in range(2):
|
|
|
|
|
|
fam = d.add_relation_family('Dummy-Dummy', 0, 0, False)
|
|
|
|
|
|
fam.add_relation_type('Dummy Relation 1',
|
|
|
|
|
|
torch.rand((100, 100), dtype=torch.float32).round().to_sparse())
|
|
|
|
|
|
fam.add_relation_type('Dummy Relation 2',
|
|
|
|
|
|
torch.rand((100, 100), dtype=torch.float32).round().to_sparse())
|
|
|
|
|
|
|
|
|
|
|
|
prep_d = prepare_training(d, TrainValTest(.8, .1, .1))
|
|
|
|
|
|
|
|
|
|
|
|
dec = DecodeLayer(input_dim=[ 32 ], data=prep_d, keep_prob=1.,
|
|
|
|
|
|
activation=lambda x: x)
|
|
|
|
|
|
|
|
|
|
|
|
assert len(list(dec.parameters())) == 6
|