diff --git a/src/decagon_pytorch/decode/pairwise.py b/src/decagon_pytorch/decode/pairwise.py index 910a8ce..93ff68c 100644 --- a/src/decagon_pytorch/decode/pairwise.py +++ b/src/decagon_pytorch/decode/pairwise.py @@ -37,7 +37,9 @@ class DEDICOMDecoder(torch.nn.Module): product1 = torch.mm(inputs_row, relation) product2 = torch.mm(product1, self.global_interaction) product3 = torch.mm(product2, relation) - rec = torch.mm(product3, torch.transpose(inputs_col, 0, 1)) + rec = torch.bmm(product3.view(product3.shape[0], 1, product3.shape[1]), + inputs_col.view(inputs_col.shape[0], inputs_col.shape[1], 1)) + rec = torch.flatten(rec) outputs.append(self.activation(rec)) return outputs @@ -67,7 +69,9 @@ class DistMultDecoder(torch.nn.Module): relation = torch.diag(self.relation[k]) intermediate_product = torch.mm(inputs_row, relation) - rec = torch.mm(intermediate_product, torch.transpose(inputs_col, 0, 1)) + rec = torch.bmm(intermediate_product.view(intermediate_product.shape[0], 1, intermediate_product.shape[1]), + inputs_col.view(inputs_col.shape[0], inputs_col.shape[1], 1)) + rec = torch.flatten(rec) outputs.append(self.activation(rec)) return outputs @@ -95,7 +99,9 @@ class BilinearDecoder(torch.nn.Module): inputs_col = dropout(inputs_col, 1.-self.drop_prob) intermediate_product = torch.mm(inputs_row, self.relation[k]) - rec = torch.mm(intermediate_product, torch.transpose(inputs_col, 0, 1)) + rec = torch.bmm(intermediate_product.view(intermediate_product.shape[0], 1, intermediate_product.shape[1]), + inputs_col.view(inputs_col.shape[0], inputs_col.shape[1], 1)) + rec = torch.flatten(rec) outputs.append(self.activation(rec)) return outputs @@ -118,6 +124,8 @@ class InnerProductDecoder(torch.nn.Module): inputs_row = dropout(inputs_row, 1.-self.drop_prob) inputs_col = dropout(inputs_col, 1.-self.drop_prob) - rec = torch.mm(inputs_row, torch.transpose(inputs_col, 0, 1)) + rec = torch.bmm(inputs_row.view(inputs_row.shape[0], 1, inputs_row.shape[1]), + inputs_col.view(inputs_col.shape[0], inputs_col.shape[1], 1)) + rec = torch.flatten(rec) outputs.append(self.activation(rec)) return outputs diff --git a/tests/decagon_pytorch/test_decode_pairwise.py b/tests/decagon_pytorch/test_decode_pairwise.py new file mode 100644 index 0000000..9b47eef --- /dev/null +++ b/tests/decagon_pytorch/test_decode_pairwise.py @@ -0,0 +1,59 @@ +import decagon_pytorch.decode.cartesian as cart +import decagon_pytorch.decode.pairwise as pair +import torch + + +def _common(cart_class, pair_class): + input_dim = 10 + n_nodes = 20 + num_relation_types = 7 + + inputs_row = torch.rand((n_nodes, input_dim)) + inputs_col = torch.rand((n_nodes, input_dim)) + + cart_dec = cart_class(input_dim=input_dim, + num_relation_types=num_relation_types) + pair_dec = pair_class(input_dim=input_dim, + num_relation_types=num_relation_types) + + if isinstance(cart_dec, cart.DEDICOMDecoder): + pair_dec.global_interaction = cart_dec.global_interaction + pair_dec.local_variation = cart_dec.local_variation + elif isinstance(cart_dec, cart.InnerProductDecoder): + pass + else: + pair_dec.relation = cart_dec.relation + + cart_pred = cart_dec(inputs_row, inputs_col) + pair_pred = pair_dec(inputs_row, inputs_col) + + assert isinstance(cart_pred, list) + assert isinstance(pair_pred, list) + + assert len(cart_pred) == len(pair_pred) + assert len(cart_pred) == num_relation_types + + for i in range(num_relation_types): + assert isinstance(cart_pred[i], torch.Tensor) + assert isinstance(pair_pred[i], torch.Tensor) + + assert cart_pred[i].shape == (n_nodes, n_nodes) + assert pair_pred[i].shape == (n_nodes,) + + assert torch.all(torch.abs(pair_pred[i] - torch.diag(cart_pred[i])) < 0.000001) + + +def test_dedicom_decoder(): + _common(cart.DEDICOMDecoder, pair.DEDICOMDecoder) + + +def test_dist_mult_decoder(): + _common(cart.DistMultDecoder, pair.DistMultDecoder) + + +def test_bilinear_decoder(): + _common(cart.BilinearDecoder, pair.BilinearDecoder) + + +def test_inner_product_decoder(): + _common(cart.InnerProductDecoder, pair.InnerProductDecoder)