| @@ -51,34 +51,23 @@ class DecagonLayer(torch.nn.Module): | |||||
| self.build() | self.build() | ||||
| def build(self): | def build(self): | ||||
| n = len(self.data.node_types) | |||||
| rel_types = self.data.relation_types | |||||
| self.next_layer_repr = [ [] for _ in range(n) ] | |||||
| self.next_layer_repr = [ [] for _ in range(len(self.data.node_types)) ] | |||||
| for node_type_row in range(n): | |||||
| if node_type_row not in rel_types: | |||||
| for (node_type_row, node_type_column), rels in self.data.relation_types.items(): | |||||
| if len(rels) == 0: | |||||
| continue | continue | ||||
| for node_type_column in range(n): | |||||
| if node_type_column not in rel_types[node_type_row]: | |||||
| continue | |||||
| rels = rel_types[node_type_row][node_type_column] | |||||
| if len(rels) == 0: | |||||
| continue | |||||
| convolutions = [] | |||||
| convolutions = [] | |||||
| for r in rels: | |||||
| conv = DropoutGraphConvActivation(self.input_dim[node_type_column], | |||||
| self.output_dim[node_type_row], r.adjacency_matrix, | |||||
| self.keep_prob, self.rel_activation) | |||||
| for r in rels: | |||||
| conv = DropoutGraphConvActivation(self.input_dim[node_type_column], | |||||
| self.output_dim[node_type_row], r.adjacency_matrix, | |||||
| self.keep_prob, self.rel_activation) | |||||
| convolutions.append(conv) | |||||
| convolutions.append(conv) | |||||
| self.next_layer_repr[node_type_row].append( | |||||
| Convolutions(node_type_column, convolutions)) | |||||
| self.next_layer_repr[node_type_row].append( | |||||
| Convolutions(node_type_column, convolutions)) | |||||
| def __call__(self, prev_layer_repr): | def __call__(self, prev_layer_repr): | ||||
| next_layer_repr = [ [] for _ in range(len(self.data.node_types)) ] | next_layer_repr = [ [] for _ in range(len(self.data.node_types)) ] | ||||
| @@ -7,6 +7,9 @@ | |||||
| from collections import defaultdict | from collections import defaultdict | ||||
| from dataclasses import dataclass | from dataclasses import dataclass | ||||
| import torch | import torch | ||||
| from typing import List, \ | |||||
| Dict, \ | |||||
| Tuple | |||||
| @dataclass | @dataclass | ||||
| @@ -25,9 +28,12 @@ class RelationType(object): | |||||
| class Data(object): | class Data(object): | ||||
| node_types: List[NodeType] | |||||
| relation_types: Dict[Tuple[int, int], List[RelationType]] | |||||
| def __init__(self) -> None: | def __init__(self) -> None: | ||||
| self.node_types = [] | self.node_types = [] | ||||
| self.relation_types = defaultdict(lambda: defaultdict(list)) | |||||
| self.relation_types = defaultdict(list) | |||||
| def add_node_type(self, name: str, count: int) -> None: | def add_node_type(self, name: str, count: int) -> None: | ||||
| name = str(name) | name = str(name) | ||||
| @@ -73,14 +79,14 @@ class Data(object): | |||||
| adjacency_matrix_backward is not None: | adjacency_matrix_backward is not None: | ||||
| raise ValueError('Relation between nodes of the same type must be expressed using a single matrix') | raise ValueError('Relation between nodes of the same type must be expressed using a single matrix') | ||||
| self.relation_types[node_type_row][node_type_column].append( | |||||
| self.relation_types[node_type_row, node_type_column].append( | |||||
| RelationType(name, node_type_row, node_type_column, | RelationType(name, node_type_row, node_type_column, | ||||
| adjacency_matrix, False)) | adjacency_matrix, False)) | ||||
| if node_type_row != node_type_column and two_way: | if node_type_row != node_type_column and two_way: | ||||
| if adjacency_matrix_backward is None: | if adjacency_matrix_backward is None: | ||||
| adjacency_matrix_backward = adjacency_matrix.transpose(0, 1) | adjacency_matrix_backward = adjacency_matrix.transpose(0, 1) | ||||
| self.relation_types[node_type_column][node_type_row].append( | |||||
| self.relation_types[node_type_column, node_type_row].append( | |||||
| RelationType(name, node_type_column, node_type_row, | RelationType(name, node_type_column, node_type_row, | ||||
| adjacency_matrix_backward, True)) | adjacency_matrix_backward, True)) | ||||
| @@ -99,16 +105,15 @@ class Data(object): | |||||
| s_1 = '' | s_1 = '' | ||||
| count = 0 | count = 0 | ||||
| for i in range(n): | |||||
| for j in range(n): | |||||
| if i not in self.relation_types or \ | |||||
| j not in self.relation_types[i]: | |||||
| for node_type_row in range(n): | |||||
| for node_type_column in range(n): | |||||
| if (node_type_row, node_type_column) not in self.relation_types: | |||||
| continue | continue | ||||
| s_1 += ' - ' + self.node_types[i].name + ' -- ' + \ | |||||
| self.node_types[j].name + ':\n' | |||||
| s_1 += ' - ' + self.node_types[node_type_row].name + ' -- ' + \ | |||||
| self.node_types[node_type_column].name + ':\n' | |||||
| for r in self.relation_types[i][j]: | |||||
| for r in self.relation_types[node_type_row, node_type_column]: | |||||
| if r.is_autogenerated: | if r.is_autogenerated: | ||||
| continue | continue | ||||
| s_1 += ' - ' + r.name + '\n' | s_1 += ' - ' + r.name + '\n' | ||||
| @@ -118,16 +123,3 @@ class Data(object): | |||||
| s += s_1 | s += s_1 | ||||
| return s.strip() | return s.strip() | ||||
| # n = sum(map(len, self.relation_types)) | |||||
| # | |||||
| # for i in range(n): | |||||
| # for j in range(n): | |||||
| # key = (i, j) | |||||
| # if key not in self.relation_types: | |||||
| # continue | |||||
| # rels = self.relation_types[key] | |||||
| # | |||||
| # for r in rels: | |||||
| # | |||||
| # return s.strip() | |||||
| @@ -44,37 +44,27 @@ class DecodeLayer(torch.nn.Module): | |||||
| def build(self) -> None: | def build(self) -> None: | ||||
| self.decoders = {} | self.decoders = {} | ||||
| n = len(self.data.node_types) | |||||
| relation_types = self.data.relation_types | |||||
| for node_type_row in range(n): | |||||
| if node_type_row not in relation_types: | |||||
| for (node_type_row, node_type_column), rels in self.data.relation_types.items(): | |||||
| if len(rels) == 0: | |||||
| continue | continue | ||||
| for node_type_column in range(n): | |||||
| if node_type_column not in relation_types[node_type_row]: | |||||
| continue | |||||
| rels = relation_types[node_type_row][node_type_column] | |||||
| if len(rels) == 0: | |||||
| continue | |||||
| if isinstance(self.decoder_class, dict): | |||||
| if (node_type_row, node_type_column) in self.decoder_class: | |||||
| decoder_class = self.decoder_class[node_type_row, node_type_column] | |||||
| elif (node_type_column, node_type_row) in self.decoder_class: | |||||
| decoder_class = self.decoder_class[node_type_column, node_type_row] | |||||
| else: | |||||
| raise KeyError('Decoder not specified for edge type: %s -- %s' % ( | |||||
| self.data.node_types[node_type_row].name, | |||||
| self.data.node_types[node_type_column].name)) | |||||
| if isinstance(self.decoder_class, dict): | |||||
| if (node_type_row, node_type_column) in self.decoder_class: | |||||
| decoder_class = self.decoder_class[node_type_row, node_type_column] | |||||
| elif (node_type_column, node_type_row) in self.decoder_class: | |||||
| decoder_class = self.decoder_class[node_type_column, node_type_row] | |||||
| else: | else: | ||||
| decoder_class = self.decoder_class | |||||
| raise KeyError('Decoder not specified for edge type: %s -- %s' % ( | |||||
| self.data.node_types[node_type_row].name, | |||||
| self.data.node_types[node_type_column].name)) | |||||
| else: | |||||
| decoder_class = self.decoder_class | |||||
| self.decoders[node_type_row, node_type_column] = \ | |||||
| decoder_class(self.input_dim[node_type_row], | |||||
| num_relation_types = len(rels), | |||||
| keep_prob = self.keep_prob, | |||||
| activation = self.activation) | |||||
| self.decoders[node_type_row, node_type_column] = \ | |||||
| decoder_class(self.input_dim[node_type_row], | |||||
| num_relation_types = len(rels), | |||||
| keep_prob = self.keep_prob, | |||||
| activation = self.activation) | |||||
| def forward(self, last_layer_repr: List[torch.Tensor]) -> Dict[Tuple[int, int], List[torch.Tensor]]: | def forward(self, last_layer_repr: List[torch.Tensor]) -> Dict[Tuple[int, int], List[torch.Tensor]]: | ||||
| res = {} | res = {} | ||||
| @@ -46,7 +46,7 @@ class PreparedRelationType(object): | |||||
| @dataclass | @dataclass | ||||
| class PreparedData(object): | class PreparedData(object): | ||||
| node_types: List[NodeType] | node_types: List[NodeType] | ||||
| relation_types: Dict[int, Dict[int, List[PreparedRelationType]]] | |||||
| relation_types: Dict[Tuple[int, int], List[PreparedRelationType]] | |||||
| def train_val_test_split_edges(edges: torch.Tensor, | def train_val_test_split_edges(edges: torch.Tensor, | ||||
| @@ -137,9 +137,9 @@ def prepare_training(data: Data) -> PreparedData: | |||||
| if not isinstance(data, Data): | if not isinstance(data, Data): | ||||
| raise ValueError('data must be of class Data') | raise ValueError('data must be of class Data') | ||||
| relation_types = defaultdict(lambda: defaultdict(list)) | |||||
| for (node_type_row, node_type_column), rels in data.relation_types: | |||||
| relation_types = defaultdict(list) | |||||
| for (node_type_row, node_type_column), rels in data.relation_types.items(): | |||||
| for r in rels: | for r in rels: | ||||
| relation_types[node_type_row][node_type_column].append( | |||||
| relation_types[node_type_row, node_type_column].append( | |||||
| prep_relation_type(r)) | prep_relation_type(r)) | ||||
| return PreparedData(data.node_types, relation_types) | return PreparedData(data.node_types, relation_types) | ||||
| @@ -86,7 +86,7 @@ def test_decagon_layer_04(): | |||||
| in_layer = OneHotInputLayer(d) | in_layer = OneHotInputLayer(d) | ||||
| multi_dgca = MultiDGCA([10], 32, | multi_dgca = MultiDGCA([10], 32, | ||||
| [r.adjacency_matrix for r in d.relation_types[0][0]], | |||||
| [r.adjacency_matrix for r in d.relation_types[0, 0]], | |||||
| keep_prob=1., activation=lambda x: x) | keep_prob=1., activation=lambda x: x) | ||||
| d_layer = DecagonLayer(in_layer.output_dim, 32, d, | d_layer = DecagonLayer(in_layer.output_dim, 32, d, | ||||
| @@ -129,7 +129,7 @@ def test_decagon_layer_05(): | |||||
| in_layer = OneHotInputLayer(d) | in_layer = OneHotInputLayer(d) | ||||
| multi_dgca = MultiDGCA([100, 100], 32, | multi_dgca = MultiDGCA([100, 100], 32, | ||||
| [r.adjacency_matrix for r in d.relation_types[0][0]], | |||||
| [r.adjacency_matrix for r in d.relation_types[0, 0]], | |||||
| keep_prob=1., activation=lambda x: x) | keep_prob=1., activation=lambda x: x) | ||||
| d_layer = DecagonLayer(in_layer.output_dim, output_dim=32, data=d, | d_layer = DecagonLayer(in_layer.output_dim, output_dim=32, data=d, | ||||