IF YOU WOULD LIKE TO GET AN ACCOUNT, please write an email to s dot adaszewski at gmail dot com. User accounts are meant only to report issues and/or generate pull requests. This is a purpose-specific Git hosting for ADARED projects. Thank you for your understanding!
Browse Source

Dummy run of DecagonLayer seems to work, that's something.

master
Stanislaw Adaszewski 3 years ago
parent
commit
69dd9a49e2
3 changed files with 59 additions and 7 deletions
  1. +2
    -0
      src/decagon_pytorch/data.py
  2. +49
    -7
      src/decagon_pytorch/layer.py
  3. +8
    -0
      tests/decagon_pytorch/test_layer.py

+ 2
- 0
src/decagon_pytorch/data.py View File

@@ -46,6 +46,8 @@ class Data(object):
if node_type_row >= n or node_type_column >= n:
raise ValueError('Node type index out of bounds, add node type first')
key = (node_type_row, node_type_column)
if adjacency_matrix is not None and not adjacency_matrix.is_sparse:
adjacency_matrix = adjacency_matrix.to_sparse()
self.relation_types[key].append(RelationType(name, node_type_row, node_type_column, adjacency_matrix))
# _ = self.decoder_types[(node_type_row, node_type_column)]


+ 49
- 7
src/decagon_pytorch/layer.py View File

@@ -30,9 +30,13 @@ from collections import defaultdict
class Layer(torch.nn.Module):
def __init__(self, output_dim: Union[int, List[int]], **kwargs) -> None:
def __init__(self,
output_dim: Union[int, List[int]],
is_sparse: bool,
**kwargs) -> None:
super().__init__(**kwargs)
self.output_dim = output_dim
self.is_sparse = is_sparse
class InputLayer(Layer):
@@ -42,7 +46,7 @@ class InputLayer(Layer):
if not isinstance(output_dim, list):
output_dim = [output_dim,] * len(data.node_types)
super().__init__(output_dim, **kwargs)
super().__init__(output_dim, is_sparse=False, **kwargs)
self.data = data
self.node_reps = None
self.build()
@@ -67,6 +71,34 @@ class InputLayer(Layer):
return s.strip()
class OneHotInputLayer(Layer):
def __init__(self, data: Data, **kwargs) -> None:
output_dim = [ a.count for a in data.node_types ]
super().__init__(output_dim, is_sparse=True, **kwargs)
self.data = data
self.node_reps = None
self.build()
def build(self) -> None:
self.node_reps = []
for i, nt in enumerate(self.data.node_types):
reps = torch.eye(nt.count).to_sparse()
reps = torch.nn.Parameter(reps)
self.register_parameter('node_reps[%d]' % i, reps)
self.node_reps.append(reps)
def forward(self) -> List[torch.nn.Parameter]:
return self.node_reps
def __repr__(self) -> str:
s = ''
s += 'One-hot GNN input layer\n'
s += ' # of node types: %d\n' % len(self.data.node_types)
for nt in self.data.node_types:
s += ' - %s (%d)\n' % (nt.name, nt.count)
return s.strip()
class DecagonLayer(Layer):
def __init__(self,
data: Data,
@@ -78,7 +110,7 @@ class DecagonLayer(Layer):
**kwargs):
if not isinstance(output_dim, list):
output_dim = [ output_dim ] * len(data.node_types)
super().__init__(output_dim, **kwargs)
super().__init__(output_dim, is_sparse=False, **kwargs)
self.data = data
self.previous_layer = previous_layer
self.input_dim = previous_layer.output_dim
@@ -98,6 +130,9 @@ class DecagonLayer(Layer):
self.keep_prob, self.rel_activation)
self.next_layer_repr[nt_row].append((conv, nt_col))
if nt_row == nt_col:
continue
conv = SparseDropoutGraphConvActivation(self.input_dim[nt_row],
self.output_dim[nt_col], rel.adjacency_matrix.transpose(0, 1),
self.keep_prob, self.rel_activation)
@@ -105,9 +140,16 @@ class DecagonLayer(Layer):
def __call__(self):
prev_layer_repr = self.previous_layer()
next_layer_repr = self.next_layer_repr
next_layer_repr = [None] * len(self.data.node_types)
print('next_layer_repr:', next_layer_repr)
for i in range(len(self.data.node_types)):
next_layer_repr[i] = map(lambda conv, neighbor_type: \
conv(prev_layer_repr[neighbor_type]), next_layer_repr[i])
next_layer_repr = list(map(sum, next_layer_repr))
next_layer_repr[i] = [
conv(prev_layer_repr[neighbor_type]) \
for (conv, neighbor_type) in \
self.next_layer_repr[i]
]
next_layer_repr[i] = sum(next_layer_repr[i])
print('next_layer_repr:', next_layer_repr)
# next_layer_repr = list(map(sum, next_layer_repr))
return next_layer_repr

+ 8
- 0
tests/decagon_pytorch/test_layer.py View File

@@ -1,4 +1,5 @@
from decagon_pytorch.layer import InputLayer, \
OneHotInputLayer, \
DecagonLayer
from decagon_pytorch.data import Data
import torch
@@ -74,3 +75,10 @@ def test_decagon_layer_01():
d = _some_data_with_interactions()
in_layer = InputLayer(d)
d_layer = DecagonLayer(d, in_layer, output_dim=32)
def test_decagon_layer_02():
d = _some_data_with_interactions()
in_layer = OneHotInputLayer(d)
d_layer = DecagonLayer(d, in_layer, output_dim=32)
_ = d_layer() # dummy call

Loading…
Cancel
Save