| @@ -51,34 +51,23 @@ class DecagonLayer(torch.nn.Module): | |||
| self.build() | |||
| def build(self): | |||
| n = len(self.data.node_types) | |||
| rel_types = self.data.relation_types | |||
| self.next_layer_repr = [ [] for _ in range(n) ] | |||
| self.next_layer_repr = [ [] for _ in range(len(self.data.node_types)) ] | |||
| for node_type_row in range(n): | |||
| if node_type_row not in rel_types: | |||
| for (node_type_row, node_type_column), rels in self.data.relation_types.items(): | |||
| if len(rels) == 0: | |||
| continue | |||
| for node_type_column in range(n): | |||
| if node_type_column not in rel_types[node_type_row]: | |||
| continue | |||
| rels = rel_types[node_type_row][node_type_column] | |||
| if len(rels) == 0: | |||
| continue | |||
| convolutions = [] | |||
| convolutions = [] | |||
| for r in rels: | |||
| conv = DropoutGraphConvActivation(self.input_dim[node_type_column], | |||
| self.output_dim[node_type_row], r.adjacency_matrix, | |||
| self.keep_prob, self.rel_activation) | |||
| for r in rels: | |||
| conv = DropoutGraphConvActivation(self.input_dim[node_type_column], | |||
| self.output_dim[node_type_row], r.adjacency_matrix, | |||
| self.keep_prob, self.rel_activation) | |||
| convolutions.append(conv) | |||
| convolutions.append(conv) | |||
| self.next_layer_repr[node_type_row].append( | |||
| Convolutions(node_type_column, convolutions)) | |||
| self.next_layer_repr[node_type_row].append( | |||
| Convolutions(node_type_column, convolutions)) | |||
| def __call__(self, prev_layer_repr): | |||
| next_layer_repr = [ [] for _ in range(len(self.data.node_types)) ] | |||
| @@ -7,6 +7,9 @@ | |||
| from collections import defaultdict | |||
| from dataclasses import dataclass | |||
| import torch | |||
| from typing import List, \ | |||
| Dict, \ | |||
| Tuple | |||
| @dataclass | |||
| @@ -25,9 +28,12 @@ class RelationType(object): | |||
| class Data(object): | |||
| node_types: List[NodeType] | |||
| relation_types: Dict[Tuple[int, int], List[RelationType]] | |||
| def __init__(self) -> None: | |||
| self.node_types = [] | |||
| self.relation_types = defaultdict(lambda: defaultdict(list)) | |||
| self.relation_types = defaultdict(list) | |||
| def add_node_type(self, name: str, count: int) -> None: | |||
| name = str(name) | |||
| @@ -73,14 +79,14 @@ class Data(object): | |||
| adjacency_matrix_backward is not None: | |||
| raise ValueError('Relation between nodes of the same type must be expressed using a single matrix') | |||
| self.relation_types[node_type_row][node_type_column].append( | |||
| self.relation_types[node_type_row, node_type_column].append( | |||
| RelationType(name, node_type_row, node_type_column, | |||
| adjacency_matrix, False)) | |||
| if node_type_row != node_type_column and two_way: | |||
| if adjacency_matrix_backward is None: | |||
| adjacency_matrix_backward = adjacency_matrix.transpose(0, 1) | |||
| self.relation_types[node_type_column][node_type_row].append( | |||
| self.relation_types[node_type_column, node_type_row].append( | |||
| RelationType(name, node_type_column, node_type_row, | |||
| adjacency_matrix_backward, True)) | |||
| @@ -99,16 +105,15 @@ class Data(object): | |||
| s_1 = '' | |||
| count = 0 | |||
| for i in range(n): | |||
| for j in range(n): | |||
| if i not in self.relation_types or \ | |||
| j not in self.relation_types[i]: | |||
| for node_type_row in range(n): | |||
| for node_type_column in range(n): | |||
| if (node_type_row, node_type_column) not in self.relation_types: | |||
| continue | |||
| s_1 += ' - ' + self.node_types[i].name + ' -- ' + \ | |||
| self.node_types[j].name + ':\n' | |||
| s_1 += ' - ' + self.node_types[node_type_row].name + ' -- ' + \ | |||
| self.node_types[node_type_column].name + ':\n' | |||
| for r in self.relation_types[i][j]: | |||
| for r in self.relation_types[node_type_row, node_type_column]: | |||
| if r.is_autogenerated: | |||
| continue | |||
| s_1 += ' - ' + r.name + '\n' | |||
| @@ -118,16 +123,3 @@ class Data(object): | |||
| s += s_1 | |||
| return s.strip() | |||
| # n = sum(map(len, self.relation_types)) | |||
| # | |||
| # for i in range(n): | |||
| # for j in range(n): | |||
| # key = (i, j) | |||
| # if key not in self.relation_types: | |||
| # continue | |||
| # rels = self.relation_types[key] | |||
| # | |||
| # for r in rels: | |||
| # | |||
| # return s.strip() | |||
| @@ -44,37 +44,27 @@ class DecodeLayer(torch.nn.Module): | |||
| def build(self) -> None: | |||
| self.decoders = {} | |||
| n = len(self.data.node_types) | |||
| relation_types = self.data.relation_types | |||
| for node_type_row in range(n): | |||
| if node_type_row not in relation_types: | |||
| for (node_type_row, node_type_column), rels in self.data.relation_types.items(): | |||
| if len(rels) == 0: | |||
| continue | |||
| for node_type_column in range(n): | |||
| if node_type_column not in relation_types[node_type_row]: | |||
| continue | |||
| rels = relation_types[node_type_row][node_type_column] | |||
| if len(rels) == 0: | |||
| continue | |||
| if isinstance(self.decoder_class, dict): | |||
| if (node_type_row, node_type_column) in self.decoder_class: | |||
| decoder_class = self.decoder_class[node_type_row, node_type_column] | |||
| elif (node_type_column, node_type_row) in self.decoder_class: | |||
| decoder_class = self.decoder_class[node_type_column, node_type_row] | |||
| else: | |||
| raise KeyError('Decoder not specified for edge type: %s -- %s' % ( | |||
| self.data.node_types[node_type_row].name, | |||
| self.data.node_types[node_type_column].name)) | |||
| if isinstance(self.decoder_class, dict): | |||
| if (node_type_row, node_type_column) in self.decoder_class: | |||
| decoder_class = self.decoder_class[node_type_row, node_type_column] | |||
| elif (node_type_column, node_type_row) in self.decoder_class: | |||
| decoder_class = self.decoder_class[node_type_column, node_type_row] | |||
| else: | |||
| decoder_class = self.decoder_class | |||
| raise KeyError('Decoder not specified for edge type: %s -- %s' % ( | |||
| self.data.node_types[node_type_row].name, | |||
| self.data.node_types[node_type_column].name)) | |||
| else: | |||
| decoder_class = self.decoder_class | |||
| self.decoders[node_type_row, node_type_column] = \ | |||
| decoder_class(self.input_dim[node_type_row], | |||
| num_relation_types = len(rels), | |||
| keep_prob = self.keep_prob, | |||
| activation = self.activation) | |||
| self.decoders[node_type_row, node_type_column] = \ | |||
| decoder_class(self.input_dim[node_type_row], | |||
| num_relation_types = len(rels), | |||
| keep_prob = self.keep_prob, | |||
| activation = self.activation) | |||
| def forward(self, last_layer_repr: List[torch.Tensor]) -> Dict[Tuple[int, int], List[torch.Tensor]]: | |||
| res = {} | |||
| @@ -46,7 +46,7 @@ class PreparedRelationType(object): | |||
| @dataclass | |||
| class PreparedData(object): | |||
| node_types: List[NodeType] | |||
| relation_types: Dict[int, Dict[int, List[PreparedRelationType]]] | |||
| relation_types: Dict[Tuple[int, int], List[PreparedRelationType]] | |||
| def train_val_test_split_edges(edges: torch.Tensor, | |||
| @@ -137,9 +137,9 @@ def prepare_training(data: Data) -> PreparedData: | |||
| if not isinstance(data, Data): | |||
| raise ValueError('data must be of class Data') | |||
| relation_types = defaultdict(lambda: defaultdict(list)) | |||
| for (node_type_row, node_type_column), rels in data.relation_types: | |||
| relation_types = defaultdict(list) | |||
| for (node_type_row, node_type_column), rels in data.relation_types.items(): | |||
| for r in rels: | |||
| relation_types[node_type_row][node_type_column].append( | |||
| relation_types[node_type_row, node_type_column].append( | |||
| prep_relation_type(r)) | |||
| return PreparedData(data.node_types, relation_types) | |||
| @@ -86,7 +86,7 @@ def test_decagon_layer_04(): | |||
| in_layer = OneHotInputLayer(d) | |||
| multi_dgca = MultiDGCA([10], 32, | |||
| [r.adjacency_matrix for r in d.relation_types[0][0]], | |||
| [r.adjacency_matrix for r in d.relation_types[0, 0]], | |||
| keep_prob=1., activation=lambda x: x) | |||
| d_layer = DecagonLayer(in_layer.output_dim, 32, d, | |||
| @@ -129,7 +129,7 @@ def test_decagon_layer_05(): | |||
| in_layer = OneHotInputLayer(d) | |||
| multi_dgca = MultiDGCA([100, 100], 32, | |||
| [r.adjacency_matrix for r in d.relation_types[0][0]], | |||
| [r.adjacency_matrix for r in d.relation_types[0, 0]], | |||
| keep_prob=1., activation=lambda x: x) | |||
| d_layer = DecagonLayer(in_layer.output_dim, output_dim=32, data=d, | |||