|
|
@@ -158,3 +158,55 @@ def test_is_inner_product_symmetric_01(): |
|
|
|
|
|
|
|
for i in range(len(res_1)):
|
|
|
|
assert torch.all(res_1[i] - res_2[i] < 0.000001)
|
|
|
|
|
|
|
|
|
|
|
|
def test_empty_dedicom_decoder_01():
|
|
|
|
repr_ = torch.rand(0, 32)
|
|
|
|
dec = DEDICOMDecoder(32, 7, keep_prob=1.,
|
|
|
|
activation=torch.sigmoid)
|
|
|
|
|
|
|
|
res = [ dec(repr_, repr_, k) for k in range(7) ]
|
|
|
|
|
|
|
|
assert isinstance(res, list)
|
|
|
|
|
|
|
|
for i in range(len(res)):
|
|
|
|
assert res[i].shape == (0,)
|
|
|
|
|
|
|
|
|
|
|
|
def test_empty_dist_mult_decoder_01():
|
|
|
|
repr_ = torch.rand(0, 32)
|
|
|
|
dec = DistMultDecoder(32, 7, keep_prob=1.,
|
|
|
|
activation=torch.sigmoid)
|
|
|
|
|
|
|
|
res = [ dec(repr_, repr_, k) for k in range(7) ]
|
|
|
|
|
|
|
|
assert isinstance(res, list)
|
|
|
|
|
|
|
|
for i in range(len(res)):
|
|
|
|
assert res[i].shape == (0,)
|
|
|
|
|
|
|
|
|
|
|
|
def test_empty_bilinear_decoder_01():
|
|
|
|
repr_ = torch.rand(0, 32)
|
|
|
|
dec = BilinearDecoder(32, 7, keep_prob=1.,
|
|
|
|
activation=torch.sigmoid)
|
|
|
|
|
|
|
|
res = [ dec(repr_, repr_, k) for k in range(7) ]
|
|
|
|
|
|
|
|
assert isinstance(res, list)
|
|
|
|
|
|
|
|
for i in range(len(res)):
|
|
|
|
assert res[i].shape == (0,)
|
|
|
|
|
|
|
|
|
|
|
|
def test_empty_inner_product_decoder_01():
|
|
|
|
repr_ = torch.rand(0, 32)
|
|
|
|
dec = InnerProductDecoder(32, 7, keep_prob=1.,
|
|
|
|
activation=torch.sigmoid)
|
|
|
|
|
|
|
|
res = [ dec(repr_, repr_, k) for k in range(7) ]
|
|
|
|
|
|
|
|
assert isinstance(res, list)
|
|
|
|
|
|
|
|
for i in range(len(res)):
|
|
|
|
assert res[i].shape == (0,)
|