@@ -51,34 +51,23 @@ class DecagonLayer(torch.nn.Module): | |||||
self.build() | self.build() | ||||
def build(self): | def build(self): | ||||
n = len(self.data.node_types) | |||||
rel_types = self.data.relation_types | |||||
self.next_layer_repr = [ [] for _ in range(n) ] | |||||
self.next_layer_repr = [ [] for _ in range(len(self.data.node_types)) ] | |||||
for node_type_row in range(n): | |||||
if node_type_row not in rel_types: | |||||
for (node_type_row, node_type_column), rels in self.data.relation_types.items(): | |||||
if len(rels) == 0: | |||||
continue | continue | ||||
for node_type_column in range(n): | |||||
if node_type_column not in rel_types[node_type_row]: | |||||
continue | |||||
rels = rel_types[node_type_row][node_type_column] | |||||
if len(rels) == 0: | |||||
continue | |||||
convolutions = [] | |||||
convolutions = [] | |||||
for r in rels: | |||||
conv = DropoutGraphConvActivation(self.input_dim[node_type_column], | |||||
self.output_dim[node_type_row], r.adjacency_matrix, | |||||
self.keep_prob, self.rel_activation) | |||||
for r in rels: | |||||
conv = DropoutGraphConvActivation(self.input_dim[node_type_column], | |||||
self.output_dim[node_type_row], r.adjacency_matrix, | |||||
self.keep_prob, self.rel_activation) | |||||
convolutions.append(conv) | |||||
convolutions.append(conv) | |||||
self.next_layer_repr[node_type_row].append( | |||||
Convolutions(node_type_column, convolutions)) | |||||
self.next_layer_repr[node_type_row].append( | |||||
Convolutions(node_type_column, convolutions)) | |||||
def __call__(self, prev_layer_repr): | def __call__(self, prev_layer_repr): | ||||
next_layer_repr = [ [] for _ in range(len(self.data.node_types)) ] | next_layer_repr = [ [] for _ in range(len(self.data.node_types)) ] | ||||
@@ -7,6 +7,9 @@ | |||||
from collections import defaultdict | from collections import defaultdict | ||||
from dataclasses import dataclass | from dataclasses import dataclass | ||||
import torch | import torch | ||||
from typing import List, \ | |||||
Dict, \ | |||||
Tuple | |||||
@dataclass | @dataclass | ||||
@@ -25,9 +28,12 @@ class RelationType(object): | |||||
class Data(object): | class Data(object): | ||||
node_types: List[NodeType] | |||||
relation_types: Dict[Tuple[int, int], List[RelationType]] | |||||
def __init__(self) -> None: | def __init__(self) -> None: | ||||
self.node_types = [] | self.node_types = [] | ||||
self.relation_types = defaultdict(lambda: defaultdict(list)) | |||||
self.relation_types = defaultdict(list) | |||||
def add_node_type(self, name: str, count: int) -> None: | def add_node_type(self, name: str, count: int) -> None: | ||||
name = str(name) | name = str(name) | ||||
@@ -73,14 +79,14 @@ class Data(object): | |||||
adjacency_matrix_backward is not None: | adjacency_matrix_backward is not None: | ||||
raise ValueError('Relation between nodes of the same type must be expressed using a single matrix') | raise ValueError('Relation between nodes of the same type must be expressed using a single matrix') | ||||
self.relation_types[node_type_row][node_type_column].append( | |||||
self.relation_types[node_type_row, node_type_column].append( | |||||
RelationType(name, node_type_row, node_type_column, | RelationType(name, node_type_row, node_type_column, | ||||
adjacency_matrix, False)) | adjacency_matrix, False)) | ||||
if node_type_row != node_type_column and two_way: | if node_type_row != node_type_column and two_way: | ||||
if adjacency_matrix_backward is None: | if adjacency_matrix_backward is None: | ||||
adjacency_matrix_backward = adjacency_matrix.transpose(0, 1) | adjacency_matrix_backward = adjacency_matrix.transpose(0, 1) | ||||
self.relation_types[node_type_column][node_type_row].append( | |||||
self.relation_types[node_type_column, node_type_row].append( | |||||
RelationType(name, node_type_column, node_type_row, | RelationType(name, node_type_column, node_type_row, | ||||
adjacency_matrix_backward, True)) | adjacency_matrix_backward, True)) | ||||
@@ -99,16 +105,15 @@ class Data(object): | |||||
s_1 = '' | s_1 = '' | ||||
count = 0 | count = 0 | ||||
for i in range(n): | |||||
for j in range(n): | |||||
if i not in self.relation_types or \ | |||||
j not in self.relation_types[i]: | |||||
for node_type_row in range(n): | |||||
for node_type_column in range(n): | |||||
if (node_type_row, node_type_column) not in self.relation_types: | |||||
continue | continue | ||||
s_1 += ' - ' + self.node_types[i].name + ' -- ' + \ | |||||
self.node_types[j].name + ':\n' | |||||
s_1 += ' - ' + self.node_types[node_type_row].name + ' -- ' + \ | |||||
self.node_types[node_type_column].name + ':\n' | |||||
for r in self.relation_types[i][j]: | |||||
for r in self.relation_types[node_type_row, node_type_column]: | |||||
if r.is_autogenerated: | if r.is_autogenerated: | ||||
continue | continue | ||||
s_1 += ' - ' + r.name + '\n' | s_1 += ' - ' + r.name + '\n' | ||||
@@ -118,16 +123,3 @@ class Data(object): | |||||
s += s_1 | s += s_1 | ||||
return s.strip() | return s.strip() | ||||
# n = sum(map(len, self.relation_types)) | |||||
# | |||||
# for i in range(n): | |||||
# for j in range(n): | |||||
# key = (i, j) | |||||
# if key not in self.relation_types: | |||||
# continue | |||||
# rels = self.relation_types[key] | |||||
# | |||||
# for r in rels: | |||||
# | |||||
# return s.strip() |
@@ -44,37 +44,27 @@ class DecodeLayer(torch.nn.Module): | |||||
def build(self) -> None: | def build(self) -> None: | ||||
self.decoders = {} | self.decoders = {} | ||||
n = len(self.data.node_types) | |||||
relation_types = self.data.relation_types | |||||
for node_type_row in range(n): | |||||
if node_type_row not in relation_types: | |||||
for (node_type_row, node_type_column), rels in self.data.relation_types.items(): | |||||
if len(rels) == 0: | |||||
continue | continue | ||||
for node_type_column in range(n): | |||||
if node_type_column not in relation_types[node_type_row]: | |||||
continue | |||||
rels = relation_types[node_type_row][node_type_column] | |||||
if len(rels) == 0: | |||||
continue | |||||
if isinstance(self.decoder_class, dict): | |||||
if (node_type_row, node_type_column) in self.decoder_class: | |||||
decoder_class = self.decoder_class[node_type_row, node_type_column] | |||||
elif (node_type_column, node_type_row) in self.decoder_class: | |||||
decoder_class = self.decoder_class[node_type_column, node_type_row] | |||||
else: | |||||
raise KeyError('Decoder not specified for edge type: %s -- %s' % ( | |||||
self.data.node_types[node_type_row].name, | |||||
self.data.node_types[node_type_column].name)) | |||||
if isinstance(self.decoder_class, dict): | |||||
if (node_type_row, node_type_column) in self.decoder_class: | |||||
decoder_class = self.decoder_class[node_type_row, node_type_column] | |||||
elif (node_type_column, node_type_row) in self.decoder_class: | |||||
decoder_class = self.decoder_class[node_type_column, node_type_row] | |||||
else: | else: | ||||
decoder_class = self.decoder_class | |||||
raise KeyError('Decoder not specified for edge type: %s -- %s' % ( | |||||
self.data.node_types[node_type_row].name, | |||||
self.data.node_types[node_type_column].name)) | |||||
else: | |||||
decoder_class = self.decoder_class | |||||
self.decoders[node_type_row, node_type_column] = \ | |||||
decoder_class(self.input_dim[node_type_row], | |||||
num_relation_types = len(rels), | |||||
keep_prob = self.keep_prob, | |||||
activation = self.activation) | |||||
self.decoders[node_type_row, node_type_column] = \ | |||||
decoder_class(self.input_dim[node_type_row], | |||||
num_relation_types = len(rels), | |||||
keep_prob = self.keep_prob, | |||||
activation = self.activation) | |||||
def forward(self, last_layer_repr: List[torch.Tensor]) -> Dict[Tuple[int, int], List[torch.Tensor]]: | def forward(self, last_layer_repr: List[torch.Tensor]) -> Dict[Tuple[int, int], List[torch.Tensor]]: | ||||
res = {} | res = {} | ||||
@@ -46,7 +46,7 @@ class PreparedRelationType(object): | |||||
@dataclass | @dataclass | ||||
class PreparedData(object): | class PreparedData(object): | ||||
node_types: List[NodeType] | node_types: List[NodeType] | ||||
relation_types: Dict[int, Dict[int, List[PreparedRelationType]]] | |||||
relation_types: Dict[Tuple[int, int], List[PreparedRelationType]] | |||||
def train_val_test_split_edges(edges: torch.Tensor, | def train_val_test_split_edges(edges: torch.Tensor, | ||||
@@ -137,9 +137,9 @@ def prepare_training(data: Data) -> PreparedData: | |||||
if not isinstance(data, Data): | if not isinstance(data, Data): | ||||
raise ValueError('data must be of class Data') | raise ValueError('data must be of class Data') | ||||
relation_types = defaultdict(lambda: defaultdict(list)) | |||||
for (node_type_row, node_type_column), rels in data.relation_types: | |||||
relation_types = defaultdict(list) | |||||
for (node_type_row, node_type_column), rels in data.relation_types.items(): | |||||
for r in rels: | for r in rels: | ||||
relation_types[node_type_row][node_type_column].append( | |||||
relation_types[node_type_row, node_type_column].append( | |||||
prep_relation_type(r)) | prep_relation_type(r)) | ||||
return PreparedData(data.node_types, relation_types) | return PreparedData(data.node_types, relation_types) |
@@ -86,7 +86,7 @@ def test_decagon_layer_04(): | |||||
in_layer = OneHotInputLayer(d) | in_layer = OneHotInputLayer(d) | ||||
multi_dgca = MultiDGCA([10], 32, | multi_dgca = MultiDGCA([10], 32, | ||||
[r.adjacency_matrix for r in d.relation_types[0][0]], | |||||
[r.adjacency_matrix for r in d.relation_types[0, 0]], | |||||
keep_prob=1., activation=lambda x: x) | keep_prob=1., activation=lambda x: x) | ||||
d_layer = DecagonLayer(in_layer.output_dim, 32, d, | d_layer = DecagonLayer(in_layer.output_dim, 32, d, | ||||
@@ -129,7 +129,7 @@ def test_decagon_layer_05(): | |||||
in_layer = OneHotInputLayer(d) | in_layer = OneHotInputLayer(d) | ||||
multi_dgca = MultiDGCA([100, 100], 32, | multi_dgca = MultiDGCA([100, 100], 32, | ||||
[r.adjacency_matrix for r in d.relation_types[0][0]], | |||||
[r.adjacency_matrix for r in d.relation_types[0, 0]], | |||||
keep_prob=1., activation=lambda x: x) | keep_prob=1., activation=lambda x: x) | ||||
d_layer = DecagonLayer(in_layer.output_dim, output_dim=32, data=d, | d_layer = DecagonLayer(in_layer.output_dim, output_dim=32, data=d, | ||||