diff --git a/docs/decagon-diagram.svg b/docs/decagon-diagram.svg
index ac8bfb2..8f8b73a 100644
--- a/docs/decagon-diagram.svg
+++ b/docs/decagon-diagram.svg
@@ -17,7 +17,227 @@
inkscape:version="0.92.3 (2405546, 2018-03-11)"
sodipodi:docname="decagon-diagram.svg">
+ id="defs2">
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+ [
+ [
+ 0 0 11 0 00 1 0 [
+ [
+ 0 0 11 0 00 1 0 A
+ x
+ [
+ [
+ 0 0 11 0 00 1 0 x'
+ =
+
+ 1
+
+ 2
+
+ 3
+
+
+
+ [
+ [
+ 0 0 11 0 00 1 0 A
+
+ 1
+
+ 2
+
+ 3
+
+
+
+ [
+ [
+ 0 1 11 0 11 1 0 A2
+
+ 1
+
+ 2
+
+ 3
+
+
+
+ [
+ [
+ 0 1 00 0 11 0 0 A'
diff --git a/src/decagon_pytorch/data.py b/src/decagon_pytorch/data.py
index 852433f..381d15d 100644
--- a/src/decagon_pytorch/data.py
+++ b/src/decagon_pytorch/data.py
@@ -19,13 +19,13 @@ class RelationType(object):
def get_adjacency_matrix(node_type_row, node_type_column):
if self.node_type_row == node_type_row and \
- self.node_type_column = node_type_column:
+ self.node_type_column == node_type_column:
return self.adjacency_matrix
elif self.node_type_row == node_type_column and \
self.node_type_column == node_type_row:
return self.adjacency_matrix.transpose(0, 1)
-
+
else:
raise ValueError('Specified row/column types do not correspond to this relation')
diff --git a/src/decagon_pytorch/layer.py b/src/decagon_pytorch/layer.py
index bb95dae..61911ff 100644
--- a/src/decagon_pytorch/layer.py
+++ b/src/decagon_pytorch/layer.py
@@ -21,11 +21,12 @@
import torch
-from .convolve import SparseMultiDGCA
+from .convolve import SparseDropoutGraphConvActivation
from .data import Data
from typing import List, \
Union, \
Callable
+from collections import defaultdict
class Layer(torch.nn.Module):
@@ -89,30 +90,46 @@ class DecagonLayer(Layer):
def build(self):
self.convolutions = {}
- for key in self.data.relation_types.keys():
+ for (node_type_row, node_type_column) in self.data.relation_types.keys():
adjacency_matrices = \
- self.data.get_adjacency_matrices(*key)
- self.convolutions[key] = SparseMultiDGCA(self.input_dim,
+ self.data.get_adjacency_matrices(node_type_row, node_type_column)
+ self.convolutions[node_type_row, node_type_column] = SparseMultiDGCA(self.input_dim,
self.output_dim, adjacency_matrices,
self.keep_prob, self.rel_activation)
# for node_type_row, node_type_col in enumerate(self.data.node_
# if rt.node_type_row == i or rt.node_type_col == i:
- def __call__(self, prev_layer_repr):
- new_layer_repr = []
- for i, nt in enumerate(self.data.node_types):
- new_repr = []
- for key in self.data.relation_types.keys():
- nt_row, nt_col = key
- if nt_row != i and nt_col != i:
- continue
- if nt_row == i:
- x = prev_layer_repr[nt_col]
- else:
- x = prev_layer_repr[nt_row]
- conv = self.convolutions[key]
- new_repr.append(conv(x))
- new_repr = sum(new_repr)
- new_layer_repr.append(new_repr)
- return new_layer_repr
+ def __call__(self):
+ prev_layer_repr = self.previous_layer()
+ next_layer_repr = defaultdict(list)
+
+ for (nt_row, nt_col), rel in self.data.relation_types.items():
+ conv = SparseDropoutGraphConvActivation(self.input_dim[nt_col],
+ self.output_dim[nt_row], rel.adjacency_matrix,
+ self.keep_prob, self.rel_activation)
+ next_layer_repr[nt_row].append(conv)
+
+ conv = SparseDropoutGraphConvActivation(self.input_dim[nt_row],
+ self.output_dim[nt_col], rel.adjacency_matrix.transpose(0, 1),
+ self.keep_prob, self.rel_activation)
+ next_layer_repr[nt_col].append(conv)
+
+ next_layer_repr = list(map(sum, next_layer_repr))
+ return next_layer_repr
+
+
+ #for i, nt in enumerate(self.data.node_types):
+ # new_repr = []
+ # for nt_row, nt_col in self.data.relation_types.keys():
+ # if nt_row != i and nt_col != i:
+ # continue
+ # if nt_row == i:
+ # x = prev_layer_repr[nt_col]
+ # else:
+ # x = prev_layer_repr[nt_row]
+ # conv = self.convolutions[key]
+ # new_repr.append(conv(x))
+ # new_repr = sum(new_repr)
+ # new_layer_repr.append(new_repr)
+ # return new_layer_repr