IF YOU WOULD LIKE TO GET AN ACCOUNT, please write an email to s dot adaszewski at gmail dot com. User accounts are meant only to report issues and/or generate pull requests. This is a purpose-specific Git hosting for ADARED projects. Thank you for your understanding!
Przeglądaj źródła

Add test_fast_graph_conv_01() and test_fast_graph_conv_02().

master
Stanislaw Adaszewski 4 lat temu
rodzic
commit
1a303f1a51
2 zmienionych plików z 66 dodań i 57 usunięć
  1. +37
    -56
      src/icosagon/fastconv.py
  2. +29
    -1
      tests/icosagon/test_fastconv.py

+ 37
- 56
src/icosagon/fastconv.py Wyświetl plik

@@ -78,80 +78,61 @@ def _cat(matrices: List[torch.Tensor]):
class FastGraphConv(torch.nn.Module):
def __init__(self,
in_channels: List[int],
out_channels: List[int],
data: Union[Data, PreparedData],
relation_family: Union[RelationFamily, PreparedRelationFamily],
in_channels: int,
out_channels: int,
adjacency_matrices: List[torch.Tensor],
keep_prob: float = 1.,
acivation: Callable[[torch.Tensor], torch.Tensor] = lambda x: x,
activation: Callable[[torch.Tensor], torch.Tensor] = lambda x: x,
**kwargs) -> None:
super().__init__(**kwargs)
in_channels = int(in_channels)
out_channels = int(out_channels)
if not isinstance(data, Data) and not isinstance(data, PreparedData):
raise TypeError('data must be an instance of Data or PreparedData')
if not isinstance(relation_family, RelationFamily) and \
not isinstance(relation_family, PreparedRelationFamily):
raise TypeError('relation_family must be an instance of RelationFamily or PreparedRelationFamily')
if not isinstance(adjacency_matrices, list):
raise TypeError('adjacency_matrices must be a list')
if len(adjacency_matrices) == 0:
raise ValueError('adjacency_matrices must not be empty')
if not all(isinstance(m, torch.Tensor) for m in adjacency_matrices):
raise TypeError('adjacency_matrices elements must be of class torch.Tensor')
if not all(m.is_sparse for m in adjacency_matrices):
raise ValueError('adjacency_matrices elements must be sparse')
keep_prob = float(keep_prob)
if not isinstance(activation, types.FunctionType):
raise TypeError('activation must be a function')
n_nodes_row = data.node_types[relation_family.node_type_row].count
n_nodes_column = data.node_types[relation_family.node_type_column].count
self.in_channels = in_channels
self.out_channels = out_channels
self.data = data
self.relation_family = relation_family
self.adjacency_matrices = adjacency_matrices
self.keep_prob = keep_prob
self.activation = activation
self.weight = torch.cat([
init_glorot(in_channels, out_channels) \
for _ in range(len(relation_family.relation_types))
], dim=1)
self.num_row_nodes = len(adjacency_matrices[0])
self.num_relation_types = len(adjacency_matrices)
self.adjacency_matrices = _sparse_diag_cat(adjacency_matrices)
self.weight_backward = torch.cat([
self.weights = torch.cat([
init_glorot(in_channels, out_channels) \
for _ in range(len(relation_family.relation_types))
for _ in range(self.num_relation_types)
], dim=1)
self.adjacency_matrix = _sparse_diag_cat([
rel.adjacency_matrix \
if rel.adjacency_matrix is not None \
else _sparse_coo_tensor([], [], size=(n_nodes_row, n_nodes_column)) \
for rel in relation_family.relation_types ])
self.adjacency_matrix_backward = _sparse_diag_cat([
rel.adjacency_matrix_backward \
if rel.adjacency_matrix_backward is not None \
else _sparse_coo_tensor([], [], size=(n_nodes_column, n_nodes_row)) \
for rel in relation_family.relation_types ])
def forward(self, prev_layer_repr: List[torch.Tensor]) -> List[torch.Tensor]:
repr_row = prev_layer_repr[self.relation_family.node_type_row]
repr_column = prev_layer_repr[self.relation_family.node_type_column]
new_repr_row = torch.sparse.mm(repr_column, self.weight) \
if repr_column.is_sparse \
else torch.mm(repr_column, self.weight)
new_repr_row = torch.sparse.mm(self.adjacency_matrix, new_repr_row) \
if self.adjacency_matrix.is_sparse \
else torch.mm(self.adjacency_matrix, new_repr_row)
new_repr_row = new_repr_row.view(len(self.relation_family.relation_types),
len(repr_row), self.out_channels)
new_repr_column = torch.sparse.mm(repr_row, self.weight) \
if repr_row.is_sparse \
else torch.mm(repr_row, self.weight)
new_repr_column = torch.sparse.mm(self.adjacency_matrix_backward, new_repr_column) \
if self.adjacency_matrix_backward.is_sparse \
else torch.mm(self.adjacency_matrix_backward, new_repr_column)
new_repr_column = new_repr_column.view(len(self.relation_family.relation_types),
len(repr_column), self.out_channels)
return (new_repr_row, new_repr_column)
def forward(self, x) -> torch.Tensor:
if self.keep_prob < 1.:
x = dropout(x, self.keep_prob)
res = torch.sparse.mm(x, self.weights) \
if x.is_sparse \
else torch.mm(x, self.weights)
res = torch.split(res, res.shape[1] // self.num_relation_types, dim=1)
res = torch.cat(res)
res = torch.sparse.mm(self.adjacency_matrices, res) \
if self.adjacency_matrices.is_sparse \
else torch.mm(self.adjacency_matrices, res)
res = res.view(self.num_relation_types, self.num_row_nodes, self.out_channels)
if self.activation is not None:
res = self.activation(res)
return res
class FastConvLayer(torch.nn.Module):


+ 29
- 1
tests/icosagon/test_fastconv.py Wyświetl plik

@@ -1,7 +1,10 @@
from icosagon.fastconv import _sparse_diag_cat, \
_cat
_cat, \
FastGraphConv
from icosagon.data import _equal
import torch
import pdb
import time
def test_sparse_diag_cat_01():
@@ -58,3 +61,28 @@ def test_cat_02():
assert res.shape == (35, 10)
assert res.is_sparse
assert torch.all(res.to_dense() == ground_truth)
def test_fast_graph_conv_01():
# pdb.set_trace()
adj_mats = [ torch.rand(10, 15).round().to_sparse() \
for _ in range(23) ]
fgc = FastGraphConv(32, 64, adj_mats)
in_repr = torch.rand(15, 32)
_ = fgc(in_repr)
def test_fast_graph_conv_02():
t = time.time()
m = (torch.rand(2000, 2000) < .001).to(torch.float32).to_sparse()
adj_mats = [ m for _ in range(1300) ]
print('Generating adj_mats took:', time.time() - t)
t = time.time()
fgc = FastGraphConv(32, 64, adj_mats)
print('FGC constructor took:', time.time() - t)
in_repr = torch.rand(2000, 32)
for _ in range(3):
t = time.time()
_ = fgc(in_repr)
print('FGC forward pass took:', time.time() - t)

Ładowanie…
Anuluj
Zapisz