|
|
@@ -1,11 +1,14 @@ |
|
|
|
from typing import List, \
|
|
|
|
Union, \
|
|
|
|
Callable
|
|
|
|
from .data import Data
|
|
|
|
from .trainprep import PreparedData
|
|
|
|
from .data import Data, \
|
|
|
|
RelationFamily
|
|
|
|
from .trainprep import PreparedData, \
|
|
|
|
PreparedRelationFamily
|
|
|
|
import torch
|
|
|
|
from .weights import init_glorot
|
|
|
|
from .normalize import _sparse_coo_tensor
|
|
|
|
import types
|
|
|
|
|
|
|
|
|
|
|
|
def _sparse_diag_cat(matrices: List[torch.Tensor]):
|
|
|
@@ -75,28 +78,80 @@ def _cat(matrices: List[torch.Tensor]): |
|
|
|
|
|
|
|
class FastGraphConv(torch.nn.Module):
|
|
|
|
def __init__(self,
|
|
|
|
in_channels: int,
|
|
|
|
out_channels: int,
|
|
|
|
adjacency_matrix: List[torch.Tensor],
|
|
|
|
**kwargs):
|
|
|
|
in_channels: List[int],
|
|
|
|
out_channels: List[int],
|
|
|
|
data: Union[Data, PreparedData],
|
|
|
|
relation_family: Union[RelationFamily, PreparedRelationFamily]
|
|
|
|
keep_prob: float = 1.,
|
|
|
|
acivation: Callable[[torch.Tensor], torch.Tensor] = lambda x: x,
|
|
|
|
**kwargs) -> None:
|
|
|
|
|
|
|
|
in_channels = int(in_channels)
|
|
|
|
out_channels = int(out_channels)
|
|
|
|
if not isinstance(data, Data) and not isinstance(data, PreparedData):
|
|
|
|
raise TypeError('data must be an instance of Data or PreparedData')
|
|
|
|
if not isinstance(relation_family, RelationFamily) and \
|
|
|
|
not isinstance(relation_family, PreparedRelationFamily):
|
|
|
|
raise TypeError('relation_family must be an instance of RelationFamily or PreparedRelationFamily')
|
|
|
|
keep_prob = float(keep_prob)
|
|
|
|
if not isinstance(activation, types.FunctionType):
|
|
|
|
raise TypeError('activation must be a function')
|
|
|
|
|
|
|
|
n_nodes_row = data.node_types[relation_family.node_type_row].count
|
|
|
|
n_nodes_column = data.node_types[relation_family.node_type_column].count
|
|
|
|
|
|
|
|
self.in_channels = in_channels
|
|
|
|
self.out_channels = out_channels
|
|
|
|
self.data = data
|
|
|
|
self.relation_family = relation_family
|
|
|
|
self.keep_prob = keep_prob
|
|
|
|
self.activation = activation
|
|
|
|
|
|
|
|
self.weight = torch.cat([
|
|
|
|
init_glorot(in_channels, out_channels) \
|
|
|
|
for _ in adjacency_matrix
|
|
|
|
for _ in range(len(relation_family.relation_types))
|
|
|
|
], dim=1)
|
|
|
|
self.adjacency_matrix = _cat(adjacency_matrix)
|
|
|
|
|
|
|
|
def forward(self, x):
|
|
|
|
x = torch.sparse.mm(x, self.weight) \
|
|
|
|
if x.is_sparse \
|
|
|
|
else torch.mm(x, self.weight)
|
|
|
|
x = torch.sparse.mm(self.adjacency_matrix, x) \
|
|
|
|
if self.adjacency_matrix.is_sparse \
|
|
|
|
else torch.mm(self.adjacency_matrix, x)
|
|
|
|
return x
|
|
|
|
self.weight_backward = torch.cat([
|
|
|
|
init_glorot(in_channels, out_channels) \
|
|
|
|
for _ in range(len(relation_family.relation_types))
|
|
|
|
], dim=1)
|
|
|
|
|
|
|
|
self.adjacency_matrix = _sparse_diag_cat([
|
|
|
|
rel.adjacency_matrix \
|
|
|
|
if rel.adjacency_matrix is not None \
|
|
|
|
else _sparse_coo_tensor([], [], size=(n_nodes_row, n_nodes_column)) \
|
|
|
|
for rel in relation_family.relation_types ])
|
|
|
|
|
|
|
|
self.adjacency_matrix_backward = _sparse_diag_cat([
|
|
|
|
rel.adjacency_matrix_backward \
|
|
|
|
if rel.adjacency_matrix_backward is not None \
|
|
|
|
else _sparse_coo_tensor([], [], size=(n_nodes_column, n_nodes_row)) \
|
|
|
|
for rel in relation_family.relation_types ])
|
|
|
|
|
|
|
|
def forward(self, prev_layer_repr: List[torch.Tensor]) -> List[torch.Tensor]:
|
|
|
|
repr_row = prev_layer_repr[self.relation_family.node_type_row]
|
|
|
|
repr_column = prev_layer_repr[self.relation_family.node_type_column]
|
|
|
|
|
|
|
|
new_repr_row = torch.sparse.mm(repr_column, self.weight) \
|
|
|
|
if repr_column.is_sparse \
|
|
|
|
else torch.mm(repr_column, self.weight)
|
|
|
|
new_repr_row = torch.sparse.mm(self.adjacency_matrix, new_repr_row) \
|
|
|
|
if self.adjacency_matrix.is_sparse \
|
|
|
|
else torch.mm(self.adjacency_matrix, new_repr_row)
|
|
|
|
new_repr_row = new_repr_row.view(len(self.relation_family.relation_types),
|
|
|
|
len(repr_row), self.out_channels)
|
|
|
|
|
|
|
|
new_repr_column = torch.sparse.mm(repr_row, self.weight) \
|
|
|
|
if repr_row.is_sparse \
|
|
|
|
else torch.mm(repr_row, self.weight)
|
|
|
|
new_repr_column = torch.sparse.mm(self.adjacency_matrix_backward, new_repr_column) \
|
|
|
|
if self.adjacency_matrix_backward.is_sparse \
|
|
|
|
else torch.mm(self.adjacency_matrix_backward, new_repr_column)
|
|
|
|
new_repr_column = new_repr_column.view(len(self.relation_family.relation_types),
|
|
|
|
len(repr_column), self.out_channels)
|
|
|
|
|
|
|
|
return (new_repr_row, new_repr_column)
|
|
|
|
|
|
|
|
|
|
|
|
class FastConvLayer(torch.nn.Module):
|
|
|
|