diff --git a/src/icosagon/bulkdec.py b/src/icosagon/bulkdec.py
new file mode 100644
index 0000000..ea6dba8
--- /dev/null
+++ b/src/icosagon/bulkdec.py
@@ -0,0 +1,134 @@
+from icosagon.data import Data
+from icosagon.trainprep import PreparedData
+from icosagon.decode import DEDICOMDecoder, \
+    DistMultDecoder, \
+    BilinearDecoder, \
+    InnerProductDecoder
+from icosagon.dropout import dropout
+import torch
+from typing import List, \
+    Callable, \
+    Union
+
+
+'''
+Let's say that I have dense latent representations row and col.
+Then let's take relation matrix rel in a list of relations REL.
+A single computation currently looks like this:
+(((row * rel) * glob) * rel) * col
+
+Shouldn't then this basically work:
+
+prod1 = torch.matmul(row, REL)
+prod2 = torch.matmul(prod1, glob)
+prod3 = torch.matmul(prod2, REL)
+res = torch.matmul(prod3, col)
+res = activation(res)
+
+res should then have shape: (num_relations, num_rows, num_columns)
+'''
+
+
+def convert_decoder(dec):
+    '''Extract batched decoder parameters from a decoder instance.
+
+    Returns (global_interaction, local_variation) where global_interaction
+    has shape (input_dim, input_dim) and local_variation has shape
+    (num_relation_types, input_dim, input_dim), so that all relation
+    types of a family can be decoded with batched matmuls.
+    '''
+    if isinstance(dec, DEDICOMDecoder):
+        global_interaction = dec.global_interaction
+        local_variation = map(torch.diag, dec.local_variation)
+    elif isinstance(dec, DistMultDecoder):
+        global_interaction = torch.eye(dec.input_dim, dec.input_dim)
+        local_variation = map(torch.diag, dec.relation)
+    elif isinstance(dec, BilinearDecoder):
+        global_interaction = torch.eye(dec.input_dim, dec.input_dim)
+        local_variation = dec.relation
+    elif isinstance(dec, InnerProductDecoder):
+        global_interaction = torch.eye(dec.input_dim, dec.input_dim)
+        local_variation = torch.eye(dec.input_dim, dec.input_dim)
+        local_variation = [ local_variation ] * dec.num_relation_types
+    else:
+        raise TypeError('Unknown decoder type in convert_decoder()')
+
+    # Stack per-relation matrices into a single (num_rel, dim, dim) tensor.
+    if not isinstance(local_variation, torch.Tensor):
+        local_variation = map(lambda a: a.view(1, *a.shape), local_variation)
+        local_variation = torch.cat(list(local_variation))
+
+    return (global_interaction, local_variation)
+
+
+class BulkDecodeLayer(torch.nn.Module):
+    '''Decode all relation types of each relation family in bulk.
+
+    Instead of decoding one relation type at a time, the per-relation
+    matrices are stacked so a single batched matmul chain produces, for
+    each family, predictions of shape (num_relation_types, num_rows,
+    num_columns).
+    '''
+    def __init__(self,
+        input_dim: List[int],
+        data: Union[Data, PreparedData],
+        keep_prob: float = 1.,
+        activation: Callable[[torch.Tensor], torch.Tensor] = torch.sigmoid,
+        **kwargs) -> None:
+
+        super().__init__(**kwargs)
+
+        self._check_params(input_dim, data)
+
+        self.input_dim = input_dim[0]
+        self.data = data
+        self.keep_prob = keep_prob
+        self.activation = activation
+
+        self.decoders = None
+        self.global_interaction = None
+        self.local_variation = None
+        self.build()
+
+    def build(self) -> None:
+        self.decoders = torch.nn.ModuleList()
+        self.global_interaction = torch.nn.ParameterList()
+        self.local_variation = torch.nn.ParameterList()
+        for fam in self.data.relation_families:
+            dec = fam.decoder_class(self.input_dim,
+                len(fam.relation_types),
+                self.keep_prob,
+                self.activation)
+            self.decoders.append(dec)
+            global_interaction, local_variation = convert_decoder(dec)
+            self.global_interaction.append(torch.nn.Parameter(global_interaction))
+            self.local_variation.append(torch.nn.Parameter(local_variation))
+
+    def forward(self, last_layer_repr: List[torch.Tensor]) -> List[torch.Tensor]:
+        res = []
+        for i, fam in enumerate(self.data.relation_families):
+            repr_row = last_layer_repr[fam.node_type_row]
+            repr_column = last_layer_repr[fam.node_type_column]
+            repr_row = dropout(repr_row, keep_prob=self.keep_prob)
+            repr_column = dropout(repr_column, keep_prob=self.keep_prob)
+            # (num_rows, dim) x (num_rel, dim, dim) -> (num_rel, num_rows, dim)
+            prod_1 = torch.matmul(repr_row, self.local_variation[i])
+            prod_2 = torch.matmul(prod_1, self.global_interaction[i])
+            prod_3 = torch.matmul(prod_2, self.local_variation[i])
+            pred = torch.matmul(prod_3, repr_column.transpose(0, 1))
+            pred = self.activation(pred)
+            res.append(pred)
+        return res
+
+    @staticmethod
+    def _check_params(input_dim, data):
+        if not isinstance(input_dim, list):
+            raise TypeError('input_dim must be a list')
+
+        if len(input_dim) != len(data.node_types):
+            raise ValueError('input_dim must have length equal to num_node_types')
+
+        if not all([ a == input_dim[0] for a in input_dim ]):
+            raise ValueError('All elements of input_dim must have the same value')
+
+        if not isinstance(data, Data) and not isinstance(data, PreparedData):
+            raise TypeError('data must be an instance of Data or PreparedData')
diff --git a/tests/icosagon/test_bulkdec.py b/tests/icosagon/test_bulkdec.py
new file mode 100644
index 0000000..c1e4670
--- /dev/null
+++ b/tests/icosagon/test_bulkdec.py
@@ -0,0 +1,113 @@
+from icosagon.data import Data
+from icosagon.bulkdec import BulkDecodeLayer
+from icosagon.input import OneHotInputLayer
+from icosagon.convlayer import DecagonLayer
+import torch
+
+
+def test_bulk_decode_layer_01():
+    data = Data()
+    data.add_node_type('Dummy', 100)
+    fam = data.add_relation_family('Dummy-Dummy', 0, 0, False)
+    fam.add_relation_type('Dummy Relation 1',
+        torch.rand((100, 100), dtype=torch.float32).round().to_sparse())
+
+    in_layer = OneHotInputLayer(data)
+    d_layer = DecagonLayer(in_layer.output_dim, 32, data)
+    dec_layer = BulkDecodeLayer(input_dim=d_layer.output_dim, data=data,
+        keep_prob=1., activation=lambda x: x)
+    seq = torch.nn.Sequential(in_layer, d_layer, dec_layer)
+
+    pred = seq(None)
+
+    assert isinstance(pred, list)
+    assert len(pred) == len(data.relation_families)
+    assert isinstance(pred[0], torch.Tensor)
+    assert len(pred[0].shape) == 3
+    assert len(pred[0]) == len(data.relation_families[0].relation_types)
+    assert pred[0].shape[1] == data.node_types[0].count
+    assert pred[0].shape[2] == data.node_types[0].count
+
+
+def test_bulk_decode_layer_02():
+    data = Data()
+    data.add_node_type('Foo', 100)
+    data.add_node_type('Bar', 50)
+    fam = data.add_relation_family('Foo-Bar', 0, 1, False)
+    fam.add_relation_type('Foobar Relation 1',
+        torch.rand((100, 50), dtype=torch.float32).round().to_sparse(),
+        torch.rand((50, 100), dtype=torch.float32).round().to_sparse())
+
+    in_layer = OneHotInputLayer(data)
+    d_layer = DecagonLayer(in_layer.output_dim, 32, data)
+    dec_layer = BulkDecodeLayer(input_dim=d_layer.output_dim, data=data,
+        keep_prob=1., activation=lambda x: x)
+    seq = torch.nn.Sequential(in_layer, d_layer, dec_layer)
+
+    pred = seq(None)
+
+    assert isinstance(pred, list)
+    assert len(pred) == len(data.relation_families)
+    assert isinstance(pred[0], torch.Tensor)
+    assert len(pred[0].shape) == 3
+    assert len(pred[0]) == len(data.relation_families[0].relation_types)
+    assert pred[0].shape[1] == data.node_types[0].count
+    assert pred[0].shape[2] == data.node_types[1].count
+
+
+def test_bulk_decode_layer_03():
+    data = Data()
+    data.add_node_type('Foo', 100)
+    data.add_node_type('Bar', 50)
+    fam = data.add_relation_family('Foo-Bar', 0, 1, False)
+    fam.add_relation_type('Foobar Relation 1',
+        torch.rand((100, 50), dtype=torch.float32).round().to_sparse(),
+        torch.rand((50, 100), dtype=torch.float32).round().to_sparse())
+    fam.add_relation_type('Foobar Relation 2',
+        torch.rand((100, 50), dtype=torch.float32).round().to_sparse(),
+        torch.rand((50, 100), dtype=torch.float32).round().to_sparse())
+
+    in_layer = OneHotInputLayer(data)
+    d_layer = DecagonLayer(in_layer.output_dim, 32, data)
+    dec_layer = BulkDecodeLayer(input_dim=d_layer.output_dim, data=data,
+        keep_prob=1., activation=lambda x: x)
+    seq = torch.nn.Sequential(in_layer, d_layer, dec_layer)
+
+    pred = seq(None)
+
+    assert isinstance(pred, list)
+    assert len(pred) == len(data.relation_families)
+    assert isinstance(pred[0], torch.Tensor)
+    assert len(pred[0].shape) == 3
+    assert len(pred[0]) == len(data.relation_families[0].relation_types)
+    assert pred[0].shape[1] == data.node_types[0].count
+    assert pred[0].shape[2] == data.node_types[1].count
+
+
+def test_bulk_decode_layer_03_big():
+    data = Data()
+    data.add_node_type('Foo', 2000)
+    data.add_node_type('Bar', 2100)
+    fam = data.add_relation_family('Foo-Bar', 0, 1, False)
+    fam.add_relation_type('Foobar Relation 1',
+        torch.rand((2000, 2100), dtype=torch.float32).round().to_sparse(),
+        torch.rand((2100, 2000), dtype=torch.float32).round().to_sparse())
+    fam.add_relation_type('Foobar Relation 2',
+        torch.rand((2000, 2100), dtype=torch.float32).round().to_sparse(),
+        torch.rand((2100, 2000), dtype=torch.float32).round().to_sparse())
+
+    in_layer = OneHotInputLayer(data)
+    d_layer = DecagonLayer(in_layer.output_dim, 32, data)
+    dec_layer = BulkDecodeLayer(input_dim=d_layer.output_dim, data=data,
+        keep_prob=1., activation=lambda x: x)
+    seq = torch.nn.Sequential(in_layer, d_layer, dec_layer)
+
+    pred = seq(None)
+
+    assert isinstance(pred, list)
+    assert len(pred) == len(data.relation_families)
+    assert isinstance(pred[0], torch.Tensor)
+    assert len(pred[0].shape) == 3
+    assert len(pred[0]) == len(data.relation_families[0].relation_types)
+    assert pred[0].shape[1] == data.node_types[0].count
+    assert pred[0].shape[2] == data.node_types[1].count