IF YOU WOULD LIKE TO GET AN ACCOUNT, please write an email to s dot adaszewski at gmail dot com. User accounts are meant only to report issues and/or generate pull requests. This is a purpose-specific Git hosting for ADARED projects. Thank you for your understanding!
Browse Source

Prepare to make Decagon layer work.

master
Stanislaw Adaszewski 3 years ago
parent
commit
2cdc76fac7
2 changed files with 28 additions and 21 deletions
  1. +23
    -16
      src/decagon_pytorch/layer.py
  2. +5
    -5
      tests/decagon_pytorch/test_layer.py

+ 23
- 16
src/decagon_pytorch/layer.py View File

@@ -22,51 +22,58 @@
import torch
from .convolve import SparseMultiDGCA
from .data import Data
from typing import List, Union
class InputLayer(torch.nn.Module):
def __init__(self, data, dimensionality=None, **kwargs):
class Layer(torch.nn.Module):
def __init__(self, output_dim: Union[int, List[int]], **kwargs) -> None:
super().__init__(**kwargs)
self.data = data
dimensionality = dimensionality or \
self.output_dim = output_dim
class InputLayer(Layer):
def __init__(self, data: Data, output_dim: Union[int, List[int]]= None, **kwargs) -> None:
output_dim = output_dim or \
list(map(lambda a: a.count, data.node_types))
if not isinstance(dimensionality, list):
dimensionality = [dimensionality,] * len(self.data.node_types)
self.dimensionality = dimensionality
if not isinstance(output_dim, list):
output_dim = [output_dim,] * len(data.node_types)
super().__init__(output_dim, **kwargs)
self.data = data
self.node_reps = None
self.build()
def build(self):
def build(self) -> None:
self.node_reps = []
for i, nt in enumerate(self.data.node_types):
reps = torch.rand(nt.count, self.dimensionality[i])
reps = torch.rand(nt.count, self.output_dim[i])
reps = torch.nn.Parameter(reps)
self.register_parameter('node_reps[%d]' % i, reps)
self.node_reps.append(reps)
def forward(self):
def forward(self) -> List[torch.nn.Parameter]:
return self.node_reps
def __repr__(self):
def __repr__(self) -> str:
s = ''
s += 'GNN input layer with dimensionality: %d\n' % self.dimensionality
s += 'GNN input layer with output_dim: %s\n' % self.output_dim
s += ' # of node types: %d\n' % len(self.data.node_types)
for nt in self.data.node_types:
s += ' - %s (%d)\n' % (nt.name, nt.count)
return s.strip()
class DecagonLayer(torch.nn.Module):
def __init__(self, data,
class DecagonLayer(Layer):
def __init__(self, data: Data,
input_dim, output_dim,
keep_prob=1.,
rel_activation=lambda x: x,
layer_activation=torch.nn.functional.relu,
**kwargs):
super().__init__(**kwargs)
super().__init__(output_dim, **kwargs)
self.data = data
self.input_dim = input_dim
self.output_dim = output_dim
self.keep_prob = keep_prob
self.rel_activation = rel_activation
self.layer_activation = layer_activation


+ 5
- 5
tests/decagon_pytorch/test_layer.py View File

@@ -36,12 +36,12 @@ def _some_data_with_interactions():
def test_input_layer_01():
d = _some_data()
for dimensionality in [32, 64, 128]:
layer = InputLayer(d, dimensionality)
assert layer.dimensionality[0] == dimensionality
for output_dim in [32, 64, 128]:
layer = InputLayer(d, output_dim)
assert layer.output_dim[0] == output_dim
assert len(layer.node_reps) == 2
assert layer.node_reps[0].shape == (1000, dimensionality)
assert layer.node_reps[1].shape == (100, dimensionality)
assert layer.node_reps[0].shape == (1000, output_dim)
assert layer.node_reps[1].shape == (100, output_dim)
assert layer.data == d


Loading…
Cancel
Save