|
@@ -136,11 +136,6 @@ class FastGraphConv(torch.nn.Module): |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class FastConvLayer(torch.nn.Module):
|
|
|
class FastConvLayer(torch.nn.Module):
|
|
|
adjacency_matrix: List[torch.Tensor]
|
|
|
|
|
|
adjacency_matrix_backward: List[torch.Tensor]
|
|
|
|
|
|
weight: List[torch.Tensor]
|
|
|
|
|
|
weight_backward: List[torch.Tensor]
|
|
|
|
|
|
|
|
|
|
|
|
def __init__(self,
|
|
|
def __init__(self,
|
|
|
input_dim: List[int],
|
|
|
input_dim: List[int],
|
|
|
output_dim: List[int],
|
|
|
output_dim: List[int],
|
|
@@ -162,70 +157,86 @@ class FastConvLayer(torch.nn.Module): |
|
|
self.rel_activation = rel_activation
|
|
|
self.rel_activation = rel_activation
|
|
|
self.layer_activation = layer_activation
|
|
|
self.layer_activation = layer_activation
|
|
|
|
|
|
|
|
|
self.adjacency_matrix = None
|
|
|
|
|
|
self.adjacency_matrix_backward = None
|
|
|
|
|
|
self.weight = None
|
|
|
|
|
|
self.weight_backward = None
|
|
|
|
|
|
|
|
|
self.is_sparse = False
|
|
|
|
|
|
self.next_layer_repr = None
|
|
|
self.build()
|
|
|
self.build()
|
|
|
|
|
|
|
|
|
def build(self):
|
|
|
def build(self):
|
|
|
self.adjacency_matrix = []
|
|
|
|
|
|
self.adjacency_matrix_backward = []
|
|
|
|
|
|
self.weight = []
|
|
|
|
|
|
self.weight_backward = []
|
|
|
|
|
|
|
|
|
self.next_layer_repr = torch.nn.ModuleList([
|
|
|
|
|
|
torch.nn.ModuleList() \
|
|
|
|
|
|
for _ in range(len(self.data.node_types))
|
|
|
|
|
|
])
|
|
|
for fam in self.data.relation_families:
|
|
|
for fam in self.data.relation_families:
|
|
|
adj_mat = [ rel.adjacency_matrix \
|
|
|
|
|
|
for rel in fam.relation_types \
|
|
|
|
|
|
if rel.adjacency_matrix is not None ]
|
|
|
|
|
|
adj_mat_back = [ rel.adjacency_matrix_backward \
|
|
|
|
|
|
for rel in fam.relation_types \
|
|
|
|
|
|
if rel.adjacency_matrix_backward is not None ]
|
|
|
|
|
|
weight = [ init_glorot(self.input_dim[fam.node_type_column],
|
|
|
|
|
|
self.output_dim[fam.node_type_row]) \
|
|
|
|
|
|
for _ in range(len(adj_mat)) ]
|
|
|
|
|
|
weight_back = [ init_glorot(self.input_dim[fam.node_type_column],
|
|
|
|
|
|
self.output_dim[fam.node_type_row]) \
|
|
|
|
|
|
for _ in range(len(adj_mat_back)) ]
|
|
|
|
|
|
adj_mat = torch.cat(adj_mat) \
|
|
|
|
|
|
if len(adj_mat) > 0 \
|
|
|
|
|
|
else None
|
|
|
|
|
|
adj_mat_back = torch.cat(adj_mat_back) \
|
|
|
|
|
|
if len(adj_mat_back) > 0 \
|
|
|
|
|
|
else None
|
|
|
|
|
|
self.adjacency_matrix.append(adj_mat)
|
|
|
|
|
|
self.adjacency_matrix_backward.append(adj_mat_back)
|
|
|
|
|
|
self.weight.append(weight)
|
|
|
|
|
|
self.weight_backward.append(weight_back)
|
|
|
|
|
|
|
|
|
self.build_family(fam)
|
|
|
|
|
|
|
|
|
|
|
|
def build_family(self, fam) -> None:
|
|
|
|
|
|
if fam.node_type_row == fam.node_type_column:
|
|
|
|
|
|
self.build_fam_one_node_type(fam)
|
|
|
|
|
|
else:
|
|
|
|
|
|
self.build_fam_two_node_types(fam)
|
|
|
|
|
|
|
|
|
|
|
|
def build_fam_one_node_type(self, fam) -> None:
|
|
|
|
|
|
adjacency_matrices = [
|
|
|
|
|
|
r.adjacency_matrix \
|
|
|
|
|
|
for r in fam.relation_types
|
|
|
|
|
|
]
|
|
|
|
|
|
conv = FastGraphConv(self.input_dim[fam.node_type_column],
|
|
|
|
|
|
self.output_dim[fam.node_type_row],
|
|
|
|
|
|
adjacency_matrices,
|
|
|
|
|
|
self.keep_prob,
|
|
|
|
|
|
self.rel_activation)
|
|
|
|
|
|
conv.input_node_type = fam.node_type_column
|
|
|
|
|
|
self.next_layer_repr[fam.node_type_row].append(conv)
|
|
|
|
|
|
|
|
|
|
|
|
def build_fam_two_node_types(self, fam) -> None:
|
|
|
|
|
|
adjacency_matrices = [
|
|
|
|
|
|
r.adjacency_matrix \
|
|
|
|
|
|
for r in fam.relation_types \
|
|
|
|
|
|
if r.adjacency_matrix is not None
|
|
|
|
|
|
]
|
|
|
|
|
|
|
|
|
|
|
|
adjacency_matrices_backward = [
|
|
|
|
|
|
r.adjacency_matrix_backward \
|
|
|
|
|
|
for r in fam.relation_types \
|
|
|
|
|
|
if r.adjacency_matrix_backward is not None
|
|
|
|
|
|
]
|
|
|
|
|
|
|
|
|
|
|
|
conv = FastGraphConv(self.input_dim[fam.node_type_column],
|
|
|
|
|
|
self.output_dim[fam.node_type_row],
|
|
|
|
|
|
adjacency_matrices,
|
|
|
|
|
|
self.keep_prob,
|
|
|
|
|
|
self.rel_activation)
|
|
|
|
|
|
|
|
|
|
|
|
conv_backward = FastGraphConv(self.input_dim[fam.node_type_row],
|
|
|
|
|
|
self.output_dim[fam.node_type_column],
|
|
|
|
|
|
adjacency_matrices_backward,
|
|
|
|
|
|
self.keep_prob,
|
|
|
|
|
|
self.rel_activation)
|
|
|
|
|
|
|
|
|
|
|
|
conv.input_node_type = fam.node_type_column
|
|
|
|
|
|
conv_backward.input_node_type = fam.node_type_row
|
|
|
|
|
|
|
|
|
|
|
|
self.next_layer_repr[fam.node_type_row].append(conv)
|
|
|
|
|
|
self.next_layer_repr[fam.node_type_column].append(conv_backward)
|
|
|
|
|
|
|
|
|
def forward(self, prev_layer_repr):
|
|
|
def forward(self, prev_layer_repr):
|
|
|
for i, fam in enumerate(self.data.relation_families):
|
|
|
|
|
|
repr_row = prev_layer_repr[fam.node_type_row]
|
|
|
|
|
|
repr_column = prev_layer_repr[fam.node_type_column]
|
|
|
|
|
|
|
|
|
|
|
|
adj_mat = self.adjacency_matrix[i]
|
|
|
|
|
|
adj_mat_back = self.adjacency_matrix_backward[i]
|
|
|
|
|
|
|
|
|
|
|
|
if adj_mat is not None:
|
|
|
|
|
|
x = dropout(repr_column, keep_prob=self.keep_prob)
|
|
|
|
|
|
x = torch.sparse.mm(x, self.weight[i]) \
|
|
|
|
|
|
if x.is_sparse \
|
|
|
|
|
|
else torch.mm(x, self.weight[i])
|
|
|
|
|
|
x = torch.sparse.mm(adj_mat, x) \
|
|
|
|
|
|
if adj_mat.is_sparse \
|
|
|
|
|
|
else torch.mm(adj_mat, x)
|
|
|
|
|
|
x = self.rel_activation(x)
|
|
|
|
|
|
x = x.view(len(fam.relation_types), len(repr_row), -1)
|
|
|
|
|
|
|
|
|
|
|
|
if adj_mat_back is not None:
|
|
|
|
|
|
x = dropout(repr_column, keep_prob=self.keep_prob)
|
|
|
|
|
|
x = torch.sparse.mm(x, self.weight_backward[i]) \
|
|
|
|
|
|
if x.is_sparse \
|
|
|
|
|
|
else torch.mm(x, self.weight_backward[i])
|
|
|
|
|
|
x = torch.sparse.mm(adj_mat_back, x) \
|
|
|
|
|
|
if adj_mat_back.is_sparse \
|
|
|
|
|
|
else torch.mm(adj_mat_back, x)
|
|
|
|
|
|
x = self.rel_activation(x)
|
|
|
|
|
|
x = x.view(len(fam.relation_types), len(repr_row), -1)
|
|
|
|
|
|
|
|
|
next_layer_repr = [ [] \
|
|
|
|
|
|
for _ in range(len(self.data.node_types)) ]
|
|
|
|
|
|
for output_node_type in range(len(self.data.node_types)):
|
|
|
|
|
|
for conv in self.next_layer_repr[output_node_type]:
|
|
|
|
|
|
rep = conv(prev_layer_repr[conv.input_node_type])
|
|
|
|
|
|
rep = torch.sum(rep, dim=0)
|
|
|
|
|
|
rep = torch.nn.functional.normalize(rep, p=2, dim=1)
|
|
|
|
|
|
next_layer_repr[output_node_type].append(rep)
|
|
|
|
|
|
if len(next_layer_repr[output_node_type]) == 0:
|
|
|
|
|
|
next_layer_repr[output_node_type] = \
|
|
|
|
|
|
torch.zeros(self.data.node_types[output_node_type].count, self.output_dim[output_node_type])
|
|
|
|
|
|
else:
|
|
|
|
|
|
next_layer_repr[output_node_type] = \
|
|
|
|
|
|
sum(next_layer_repr[output_node_type])
|
|
|
|
|
|
next_layer_repr[output_node_type] = \
|
|
|
|
|
|
self.layer_activation(next_layer_repr[output_node_type])
|
|
|
|
|
|
return next_layer_repr
|
|
|
|
|
|
|
|
|
@staticmethod
|
|
|
@staticmethod
|
|
|
def _check_params(input_dim, output_dim, data, keep_prob,
|
|
|
def _check_params(input_dim, output_dim, data, keep_prob,
|
|
|