|
@@ -5,25 +5,14 @@ import tensorflow as tf |
|
|
import torch
|
|
|
import torch
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def test_dedicom():
|
|
|
|
|
|
dedicom_torch = decagon_pytorch.decode.DEDICOMDecoder(input_dim=10,
|
|
|
|
|
|
num_relation_types=7)
|
|
|
|
|
|
dedicom_tf = decagon.deep.layers.DEDICOMDecoder(input_dim=10, num_types=7,
|
|
|
|
|
|
edge_type=(0, 0))
|
|
|
|
|
|
|
|
|
|
|
|
dedicom_tf.vars['global_interaction'] = \
|
|
|
|
|
|
tf.convert_to_tensor(dedicom_torch.global_interaction.detach().numpy())
|
|
|
|
|
|
for i in range(dedicom_tf.num_types):
|
|
|
|
|
|
dedicom_tf.vars['local_variation_%d' % i] = \
|
|
|
|
|
|
tf.convert_to_tensor(dedicom_torch.local_variation[i].detach().numpy())
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def _common(decoder_torch, decoder_tf):
|
|
|
inputs = np.random.rand(20, 10).astype(np.float32)
|
|
|
inputs = np.random.rand(20, 10).astype(np.float32)
|
|
|
inputs_torch = torch.tensor(inputs)
|
|
|
inputs_torch = torch.tensor(inputs)
|
|
|
inputs_tf = {
|
|
|
inputs_tf = {
|
|
|
0: tf.convert_to_tensor(inputs)
|
|
|
0: tf.convert_to_tensor(inputs)
|
|
|
}
|
|
|
}
|
|
|
out_torch = dedicom_torch(inputs_torch, inputs_torch)
|
|
|
|
|
|
out_tf = dedicom_tf(inputs_tf)
|
|
|
|
|
|
|
|
|
out_torch = decoder_torch(inputs_torch, inputs_torch)
|
|
|
|
|
|
out_tf = decoder_tf(inputs_tf)
|
|
|
|
|
|
|
|
|
assert len(out_torch) == len(out_tf)
|
|
|
assert len(out_torch) == len(out_tf)
|
|
|
assert len(out_tf) == 7
|
|
|
assert len(out_tf) == 7
|
|
@@ -39,3 +28,31 @@ def test_dedicom(): |
|
|
print('item_tf:', item_tf)
|
|
|
print('item_tf:', item_tf)
|
|
|
assert np.all(item_torch - item_tf < .000001)
|
|
|
assert np.all(item_torch - item_tf < .000001)
|
|
|
sess.close()
|
|
|
sess.close()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def test_dedicom_decoder():
|
|
|
|
|
|
dedicom_torch = decagon_pytorch.decode.DEDICOMDecoder(input_dim=10,
|
|
|
|
|
|
num_relation_types=7)
|
|
|
|
|
|
dedicom_tf = decagon.deep.layers.DEDICOMDecoder(input_dim=10, num_types=7,
|
|
|
|
|
|
edge_type=(0, 0))
|
|
|
|
|
|
|
|
|
|
|
|
dedicom_tf.vars['global_interaction'] = \
|
|
|
|
|
|
tf.convert_to_tensor(dedicom_torch.global_interaction.detach().numpy())
|
|
|
|
|
|
for i in range(dedicom_tf.num_types):
|
|
|
|
|
|
dedicom_tf.vars['local_variation_%d' % i] = \
|
|
|
|
|
|
tf.convert_to_tensor(dedicom_torch.local_variation[i].detach().numpy())
|
|
|
|
|
|
|
|
|
|
|
|
_common(dedicom_torch, dedicom_tf)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def test_dist_mult_decoder():
|
|
|
|
|
|
distmult_torch = decagon_pytorch.decode.DistMultDecoder(input_dim=10,
|
|
|
|
|
|
num_relation_types=7)
|
|
|
|
|
|
distmult_tf = decagon.deep.layers.DistMultDecoder(input_dim=10, num_types=7,
|
|
|
|
|
|
edge_type=(0, 0))
|
|
|
|
|
|
|
|
|
|
|
|
for i in range(distmult_tf.num_types):
|
|
|
|
|
|
distmult_tf.vars['relation_%d' % i] = \
|
|
|
|
|
|
tf.convert_to_tensor(distmult_torch.relation[i].detach().numpy())
|
|
|
|
|
|
|
|
|
|
|
|
_common(distmult_torch, distmult_tf)
|