|
@@ -124,11 +124,14 @@ class DecagonLayer(Layer): |
|
|
self.next_layer_repr = defaultdict(list)
|
|
|
self.next_layer_repr = defaultdict(list)
|
|
|
|
|
|
|
|
|
for (nt_row, nt_col), relation_types in self.data.relation_types.items():
|
|
|
for (nt_row, nt_col), relation_types in self.data.relation_types.items():
|
|
|
|
|
|
row_convs = []
|
|
|
|
|
|
col_convs = []
|
|
|
|
|
|
|
|
|
for rel in relation_types:
|
|
|
for rel in relation_types:
|
|
|
conv = DropoutGraphConvActivation(self.input_dim[nt_col],
|
|
|
conv = DropoutGraphConvActivation(self.input_dim[nt_col],
|
|
|
self.output_dim[nt_row], rel.adjacency_matrix,
|
|
|
self.output_dim[nt_row], rel.adjacency_matrix,
|
|
|
self.keep_prob, self.rel_activation)
|
|
|
self.keep_prob, self.rel_activation)
|
|
|
self.next_layer_repr[nt_row].append((conv, nt_col))
|
|
|
|
|
|
|
|
|
row_convs.append(conv)
|
|
|
|
|
|
|
|
|
if nt_row == nt_col:
|
|
|
if nt_row == nt_col:
|
|
|
continue
|
|
|
continue
|
|
@@ -136,21 +139,27 @@ class DecagonLayer(Layer): |
|
|
conv = DropoutGraphConvActivation(self.input_dim[nt_row],
|
|
|
conv = DropoutGraphConvActivation(self.input_dim[nt_row],
|
|
|
self.output_dim[nt_col], rel.adjacency_matrix.transpose(0, 1),
|
|
|
self.output_dim[nt_col], rel.adjacency_matrix.transpose(0, 1),
|
|
|
self.keep_prob, self.rel_activation)
|
|
|
self.keep_prob, self.rel_activation)
|
|
|
self.next_layer_repr[nt_col].append((conv, nt_row))
|
|
|
|
|
|
|
|
|
col_convs.append(conv)
|
|
|
|
|
|
|
|
|
|
|
|
self.next_layer_repr[nt_row].append((row_convs, nt_col))
|
|
|
|
|
|
|
|
|
|
|
|
if nt_row == nt_col:
|
|
|
|
|
|
continue
|
|
|
|
|
|
|
|
|
|
|
|
self.next_layer_repr[nt_col].append((col_convs, nt_row))
|
|
|
|
|
|
|
|
|
def __call__(self):
|
|
|
def __call__(self):
|
|
|
prev_layer_repr = self.previous_layer()
|
|
|
prev_layer_repr = self.previous_layer()
|
|
|
next_layer_repr = [None] * len(self.data.node_types)
|
|
|
|
|
|
|
|
|
next_layer_repr = [ [] for _ in range(len(self.data.node_types)) ]
|
|
|
print('next_layer_repr:', next_layer_repr)
|
|
|
print('next_layer_repr:', next_layer_repr)
|
|
|
for i in range(len(self.data.node_types)):
|
|
|
for i in range(len(self.data.node_types)):
|
|
|
next_layer_repr[i] = [
|
|
|
|
|
|
conv(prev_layer_repr[neighbor_type]) \
|
|
|
|
|
|
for (conv, neighbor_type) in \
|
|
|
|
|
|
self.next_layer_repr[i]
|
|
|
|
|
|
]
|
|
|
|
|
|
|
|
|
for convs, neighbor_type in self.next_layer_repr[i]:
|
|
|
|
|
|
convs = [ conv(prev_layer_repr[neighbor_type]) \
|
|
|
|
|
|
for conv in convs ]
|
|
|
|
|
|
convs = sum(convs)
|
|
|
|
|
|
convs = torch.nn.functional.normalize(convs, p=2, dim=1)
|
|
|
|
|
|
next_layer_repr[i].append(convs)
|
|
|
next_layer_repr[i] = sum(next_layer_repr[i])
|
|
|
next_layer_repr[i] = sum(next_layer_repr[i])
|
|
|
next_layer_repr[i] = torch.nn.functional.normalize(next_layer_repr[i], p=2, dim=1)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
next_layer_repr[i] = self.layer_activation(next_layer_repr[i])
|
|
|
print('next_layer_repr:', next_layer_repr)
|
|
|
print('next_layer_repr:', next_layer_repr)
|
|
|
# next_layer_repr = list(map(sum, next_layer_repr))
|
|
|
|
|
|
return next_layer_repr
|
|
|
return next_layer_repr
|