| @@ -85,51 +85,28 @@ class DecagonLayer(Layer): | |||
| self.keep_prob = keep_prob | |||
| self.rel_activation = rel_activation | |||
| self.layer_activation = layer_activation | |||
| self.convolutions = None | |||
| self.next_layer_repr = None | |||
| self.build() | |||
| def build(self): | |||
| self.convolutions = {} | |||
| for (node_type_row, node_type_column) in self.data.relation_types.keys(): | |||
| adjacency_matrices = \ | |||
| self.data.get_adjacency_matrices(node_type_row, node_type_column) | |||
| self.convolutions[node_type_row, node_type_column] = SparseMultiDGCA(self.input_dim, | |||
| self.output_dim, adjacency_matrices, | |||
| self.keep_prob, self.rel_activation) | |||
| # for node_type_row, node_type_col in enumerate(self.data.node_ | |||
| # if rt.node_type_row == i or rt.node_type_col == i: | |||
| def __call__(self): | |||
| prev_layer_repr = self.previous_layer() | |||
| next_layer_repr = defaultdict(list) | |||
| self.next_layer_repr = defaultdict(list) | |||
| for (nt_row, nt_col), rel in self.data.relation_types.items(): | |||
| conv = SparseDropoutGraphConvActivation(self.input_dim[nt_col], | |||
| self.output_dim[nt_row], rel.adjacency_matrix, | |||
| self.keep_prob, self.rel_activation) | |||
| next_layer_repr[nt_row].append(conv) | |||
| self.next_layer_repr[nt_row].append((conv, nt_col)) | |||
| conv = SparseDropoutGraphConvActivation(self.input_dim[nt_row], | |||
| self.output_dim[nt_col], rel.adjacency_matrix.transpose(0, 1), | |||
| self.keep_prob, self.rel_activation) | |||
| next_layer_repr[nt_col].append(conv) | |||
| self.next_layer_repr[nt_col].append((conv, nt_row)) | |||
| def __call__(self): | |||
| prev_layer_repr = self.previous_layer() | |||
| next_layer_repr = self.next_layer_repr | |||
| for i in range(len(self.data.node_types)): | |||
| next_layer_repr[i] = map(lambda conv, neighbor_type: \ | |||
| conv(prev_layer_repr[neighbor_type]), next_layer_repr[i]) | |||
| next_layer_repr = list(map(sum, next_layer_repr)) | |||
| return next_layer_repr | |||
| #for i, nt in enumerate(self.data.node_types): | |||
| # new_repr = [] | |||
| # for nt_row, nt_col in self.data.relation_types.keys(): | |||
| # if nt_row != i and nt_col != i: | |||
| # continue | |||
| # if nt_row == i: | |||
| # x = prev_layer_repr[nt_col] | |||
| # else: | |||
| # x = prev_layer_repr[nt_row] | |||
| # conv = self.convolutions[key] | |||
| # new_repr.append(conv(x)) | |||
| # new_repr = sum(new_repr) | |||
| # new_layer_repr.append(new_repr) | |||
| # return new_layer_repr | |||