@@ -78,80 +78,61 @@ def _cat(matrices: List[torch.Tensor]): |
 
 
 class FastGraphConv(torch.nn.Module):
     def __init__(self,
-        in_channels: List[int],
-        out_channels: List[int],
-        data: Union[Data, PreparedData],
-        relation_family: Union[RelationFamily, PreparedRelationFamily],
+        in_channels: int,
+        out_channels: int,
+        adjacency_matrices: List[torch.Tensor],
         keep_prob: float = 1.,
-        acivation: Callable[[torch.Tensor], torch.Tensor] = lambda x: x,
+        activation: Callable[[torch.Tensor], torch.Tensor] = lambda x: x,
         **kwargs) -> None:
 
         super().__init__(**kwargs)
 
         in_channels = int(in_channels)
         out_channels = int(out_channels)
-        if not isinstance(data, Data) and not isinstance(data, PreparedData):
-            raise TypeError('data must be an instance of Data or PreparedData')
-        if not isinstance(relation_family, RelationFamily) and \
-            not isinstance(relation_family, PreparedRelationFamily):
-            raise TypeError('relation_family must be an instance of RelationFamily or PreparedRelationFamily')
+        if not isinstance(adjacency_matrices, list):
+            raise TypeError('adjacency_matrices must be a list')
+        if len(adjacency_matrices) == 0:
+            raise ValueError('adjacency_matrices must not be empty')
+        if not all(isinstance(m, torch.Tensor) for m in adjacency_matrices):
+            raise TypeError('adjacency_matrices elements must be of class torch.Tensor')
+        if not all(m.is_sparse for m in adjacency_matrices):
+            raise ValueError('adjacency_matrices elements must be sparse')
         keep_prob = float(keep_prob)
         if not isinstance(activation, types.FunctionType):
             raise TypeError('activation must be a function')
 
-        n_nodes_row = data.node_types[relation_family.node_type_row].count
-        n_nodes_column = data.node_types[relation_family.node_type_column].count
-
         self.in_channels = in_channels
         self.out_channels = out_channels
-        self.data = data
-        self.relation_family = relation_family
+        self.adjacency_matrices = adjacency_matrices
         self.keep_prob = keep_prob
         self.activation = activation
 
-        self.weight = torch.cat([
-            init_glorot(in_channels, out_channels) \
-                for _ in range(len(relation_family.relation_types))
-        ], dim=1)
+        self.num_row_nodes = len(adjacency_matrices[0])
+        self.num_relation_types = len(adjacency_matrices)
 
+        self.adjacency_matrices = _sparse_diag_cat(adjacency_matrices)
 
-        self.weight_backward = torch.cat([
+        self.weights = torch.cat([
             init_glorot(in_channels, out_channels) \
-                for _ in range(len(relation_family.relation_types))
+                for _ in range(self.num_relation_types)
         ], dim=1)
 
-        self.adjacency_matrix = _sparse_diag_cat([
-            rel.adjacency_matrix \
-                if rel.adjacency_matrix is not None \
-                else _sparse_coo_tensor([], [], size=(n_nodes_row, n_nodes_column)) \
-                for rel in relation_family.relation_types ])
-
-        self.adjacency_matrix_backward = _sparse_diag_cat([
-            rel.adjacency_matrix_backward \
-                if rel.adjacency_matrix_backward is not None \
-                else _sparse_coo_tensor([], [], size=(n_nodes_column, n_nodes_row)) \
-                for rel in relation_family.relation_types ])
-
-    def forward(self, prev_layer_repr: List[torch.Tensor]) -> List[torch.Tensor]:
-        repr_row = prev_layer_repr[self.relation_family.node_type_row]
-        repr_column = prev_layer_repr[self.relation_family.node_type_column]
-
-        new_repr_row = torch.sparse.mm(repr_column, self.weight) \
-            if repr_column.is_sparse \
-            else torch.mm(repr_column, self.weight)
-        new_repr_row = torch.sparse.mm(self.adjacency_matrix, new_repr_row) \
-            if self.adjacency_matrix.is_sparse \
-            else torch.mm(self.adjacency_matrix, new_repr_row)
-        new_repr_row = new_repr_row.view(len(self.relation_family.relation_types),
-            len(repr_row), self.out_channels)
-
-        new_repr_column = torch.sparse.mm(repr_row, self.weight) \
-            if repr_row.is_sparse \
-            else torch.mm(repr_row, self.weight)
-        new_repr_column = torch.sparse.mm(self.adjacency_matrix_backward, new_repr_column) \
-            if self.adjacency_matrix_backward.is_sparse \
-            else torch.mm(self.adjacency_matrix_backward, new_repr_column)
-        new_repr_column = new_repr_column.view(len(self.relation_family.relation_types),
-            len(repr_column), self.out_channels)
-
-        return (new_repr_row, new_repr_column)
-
+    def forward(self, x) -> torch.Tensor:
+        if self.keep_prob < 1.:
+            x = dropout(x, self.keep_prob)
+        res = torch.sparse.mm(x, self.weights) \
+            if x.is_sparse \
+            else torch.mm(x, self.weights)
+        res = torch.split(res, res.shape[1] // self.num_relation_types, dim=1)
+        res = torch.cat(res)
+        res = torch.sparse.mm(self.adjacency_matrices, res) \
+            if self.adjacency_matrices.is_sparse \
+            else torch.mm(self.adjacency_matrices, res)
+        res = res.view(self.num_relation_types, self.num_row_nodes, self.out_channels)
+        if self.activation is not None:
+            res = self.activation(res)
+
+        return res
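
A minimal usage sketch of the reworked FastGraphConv above, assuming this module is importable as icosagon.fastconv (the import path is not shown in this hunk); the node counts, channel sizes and identity adjacency matrices below are illustrative values only, not part of the commit.

import torch

from icosagon.fastconv import FastGraphConv  # assumed import path

# Two relation types over a 10 x 10 (row x column) node pair,
# each given as a sparse adjacency matrix.
adjacency_matrices = [
    torch.eye(10).to_sparse(),
    torch.eye(10).to_sparse(),
]

conv = FastGraphConv(in_channels=16, out_channels=32,
    adjacency_matrices=adjacency_matrices,
    activation=lambda x: torch.relu(x))

x = torch.rand(10, 16)   # one row of input features per column-side node
out = conv(x)            # per-relation outputs stacked along dim 0
assert out.shape == (2, 10, 32)   # (num_relation_types, num_row_nodes, out_channels)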
 
 
 class FastConvLayer(torch.nn.Module):