| @@ -158,3 +158,55 @@ def test_is_inner_product_symmetric_01(): | |||||
| for i in range(len(res_1)): | for i in range(len(res_1)): | ||||
| assert torch.all(res_1[i] - res_2[i] < 0.000001) | assert torch.all(res_1[i] - res_2[i] < 0.000001) | ||||
| def test_empty_dedicom_decoder_01(): | |||||
| repr_ = torch.rand(0, 32) | |||||
| dec = DEDICOMDecoder(32, 7, keep_prob=1., | |||||
| activation=torch.sigmoid) | |||||
| res = [ dec(repr_, repr_, k) for k in range(7) ] | |||||
| assert isinstance(res, list) | |||||
| for i in range(len(res)): | |||||
| assert res[i].shape == (0,) | |||||
| def test_empty_dist_mult_decoder_01(): | |||||
| repr_ = torch.rand(0, 32) | |||||
| dec = DistMultDecoder(32, 7, keep_prob=1., | |||||
| activation=torch.sigmoid) | |||||
| res = [ dec(repr_, repr_, k) for k in range(7) ] | |||||
| assert isinstance(res, list) | |||||
| for i in range(len(res)): | |||||
| assert res[i].shape == (0,) | |||||
| def test_empty_bilinear_decoder_01(): | |||||
| repr_ = torch.rand(0, 32) | |||||
| dec = BilinearDecoder(32, 7, keep_prob=1., | |||||
| activation=torch.sigmoid) | |||||
| res = [ dec(repr_, repr_, k) for k in range(7) ] | |||||
| assert isinstance(res, list) | |||||
| for i in range(len(res)): | |||||
| assert res[i].shape == (0,) | |||||
| def test_empty_inner_product_decoder_01(): | |||||
| repr_ = torch.rand(0, 32) | |||||
| dec = InnerProductDecoder(32, 7, keep_prob=1., | |||||
| activation=torch.sigmoid) | |||||
| res = [ dec(repr_, repr_, k) for k in range(7) ] | |||||
| assert isinstance(res, list) | |||||
| for i in range(len(res)): | |||||
| assert res[i].shape == (0,) | |||||