|
|
@@ -84,3 +84,78 @@ def test_inner_product_decoder_01(): |
|
|
|
|
|
|
|
for i in range(len(res_1)):
|
|
|
|
assert torch.all(res_1[i] == res_2[i])
|
|
|
|
|
|
|
|
|
|
|
|
def test_is_dedicom_not_symmetric_01():
|
|
|
|
repr_1 = torch.rand(20, 32)
|
|
|
|
repr_2 = torch.rand(20, 32)
|
|
|
|
dec = DEDICOMDecoder(32, 7, keep_prob=1.,
|
|
|
|
activation=torch.sigmoid)
|
|
|
|
|
|
|
|
res_1 = [ dec(repr_1, repr_2, k) for k in range(7) ]
|
|
|
|
res_2 = [ dec(repr_2, repr_1, k) for k in range(7) ]
|
|
|
|
|
|
|
|
|
|
|
|
assert isinstance(res_1, list)
|
|
|
|
assert isinstance(res_2, list)
|
|
|
|
|
|
|
|
assert len(res_1) == len(res_2)
|
|
|
|
|
|
|
|
for i in range(len(res_1)):
|
|
|
|
assert not torch.all(res_1[i] - res_2[i] < 0.000001)
|
|
|
|
|
|
|
|
|
|
|
|
def test_is_dist_mult_symmetric_01():
|
|
|
|
repr_1 = torch.rand(20, 32)
|
|
|
|
repr_2 = torch.rand(20, 32)
|
|
|
|
dec = DistMultDecoder(32, 7, keep_prob=1.,
|
|
|
|
activation=torch.sigmoid)
|
|
|
|
|
|
|
|
res_1 = [ dec(repr_1, repr_2, k) for k in range(7) ]
|
|
|
|
res_2 = [ dec(repr_2, repr_1, k) for k in range(7) ]
|
|
|
|
|
|
|
|
|
|
|
|
assert isinstance(res_1, list)
|
|
|
|
assert isinstance(res_2, list)
|
|
|
|
|
|
|
|
assert len(res_1) == len(res_2)
|
|
|
|
|
|
|
|
for i in range(len(res_1)):
|
|
|
|
assert torch.all(res_1[i] - res_2[i] < 0.000001)
|
|
|
|
|
|
|
|
|
|
|
|
def test_is_bilinear_not_symmetric_01():
|
|
|
|
repr_1 = torch.rand(20, 32)
|
|
|
|
repr_2 = torch.rand(20, 32)
|
|
|
|
dec = BilinearDecoder(32, 7, keep_prob=1.,
|
|
|
|
activation=torch.sigmoid)
|
|
|
|
|
|
|
|
res_1 = [ dec(repr_1, repr_2, k) for k in range(7) ]
|
|
|
|
res_2 = [ dec(repr_2, repr_1, k) for k in range(7) ]
|
|
|
|
|
|
|
|
assert isinstance(res_1, list)
|
|
|
|
assert isinstance(res_2, list)
|
|
|
|
|
|
|
|
assert len(res_1) == len(res_2)
|
|
|
|
|
|
|
|
for i in range(len(res_1)):
|
|
|
|
assert not torch.all(res_1[i] - res_2[i] < 0.000001)
|
|
|
|
|
|
|
|
|
|
|
|
def test_is_inner_product_symmetric_01():
|
|
|
|
repr_1 = torch.rand(20, 32)
|
|
|
|
repr_2 = torch.rand(20, 32)
|
|
|
|
dec = InnerProductDecoder(32, 7, keep_prob=1.,
|
|
|
|
activation=torch.sigmoid)
|
|
|
|
|
|
|
|
res_1 = [ dec(repr_1, repr_2, k) for k in range(7) ]
|
|
|
|
res_2 = [ dec(repr_2, repr_1, k) for k in range(7) ]
|
|
|
|
|
|
|
|
|
|
|
|
assert isinstance(res_1, list)
|
|
|
|
assert isinstance(res_2, list)
|
|
|
|
|
|
|
|
assert len(res_1) == len(res_2)
|
|
|
|
|
|
|
|
for i in range(len(res_1)):
|
|
|
|
assert torch.all(res_1[i] - res_2[i] < 0.000001)
|