|
|
@@ -210,3 +210,31 @@ def test_empty_inner_product_decoder_01(): |
|
|
|
|
|
|
|
for i in range(len(res)):
|
|
|
|
assert res[i].shape == (0,)
|
|
|
|
|
|
|
|
|
|
|
|
def test_dedicom_decoder_parameter_count_01():
|
|
|
|
dec = DEDICOMDecoder(32, 7, keep_prob=1.,
|
|
|
|
activation=torch.sigmoid)
|
|
|
|
|
|
|
|
assert len(list(dec.parameters())) == 8
|
|
|
|
|
|
|
|
|
|
|
|
def test_dist_mult_decoder_parameter_count_01():
|
|
|
|
dec = DistMultDecoder(32, 7, keep_prob=1.,
|
|
|
|
activation=torch.sigmoid)
|
|
|
|
|
|
|
|
assert len(list(dec.parameters())) == 7
|
|
|
|
|
|
|
|
|
|
|
|
def test_bilinear_decoder_parameter_count_01():
|
|
|
|
dec = BilinearDecoder(32, 7, keep_prob=1.,
|
|
|
|
activation=torch.sigmoid)
|
|
|
|
|
|
|
|
assert len(list(dec.parameters())) == 7
|
|
|
|
|
|
|
|
|
|
|
|
def test_inner_product_decoder_parameter_count_01():
|
|
|
|
dec = InnerProductDecoder(32, 7, keep_prob=1.,
|
|
|
|
activation=torch.sigmoid)
|
|
|
|
|
|
|
|
assert len(list(dec.parameters())) == 0
|