IF YOU WOULD LIKE TO GET AN ACCOUNT, please write an email to s dot adaszewski at gmail dot com. User accounts are meant only to report issues and/or generate pull requests. This is a purpose-specific Git hosting for ADARED projects. Thank you for your understanding!
Browse Source

Reimplement FastConvLayer.

master
Stanislaw Adaszewski 3 years ago
parent
commit
aa3cc1f3ad
1 changed files with 74 additions and 63 deletions
  1. +74
    -63
      src/icosagon/fastconv.py

+ 74
- 63
src/icosagon/fastconv.py View File

@@ -136,11 +136,6 @@ class FastGraphConv(torch.nn.Module):
class FastConvLayer(torch.nn.Module):
adjacency_matrix: List[torch.Tensor]
adjacency_matrix_backward: List[torch.Tensor]
weight: List[torch.Tensor]
weight_backward: List[torch.Tensor]
def __init__(self,
input_dim: List[int],
output_dim: List[int],
@@ -162,70 +157,86 @@ class FastConvLayer(torch.nn.Module):
self.rel_activation = rel_activation
self.layer_activation = layer_activation
self.adjacency_matrix = None
self.adjacency_matrix_backward = None
self.weight = None
self.weight_backward = None
self.is_sparse = False
self.next_layer_repr = None
self.build()
def build(self):
self.adjacency_matrix = []
self.adjacency_matrix_backward = []
self.weight = []
self.weight_backward = []
self.next_layer_repr = torch.nn.ModuleList([
torch.nn.ModuleList() \
for _ in range(len(self.data.node_types))
])
for fam in self.data.relation_families:
adj_mat = [ rel.adjacency_matrix \
for rel in fam.relation_types \
if rel.adjacency_matrix is not None ]
adj_mat_back = [ rel.adjacency_matrix_backward \
for rel in fam.relation_types \
if rel.adjacency_matrix_backward is not None ]
weight = [ init_glorot(self.input_dim[fam.node_type_column],
self.output_dim[fam.node_type_row]) \
for _ in range(len(adj_mat)) ]
weight_back = [ init_glorot(self.input_dim[fam.node_type_column],
self.output_dim[fam.node_type_row]) \
for _ in range(len(adj_mat_back)) ]
adj_mat = torch.cat(adj_mat) \
if len(adj_mat) > 0 \
else None
adj_mat_back = torch.cat(adj_mat_back) \
if len(adj_mat_back) > 0 \
else None
self.adjacency_matrix.append(adj_mat)
self.adjacency_matrix_backward.append(adj_mat_back)
self.weight.append(weight)
self.weight_backward.append(weight_back)
self.build_family(fam)
def build_family(self, fam) -> None:
if fam.node_type_row == fam.node_type_column:
self.build_fam_one_node_type(fam)
else:
self.build_fam_two_node_types(fam)
def build_fam_one_node_type(self, fam) -> None:
adjacency_matrices = [
r.adjacency_matrix \
for r in fam.relation_types
]
conv = FastGraphConv(self.input_dim[fam.node_type_column],
self.output_dim[fam.node_type_row],
adjacency_matrices,
self.keep_prob,
self.rel_activation)
conv.input_node_type = fam.node_type_column
self.next_layer_repr[fam.node_type_row].append(conv)
def build_fam_two_node_types(self, fam) -> None:
adjacency_matrices = [
r.adjacency_matrix \
for r in fam.relation_types \
if r.adjacency_matrix is not None
]
adjacency_matrices_backward = [
r.adjacency_matrix_backward \
for r in fam.relation_types \
if r.adjacency_matrix_backward is not None
]
conv = FastGraphConv(self.input_dim[fam.node_type_column],
self.output_dim[fam.node_type_row],
adjacency_matrices,
self.keep_prob,
self.rel_activation)
conv_backward = FastGraphConv(self.input_dim[fam.node_type_row],
self.output_dim[fam.node_type_column],
adjacency_matrices_backward,
self.keep_prob,
self.rel_activation)
conv.input_node_type = fam.node_type_column
conv_backward.input_node_type = fam.node_type_row
self.next_layer_repr[fam.node_type_row].append(conv)
self.next_layer_repr[fam.node_type_column].append(conv_backward)
def forward(self, prev_layer_repr):
for i, fam in enumerate(self.data.relation_families):
repr_row = prev_layer_repr[fam.node_type_row]
repr_column = prev_layer_repr[fam.node_type_column]
adj_mat = self.adjacency_matrix[i]
adj_mat_back = self.adjacency_matrix_backward[i]
if adj_mat is not None:
x = dropout(repr_column, keep_prob=self.keep_prob)
x = torch.sparse.mm(x, self.weight[i]) \
if x.is_sparse \
else torch.mm(x, self.weight[i])
x = torch.sparse.mm(adj_mat, x) \
if adj_mat.is_sparse \
else torch.mm(adj_mat, x)
x = self.rel_activation(x)
x = x.view(len(fam.relation_types), len(repr_row), -1)
if adj_mat_back is not None:
x = dropout(repr_column, keep_prob=self.keep_prob)
x = torch.sparse.mm(x, self.weight_backward[i]) \
if x.is_sparse \
else torch.mm(x, self.weight_backward[i])
x = torch.sparse.mm(adj_mat_back, x) \
if adj_mat_back.is_sparse \
else torch.mm(adj_mat_back, x)
x = self.rel_activation(x)
x = x.view(len(fam.relation_types), len(repr_row), -1)
next_layer_repr = [ [] \
for _ in range(len(self.data.node_types)) ]
for output_node_type in range(len(self.data.node_types)):
for conv in self.next_layer_repr[output_node_type]:
rep = conv(prev_layer_repr[conv.input_node_type])
rep = torch.sum(rep, dim=0)
rep = torch.nn.functional.normalize(rep, p=2, dim=1)
next_layer_repr[output_node_type].append(rep)
if len(next_layer_repr[output_node_type]) == 0:
next_layer_repr[output_node_type] = \
torch.zeros(self.data.node_types[output_node_type].count, self.output_dim[output_node_type])
else:
next_layer_repr[output_node_type] = \
sum(next_layer_repr[output_node_type])
next_layer_repr[output_node_type] = \
self.layer_activation(next_layer_repr[output_node_type])
return next_layer_repr
@staticmethod
def _check_params(input_dim, output_dim, data, keep_prob,


Loading…
Cancel
Save