from icosagon.data import Data
from icosagon.trainprep import PreparedData
from icosagon.decode import DEDICOMDecoder, \
    DistMultDecoder, \
    BilinearDecoder, \
    InnerProductDecoder
from icosagon.dropout import dropout
import torch
from typing import List, \
    Callable, \
    Union


'''
Let's say that I have dense latent representations row and col.
Then let's take relation matrix rel in a list of relations REL.
A single computation currently looks like this:
(((row * rel) * glob) * rel) * col

Shouldn't then this basically work:

prod1 = torch.matmul(row, REL)
prod2 = torch.matmul(prod1, glob)
prod3 = torch.matmul(prod2, REL)
res = torch.matmul(prod3, col)
res = activation(res)

res should then have shape: (num_relations, num_rows, num_columns)
'''


def convert_decoder(dec):
    """Express a decoder as a ``(global_interaction, local_variation)`` pair.

    Every supported decoder is rewritten in the common form
    ``row @ local @ global @ local @ col^T`` so that all relation types
    of a family can be evaluated with batched matrix products.

    Args:
        dec: An instance of DEDICOMDecoder, DistMultDecoder,
            BilinearDecoder or InnerProductDecoder.

    Returns:
        A tuple ``(global_interaction, local_variation)``. Unless the
        decoder already supplied ``local_variation`` as a single tensor,
        the per-relation matrices are concatenated along a new leading
        dimension, one slice per relation type.

    Raises:
        TypeError: If ``dec`` is not one of the supported decoder classes.
    """
    if isinstance(dec, DEDICOMDecoder):
        global_interaction = dec.global_interaction
        # torch.diag is applied to each per-relation entry; presumably
        # these are 1-D vectors being expanded to diagonal matrices.
        local_variation = [torch.diag(a) for a in dec.local_variation]
    elif isinstance(dec, DistMultDecoder):
        global_interaction = torch.eye(dec.input_dim, dec.input_dim)
        local_variation = [torch.diag(a) for a in dec.relation]
    elif isinstance(dec, BilinearDecoder):
        global_interaction = torch.eye(dec.input_dim, dec.input_dim)
        local_variation = dec.relation
    elif isinstance(dec, InnerProductDecoder):
        global_interaction = torch.eye(dec.input_dim, dec.input_dim)
        identity = torch.eye(dec.input_dim, dec.input_dim)
        local_variation = [identity] * dec.num_relation_types
    else:
        raise TypeError('Unknown decoder type in convert_decoder()')

    if not isinstance(local_variation, torch.Tensor):
        # Give every matrix a leading singleton dimension, then
        # concatenate along it to obtain one slice per relation type.
        local_variation = torch.cat(
            [a.view(1, *a.shape) for a in local_variation])

    return (global_interaction, local_variation)


class BulkDecodeLayer(torch.nn.Module):
    """Decode all relation types of each relation family in bulk.

    For every relation family in ``data`` a decoder of the family's
    ``decoder_class`` is instantiated and converted with
    ``convert_decoder()`` into a global interaction matrix and a stack
    of local variation matrices, both registered as parameters.
    """

    def __init__(self,
                 input_dim: List[int],
                 data: Union[Data, PreparedData],
                 keep_prob: float = 1.,
                 activation: Callable[[torch.Tensor], torch.Tensor] = torch.sigmoid,
                 **kwargs) -> None:
        """Validate the arguments and build the per-family parameters.

        Args:
            input_dim: One entry per node type; all entries must be equal.
                Only the first entry is used as the latent dimension.
            data: The Data or PreparedData describing the relation families.
            keep_prob: Passed as ``keep_prob`` to ``dropout()`` and to the
                decoder constructors.
            activation: Function applied to the predictions in ``forward()``
                and passed to the decoder constructors.
            **kwargs: Forwarded to ``torch.nn.Module.__init__``.

        Raises:
            TypeError: See ``_check_params()``.
            ValueError: See ``_check_params()``.
        """
        super().__init__(**kwargs)

        self._check_params(input_dim, data)

        self.input_dim = input_dim[0]
        self.data = data
        self.keep_prob = keep_prob
        self.activation = activation

        self.decoders = None
        self.global_interaction = None
        self.local_variation = None
        self.build()

    def build(self) -> None:
        """Create one decoder and its converted parameters per family."""
        self.decoders = torch.nn.ModuleList()
        self.global_interaction = torch.nn.ParameterList()
        self.local_variation = torch.nn.ParameterList()
        for fam in self.data.relation_families:
            dec = fam.decoder_class(self.input_dim,
                                    len(fam.relation_types),
                                    self.keep_prob,
                                    self.activation)
            self.decoders.append(dec)
            global_interaction, local_variation = convert_decoder(dec)
            self.global_interaction.append(
                torch.nn.Parameter(global_interaction))
            self.local_variation.append(
                torch.nn.Parameter(local_variation))

    def forward(self, last_layer_repr: List[torch.Tensor]) -> List[torch.Tensor]:
        """Compute predictions for every relation family.

        Args:
            last_layer_repr: Latent representations indexed by node type.

        Returns:
            One tensor per relation family, in the order of
            ``data.relation_families``, with the activation applied.
            Intended shape per the design note at the top of the file:
            (num_relations, num_rows, num_columns).
        """
        res = []
        for i, fam in enumerate(self.data.relation_families):
            repr_row = last_layer_repr[fam.node_type_row]
            repr_column = last_layer_repr[fam.node_type_column]
            repr_row = dropout(repr_row, keep_prob=self.keep_prob)
            repr_column = dropout(repr_column, keep_prob=self.keep_prob)

            local_variation = self.local_variation[i]
            # torch.matmul broadcasts the 2-D operands against the
            # stacked local variation matrices.
            prod_1 = torch.matmul(repr_row, local_variation)
            prod_2 = torch.matmul(prod_1, self.global_interaction[i])
            prod_3 = torch.matmul(prod_2, local_variation)
            pred = torch.matmul(prod_3, repr_column.transpose(0, 1))
            # The configured activation was previously stored but never
            # applied, leaving the layer returning raw scores.
            pred = self.activation(pred)
            res.append(pred)
        return res

    @staticmethod
    def _check_params(input_dim, data):
        """Validate the constructor arguments.

        Raises:
            TypeError: If ``input_dim`` is not a list or ``data`` is neither
                a Data nor a PreparedData instance.
            ValueError: If ``input_dim`` does not have one entry per node
                type or its entries are not all equal.
        """
        if not isinstance(input_dim, list):
            raise TypeError('input_dim must be a list')

        # Check the type of data before touching data.node_types, so that
        # a wrong argument yields this TypeError, not an AttributeError.
        if not isinstance(data, (Data, PreparedData)):
            raise TypeError('data must be an instance of Data or PreparedData')

        if len(input_dim) != len(data.node_types):
            raise ValueError('input_dim must have length equal to num_node_types')

        if not all(a == input_dim[0] for a in input_dim):
            raise ValueError('All elements of input_dim must have the same value')