|
|
@@ -0,0 +1,110 @@ |
|
|
|
from typing import List, \
|
|
|
|
Union, \
|
|
|
|
Callable
|
|
|
|
from .data import Data
|
|
|
|
from .trainprep import PreparedData
|
|
|
|
import torch
|
|
|
|
from .weights import init_glorot
|
|
|
|
|
|
|
|
|
|
|
|
class FastConvLayer(torch.nn.Module):
|
|
|
|
adjacency_matrix: List[torch.Tensor]
|
|
|
|
adjacency_matrix_backward: List[torch.Tensor]
|
|
|
|
weight: List[torch.Tensor]
|
|
|
|
weight_backward: List[torch.Tensor]
|
|
|
|
|
|
|
|
def __init__(self,
|
|
|
|
input_dim: List[int],
|
|
|
|
output_dim: List[int],
|
|
|
|
data: Union[Data, PreparedData],
|
|
|
|
keep_prob: float = 1.,
|
|
|
|
rel_activation: Callable[[torch.Tensor], torch.Tensor] = lambda x: x
|
|
|
|
layer_activation: Callable[[torch.Tensor], torch.Tensor] = torch.nn.functional.relu,
|
|
|
|
**kwargs):
|
|
|
|
|
|
|
|
super().__init__(**kwargs)
|
|
|
|
|
|
|
|
self._check_params(input_dim, output_dim, data, keep_prob,
|
|
|
|
rel_activation, layer_activation)
|
|
|
|
|
|
|
|
self.input_dim = input_dim
|
|
|
|
self.output_dim = output_dim
|
|
|
|
self.data = data
|
|
|
|
self.keep_prob = keep_prob
|
|
|
|
self.rel_activation = rel_activation
|
|
|
|
self.layer_activation = layer_activation
|
|
|
|
|
|
|
|
self.adjacency_matrix = None
|
|
|
|
self.adjacency_matrix_backward = None
|
|
|
|
self.weight = None
|
|
|
|
self.weight_backward = None
|
|
|
|
self.build()
|
|
|
|
|
|
|
|
def build(self):
|
|
|
|
self.adjacency_matrix = []
|
|
|
|
self.adjacency_matrix_backward = []
|
|
|
|
self.weight = []
|
|
|
|
self.weight_backward = []
|
|
|
|
for fam in self.data.relation_families:
|
|
|
|
adj_mat = [ rel.adjacency_matrix \
|
|
|
|
for rel in fam.relation_types \
|
|
|
|
if rel.adjacency_matrix is not None ]
|
|
|
|
adj_mat_back = [ rel.adjacency_matrix_backward \
|
|
|
|
for rel in fam.relation_types \
|
|
|
|
if rel.adjacency_matrix_backward is not None ]
|
|
|
|
weight = [ init_glorot(self.input_dim[fam.node_type_column],
|
|
|
|
self.output_dim[fam.node_type_row]) \
|
|
|
|
for _ in range(len(adj_mat)) ]
|
|
|
|
weight_back = [ init_glorot(self.input_dim[fam.node_type_column],
|
|
|
|
self.output_dim[fam.node_type_row]) \
|
|
|
|
for _ in range(len(adj_mat_back)) ]
|
|
|
|
adj_mat = torch.cat(adj_mat) \
|
|
|
|
if len(adj_mat) > 0 \
|
|
|
|
else None
|
|
|
|
adj_mat_back = torch.cat(adj_mat_back) \
|
|
|
|
if len(adj_mat_back) > 0 \
|
|
|
|
else None
|
|
|
|
self.adjacency_matrix.append(adj_mat)
|
|
|
|
self.adjacency_matrix_backward.append(adj_mat_back)
|
|
|
|
self.weight.append(weight)
|
|
|
|
self.weight_back.append(weight_back)
|
|
|
|
|
|
|
|
def forward(self, prev_layer_repr):
|
|
|
|
for i, fam in enumerate(self.data.relation_families):
|
|
|
|
repr_row = prev_layer_repr[fam.node_type_row]
|
|
|
|
repr_column = prev_layer_repr[fam.node_type_column]
|
|
|
|
|
|
|
|
adj_mat = self.adjacency_matrix[i]
|
|
|
|
adj_mat_back = self.adjacency_matrix_backward[i]
|
|
|
|
|
|
|
|
if adj_mat is not None:
|
|
|
|
x = dropout(repr_column, keep_prob=self.keep_prob)
|
|
|
|
x = torch.sparse.mm(x, self.weight[i]) \
|
|
|
|
if x.is_sparse \
|
|
|
|
else torch.mm(x, self.weight[i])
|
|
|
|
x = torch.sparse.mm(adj_mat, repr_row) \
|
|
|
|
if adj_mat.is_sparse \
|
|
|
|
else torch.mm(adj_mat, repr_row)
|
|
|
|
x = self.rel_activation(x)
|
|
|
|
x = x.view(len(fam.relation_types), len(repr_row), -1)
|
|
|
|
|
|
|
|
if adj_mat_back is not None:
|
|
|
|
x = torch.sparse.mm(adj_mat_back, repr_row) \
|
|
|
|
if adj_mat_back.is_sparse \
|
|
|
|
else torch.mm(adj_mat_back, repr_row)
|
|
|
|
|
|
|
|
@staticmethod
|
|
|
|
def _check_params(input_dim, output_dim, data, keep_prob,
|
|
|
|
rel_activation, layer_activation):
|
|
|
|
|
|
|
|
if not isinstance(input_dim, list):
|
|
|
|
raise ValueError('input_dim must be a list')
|
|
|
|
|
|
|
|
if not output_dim:
|
|
|
|
raise ValueError('output_dim must be specified')
|
|
|
|
|
|
|
|
if not isinstance(output_dim, list):
|
|
|
|
output_dim = [output_dim] * len(data.node_types)
|
|
|
|
|
|
|
|
if not isinstance(data, Data) and not isinstance(data, PreparedData):
|
|
|
|
raise ValueError('data must be of type Data or PreparedData')
|