|
|
@@ -68,6 +68,33 @@ class DecagonLayer(torch.nn.Module): |
|
|
|
self.next_layer_repr[fam.node_type_row].append(
|
|
|
|
Convolutions(fam.node_type_column, convolutions))
|
|
|
|
|
|
|
|
# def build_fam_two_node_types_sym(self, fam) -> None:
|
|
|
|
# convolutions_row = torch.nn.ModuleList()
|
|
|
|
# convolutions_column = torch.nn.ModuleList()
|
|
|
|
#
|
|
|
|
# if self.input_dim[fam.node_type_column] != \
|
|
|
|
# self.input_dim[fam.node_type_row]:
|
|
|
|
# raise ValueError('input_dim for row and column must be equal for a symmetric family')
|
|
|
|
#
|
|
|
|
# if self.output_dim[fam.node_type_column] != \
|
|
|
|
# self.output_dim[fam.node_type_row]:
|
|
|
|
# raise ValueError('output_dim for row and column must be equal for a symmetric family')
|
|
|
|
#
|
|
|
|
# for r in fam.relation_types:
|
|
|
|
# assert r.adjacency_matrix is not None and \
|
|
|
|
# r.adjacency_matrix_backward is not None
|
|
|
|
# conv = DropoutGraphConvActivation(self.input_dim[fam.node_type_column],
|
|
|
|
# self.output_dim[fam.node_type_row], r.adjacency_matrix,
|
|
|
|
# self.keep_prob, self.rel_activation)
|
|
|
|
# convolutions_row.append(conv)
|
|
|
|
# convolutions_column.append(conv.clone(r.adjacency_matrix_backward))
|
|
|
|
#
|
|
|
|
# self.next_layer_repr[fam.node_type_row].append(
|
|
|
|
# Convolutions(fam.node_type_column, convolutions_row))
|
|
|
|
#
|
|
|
|
# self.next_layer_repr[fam.node_type_column].append(
|
|
|
|
# Convolutions(fam.node_type_row, convolutions_column))
|
|
|
|
|
|
|
|
def build_fam_two_node_types(self, fam) -> None:
|
|
|
|
convolutions_row = torch.nn.ModuleList()
|
|
|
|
convolutions_column = torch.nn.ModuleList()
|
|
|
@@ -91,6 +118,12 @@ class DecagonLayer(torch.nn.Module): |
|
|
|
self.next_layer_repr[fam.node_type_column].append(
|
|
|
|
Convolutions(fam.node_type_row, convolutions_column))
|
|
|
|
|
|
|
|
# def build_fam_two_node_types(self, fam) -> None:
|
|
|
|
# if fam.is_symmetric:
|
|
|
|
# self.build_fam_two_node_types_sym(fam)
|
|
|
|
# else:
|
|
|
|
# self.build_fam_two_node_types_asym(fam)
|
|
|
|
|
|
|
|
def build_family(self, fam) -> None:
|
|
|
|
if fam.node_type_row == fam.node_type_column:
|
|
|
|
self.build_fam_one_node_type(fam)
|
|
|
|