IF YOU WOULD LIKE TO GET AN ACCOUNT, please write an email to s dot adaszewski at gmail dot com. User accounts are meant only to report issues and/or generate pull requests. This is a purpose-specific Git hosting for ADARED projects. Thank you for your understanding!
Browse Source

Implement and test pairwise decoders.

master
Stanislaw Adaszewski 3 years ago
parent
commit
39aa1b4211
2 changed files with 71 additions and 4 deletions
  1. +12
    -4
      src/decagon_pytorch/decode/pairwise.py
  2. +59
    -0
      tests/decagon_pytorch/test_decode_pairwise.py

+ 12
- 4
src/decagon_pytorch/decode/pairwise.py View File

@@ -37,7 +37,9 @@ class DEDICOMDecoder(torch.nn.Module):
product1 = torch.mm(inputs_row, relation)
product2 = torch.mm(product1, self.global_interaction)
product3 = torch.mm(product2, relation)
rec = torch.mm(product3, torch.transpose(inputs_col, 0, 1))
rec = torch.bmm(product3.view(product3.shape[0], 1, product3.shape[1]),
inputs_col.view(inputs_col.shape[0], inputs_col.shape[1], 1))
rec = torch.flatten(rec)
outputs.append(self.activation(rec))
return outputs
@@ -67,7 +69,9 @@ class DistMultDecoder(torch.nn.Module):
relation = torch.diag(self.relation[k])
intermediate_product = torch.mm(inputs_row, relation)
rec = torch.mm(intermediate_product, torch.transpose(inputs_col, 0, 1))
rec = torch.bmm(intermediate_product.view(intermediate_product.shape[0], 1, intermediate_product.shape[1]),
inputs_col.view(inputs_col.shape[0], inputs_col.shape[1], 1))
rec = torch.flatten(rec)
outputs.append(self.activation(rec))
return outputs
@@ -95,7 +99,9 @@ class BilinearDecoder(torch.nn.Module):
inputs_col = dropout(inputs_col, 1.-self.drop_prob)
intermediate_product = torch.mm(inputs_row, self.relation[k])
rec = torch.mm(intermediate_product, torch.transpose(inputs_col, 0, 1))
rec = torch.bmm(intermediate_product.view(intermediate_product.shape[0], 1, intermediate_product.shape[1]),
inputs_col.view(inputs_col.shape[0], inputs_col.shape[1], 1))
rec = torch.flatten(rec)
outputs.append(self.activation(rec))
return outputs
@@ -118,6 +124,8 @@ class InnerProductDecoder(torch.nn.Module):
inputs_row = dropout(inputs_row, 1.-self.drop_prob)
inputs_col = dropout(inputs_col, 1.-self.drop_prob)
rec = torch.mm(inputs_row, torch.transpose(inputs_col, 0, 1))
rec = torch.bmm(inputs_row.view(inputs_row.shape[0], 1, inputs_row.shape[1]),
inputs_col.view(inputs_col.shape[0], inputs_col.shape[1], 1))
rec = torch.flatten(rec)
outputs.append(self.activation(rec))
return outputs

+ 59
- 0
tests/decagon_pytorch/test_decode_pairwise.py View File

@@ -0,0 +1,59 @@
import decagon_pytorch.decode.cartesian as cart
import decagon_pytorch.decode.pairwise as pair
import torch
def _common(cart_class, pair_class):
input_dim = 10
n_nodes = 20
num_relation_types = 7
inputs_row = torch.rand((n_nodes, input_dim))
inputs_col = torch.rand((n_nodes, input_dim))
cart_dec = cart_class(input_dim=input_dim,
num_relation_types=num_relation_types)
pair_dec = pair_class(input_dim=input_dim,
num_relation_types=num_relation_types)
if isinstance(cart_dec, cart.DEDICOMDecoder):
pair_dec.global_interaction = cart_dec.global_interaction
pair_dec.local_variation = cart_dec.local_variation
elif isinstance(cart_dec, cart.InnerProductDecoder):
pass
else:
pair_dec.relation = cart_dec.relation
cart_pred = cart_dec(inputs_row, inputs_col)
pair_pred = pair_dec(inputs_row, inputs_col)
assert isinstance(cart_pred, list)
assert isinstance(pair_pred, list)
assert len(cart_pred) == len(pair_pred)
assert len(cart_pred) == num_relation_types
for i in range(num_relation_types):
assert isinstance(cart_pred[i], torch.Tensor)
assert isinstance(pair_pred[i], torch.Tensor)
assert cart_pred[i].shape == (n_nodes, n_nodes)
assert pair_pred[i].shape == (n_nodes,)
assert torch.all(torch.abs(pair_pred[i] - torch.diag(cart_pred[i])) < 0.000001)
def test_dedicom_decoder():
_common(cart.DEDICOMDecoder, pair.DEDICOMDecoder)
def test_dist_mult_decoder():
_common(cart.DistMultDecoder, pair.DistMultDecoder)
def test_bilinear_decoder():
_common(cart.BilinearDecoder, pair.BilinearDecoder)
def test_inner_product_decoder():
_common(cart.InnerProductDecoder, pair.InnerProductDecoder)

Loading…
Cancel
Save