| @@ -84,3 +84,78 @@ def test_inner_product_decoder_01(): | |||
| for i in range(len(res_1)): | |||
| assert torch.all(res_1[i] == res_2[i]) | |||
| def test_is_dedicom_not_symmetric_01(): | |||
| repr_1 = torch.rand(20, 32) | |||
| repr_2 = torch.rand(20, 32) | |||
| dec = DEDICOMDecoder(32, 7, keep_prob=1., | |||
| activation=torch.sigmoid) | |||
| res_1 = [ dec(repr_1, repr_2, k) for k in range(7) ] | |||
| res_2 = [ dec(repr_2, repr_1, k) for k in range(7) ] | |||
| assert isinstance(res_1, list) | |||
| assert isinstance(res_2, list) | |||
| assert len(res_1) == len(res_2) | |||
| for i in range(len(res_1)): | |||
| assert not torch.all(res_1[i] - res_2[i] < 0.000001) | |||
| def test_is_dist_mult_symmetric_01(): | |||
| repr_1 = torch.rand(20, 32) | |||
| repr_2 = torch.rand(20, 32) | |||
| dec = DistMultDecoder(32, 7, keep_prob=1., | |||
| activation=torch.sigmoid) | |||
| res_1 = [ dec(repr_1, repr_2, k) for k in range(7) ] | |||
| res_2 = [ dec(repr_2, repr_1, k) for k in range(7) ] | |||
| assert isinstance(res_1, list) | |||
| assert isinstance(res_2, list) | |||
| assert len(res_1) == len(res_2) | |||
| for i in range(len(res_1)): | |||
| assert torch.all(res_1[i] - res_2[i] < 0.000001) | |||
| def test_is_bilinear_not_symmetric_01(): | |||
| repr_1 = torch.rand(20, 32) | |||
| repr_2 = torch.rand(20, 32) | |||
| dec = BilinearDecoder(32, 7, keep_prob=1., | |||
| activation=torch.sigmoid) | |||
| res_1 = [ dec(repr_1, repr_2, k) for k in range(7) ] | |||
| res_2 = [ dec(repr_2, repr_1, k) for k in range(7) ] | |||
| assert isinstance(res_1, list) | |||
| assert isinstance(res_2, list) | |||
| assert len(res_1) == len(res_2) | |||
| for i in range(len(res_1)): | |||
| assert not torch.all(res_1[i] - res_2[i] < 0.000001) | |||
| def test_is_inner_product_symmetric_01(): | |||
| repr_1 = torch.rand(20, 32) | |||
| repr_2 = torch.rand(20, 32) | |||
| dec = InnerProductDecoder(32, 7, keep_prob=1., | |||
| activation=torch.sigmoid) | |||
| res_1 = [ dec(repr_1, repr_2, k) for k in range(7) ] | |||
| res_2 = [ dec(repr_2, repr_1, k) for k in range(7) ] | |||
| assert isinstance(res_1, list) | |||
| assert isinstance(res_2, list) | |||
| assert len(res_1) == len(res_2) | |||
| for i in range(len(res_1)): | |||
| assert torch.all(res_1[i] - res_2[i] < 0.000001) | |||