|
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119 |
- from icosagon.data import Data
- from icosagon.trainprep import PreparedData
- from icosagon.decode import DEDICOMDecoder, \
- DistMultDecoder, \
- BilinearDecoder, \
- InnerProductDecoder
- from icosagon.dropout import dropout
- import torch
- from typing import List, \
- Callable, \
- Union
-
-
- '''
- Let's say that I have dense latent representations row and col.
- Then let's take relation matrix rel in a list of relations REL.
- A single computation currenty looks like this:
- (((row * rel) * glob) * rel) * col
-
- Shouldn't then this basically work:
-
- prod1 = torch.matmul(row, REL)
- prod2 = torch.matmul(prod1, glob)
- prod3 = torch.matmul(prod2, REL)
- res = torch.matmul(prod3, col)
- res = activation(res)
-
- res should then have shape: (num_relations, num_rows, num_columns)
- '''
-
-
- def convert_decoder(dec):
- if isinstance(dec, DEDICOMDecoder):
- global_interaction = dec.global_interaction
- local_variation = map(torch.diag, dec.local_variation)
- elif isinstance(dec, DistMultDecoder):
- global_interaction = torch.eye(dec.input_dim, dec.input_dim)
- local_variation = map(torch.diag, dec.relation)
- elif isinstance(dec, BilinearDecoder):
- global_interaction = torch.eye(dec.input_dim, dec.input_dim)
- local_variation = dec.relation
- elif isinstance(dec, InnerProductDecoder):
- global_interaction = torch.eye(dec.input_dim, dec.input_dim)
- local_variation = torch.eye(dec.input_dim, dec.input_dim)
- local_variation = [ local_variation ] * dec.num_relation_types
- else:
- raise TypeError('Unknown decoder type in convert_decoder()')
-
- if not isinstance(local_variation, torch.Tensor):
- local_variation = map(lambda a: a.view(1, *a.shape), local_variation)
- local_variation = torch.cat(list(local_variation))
-
- return (global_interaction, local_variation)
-
-
- class BulkDecodeLayer(torch.nn.Module):
- def __init__(self,
- input_dim: List[int],
- data: Union[Data, PreparedData],
- keep_prob: float = 1.,
- activation: Callable[[torch.Tensor], torch.Tensor] = torch.sigmoid,
- **kwargs) -> None:
-
- super().__init__(**kwargs)
-
- self._check_params(input_dim, data)
-
- self.input_dim = input_dim[0]
- self.data = data
- self.keep_prob = keep_prob
- self.activation = activation
-
- self.decoders = None
- self.global_interaction = None
- self.local_variation = None
- self.build()
-
- def build(self) -> None:
- self.decoders = torch.nn.ModuleList()
- self.global_interaction = torch.nn.ParameterList()
- self.local_variation = torch.nn.ParameterList()
- for fam in self.data.relation_families:
- dec = fam.decoder_class(self.input_dim,
- len(fam.relation_types),
- self.keep_prob,
- self.activation)
- self.decoders.append(dec)
- global_interaction, local_variation = convert_decoder(dec)
- self.global_interaction.append(torch.nn.Parameter(global_interaction))
- self.local_variation.append(torch.nn.Parameter(local_variation))
-
- def forward(self, last_layer_repr: List[torch.Tensor]) -> List[torch.Tensor]:
- res = []
- for i, fam in enumerate(self.data.relation_families):
- repr_row = last_layer_repr[fam.node_type_row]
- repr_column = last_layer_repr[fam.node_type_column]
- repr_row = dropout(repr_row, keep_prob=self.keep_prob)
- repr_column = dropout(repr_column, keep_prob=self.keep_prob)
- prod_1 = torch.matmul(repr_row, self.local_variation[i])
- print(f'local_variation[{i}].shape: {self.local_variation[i].shape}')
- prod_2 = torch.matmul(prod_1, self.global_interaction[i])
- prod_3 = torch.matmul(prod_2, self.local_variation[i])
- pred = torch.matmul(prod_3, repr_column.transpose(0, 1))
- res.append(pred)
- return res
-
- @staticmethod
- def _check_params(input_dim, data):
- if not isinstance(input_dim, list):
- raise TypeError('input_dim must be a list')
-
- if len(input_dim) != len(data.node_types):
- raise ValueError('input_dim must have length equal to num_node_types')
-
- if not all([ a == input_dim[0] for a in input_dim ]):
- raise ValueError('All elements of input_dim must have the same value')
-
- if not isinstance(data, Data) and not isinstance(data, PreparedData):
- raise TypeError('data must be an instance of Data or PreparedData')
|