@@ -51,34 +51,23 @@ class DecagonLayer(torch.nn.Module):
         self.build()

     def build(self):
-        n = len(self.data.node_types)
-        rel_types = self.data.relation_types
-
-        self.next_layer_repr = [ [] for _ in range(n) ]
-
-        for node_type_row in range(n):
-            if node_type_row not in rel_types:
-                continue
-
-            for node_type_column in range(n):
-                if node_type_column not in rel_types[node_type_row]:
-                    continue
-
-                rels = rel_types[node_type_row][node_type_column]
-                if len(rels) == 0:
-                    continue
-
-                convolutions = []
-
-                for r in rels:
-                    conv = DropoutGraphConvActivation(self.input_dim[node_type_column],
-                        self.output_dim[node_type_row], r.adjacency_matrix,
-                        self.keep_prob, self.rel_activation)
-                    convolutions.append(conv)
-
-                self.next_layer_repr[node_type_row].append(
-                    Convolutions(node_type_column, convolutions))
+        self.next_layer_repr = [ [] for _ in range(len(self.data.node_types)) ]
+
+        for (node_type_row, node_type_column), rels in self.data.relation_types.items():
+            if len(rels) == 0:
+                continue
+
+            convolutions = []
+
+            for r in rels:
+                conv = DropoutGraphConvActivation(self.input_dim[node_type_column],
+                    self.output_dim[node_type_row], r.adjacency_matrix,
+                    self.keep_prob, self.rel_activation)
+                convolutions.append(conv)
+
+            self.next_layer_repr[node_type_row].append(
+                Convolutions(node_type_column, convolutions))

     def __call__(self, prev_layer_repr):
         next_layer_repr = [ [] for _ in range(len(self.data.node_types)) ]
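The rewritten `build()` walks the flat, tuple-keyed `relation_types` mapping directly instead of scanning every `(row, column)` pair and probing nested dicts. A minimal sketch of that iteration pattern, assuming a dict shaped like the new `Data.relation_types`; the `Rel` class below is a made-up stand-in for `RelationType`:

    from collections import defaultdict

    class Rel:  # hypothetical stand-in, only the field used here
        def __init__(self, name):
            self.name = name

    relation_types = defaultdict(list)
    relation_types[0, 1].append(Rel('gene-drug'))
    relation_types[1, 0].append(Rel('drug-gene'))

    # Same traversal shape as the new build(): one pass over existing keys,
    # no range(n) scans and no membership checks on nested dicts.
    for (node_type_row, node_type_column), rels in relation_types.items():
        if len(rels) == 0:
            continue
        for r in rels:
            print(node_type_row, node_type_column, r.name)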
@@ -7,6 +7,9 @@
 from collections import defaultdict
 from dataclasses import dataclass

 import torch
+from typing import List, \
+    Dict, \
+    Tuple

 @dataclass
@@ -25,9 +28,12 @@ class RelationType(object):


 class Data(object):
+    node_types: List[NodeType]
+    relation_types: Dict[Tuple[int, int], List[RelationType]]
+
     def __init__(self) -> None:
         self.node_types = []
-        self.relation_types = defaultdict(lambda: defaultdict(list))
+        self.relation_types = defaultdict(list)

     def add_node_type(self, name: str, count: int) -> None:
         name = str(name)
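A side note on the new `defaultdict(list)`: indexing with a `(row, column)` tuple lazily creates the empty list that `add_relation_type` below appends to, while `in` tests never insert anything, which is what the `(node_type_row, node_type_column) not in self.relation_types` check in the `__str__` hunk relies on. A small standalone sketch:

    from collections import defaultdict

    relation_types = defaultdict(list)

    print((0, 1) in relation_types)     # False -- membership tests never insert
    relation_types[0, 1].append('rel')  # indexing creates the list, then appends
    print((0, 1) in relation_types)     # True
    print(relation_types[0, 1])         # ['rel']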
@@ -73,14 +79,14 @@ class Data(object):
                 adjacency_matrix_backward is not None:
             raise ValueError('Relation between nodes of the same type must be expressed using a single matrix')

-        self.relation_types[node_type_row][node_type_column].append(
+        self.relation_types[node_type_row, node_type_column].append(
             RelationType(name, node_type_row, node_type_column,
                 adjacency_matrix, False))

         if node_type_row != node_type_column and two_way:
             if adjacency_matrix_backward is None:
                 adjacency_matrix_backward = adjacency_matrix.transpose(0, 1)
-            self.relation_types[node_type_column][node_type_row].append(
+            self.relation_types[node_type_column, node_type_row].append(
                 RelationType(name, node_type_column, node_type_row,
                     adjacency_matrix_backward, True))
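With tuple keys, the forward relation lands under `(node_type_row, node_type_column)` and the optional autogenerated transpose under `(node_type_column, node_type_row)`. A rough usage sketch, assuming the `Data` class from this file; the exact parameter order of `add_relation_type` is not visible in the hunk, so the call below is illustrative only:

    import torch

    d = Data()
    d.add_node_type('Gene', 1000)
    d.add_node_type('Drug', 100)

    # Assumed argument order: name, row type, column type, adjacency matrix.
    d.add_relation_type('Target', 0, 1, torch.rand(1000, 100).round(), two_way=True)

    print(len(d.relation_types[0, 1]))  # 1 -- forward relation
    print(len(d.relation_types[1, 0]))  # 1 -- autogenerated transposed relation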
@@ -99,16 +105,15 @@ class Data(object):
             s_1 = ''
             count = 0
-            for i in range(n):
-                for j in range(n):
-                    if i not in self.relation_types or \
-                        j not in self.relation_types[i]:
+            for node_type_row in range(n):
+                for node_type_column in range(n):
+                    if (node_type_row, node_type_column) not in self.relation_types:
                         continue
-                    s_1 += ' - ' + self.node_types[i].name + ' -- ' + \
-                        self.node_types[j].name + ':\n'
-                    for r in self.relation_types[i][j]:
+                    s_1 += ' - ' + self.node_types[node_type_row].name + ' -- ' + \
+                        self.node_types[node_type_column].name + ':\n'
+                    for r in self.relation_types[node_type_row, node_type_column]:
                         if r.is_autogenerated:
                             continue
                         s_1 += ' - ' + r.name + '\n'
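The summary loop now keys directly on the `(row, column)` pair and still skips autogenerated back edges. A toy reconstruction of the same traversal, with made-up node and relation names and a simplified relation record:

    from collections import namedtuple

    Rel = namedtuple('Rel', ['name', 'is_autogenerated'])  # simplified stand-in

    node_type_names = ['Gene', 'Drug']
    relation_types = {
        (0, 1): [Rel('Target', False)],
        (1, 0): [Rel('Target', True)],  # autogenerated transpose, filtered out
    }

    s_1 = ''
    n = len(node_type_names)
    for node_type_row in range(n):
        for node_type_column in range(n):
            if (node_type_row, node_type_column) not in relation_types:
                continue
            s_1 += ' - ' + node_type_names[node_type_row] + ' -- ' + \
                node_type_names[node_type_column] + ':\n'
            for r in relation_types[node_type_row, node_type_column]:
                if r.is_autogenerated:
                    continue
                s_1 += '   - ' + r.name + '\n'
    print(s_1)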
@@ -118,16 +123,3 @@ class Data(object):
             s += s_1

         return s.strip()
-
-        # n = sum(map(len, self.relation_types))
-        #
-        # for i in range(n):
-        #     for j in range(n):
-        #         key = (i, j)
-        #         if key not in self.relation_types:
-        #             continue
-        #         rels = self.relation_types[key]
-        #
-        #         for r in rels:
-        #
-        # return s.strip()
@@ -44,37 +44,27 @@ class DecodeLayer(torch.nn.Module):
     def build(self) -> None:
         self.decoders = {}

-        n = len(self.data.node_types)
-        relation_types = self.data.relation_types
-
-        for node_type_row in range(n):
-            if node_type_row not in relation_types:
-                continue
-
-            for node_type_column in range(n):
-                if node_type_column not in relation_types[node_type_row]:
-                    continue
-
-                rels = relation_types[node_type_row][node_type_column]
-                if len(rels) == 0:
-                    continue
-
-                if isinstance(self.decoder_class, dict):
-                    if (node_type_row, node_type_column) in self.decoder_class:
-                        decoder_class = self.decoder_class[node_type_row, node_type_column]
-                    elif (node_type_column, node_type_row) in self.decoder_class:
-                        decoder_class = self.decoder_class[node_type_column, node_type_row]
-                    else:
-                        raise KeyError('Decoder not specified for edge type: %s -- %s' % (
-                            self.data.node_types[node_type_row].name,
-                            self.data.node_types[node_type_column].name))
-                else:
-                    decoder_class = self.decoder_class
-
-                self.decoders[node_type_row, node_type_column] = \
-                    decoder_class(self.input_dim[node_type_row],
-                        num_relation_types = len(rels),
-                        keep_prob = self.keep_prob,
-                        activation = self.activation)
+        for (node_type_row, node_type_column), rels in self.data.relation_types.items():
+            if len(rels) == 0:
+                continue
+
+            if isinstance(self.decoder_class, dict):
+                if (node_type_row, node_type_column) in self.decoder_class:
+                    decoder_class = self.decoder_class[node_type_row, node_type_column]
+                elif (node_type_column, node_type_row) in self.decoder_class:
+                    decoder_class = self.decoder_class[node_type_column, node_type_row]
+                else:
+                    raise KeyError('Decoder not specified for edge type: %s -- %s' % (
+                        self.data.node_types[node_type_row].name,
+                        self.data.node_types[node_type_column].name))
+            else:
+                decoder_class = self.decoder_class
+
+            self.decoders[node_type_row, node_type_column] = \
+                decoder_class(self.input_dim[node_type_row],
+                    num_relation_types = len(rels),
+                    keep_prob = self.keep_prob,
+                    activation = self.activation)

     def forward(self, last_layer_repr: List[torch.Tensor]) -> Dict[Tuple[int, int], List[torch.Tensor]]:
         res = {}
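When `decoder_class` is passed as a dict, the lookup accepts the edge type in either orientation before raising. A minimal sketch of that dispatch, with a hypothetical decoder name standing in for an actual decoder class:

    decoder_class = {
        (0, 1): 'BilinearDecoder',  # hypothetical; any decoder class in practice
    }

    def pick_decoder(node_type_row, node_type_column):
        if isinstance(decoder_class, dict):
            if (node_type_row, node_type_column) in decoder_class:
                return decoder_class[node_type_row, node_type_column]
            elif (node_type_column, node_type_row) in decoder_class:
                return decoder_class[node_type_column, node_type_row]
            raise KeyError('Decoder not specified for edge type: %d -- %d' %
                (node_type_row, node_type_column))
        return decoder_class

    print(pick_decoder(0, 1))  # direct hit
    print(pick_decoder(1, 0))  # reversed key also resolves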
@@ -46,7 +46,7 @@ class PreparedRelationType(object):
 @dataclass
 class PreparedData(object):
     node_types: List[NodeType]
-    relation_types: Dict[int, Dict[int, List[PreparedRelationType]]]
+    relation_types: Dict[Tuple[int, int], List[PreparedRelationType]]


 def train_val_test_split_edges(edges: torch.Tensor,
@@ -137,9 +137,9 @@ def prepare_training(data: Data) -> PreparedData:
     if not isinstance(data, Data):
         raise ValueError('data must be of class Data')

-    relation_types = defaultdict(lambda: defaultdict(list))
-    for (node_type_row, node_type_column), rels in data.relation_types:
+    relation_types = defaultdict(list)
+    for (node_type_row, node_type_column), rels in data.relation_types.items():
         for r in rels:
-            relation_types[node_type_row][node_type_column].append(
+            relation_types[node_type_row, node_type_column].append(
                 prep_relation_type(r))

     return PreparedData(data.node_types, relation_types)
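Besides flattening the target dict, this hunk fixes the loop header: iterating a dict directly yields only its keys, so the old `for (node_type_row, node_type_column), rels in data.relation_types:` would have tried to unpack each key tuple itself and failed with a TypeError, whereas `.items()` yields `(key, value)` pairs that unpack cleanly. A tiny illustration:

    d = {(0, 1): ['rel-a'], (1, 0): ['rel-b']}

    for key in d:             # plain iteration: keys only
        print(key)            # (0, 1) then (1, 0)

    for (row, col), rels in d.items():  # items(): key/value pairs
        print(row, col, rels)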
@@ -86,7 +86,7 @@ def test_decagon_layer_04():
     in_layer = OneHotInputLayer(d)
     multi_dgca = MultiDGCA([10], 32,
-        [r.adjacency_matrix for r in d.relation_types[0][0]],
+        [r.adjacency_matrix for r in d.relation_types[0, 0]],
         keep_prob=1., activation=lambda x: x)
     d_layer = DecagonLayer(in_layer.output_dim, 32, d,
@@ -129,7 +129,7 @@ def test_decagon_layer_05():
     in_layer = OneHotInputLayer(d)
     multi_dgca = MultiDGCA([100, 100], 32,
-        [r.adjacency_matrix for r in d.relation_types[0][0]],
+        [r.adjacency_matrix for r in d.relation_types[0, 0]],
         keep_prob=1., activation=lambda x: x)
     d_layer = DecagonLayer(in_layer.output_dim, output_dim=32, data=d,
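The tests switch to tuple indexing because on the flat `defaultdict(list)` the old chained form `d.relation_types[0][0]` no longer means "relations from node type 0 to node type 0": the first `[0]` would lazily create an empty list under the integer key `0`, and the second `[0]` would then raise an IndexError. A quick standalone check of that behaviour:

    from collections import defaultdict

    relation_types = defaultdict(list)
    relation_types[0, 0].append('self-relation')

    print(relation_types[0, 0])  # ['self-relation'] -- tuple key, as in the tests
    print(relation_types[0])     # [] -- integer key creates a fresh empty list
    # relation_types[0][0]       # would raise IndexError: list index out of range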