|
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257 |
- from typing import List, \
- Union, \
- Callable
- from .data import Data, \
- RelationFamily
- from .trainprep import PreparedData, \
- PreparedRelationFamily
- import torch
- from .weights import init_glorot
- from .normalize import _sparse_coo_tensor
- import types
-
-
- def _sparse_diag_cat(matrices: List[torch.Tensor]):
- if len(matrices) == 0:
- raise ValueError('The list of matrices must be non-empty')
-
- if not all(m.is_sparse for m in matrices):
- raise ValueError('All matrices must be sparse')
-
- if not all(len(m.shape) == 2 for m in matrices):
- raise ValueError('All matrices must be 2D')
-
- indices = []
- values = []
- row_offset = 0
- col_offset = 0
-
- for m in matrices:
- ind = m._indices().clone()
- ind[0] += row_offset
- ind[1] += col_offset
- indices.append(ind)
- values.append(m._values())
- row_offset += m.shape[0]
- col_offset += m.shape[1]
-
- indices = torch.cat(indices, dim=1)
- values = torch.cat(values)
-
- return _sparse_coo_tensor(indices, values, size=(row_offset, col_offset))
-
-
- def _cat(matrices: List[torch.Tensor]):
- if len(matrices) == 0:
- raise ValueError('Empty list passed to _cat()')
-
- n = sum(a.is_sparse for a in matrices)
- if n != 0 and n != len(matrices):
- raise ValueError('All matrices must have the same layout (dense or sparse)')
-
- if not all(a.shape[1:] == matrices[0].shape[1:] for a in matrices):
- raise ValueError('All matrices must have the same dimensions apart from dimension 0')
-
- if not matrices[0].is_sparse:
- return torch.cat(matrices)
-
- total_rows = sum(a.shape[0] for a in matrices)
- indices = []
- values = []
- row_offset = 0
-
- for a in matrices:
- ind = a._indices().clone()
- val = a._values()
- ind[0] += row_offset
- ind = ind.transpose(0, 1)
- indices.append(ind)
- values.append(val)
- row_offset += a.shape[0]
-
- indices = torch.cat(indices).transpose(0, 1)
- values = torch.cat(values)
-
- res = _sparse_coo_tensor(indices, values, size=(row_offset, matrices[0].shape[1]))
- return res
-
-
class FastGraphConv(torch.nn.Module):
    """Graph convolution applied jointly to all relation types of a single
    relation family, using a block-diagonal adjacency matrix and
    column-concatenated per-relation weight matrices.
    """

    def __init__(self,
        in_channels: int,
        out_channels: int,
        data: Union[Data, PreparedData],
        relation_family: Union[RelationFamily, PreparedRelationFamily],
        keep_prob: float = 1.,
        activation: Callable[[torch.Tensor], torch.Tensor] = lambda x: x,
        **kwargs) -> None:
        """Initialize the convolution.

        Parameters:
            in_channels: input feature count per node (coerced with int()).
            out_channels: output feature count per node (coerced with int()).
            data: graph data object the relation family belongs to.
            relation_family: family whose relation types are convolved jointly.
            keep_prob: dropout keep probability (stored; not used in forward()).
            activation: element-wise activation (stored; not used in forward()).

        Raises:
            TypeError: if data, relation_family or activation have the wrong type.
        """
        # Fixes vs. the previous revision: missing comma after the
        # relation_family parameter (syntax error), parameter spelled
        # 'acivation' while the body referenced 'activation' (NameError),
        # and the missing super().__init__() call required by nn.Module.
        super().__init__(**kwargs)

        in_channels = int(in_channels)
        out_channels = int(out_channels)
        if not isinstance(data, (Data, PreparedData)):
            raise TypeError('data must be an instance of Data or PreparedData')
        if not isinstance(relation_family, (RelationFamily, PreparedRelationFamily)):
            raise TypeError('relation_family must be an instance of RelationFamily or PreparedRelationFamily')
        keep_prob = float(keep_prob)
        # Generalized from isinstance(..., types.FunctionType) so builtin
        # callables (e.g. torch.relu) are accepted as well.
        if not callable(activation):
            raise TypeError('activation must be callable')

        n_nodes_row = data.node_types[relation_family.node_type_row].count
        n_nodes_column = data.node_types[relation_family.node_type_column].count

        self.in_channels = in_channels
        self.out_channels = out_channels
        self.data = data
        self.relation_family = relation_family
        self.keep_prob = keep_prob
        self.activation = activation

        # One (in_channels x out_channels) Glorot-initialized matrix per
        # relation type, concatenated along columns.
        # NOTE(review): these are plain tensors, not nn.Parameter, so they
        # are not registered for optimization -- confirm this is intended.
        self.weight = torch.cat([
            init_glorot(in_channels, out_channels)
            for _ in range(len(relation_family.relation_types))
        ], dim=1)

        self.weight_backward = torch.cat([
            init_glorot(in_channels, out_channels)
            for _ in range(len(relation_family.relation_types))
        ], dim=1)

        # Block-diagonal stack of per-relation adjacency matrices; empty
        # sparse matrices stand in for relations that lack one.
        self.adjacency_matrix = _sparse_diag_cat([
            rel.adjacency_matrix
            if rel.adjacency_matrix is not None
            else _sparse_coo_tensor([], [], size=(n_nodes_row, n_nodes_column))
            for rel in relation_family.relation_types ])

        self.adjacency_matrix_backward = _sparse_diag_cat([
            rel.adjacency_matrix_backward
            if rel.adjacency_matrix_backward is not None
            else _sparse_coo_tensor([], [], size=(n_nodes_column, n_nodes_row))
            for rel in relation_family.relation_types ])

    def forward(self, prev_layer_repr: List[torch.Tensor]) -> List[torch.Tensor]:
        """Propagate node representations in both edge directions.

        Returns a pair (new_repr_row, new_repr_column), each of shape
        (num_relation_types, num_nodes, out_channels).
        """
        repr_row = prev_layer_repr[self.relation_family.node_type_row]
        repr_column = prev_layer_repr[self.relation_family.node_type_column]

        # Row update: project the column representations, then aggregate
        # along the (block-diagonal) forward adjacency matrix.
        new_repr_row = torch.sparse.mm(repr_column, self.weight) \
            if repr_column.is_sparse \
            else torch.mm(repr_column, self.weight)
        new_repr_row = torch.sparse.mm(self.adjacency_matrix, new_repr_row) \
            if self.adjacency_matrix.is_sparse \
            else torch.mm(self.adjacency_matrix, new_repr_row)
        new_repr_row = new_repr_row.view(len(self.relation_family.relation_types),
            len(repr_row), self.out_channels)

        # Column update: symmetric to the above, along the backward
        # adjacency matrix.
        # NOTE(review): this uses self.weight rather than
        # self.weight_backward (which is otherwise unused) -- looks like a
        # bug, but confirm the intended semantics before changing.
        new_repr_column = torch.sparse.mm(repr_row, self.weight) \
            if repr_row.is_sparse \
            else torch.mm(repr_row, self.weight)
        new_repr_column = torch.sparse.mm(self.adjacency_matrix_backward, new_repr_column) \
            if self.adjacency_matrix_backward.is_sparse \
            else torch.mm(self.adjacency_matrix_backward, new_repr_column)
        new_repr_column = new_repr_column.view(len(self.relation_family.relation_types),
            len(repr_column), self.out_channels)

        return (new_repr_row, new_repr_column)
-
-
- class FastConvLayer(torch.nn.Module):
- adjacency_matrix: List[torch.Tensor]
- adjacency_matrix_backward: List[torch.Tensor]
- weight: List[torch.Tensor]
- weight_backward: List[torch.Tensor]
-
- def __init__(self,
- input_dim: List[int],
- output_dim: List[int],
- data: Union[Data, PreparedData],
- keep_prob: float = 1.,
- rel_activation: Callable[[torch.Tensor], torch.Tensor] = lambda x: x,
- layer_activation: Callable[[torch.Tensor], torch.Tensor] = torch.nn.functional.relu,
- **kwargs):
-
- super().__init__(**kwargs)
-
- self._check_params(input_dim, output_dim, data, keep_prob,
- rel_activation, layer_activation)
-
- self.input_dim = input_dim
- self.output_dim = output_dim
- self.data = data
- self.keep_prob = keep_prob
- self.rel_activation = rel_activation
- self.layer_activation = layer_activation
-
- self.adjacency_matrix = None
- self.adjacency_matrix_backward = None
- self.weight = None
- self.weight_backward = None
- self.build()
-
- def build(self):
- self.adjacency_matrix = []
- self.adjacency_matrix_backward = []
- self.weight = []
- self.weight_backward = []
- for fam in self.data.relation_families:
- adj_mat = [ rel.adjacency_matrix \
- for rel in fam.relation_types \
- if rel.adjacency_matrix is not None ]
- adj_mat_back = [ rel.adjacency_matrix_backward \
- for rel in fam.relation_types \
- if rel.adjacency_matrix_backward is not None ]
- weight = [ init_glorot(self.input_dim[fam.node_type_column],
- self.output_dim[fam.node_type_row]) \
- for _ in range(len(adj_mat)) ]
- weight_back = [ init_glorot(self.input_dim[fam.node_type_column],
- self.output_dim[fam.node_type_row]) \
- for _ in range(len(adj_mat_back)) ]
- adj_mat = torch.cat(adj_mat) \
- if len(adj_mat) > 0 \
- else None
- adj_mat_back = torch.cat(adj_mat_back) \
- if len(adj_mat_back) > 0 \
- else None
- self.adjacency_matrix.append(adj_mat)
- self.adjacency_matrix_backward.append(adj_mat_back)
- self.weight.append(weight)
- self.weight_back.append(weight_back)
-
    def forward(self, prev_layer_repr):
        # NOTE(review): this method looks unfinished. Each branch rebinds
        # ``x`` without accumulating per-family results, the second branch
        # discards the first branch's value, and the method returns None.
        # Documenting as-is; do not rely on this implementation.
        for i, fam in enumerate(self.data.relation_families):
            # Input representations for this family's two node types.
            repr_row = prev_layer_repr[fam.node_type_row]
            repr_column = prev_layer_repr[fam.node_type_column]

            # Concatenated adjacency matrices built by build(); either may
            # be None when the family defines no matrix for a direction.
            adj_mat = self.adjacency_matrix[i]
            adj_mat_back = self.adjacency_matrix_backward[i]

            if adj_mat is not None:
                # ``dropout`` is not among this file's visible imports --
                # presumably provided by a sibling module; verify it is in
                # scope before calling this method.
                x = dropout(repr_column, keep_prob=self.keep_prob)
                # NOTE(review): self.weight[i] is a *list* of weight
                # matrices (see build()), so passing it to torch.mm looks
                # wrong; also the aggregation below multiplies repr_row
                # rather than the projected ``x``, discarding it. Confirm
                # intended behavior.
                x = torch.sparse.mm(x, self.weight[i]) \
                    if x.is_sparse \
                    else torch.mm(x, self.weight[i])
                x = torch.sparse.mm(adj_mat, repr_row) \
                    if adj_mat.is_sparse \
                    else torch.mm(adj_mat, repr_row)
                x = self.rel_activation(x)
                # Split the stacked result back into per-relation slices.
                x = x.view(len(fam.relation_types), len(repr_row), -1)

            if adj_mat_back is not None:
                # Backward direction: aggregates the raw repr_row; unlike
                # the forward branch, no weights, activation or reshape
                # are applied here.
                x = torch.sparse.mm(adj_mat_back, repr_row) \
                    if adj_mat_back.is_sparse \
                    else torch.mm(adj_mat_back, repr_row)
-
    @staticmethod
    def _check_params(input_dim, output_dim, data, keep_prob,
        rel_activation, layer_activation):
        """Validate constructor arguments, raising ValueError on bad input.

        Called from __init__ before any state is stored. keep_prob,
        rel_activation and layer_activation are received but not validated
        in the lines visible here.
        """

        if not isinstance(input_dim, list):
            raise ValueError('input_dim must be a list')

        if not output_dim:
            raise ValueError('output_dim must be specified')

        if not isinstance(output_dim, list):
            # NOTE(review): this rebinds a local only -- the broadcast list
            # never reaches the caller, so a scalar output_dim passes
            # validation without effect. It also evaluates data.node_types
            # before data's type is checked below.
            output_dim = [output_dim] * len(data.node_types)

        if not isinstance(data, Data) and not isinstance(data, PreparedData):
            raise ValueError('data must be of type Data or PreparedData')
|