IF YOU WOULD LIKE TO GET AN ACCOUNT, please write an email to s dot adaszewski at gmail dot com. User accounts are meant only to report issues and/or generate pull requests. This is a purpose-specific Git hosting for ADARED projects. Thank you for your understanding!
Browse Source

Work on FastGraphConv.

master
Stanislaw Adaszewski 3 years ago
parent
commit
e2ae97add3
1 changed files with 71 additions and 16 deletions
  1. +71
    -16
      src/icosagon/fastconv.py

+ 71
- 16
src/icosagon/fastconv.py View File

@@ -1,11 +1,14 @@
from typing import List, \
Union, \
Callable
from .data import Data
from .trainprep import PreparedData
from .data import Data, \
RelationFamily
from .trainprep import PreparedData, \
PreparedRelationFamily
import torch
from .weights import init_glorot
from .normalize import _sparse_coo_tensor
import types
def _sparse_diag_cat(matrices: List[torch.Tensor]):
@@ -75,28 +78,80 @@ def _cat(matrices: List[torch.Tensor]):
class FastGraphConv(torch.nn.Module):
def __init__(self,
in_channels: int,
out_channels: int,
adjacency_matrix: List[torch.Tensor],
**kwargs):
in_channels: List[int],
out_channels: List[int],
data: Union[Data, PreparedData],
relation_family: Union[RelationFamily, PreparedRelationFamily]
keep_prob: float = 1.,
acivation: Callable[[torch.Tensor], torch.Tensor] = lambda x: x,
**kwargs) -> None:
in_channels = int(in_channels)
out_channels = int(out_channels)
if not isinstance(data, Data) and not isinstance(data, PreparedData):
raise TypeError('data must be an instance of Data or PreparedData')
if not isinstance(relation_family, RelationFamily) and \
not isinstance(relation_family, PreparedRelationFamily):
raise TypeError('relation_family must be an instance of RelationFamily or PreparedRelationFamily')
keep_prob = float(keep_prob)
if not isinstance(activation, types.FunctionType):
raise TypeError('activation must be a function')
n_nodes_row = data.node_types[relation_family.node_type_row].count
n_nodes_column = data.node_types[relation_family.node_type_column].count
self.in_channels = in_channels
self.out_channels = out_channels
self.data = data
self.relation_family = relation_family
self.keep_prob = keep_prob
self.activation = activation
self.weight = torch.cat([
init_glorot(in_channels, out_channels) \
for _ in adjacency_matrix
for _ in range(len(relation_family.relation_types))
], dim=1)
self.adjacency_matrix = _cat(adjacency_matrix)
def forward(self, x):
x = torch.sparse.mm(x, self.weight) \
if x.is_sparse \
else torch.mm(x, self.weight)
x = torch.sparse.mm(self.adjacency_matrix, x) \
if self.adjacency_matrix.is_sparse \
else torch.mm(self.adjacency_matrix, x)
return x
self.weight_backward = torch.cat([
init_glorot(in_channels, out_channels) \
for _ in range(len(relation_family.relation_types))
], dim=1)
self.adjacency_matrix = _sparse_diag_cat([
rel.adjacency_matrix \
if rel.adjacency_matrix is not None \
else _sparse_coo_tensor([], [], size=(n_nodes_row, n_nodes_column)) \
for rel in relation_family.relation_types ])
self.adjacency_matrix_backward = _sparse_diag_cat([
rel.adjacency_matrix_backward \
if rel.adjacency_matrix_backward is not None \
else _sparse_coo_tensor([], [], size=(n_nodes_column, n_nodes_row)) \
for rel in relation_family.relation_types ])
def forward(self, prev_layer_repr: List[torch.Tensor]) -> List[torch.Tensor]:
repr_row = prev_layer_repr[self.relation_family.node_type_row]
repr_column = prev_layer_repr[self.relation_family.node_type_column]
new_repr_row = torch.sparse.mm(repr_column, self.weight) \
if repr_column.is_sparse \
else torch.mm(repr_column, self.weight)
new_repr_row = torch.sparse.mm(self.adjacency_matrix, new_repr_row) \
if self.adjacency_matrix.is_sparse \
else torch.mm(self.adjacency_matrix, new_repr_row)
new_repr_row = new_repr_row.view(len(self.relation_family.relation_types),
len(repr_row), self.out_channels)
new_repr_column = torch.sparse.mm(repr_row, self.weight) \
if repr_row.is_sparse \
else torch.mm(repr_row, self.weight)
new_repr_column = torch.sparse.mm(self.adjacency_matrix_backward, new_repr_column) \
if self.adjacency_matrix_backward.is_sparse \
else torch.mm(self.adjacency_matrix_backward, new_repr_column)
new_repr_column = new_repr_column.view(len(self.relation_family.relation_types),
len(repr_column), self.out_channels)
return (new_repr_row, new_repr_column)
class FastConvLayer(torch.nn.Module):


Loading…
Cancel
Save