| @@ -21,6 +21,7 @@ | |||||
| import torch | import torch | ||||
| from .convole import SparseMultiDGCA | |||||
| class InputLayer(torch.nn.Module): | class InputLayer(torch.nn.Module): | ||||
| @@ -52,9 +53,48 @@ class InputLayer(torch.nn.Module): | |||||
| class DecagonLayer(torch.nn.Module): | class DecagonLayer(torch.nn.Module): | ||||
| def __init__(self, data, **kwargs): | |||||
| def __init__(self, data, | |||||
| input_dim, output_dim, | |||||
| keep_prob=1., | |||||
| rel_activation=lambda x: x, | |||||
| layer_activation=torch.nn.functional.relu, | |||||
| **kwargs): | |||||
| super().__init__(**kwargs) | super().__init__(**kwargs) | ||||
| self.data = data | self.data = data | ||||
| self.input_dim = input_dim | |||||
| self.output_dim = output_dim | |||||
| self.keep_prob = keep_prob | |||||
| self.rel_activation = rel_activation | |||||
| self.layer_activation = layer_activation | |||||
| self.convolutions = None | |||||
| self.build() | |||||
| def build(self): | |||||
| self.convolutions = {} | |||||
| for key in self.data.relation_types.keys(): | |||||
| adjacency_matrices = \ | |||||
| self.data.get_adjacency_matrices(*key) | |||||
| self.convolutions[key] = SparseMultiDGCA(self.input_dim, | |||||
| self.output_dim, adjacency_matrices, | |||||
| self.keep_prob, self.rel_activation) | |||||
| def __call__(self, previous_layer): | |||||
| pass | |||||
| # for node_type_row, node_type_col in enumerate(self.data.node_ | |||||
| # if rt.node_type_row == i or rt.node_type_col == i: | |||||
| def __call__(self, prev_layer_repr): | |||||
| new_layer_repr = [] | |||||
| for i, nt in enumerate(self.data.node_types): | |||||
| new_repr = [] | |||||
| for key in self.data.relation_types.keys(): | |||||
| nt_row, nt_col = key | |||||
| if nt_row != i and nt_col != i: | |||||
| continue | |||||
| if nt_row == i: | |||||
| x = prev_layer_repr[nt_col] | |||||
| else: | |||||
| x = prev_layer_repr[nt_row] | |||||
| conv = self.convolutions[key] | |||||
| new_repr.append(conv(x)) | |||||
| new_repr = sum(new_repr) | |||||
| new_layer_repr.append(new_repr) | |||||
| return new_layer_repr | |||||