diff --git a/.gitignore b/.gitignore
index c20c2ab..63adf45 100644
--- a/.gitignore
+++ b/.gitignore
@@ -1,2 +1,11 @@
__pycache__
-
+.cache/
+.coverage
+/docs/icosagon/*.dot
+/docs/icosagon/*.png
+/experiments/decagon_run/profiler_results
+/experiments/decagon_run_effcat/profiler_results
+/src/torch_stablesort/dist
+/src/torch_stablesort/build
+/src/torch_stablesort/torch_stablesort.egg-info
+a.out
diff --git a/README.md b/README.md
index ab1c38f..a1c01f7 100644
--- a/README.md
+++ b/README.md
@@ -9,9 +9,14 @@ settings.
Decagon-PyTorch is a PyTorch reimplementation of the algorithm.
+## Citing
+
+If you use this code in your research please cite this repository as:
+
+Adaszewski S. (2020) https://code.adared.ch/sadaszewski/decagon-pytorch
+
## References
1. Zitnik, M., Agrawal, M., & Leskovec, J. (2018).
[Modeling polypharmacy side effects with graph convolutional networks](https://academic.oup.com/bioinformatics/article/34/13/i457/5045770)
Bioinformatics, 34(13), i457-i466.
-
diff --git a/decagon_pytorch/__init__.py b/decagon_pytorch/__init__.py
deleted file mode 100644
index f628a28..0000000
--- a/decagon_pytorch/__init__.py
+++ /dev/null
@@ -1,3 +0,0 @@
-from .weights import *
-from .convolve import *
-from .model import *
diff --git a/decagon_pytorch/dropout.py b/decagon_pytorch/dropout.py
deleted file mode 100644
index 3162572..0000000
--- a/decagon_pytorch/dropout.py
+++ /dev/null
@@ -1,18 +0,0 @@
-import torch
-
-
-def dropout_sparse(x, keep_prob):
- """Dropout for sparse tensors.
- """
- x = x.coalesce()
- i = x._indices()
- v = x._values()
- size = x.size()
-
- n = keep_prob + torch.rand(len(v))
- n = torch.floor(n).to(torch.bool)
- i = i[:,n]
- v = v[n]
- x = torch.sparse_coo_tensor(i, v, size=size)
-
- return x * (1./keep_prob)
diff --git a/decagon_pytorch/weights.py b/decagon_pytorch/weights.py
deleted file mode 100644
index 305a70f..0000000
--- a/decagon_pytorch/weights.py
+++ /dev/null
@@ -1,13 +0,0 @@
-import torch
-import numpy as np
-
-
-def init_glorot(input_dim, output_dim):
- """Create a weight variable with Glorot & Bengio (AISTATS 2010)
- initialization.
- """
- init_range = np.sqrt(6.0 / (input_dim + output_dim))
- initial = -init_range + 2 * init_range * \
- torch.rand(( input_dim, output_dim ), dtype=torch.float32)
- initial = initial.requires_grad_(True)
- return initial
diff --git a/docker/mlflow-tracking-server/Dockerfile b/docker/mlflow-tracking-server/Dockerfile
new file mode 100644
index 0000000..735a0d0
--- /dev/null
+++ b/docker/mlflow-tracking-server/Dockerfile
@@ -0,0 +1,33 @@
+FROM debian:latest
+
+RUN apt-get update && \
+ apt-get install -y python3-pip \
+ python3-pandas \
+ python3-alembic \
+ python3-sqlalchemy \
+ python3-yaml \
+ python3-flask \
+ python3-gunicorn \
+ python3-protobuf \
+ python3-urllib3 \
+ python3-certifi \
+ python3-idna \
+ python3-requests \
+# python3-docker \
+ python3-smmap \
+ python3-gitdb \
+ python3-git \
+ python3-sqlparse \
+ python3-oauthlib \
+ python3-requests-oauthlib \
+ python3-isodate \
+# python3-msrest \
+ python3-prometheus-client \
+ python3-cloudpickle \
+ python3-tabulate && \
+ pip3 install mlflow
+
+# RUN apk add --no-cache gcc gfortran libgfortran musl-dev python3 py3
+#pip python3-dev py3-numpy libffi libffi-dev && \
+# pip3 install mlflow && \
+# apk del musl-dev python3-dev gcc gfortran py3-pip libffi-dev
diff --git a/docker/postgresql/Dockerfile b/docker/postgresql/Dockerfile
new file mode 100644
index 0000000..137526d
--- /dev/null
+++ b/docker/postgresql/Dockerfile
@@ -0,0 +1,6 @@
+FROM alpine:latest
+
+RUN apk add postgresql
+
+RUN mkdir /data && \
+ chown postgres:postgres /data
diff --git a/docker/postgresql/docker-init.sh b/docker/postgresql/docker-init.sh
new file mode 100644
index 0000000..a69650a
--- /dev/null
+++ b/docker/postgresql/docker-init.sh
@@ -0,0 +1,5 @@
+#!/bin/sh
+
+if [ ! -f /data/PG_VERSION ]; then
+ initdb -D /data --pwfile=/superuser_password
+fi
diff --git a/docs/cumcount.svg b/docs/cumcount.svg
new file mode 100644
index 0000000..4da5428
--- /dev/null
+++ b/docs/cumcount.svg
@@ -0,0 +1,2074 @@
+
+
+
+
diff --git a/docs/decagon-diagram.svg b/docs/decagon-diagram.svg
new file mode 100644
index 0000000..aafe9ba
--- /dev/null
+++ b/docs/decagon-diagram.svg
@@ -0,0 +1,2081 @@
+
+
+
+
diff --git a/docs/icosagon-classes.svg b/docs/icosagon-classes.svg
new file mode 100644
index 0000000..fea7692
--- /dev/null
+++ b/docs/icosagon-classes.svg
@@ -0,0 +1,996 @@
+
+
+
+
diff --git a/docs/icosagon-hierarchy.svg b/docs/icosagon-hierarchy.svg
new file mode 100644
index 0000000..ed26cdd
--- /dev/null
+++ b/docs/icosagon-hierarchy.svg
@@ -0,0 +1,334 @@
+
+
+
+
diff --git a/docs/icosagon-reltype-rules.svg b/docs/icosagon-reltype-rules.svg
new file mode 100644
index 0000000..bda866a
--- /dev/null
+++ b/docs/icosagon-reltype-rules.svg
@@ -0,0 +1,636 @@
+
+
+
+
diff --git a/docs/matrix-multiply.svg b/docs/matrix-multiply.svg
new file mode 100644
index 0000000..a2af2a5
--- /dev/null
+++ b/docs/matrix-multiply.svg
@@ -0,0 +1,1462 @@
+
+
+
+
diff --git a/docs/nodes-involved.svg b/docs/nodes-involved.svg
new file mode 100644
index 0000000..045d85e
--- /dev/null
+++ b/docs/nodes-involved.svg
@@ -0,0 +1,438 @@
+
+
+
+
diff --git a/docs/required-vertices-per-layer.svg b/docs/required-vertices-per-layer.svg
new file mode 100644
index 0000000..f7de037
--- /dev/null
+++ b/docs/required-vertices-per-layer.svg
@@ -0,0 +1,1463 @@
+
+
+
+
diff --git a/docs/train-val-test-diagram.svg b/docs/train-val-test-diagram.svg
new file mode 100644
index 0000000..90428c0
--- /dev/null
+++ b/docs/train-val-test-diagram.svg
@@ -0,0 +1,1445 @@
+
+
+
+
diff --git a/experiments/decagon_run/decagon_run.py b/experiments/decagon_run/decagon_run.py
new file mode 100644
index 0000000..4093e08
--- /dev/null
+++ b/experiments/decagon_run/decagon_run.py
@@ -0,0 +1,126 @@
+#!/usr/bin/env python3
+
+from icosagon.data import Data
+from icosagon.trainprep import TrainValTest, \
+ prepare_training
+from icosagon.model import Model
+from icosagon.trainloop import TrainLoop
+import os
+import pandas as pd
+from bisect import bisect_left
+import torch
+import sys
+
+
+def index(a, x):
+ i = bisect_left(a, x)
+ if i != len(a) and a[i] == x:
+ return i
+ raise ValueError
+
+
+def load_data(dev):
+ path = '/pstore/data/data_science/ref/decagon'
+ df_combo = pd.read_csv(os.path.join(path, 'bio-decagon-combo.csv'))
+ df_effcat = pd.read_csv(os.path.join(path, 'bio-decagon-effectcategories.csv'))
+ df_mono = pd.read_csv(os.path.join(path, 'bio-decagon-mono.csv'))
+ df_ppi = pd.read_csv(os.path.join(path, 'bio-decagon-ppi.csv'))
+ df_tgtall = pd.read_csv(os.path.join(path, 'bio-decagon-targets-all.csv'))
+ df_tgt = pd.read_csv(os.path.join(path, 'bio-decagon-targets.csv'))
+ lst = [ 'df_combo', 'df_effcat', 'df_mono', 'df_ppi', 'df_tgtall', 'df_tgt' ]
+ for nam in lst:
+ print(f'len({nam}): {len(locals()[nam])}')
+ print(f'{nam}.columns: {locals()[nam].columns}')
+
+ genes = set()
+ genes = genes.union(df_ppi['Gene 1']).union(df_ppi['Gene 2']) \
+ .union(df_tgtall['Gene']).union(df_tgt['Gene'])
+ genes = sorted(genes)
+ print('len(genes):', len(genes))
+
+ drugs = set()
+ drugs = drugs.union(df_combo['STITCH 1']).union(df_combo['STITCH 2']) \
+ .union(df_mono['STITCH']).union(df_tgtall['STITCH']).union(df_tgt['STITCH'])
+ drugs = sorted(drugs)
+ print('len(drugs):', len(drugs))
+
+ data = Data()
+ data.add_node_type('Gene', len(genes))
+ data.add_node_type('Drug', len(drugs))
+
+ print('Preparing PPI...')
+ print('Indexing rows...')
+ rows = [index(genes, g) for g in df_ppi['Gene 1']]
+ print('Indexing cols...')
+ cols = [index(genes, g) for g in df_ppi['Gene 2']]
+ indices = list(zip(rows, cols))
+ indices = torch.tensor(indices).transpose(0, 1)
+ values = torch.ones(len(rows))
+ print('indices.shape:', indices.shape, 'values.shape:', values.shape)
+ adj_mat = torch.sparse_coo_tensor(indices, values, size=(len(genes),) * 2,
+ device=dev)
+ adj_mat = (adj_mat + adj_mat.transpose(0, 1)) / 2
+ print('adj_mat created')
+ fam = data.add_relation_family('PPI', 0, 0, True)
+ rel = fam.add_relation_type('PPI', adj_mat)
+ print('OK')
+
+ print('Preparing Drug-Gene (Target) edges...')
+ rows = [index(drugs, d) for d in df_tgtall['STITCH']]
+ cols = [index(genes, g) for g in df_tgtall['Gene']]
+ indices = list(zip(rows, cols))
+ indices = torch.tensor(indices).transpose(0, 1)
+ values = torch.ones(len(rows))
+ adj_mat = torch.sparse_coo_tensor(indices, values, size=(len(drugs), len(genes)),
+ device=dev)
+ fam = data.add_relation_family('Drug-Gene (Target)', 1, 0, True)
+ rel = fam.add_relation_type('Drug-Gene (Target)', adj_mat)
+ print('OK')
+
+ print('Preparing Drug-Drug (Side Effect) edges...')
+ fam = data.add_relation_family('Drug-Drug (Side Effect)', 1, 1, True)
+ print('# of side effects:', len(df_combo), 'unique:', len(df_combo['Polypharmacy Side Effect'].unique()))
+ for eff, df in df_combo.groupby('Polypharmacy Side Effect'):
+ sys.stdout.write('.') # print(eff, '...')
+ sys.stdout.flush()
+ rows = [index(drugs, d) for d in df['STITCH 1']]
+ cols = [index(drugs, d) for d in df['STITCH 2']]
+ indices = list(zip(rows, cols))
+ indices = torch.tensor(indices).transpose(0, 1)
+ values = torch.ones(len(rows))
+ adj_mat = torch.sparse_coo_tensor(indices, values, size=(len(drugs), len(drugs)),
+ device=dev)
+ adj_mat = (adj_mat + adj_mat.transpose(0, 1)) / 2
+ rel = fam.add_relation_type(df['Polypharmacy Side Effect'], adj_mat)
+ print()
+ print('OK')
+
+ return data
+
+
+def _wrap(obj, method_name):
+ orig_fn = getattr(obj, method_name)
+ def fn(*args, **kwargs):
+ print(f'{method_name}() :: ENTER')
+ res = orig_fn(*args, **kwargs)
+ print(f'{method_name}() :: EXIT')
+ return res
+ setattr(obj, method_name, fn)
+
+
+def main():
+ dev = torch.device('cpu')
+ data = load_data(dev)
+ prep_d = prepare_training(data, TrainValTest(.8, .1, .1))
+ _wrap(Model, 'build')
+ model = Model(prep_d)
+ model = model.to(dev)
+ # model = torch.nn.DataParallel(model, ['cuda:0', 'cuda:1'])
+ _wrap(TrainLoop, 'build')
+ _wrap(TrainLoop, 'run_epoch')
+ loop = TrainLoop(model, batch_size=512, shuffle=True)
+ loop.run_epoch()
+
+
+if __name__ == '__main__':
+ main()
diff --git a/experiments/decagon_run_effcat/decagon_run_effcat.py b/experiments/decagon_run_effcat/decagon_run_effcat.py
new file mode 100644
index 0000000..afe65ec
--- /dev/null
+++ b/experiments/decagon_run_effcat/decagon_run_effcat.py
@@ -0,0 +1,131 @@
+#!/usr/bin/env python3
+
+from icosagon.data import Data
+from icosagon.trainprep import TrainValTest, \
+ prepare_training
+from icosagon.model import Model
+from icosagon.trainloop import TrainLoop
+import os
+import pandas as pd
+from bisect import bisect_left
+import torch
+import sys
+
+
+def index(a, x):
+ i = bisect_left(a, x)
+ if i != len(a) and a[i] == x:
+ return i
+ raise ValueError
+
+
+def load_data(dev):
+ path = '/pstore/data/data_science/ref/decagon'
+ df_combo = pd.read_csv(os.path.join(path, 'bio-decagon-combo.csv'))
+ df_effcat = pd.read_csv(os.path.join(path, 'bio-decagon-effectcategories.csv'))
+ df_mono = pd.read_csv(os.path.join(path, 'bio-decagon-mono.csv'))
+ df_ppi = pd.read_csv(os.path.join(path, 'bio-decagon-ppi.csv'))
+ df_tgtall = pd.read_csv(os.path.join(path, 'bio-decagon-targets-all.csv'))
+ df_tgt = pd.read_csv(os.path.join(path, 'bio-decagon-targets.csv'))
+ lst = [ 'df_combo', 'df_effcat', 'df_mono', 'df_ppi', 'df_tgtall', 'df_tgt' ]
+ for nam in lst:
+ print(f'len({nam}): {len(locals()[nam])}')
+ print(f'{nam}.columns: {locals()[nam].columns}')
+
+ genes = set()
+ genes = genes.union(df_ppi['Gene 1']).union(df_ppi['Gene 2']) \
+ .union(df_tgtall['Gene']).union(df_tgt['Gene'])
+ genes = sorted(genes)
+ print('len(genes):', len(genes))
+
+ drugs = set()
+ drugs = drugs.union(df_combo['STITCH 1']).union(df_combo['STITCH 2']) \
+ .union(df_mono['STITCH']).union(df_tgtall['STITCH']).union(df_tgt['STITCH'])
+ drugs = sorted(drugs)
+ print('len(drugs):', len(drugs))
+
+ data = Data()
+ data.add_node_type('Gene', len(genes))
+ data.add_node_type('Drug', len(drugs))
+
+ print('Preparing PPI...')
+ print('Indexing rows...')
+ rows = [index(genes, g) for g in df_ppi['Gene 1']]
+ print('Indexing cols...')
+ cols = [index(genes, g) for g in df_ppi['Gene 2']]
+ indices = list(zip(rows, cols))
+ indices = torch.tensor(indices).transpose(0, 1)
+ values = torch.ones(len(rows))
+ print('indices.shape:', indices.shape, 'values.shape:', values.shape)
+ adj_mat = torch.sparse_coo_tensor(indices, values, size=(len(genes),) * 2,
+ device=dev)
+ adj_mat = (adj_mat + adj_mat.transpose(0, 1)) / 2
+ print('adj_mat created')
+ fam = data.add_relation_family('PPI', 0, 0, True)
+ rel = fam.add_relation_type('PPI', adj_mat)
+ print('OK')
+
+ print('Preparing Drug-Gene (Target) edges...')
+ rows = [index(drugs, d) for d in df_tgtall['STITCH']]
+ cols = [index(genes, g) for g in df_tgtall['Gene']]
+ indices = list(zip(rows, cols))
+ indices = torch.tensor(indices).transpose(0, 1)
+ values = torch.ones(len(rows))
+ adj_mat = torch.sparse_coo_tensor(indices, values, size=(len(drugs), len(genes)),
+ device=dev)
+ fam = data.add_relation_family('Drug-Gene (Target)', 1, 0, True)
+ rel = fam.add_relation_type('Drug-Gene (Target)', adj_mat)
+ print('OK')
+
+ df_combo_effcat = df_combo.merge(df_effcat, left_on='Polypharmacy Side Effect', right_on='Side Effect')
+ disease_classes = []
+
+ print('Preparing Drug-Drug (Side Effect) edges...')
+ fam = data.add_relation_family('Drug-Drug (Side Effect)', 1, 1, True)
+ print('# of side effects:', len(df_combo), 'unique:', len(df_combo['Polypharmacy Side Effect'].unique()))
+ for discls, df in df_combo_effcat.groupby('Disease Class'):
+ disease_classes.append(discls)
+ sys.stdout.write('.') # print(eff, '...')
+ sys.stdout.flush()
+ rows = [index(drugs, d) for d in df['STITCH 1']]
+ cols = [index(drugs, d) for d in df['STITCH 2']]
+ indices = list(zip(rows, cols))
+ indices = torch.tensor(indices).transpose(0, 1)
+ values = torch.ones(len(rows))
+ adj_mat = torch.sparse_coo_tensor(indices, values, size=(len(drugs), len(drugs)),
+ device=dev)
+ adj_mat = (adj_mat + adj_mat.transpose(0, 1)) / 2
+ rel = fam.add_relation_type(df['Polypharmacy Side Effect'], adj_mat)
+ print()
+ print('len(disease_classes):', len(disease_classes))
+ print('OK')
+
+ return data
+
+
+def _wrap(obj, method_name):
+ orig_fn = getattr(obj, method_name)
+ def fn(*args, **kwargs):
+ print(f'{method_name}() :: ENTER')
+ res = orig_fn(*args, **kwargs)
+ print(f'{method_name}() :: EXIT')
+ return res
+ setattr(obj, method_name, fn)
+
+
+def main():
+ dev = torch.device('cuda:0')
+ data = load_data(dev)
+ prep_d = prepare_training(data, TrainValTest(.8, .1, .1))
+ _wrap(Model, 'build')
+ model = Model(prep_d)
+ model = model.to(dev)
+ # model = torch.nn.DataParallel(model, ['cuda:0', 'cuda:1'])
+ _wrap(TrainLoop, 'build')
+ _wrap(TrainLoop, 'run_epoch')
+ loop = TrainLoop(model, batch_size=512, shuffle=True)
+ loop.run_epoch()
+
+
+if __name__ == '__main__':
+ main()
diff --git a/experiments/triacontagon_run/triacontagon_run.py b/experiments/triacontagon_run/triacontagon_run.py
new file mode 100644
index 0000000..cfa81f0
--- /dev/null
+++ b/experiments/triacontagon_run/triacontagon_run.py
@@ -0,0 +1,146 @@
+#!/usr/bin/env python3
+
+from triacontagon.data import Data
+from triacontagon.split import split_data
+from triacontagon.model import Model
+from triacontagon.loop import TrainLoop
+from triacontagon.decode import dedicom_decoder
+from triacontagon.util import common_one_hot_encoding
+
+import os
+import pandas as pd
+from bisect import bisect_left
+import torch
+import sys
+
+
+def index(a, x):
+ i = bisect_left(a, x)
+ if i != len(a) and a[i] == x:
+ return i
+ raise ValueError
+
+
+def load_data(dev):
+ path = '/pstore/data/data_science/ref/decagon'
+
+ df_combo = pd.read_csv(os.path.join(path, 'bio-decagon-combo.csv'))
+ df_effcat = pd.read_csv(os.path.join(path, 'bio-decagon-effectcategories.csv'))
+ df_mono = pd.read_csv(os.path.join(path, 'bio-decagon-mono.csv'))
+ df_ppi = pd.read_csv(os.path.join(path, 'bio-decagon-ppi.csv'))
+ df_tgtall = pd.read_csv(os.path.join(path, 'bio-decagon-targets-all.csv'))
+ df_tgt = pd.read_csv(os.path.join(path, 'bio-decagon-targets.csv'))
+
+ lst = [ 'df_combo', 'df_effcat', 'df_mono', 'df_ppi', 'df_tgtall', 'df_tgt' ]
+
+ for nam in lst:
+ print(f'len({nam}): {len(locals()[nam])}')
+ print(f'{nam}.columns: {locals()[nam].columns}')
+
+ genes = set()
+ genes = genes.union(df_ppi['Gene 1']).union(df_ppi['Gene 2']) \
+ .union(df_tgtall['Gene']).union(df_tgt['Gene'])
+ genes = sorted(genes)
+ print('len(genes):', len(genes))
+
+ drugs = set()
+ drugs = drugs.union(df_combo['STITCH 1']).union(df_combo['STITCH 2']) \
+ .union(df_mono['STITCH']).union(df_tgtall['STITCH']).union(df_tgt['STITCH'])
+ drugs = sorted(drugs)
+ print('len(drugs):', len(drugs))
+
+ data = Data()
+ data.add_vertex_type('Gene', len(genes))
+ data.add_vertex_type('Drug', len(drugs))
+
+ print('Preparing PPI...')
+ print('Indexing rows...')
+ rows = [index(genes, g) for g in df_ppi['Gene 1']]
+ print('Indexing cols...')
+ cols = [index(genes, g) for g in df_ppi['Gene 2']]
+ indices = list(zip(rows, cols))
+ indices = torch.tensor(indices).transpose(0, 1)
+ values = torch.ones(len(rows))
+ print('indices.shape:', indices.shape, 'values.shape:', values.shape)
+ adj_mat = torch.sparse_coo_tensor(indices, values, size=(len(genes),) * 2,
+ device=dev)
+ adj_mat = (adj_mat + adj_mat.transpose(0, 1)) / 2
+ print('adj_mat created')
+ data.add_edge_type('PPI', 0, 0, [ adj_mat ], dedicom_decoder)
+ print('OK')
+
+ print('Preparing Drug-Gene (Target) edges...')
+ rows = [index(drugs, d) for d in df_tgtall['STITCH']]
+ cols = [index(genes, g) for g in df_tgtall['Gene']]
+ indices = list(zip(rows, cols))
+ indices = torch.tensor(indices).transpose(0, 1)
+ values = torch.ones(len(rows))
+ adj_mat = torch.sparse_coo_tensor(indices, values, size=(len(drugs), len(genes)),
+ device=dev)
+ data.add_edge_type('Drug-Gene', 1, 0, [ adj_mat ], dedicom_decoder)
+ data.add_edge_type('Gene-Drug', 0, 1, [ adj_mat.transpose(0, 1) ], dedicom_decoder)
+ print('OK')
+
+ print('Preparing Drug-Drug (Side Effect) edges...')
+ fam = data.add_relation_family('Drug-Drug (Side Effect)', 1, 1, True)
+ print('# of side effects:', len(df_combo), 'unique:', len(df_combo['Polypharmacy Side Effect'].unique()))
+ adjacency_matrices = []
+ side_effect_names = []
+ for eff, df in df_combo.groupby('Polypharmacy Side Effect'):
+ sys.stdout.write('.') # print(eff, '...')
+ sys.stdout.flush()
+ rows = [index(drugs, d) for d in df['STITCH 1']]
+ cols = [index(drugs, d) for d in df['STITCH 2']]
+ indices = list(zip(rows, cols))
+ indices = torch.tensor(indices).transpose(0, 1)
+ values = torch.ones(len(rows))
+ adj_mat = torch.sparse_coo_tensor(indices, values, size=(len(drugs), len(drugs)),
+ device=dev)
+ adj_mat = (adj_mat + adj_mat.transpose(0, 1)) / 2
+ adjacency_matrices.append(adj_mat)
+ side_effect_names.append(df['Polypharmacy Side Effect'])
+ fam.add_edge_type('Drug-Drug', 1, 1, adjacency_matrices, dedicom_decoder)
+ print()
+ print('OK')
+
+ return data
+
+
+def _wrap(obj, method_name):
+ orig_fn = getattr(obj, method_name)
+ def fn(*args, **kwargs):
+ print(f'{method_name}() :: ENTER')
+ res = orig_fn(*args, **kwargs)
+ print(f'{method_name}() :: EXIT')
+ return res
+ setattr(obj, method_name, fn)
+
+
+def main():
+ dev = torch.device('cuda:0')
+ data = load_data(dev)
+
+ train_data, val_data, test_data = split_data(data, (.8, .1, .1))
+
+ n = sum(vt.count for vt in data.vertex_types)
+
+ model = Model(data, [n, 32, 64], keep_prob=.9,
+ conv_activation=torch.sigmoid,
+ dec_activation=torch.sigmoid).to(dev)
+
+ initial_repr = common_one_hot_encoding([ vt.count \
+ for vt in data.vertex_types ], device=dev)
+
+ loop = TrainLoop(model, val_data, test_data,
+ initial_repr, max_epochs=50, batch_size=512,
+ loss=torch.nn.functional.binary_cross_entropy_with_logits,
+ lr=0.001)
+
+ loop.run()
+
+ with open('/pstore/data/data_science/year/2020/adaszews/models/triacontagon/basic_run.pth', 'wb') as f:
+ torch.save(model.state_dict(), f)
+
+
+if __name__ == '__main__':
+ main()
diff --git a/requirements.txt b/requirements.txt
new file mode 100644
index 0000000..89425dd
--- /dev/null
+++ b/requirements.txt
@@ -0,0 +1,3 @@
+numpy
+torch
+dataclasses
diff --git a/src/decagon_pytorch/__init__.py b/src/decagon_pytorch/__init__.py
new file mode 100644
index 0000000..8271fcc
--- /dev/null
+++ b/src/decagon_pytorch/__init__.py
@@ -0,0 +1,9 @@
+#
+# Copyright (C) Stanislaw Adaszewski, 2020
+# License: GPLv3
+#
+
+from .weights import *
+from .convolve import *
+from .model import *
+from .layer.decode import *
diff --git a/src/decagon_pytorch/batch.py b/src/decagon_pytorch/batch.py
new file mode 100644
index 0000000..aeed492
--- /dev/null
+++ b/src/decagon_pytorch/batch.py
@@ -0,0 +1,49 @@
+#
+# Copyright (C) Stanislaw Adaszewski, 2020
+# License: GPLv3
+#
+
+
+import scipy.sparse as sp
+
+
+class Batch(object):
+ def __init__(self, adjacency_matrix):
+ pass
+
+ def get(size):
+ pass
+
+
+def train_test_split(data, train_size=.8):
+ pass
+
+
+class Minibatch(object):
+ def __init__(self, data, node_type_row, node_type_column, size):
+ self.data = data
+ self.adjacency_matrix = data.get_adjacency_matrix(node_type_row, node_type_column)
+ self.size = size
+ self.order = np.random.permutation(adjacency_matrix.nnz)
+ self.count = 0
+
+ def reset(self):
+ self.count = 0
+ self.order = np.random.permutation(adjacency_matrix.nnz)
+
+ def __iter__(self):
+ adj_mat = self.adjacency_matrix
+ size = self.size
+ order = np.random.permutation(adj_mat.nnz)
+ for i in range(0, len(order), size):
+ row = adj_mat.row[i:i + size]
+ col = adj_mat.col[i:i + size]
+ data = adj_mat.data[i:i + size]
+ adj_mat_batch = sp.coo_matrix((data, (row, col)), shape=adj_mat.shape)
+ yield adj_mat_batch
+ degree = self.adjacency_matrix.sum(1)
+
+
+
+ def __len__(self):
+ pass
diff --git a/src/decagon_pytorch/convolve/__init__.py b/src/decagon_pytorch/convolve/__init__.py
new file mode 100644
index 0000000..c0f6305
--- /dev/null
+++ b/src/decagon_pytorch/convolve/__init__.py
@@ -0,0 +1,169 @@
+#
+# Copyright (C) Stanislaw Adaszewski, 2020
+# License: GPLv3
+#
+
+
+"""
+This module implements the basic convolutional blocks of Decagon.
+Just as a quick reminder, the basic convolution formula here is:
+
+y = A * (x * W)
+
+where:
+
+W is a weight matrix
+A is an adjacency matrix
+x is a matrix of latent representations of a particular type of neighbors.
+
+As we have x here twice, a trick is obviously necessary for this to work.
+A must be previously normalized with:
+
+c_{r}^{ij} = 1/sqrt(|N_{r}^{i}| |N_{r}^{j}|)
+
+or
+
+c_{r}^{i} = 1/|N_{r}^{i}|
+
+Let's work through this step by step to convince ourselves that the
+formula is correct.
+
+x = [
+ [0, 1, 0, 1],
+ [1, 1, 1, 0],
+ [0, 0, 0, 1]
+]
+
+W = [
+ [0, 1],
+ [1, 0],
+ [0.5, 0.5],
+ [0.25, 0.75]
+]
+
+A = [
+ [0, 1, 0],
+ [1, 0, 1],
+ [0, 1, 0]
+]
+
+so the graph looks like this:
+
+(0) -- (1) -- (2)
+
+and therefore the representations in the next layer should be:
+
+h_{0}^{k+1} = c_{r}^{0,1} * h_{1}^{k} * W + c_{r}^{0} * h_{0}^{k}
+h_{1}^{k+1} = c_{r}^{0,1} * h_{0}^{k} * W + c_{r}^{2,1} * h_{2}^{k} +
+ c_{r}^{1} * h_{1}^{k}
+h_{2}^{k+1} = c_{r}^{2,1} * h_{1}^{k} * W + c_{r}^{2} * h_{2}^{k}
+
+In actual Decagon code we can see that that latter part propagating directly
+the old representation is gone. I will try to do the same for now.
+
+So we have to only take care of:
+
+h_{0}^{k+1} = c_{r}^{0,1} * h_{1}^{k} * W
+h_{1}^{k+1} = c_{r}^{0,1} * h_{0}^{k} * W + c_{r}^{2,1} * h_{2}^{k}
+h_{2}^{k+1} = c_{r}^{2,1} * h_{1}^{k} * W
+
+If A is square the Decagon's EdgeMinibatchIterator preprocesses it as follows:
+
+A = A + eye(len(A))
+rowsum = A.sum(1)
+deg_mat_inv_sqrt = diags(power(rowsum, -0.5))
+A = dot(A, deg_mat_inv_sqrt)
+A = A.transpose()
+A = A.dot(deg_mat_inv_sqrt)
+
+Let's see what gives in our case:
+
+A = A + eye(len(A))
+
+[
+ [1, 1, 0],
+ [1, 1, 1],
+ [0, 1, 1]
+]
+
+rowsum = A.sum(1)
+
+[2, 3, 2]
+
+deg_mat_inv_sqrt = diags(power(rowsum, -0.5))
+
+[
+ [1./sqrt(2), 0, 0],
+ [0, 1./sqrt(3), 0],
+ [0, 0, 1./sqrt(2)]
+]
+
+A = dot(A, deg_mat_inv_sqrt)
+
+[
+ [ 1/sqrt(2), 1/sqrt(3), 0 ],
+ [ 1/sqrt(2), 1/sqrt(3), 1/sqrt(2) ],
+ [ 0, 1/sqrt(3), 1/sqrt(2) ]
+]
+
+A = A.transpose()
+
+[
+ [ 1/sqrt(2), 1/sqrt(2), 0 ],
+ [ 1/sqrt(3), 1/sqrt(3), 1/sqrt(3) ],
+ [ 0, 1/sqrt(2), 1/sqrt(2) ]
+]
+
+A = A.dot(deg_mat_inv_sqrt)
+
+[
+ [ 1/sqrt(2) * 1/sqrt(2), 1/sqrt(2) * 1/sqrt(3), 0 ],
+ [ 1/sqrt(3) * 1/sqrt(2), 1/sqrt(3) * 1/sqrt(3), 1/sqrt(3) * 1/sqrt(2) ],
+ [ 0, 1/sqrt(2) * 1/sqrt(3), 1/sqrt(2) * 1/sqrt(2) ],
+]
+
+thus:
+
+[
+ [0.5 , 0.40824829, 0. ],
+ [0.40824829, 0.33333333, 0.40824829],
+ [0. , 0.40824829, 0.5 ]
+]
+
+This checks out with the 1/sqrt(|N_{r}^{i}| |N_{r}^{j}|) formula.
+
+Then, we get back to the main calculation:
+
+y = x * W
+y = A * y
+
+y = x * W
+
+[
+ [ 1.25, 0.75 ],
+ [ 1.5 , 1.5 ],
+ [ 0.25, 0.75 ]
+]
+
+y = A * y
+
+[
+ 0.5 * [ 1.25, 0.75 ] + 0.40824829 * [ 1.5, 1.5 ],
+ 0.40824829 * [ 1.25, 0.75 ] + 0.33333333 * [ 1.5, 1.5 ] + 0.40824829 * [ 0.25, 0.75 ],
+ 0.40824829 * [ 1.5, 1.5 ] + 0.5 * [ 0.25, 0.75 ]
+]
+
+that is:
+
+[
+ [1.23737243, 0.98737244],
+ [1.11237243, 1.11237243],
+ [0.73737244, 0.98737244]
+].
+
+All checks out nicely, good.
+"""
+
+from .dense import *
+from .sparse import *
+from .universal import *
diff --git a/src/decagon_pytorch/convolve/dense.py b/src/decagon_pytorch/convolve/dense.py
new file mode 100644
index 0000000..37cb700
--- /dev/null
+++ b/src/decagon_pytorch/convolve/dense.py
@@ -0,0 +1,73 @@
+#
+# Copyright (C) Stanislaw Adaszewski, 2020
+# License: GPLv3
+#
+
+
+import torch
+from ..dropout import dropout
+from ..weights import init_glorot
+from typing import List, Callable
+
+
+class DenseGraphConv(torch.nn.Module):
+ def __init__(self, in_channels: int, out_channels: int,
+ adjacency_matrix: torch.Tensor, **kwargs) -> None:
+ super().__init__(**kwargs)
+ self.in_channels = in_channels
+ self.out_channels = out_channels
+ self.weight = init_glorot(in_channels, out_channels)
+ self.adjacency_matrix = adjacency_matrix
+
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
+ x = torch.mm(x, self.weight)
+ x = torch.mm(self.adjacency_matrix, x)
+ return x
+
+
+class DenseDropoutGraphConvActivation(torch.nn.Module):
+ def __init__(self, input_dim: int, output_dim: int,
+ adjacency_matrix: torch.Tensor, keep_prob: float=1.,
+ activation: Callable[[torch.Tensor], torch.Tensor]=torch.nn.functional.relu,
+ **kwargs) -> None:
+ super().__init__(**kwargs)
+ self.graph_conv = DenseGraphConv(input_dim, output_dim, adjacency_matrix)
+ self.keep_prob = keep_prob
+ self.activation = activation
+
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
+ x = dropout(x, keep_prob=self.keep_prob)
+ x = self.graph_conv(x)
+ x = self.activation(x)
+ return x
+
+
+class DenseMultiDGCA(torch.nn.Module):
+ def __init__(self, input_dim: List[int], output_dim: int,
+ adjacency_matrices: List[torch.Tensor], keep_prob: float=1.,
+ activation: Callable[[torch.Tensor], torch.Tensor]=torch.nn.functional.relu,
+ **kwargs) -> None:
+ super().__init__(**kwargs)
+ self.input_dim = input_dim
+ self.output_dim = output_dim
+ self.adjacency_matrices = adjacency_matrices
+ self.keep_prob = keep_prob
+ self.activation = activation
+ self.dgca = None
+ self.build()
+
+ def build(self):
+ if len(self.input_dim) != len(self.adjacency_matrices):
+ raise ValueError('input_dim must have the same length as adjacency_matrices')
+ self.dgca = []
+ for input_dim, adj_mat in zip(self.input_dim, self.adjacency_matrices):
+ self.dgca.append(DenseDropoutGraphConvActivation(input_dim, self.output_dim, adj_mat, self.keep_prob, self.activation))
+
+ def forward(self, x: List[torch.Tensor]) -> List[torch.Tensor]:
+ if not isinstance(x, list):
+ raise ValueError('x must be a list of tensors')
+ out = torch.zeros(len(x[0]), self.output_dim, dtype=x[0].dtype)
+ for i, f in enumerate(self.dgca):
+ out += f(x[i])
+ out = torch.nn.functional.normalize(out, p=2, dim=1)
+ return out
diff --git a/src/decagon_pytorch/convolve/sparse.py b/src/decagon_pytorch/convolve/sparse.py
new file mode 100644
index 0000000..c472007
--- /dev/null
+++ b/src/decagon_pytorch/convolve/sparse.py
@@ -0,0 +1,78 @@
+#
+# Copyright (C) Stanislaw Adaszewski, 2020
+# License: GPLv3
+#
+
+
+import torch
+from ..dropout import dropout_sparse
+from ..weights import init_glorot
+from typing import List, Callable
+
+
+class SparseGraphConv(torch.nn.Module):
+ """Convolution layer for sparse inputs."""
+ def __init__(self, in_channels: int, out_channels: int,
+ adjacency_matrix: torch.Tensor, **kwargs) -> None:
+ super().__init__(**kwargs)
+ self.in_channels = in_channels
+ self.out_channels = out_channels
+ self.weight = init_glorot(in_channels, out_channels)
+ self.adjacency_matrix = adjacency_matrix
+
+
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
+ x = torch.sparse.mm(x, self.weight)
+ x = torch.sparse.mm(self.adjacency_matrix, x)
+ return x
+
+
+class SparseDropoutGraphConvActivation(torch.nn.Module):
+ def __init__(self, input_dim: int, output_dim: int,
+ adjacency_matrix: torch.Tensor, keep_prob: float=1.,
+ activation: Callable[[torch.Tensor], torch.Tensor]=torch.nn.functional.relu,
+ **kwargs) -> None:
+ super().__init__(**kwargs)
+ self.input_dim = input_dim
+ self.output_dim = output_dim
+ self.adjacency_matrix = adjacency_matrix
+ self.keep_prob = keep_prob
+ self.activation = activation
+ self.sparse_graph_conv = SparseGraphConv(input_dim, output_dim, adjacency_matrix)
+
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
+ x = dropout_sparse(x, self.keep_prob)
+ x = self.sparse_graph_conv(x)
+ x = self.activation(x)
+ return x
+
+
+class SparseMultiDGCA(torch.nn.Module):
+ def __init__(self, input_dim: List[int], output_dim: int,
+ adjacency_matrices: List[torch.Tensor], keep_prob: float=1.,
+ activation: Callable[[torch.Tensor], torch.Tensor]=torch.nn.functional.relu,
+ **kwargs) -> None:
+ super().__init__(**kwargs)
+ self.input_dim = input_dim
+ self.output_dim = output_dim
+ self.adjacency_matrices = adjacency_matrices
+ self.keep_prob = keep_prob
+ self.activation = activation
+ self.sparse_dgca = None
+ self.build()
+
+ def build(self):
+ if len(self.input_dim) != len(self.adjacency_matrices):
+ raise ValueError('input_dim must have the same length as adjacency_matrices')
+ self.sparse_dgca = []
+ for input_dim, adj_mat in zip(self.input_dim, self.adjacency_matrices):
+ self.sparse_dgca.append(SparseDropoutGraphConvActivation(input_dim, self.output_dim, adj_mat, self.keep_prob, self.activation))
+
+ def forward(self, x: List[torch.Tensor]) -> torch.Tensor:
+ if not isinstance(x, list):
+ raise ValueError('x must be a list of tensors')
+ out = torch.zeros(len(x[0]), self.output_dim, dtype=x[0].dtype)
+ for i, f in enumerate(self.sparse_dgca):
+ out += f(x[i])
+ out = torch.nn.functional.normalize(out, p=2, dim=1)
+ return out
diff --git a/src/decagon_pytorch/convolve/universal.py b/src/decagon_pytorch/convolve/universal.py
new file mode 100644
index 0000000..b266448
--- /dev/null
+++ b/src/decagon_pytorch/convolve/universal.py
@@ -0,0 +1,85 @@
+#
+# Copyright (C) Stanislaw Adaszewski, 2020
+# License: GPLv3
+#
+
+
+import torch
+from ..dropout import dropout_sparse, \
+ dropout
+from ..weights import init_glorot
+from typing import List, Callable
+
+
+class GraphConv(torch.nn.Module):
+ """Convolution layer for sparse AND dense inputs."""
+ def __init__(self, in_channels: int, out_channels: int,
+ adjacency_matrix: torch.Tensor, **kwargs) -> None:
+ super().__init__(**kwargs)
+ self.in_channels = in_channels
+ self.out_channels = out_channels
+ self.weight = init_glorot(in_channels, out_channels)
+ self.adjacency_matrix = adjacency_matrix
+
+
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
+ x = torch.sparse.mm(x, self.weight) \
+ if x.is_sparse \
+ else torch.mm(x, self.weight)
+ x = torch.sparse.mm(self.adjacency_matrix, x) \
+ if self.adjacency_matrix.is_sparse \
+ else torch.mm(self.adjacency_matrix, x)
+ return x
+
+
+class DropoutGraphConvActivation(torch.nn.Module):
+ def __init__(self, input_dim: int, output_dim: int,
+ adjacency_matrix: torch.Tensor, keep_prob: float=1.,
+ activation: Callable[[torch.Tensor], torch.Tensor]=torch.nn.functional.relu,
+ **kwargs) -> None:
+ super().__init__(**kwargs)
+ self.input_dim = input_dim
+ self.output_dim = output_dim
+ self.adjacency_matrix = adjacency_matrix
+ self.keep_prob = keep_prob
+ self.activation = activation
+ self.graph_conv = GraphConv(input_dim, output_dim, adjacency_matrix)
+
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
+ x = dropout_sparse(x, self.keep_prob) \
+ if x.is_sparse \
+ else dropout(x, self.keep_prob)
+ x = self.graph_conv(x)
+ x = self.activation(x)
+ return x
+
+
+class MultiDGCA(torch.nn.Module):
+ def __init__(self, input_dim: List[int], output_dim: int,
+ adjacency_matrices: List[torch.Tensor], keep_prob: float=1.,
+ activation: Callable[[torch.Tensor], torch.Tensor]=torch.nn.functional.relu,
+ **kwargs) -> None:
+ super().__init__(**kwargs)
+ self.input_dim = input_dim
+ self.output_dim = output_dim
+ self.adjacency_matrices = adjacency_matrices
+ self.keep_prob = keep_prob
+ self.activation = activation
+ self.dgca = None
+ self.build()
+
+ def build(self):
+ if len(self.input_dim) != len(self.adjacency_matrices):
+ raise ValueError('input_dim must have the same length as adjacency_matrices')
+ self.dgca = []
+ for input_dim, adj_mat in zip(self.input_dim, self.adjacency_matrices):
+ self.dgca.append(DropoutGraphConvActivation(input_dim, self.output_dim, adj_mat, self.keep_prob, self.activation))
+
+ def forward(self, x: List[torch.Tensor]) -> List[torch.Tensor]:
+ if not isinstance(x, list):
+ raise ValueError('x must be a list of tensors')
+ out = torch.zeros(len(x[0]), self.output_dim, dtype=x[0].dtype)
+ for i, f in enumerate(self.dgca):
+ out += f(x[i])
+ out = torch.nn.functional.normalize(out, p=2, dim=1)
+ return out
diff --git a/src/decagon_pytorch/data/__init__.py b/src/decagon_pytorch/data/__init__.py
new file mode 100644
index 0000000..5820dbb
--- /dev/null
+++ b/src/decagon_pytorch/data/__init__.py
@@ -0,0 +1,2 @@
+from .matrix import *
+from .list import *
diff --git a/src/decagon_pytorch/data/list.py b/src/decagon_pytorch/data/list.py
new file mode 100644
index 0000000..ca022cb
--- /dev/null
+++ b/src/decagon_pytorch/data/list.py
@@ -0,0 +1,68 @@
+from .matrix import NodeType
+import torch
+from collections import defaultdict
+
+
+class AdjListRelationType(object):
+ def __init__(self, name, node_type_row, node_type_column,
+ adjacency_list, adjacency_list_transposed=None):
+
+ #if adjacency_matrix_transposed is not None and \
+ # adjacency_matrix_transposed.shape != adjacency_matrix.transpose(0, 1).shape:
+ # raise ValueError('adjacency_matrix_transposed has incorrect shape')
+
+ self.name = name
+ self.node_type_row = node_type_row
+ self.node_type_column = node_type_column
+ self.adjacency_list = adjacency_list
+ self.adjacency_list_transposed = adjacency_list_transposed
+
+ def get_adjacency_list(self, node_type_row, node_type_column):
+ if self.node_type_row == node_type_row and \
+ self.node_type_column == node_type_column:
+ return self.adjacency_list
+
+ elif self.node_type_row == node_type_column and \
+ self.node_type_column == node_type_row:
+ if self.adjacency_list_transposed is not None:
+ return self.adjacency_list_transposed
+ else:
+ return torch.index_select(self.adjacency_list, 1,
+ torch.LongTensor([1, 0]))
+
+ else:
+ raise ValueError('Specified row/column types do not correspond to this relation')
+
+
+def _verify_adjacency_list(adjacency_list, node_count_row, node_count_col):
+ assert isinstance(adjacency_list, torch.Tensor)
+ assert len(adjacency_list.shape) == 2
+ assert torch.all(adjacency_list[:, 0] >= 0)
+ assert torch.all(adjacency_list[:, 0] < node_count_row)
+ assert torch.all(adjacency_list[:, 1] >= 0)
+ assert torch.all(adjacency_list[:, 1] < node_count_col)
+
+
+class AdjListData(object):
+ def __init__(self):
+ self.node_types = []
+ self.relation_types = defaultdict(list)
+
+ def add_node_type(self, name, count): # , latent_length):
+ self.node_types.append(NodeType(name, count))
+
+ def add_relation_type(self, name, node_type_row, node_type_col, adjacency_list, adjacency_list_transposed=None):
+ assert node_type_row >= 0 and node_type_row < len(self.node_types)
+ assert node_type_col >= 0 and node_type_col < len(self.node_types)
+
+ node_count_row = self.node_types[node_type_row].count
+ node_count_col = self.node_types[node_type_col].count
+
+ _verify_adjacency_list(adjacency_list, node_count_row, node_count_col)
+ if adjacency_list_transposed is not None:
+ _verify_adjacency_list(adjacency_list_transposed,
+ node_count_col, node_count_row)
+
+ self.relation_types[node_type_row, node_type_col].append(
+ AdjListRelationType(name, node_type_row, node_type_col,
+ adjacency_list, adjacency_list_transposed))
diff --git a/src/decagon_pytorch/data/matrix.py b/src/decagon_pytorch/data/matrix.py
new file mode 100644
index 0000000..cd4b110
--- /dev/null
+++ b/src/decagon_pytorch/data/matrix.py
@@ -0,0 +1,96 @@
+#
+# Copyright (C) Stanislaw Adaszewski, 2020
+# License: GPLv3
+#
+
+
+from collections import defaultdict
+from ..weights import init_glorot
+
+
+class NodeType(object):
+ def __init__(self, name, count):
+ self.name = name
+ self.count = count
+
+
+class RelationType(object):
+ def __init__(self, name, node_type_row, node_type_column,
+ adjacency_matrix, adjacency_matrix_transposed):
+
+ if adjacency_matrix_transposed is not None and \
+ adjacency_matrix_transposed.shape != adjacency_matrix.transpose(0, 1).shape:
+ raise ValueError('adjacency_matrix_transposed has incorrect shape')
+
+ self.name = name
+ self.node_type_row = node_type_row
+ self.node_type_column = node_type_column
+ self.adjacency_matrix = adjacency_matrix
+ self.adjacency_matrix_transposed = adjacency_matrix_transposed
+
+ def get_adjacency_matrix(node_type_row, node_type_column):
+ if self.node_type_row == node_type_row and \
+ self.node_type_column == node_type_column:
+ return self.adjacency_matrix
+
+ elif self.node_type_row == node_type_column and \
+ self.node_type_column == node_type_row:
+ if self.adjacency_matrix_transposed:
+ return self.adjacency_matrix_transposed
+ else:
+ return self.adjacency_matrix.transpose(0, 1)
+
+ else:
+ raise ValueError('Specified row/column types do not correspond to this relation')
+
+
+class Data(object):
+ def __init__(self):
+ self.node_types = []
+ self.relation_types = defaultdict(list)
+
+ def add_node_type(self, name, count): # , latent_length):
+ self.node_types.append(NodeType(name, count))
+
+ def add_relation_type(self, name, node_type_row, node_type_column, adjacency_matrix, adjacency_matrix_transposed=None):
+ n = len(self.node_types)
+ if node_type_row >= n or node_type_column >= n:
+ raise ValueError('Node type index out of bounds, add node type first')
+ key = (node_type_row, node_type_column)
+ if adjacency_matrix is not None and not adjacency_matrix.is_sparse:
+ adjacency_matrix = adjacency_matrix.to_sparse()
+ self.relation_types[key].append(RelationType(name, node_type_row, node_type_column, adjacency_matrix, adjacency_matrix_transposed))
+
+ def get_adjacency_matrices(self, node_type_row, node_type_column):
+ res = []
+ for (i, j), rels in self.relation_types.items():
+ if node_type_row not in [i, j] and node_type_column not in [i, j]:
+ continue
+ for r in rels:
+ res.append(r.get_adjacency_matrix(node_type_row, node_type_column))
+ return res
+
+ def __repr__(self):
+ n = len(self.node_types)
+ if n == 0:
+ return 'Empty GNN Data'
+ s = ''
+ s += 'GNN Data with:\n'
+ s += '- ' + str(n) + ' node type(s):\n'
+ for nt in self.node_types:
+ s += ' - ' + nt.name + '\n'
+ if len(self.relation_types) == 0:
+ s += '- No relation types\n'
+ return s.strip()
+ n = sum(map(len, self.relation_types))
+ s += '- ' + str(n) + ' relation type(s):\n'
+ for i in range(n):
+ for j in range(n):
+ key = (i, j)
+ if key not in self.relation_types:
+ continue
+ rels = self.relation_types[key]
+ s += ' - ' + self.node_types[i].name + ' -- ' + self.node_types[j].name + ':\n'
+ for r in rels:
+ s += ' - ' + r.name + '\n'
+ return s.strip()
diff --git a/src/decagon_pytorch/data/trainprep.py b/src/decagon_pytorch/data/trainprep.py
new file mode 100644
index 0000000..b2b0a49
--- /dev/null
+++ b/src/decagon_pytorch/data/trainprep.py
@@ -0,0 +1,77 @@
+from .sampling import fixed_unigram_candidate_sampler
+import torch
+
+
+def train_val_test_split_edges(edges, ratios):
+ train_ratio, val_ratio, test_ratio = ratios
+
+ if train_ratio + val_ratio + test_ratio != 1.0:
+ raise ValueError('Train, validation and test ratios must add up to 1')
+
+ order = torch.randperm(len(edges))
+ edges = edges[order, :]
+ n = round(len(edges) * train_ratio)
+ edges_train = edges[:n]
+ n_1 = round(len(edges) * (train_ratio + val_ratio))
+ edges_val = edges[n:n_1]
+ edges_test = edges[n_1:]
+
+ return edges_train, edges_val, edges_test
+
+
+def prepare_adj_mat(adj_mat, ratios):
+ degrees = adj_mat.sum(0)
+ edges_pos = torch.nonzero(adj_mat)
+
+ neg_neighbors = fixed_unigram_candidate_sampler(edges_pos[:, 1],
+ len(edges), degrees, 0.75)
+ edges_neg = torch.cat((edges_pos[:, 0], neg_neighbors.view(-1, 1)), 1)
+
+ edges_pos = (edges_pos_train, edges_pos_val, edges_pos_test) = \
+ train_val_test_split_edges(edges_pos, ratios)
+ edges_neg = (edges_neg_train, edges_neg_val, edges_neg_test) = \
+ train_val_test_split_edges(edges_neg, ratios)
+
+ return edges_pos, edges_neg
+
+
+class PreparedRelation(object):
+ def __init__(self, node_type_row, node_type_column,
+ adj_mat_train, adj_mat_train_trans,
+ edges_pos, edges_neg, edges_pos_trans, edges_neg_trans):
+
+ self.adj_mat_train = adj_mat_train
+ self.adj_mat_train_trans = adj_mat_train_trans
+ self.edges_pos = edges_pos
+ self.edges_neg = edges_neg
+ self.edges_pos_trans = edges_pos_trans
+ self.edges_neg_trans = edges_neg_trans
+
+
+def prepare_relation(r, ratios):
+ adj_mat = r.get_adjacency_matrix(r.node_type_row, r.node_type_column)
+ edges_pos, edges_neg = prepare_adj_mat(adj_mat)
+
+ # adj_mat_train = torch.zeros_like(adj_mat)
+ # adj_mat_train[edges_pos[0][:, 0], edges_pos[0][:, 0]] = 1
+ adj_mat_train = torch.sparse_coo_tensor(indices = edges_pos[0].transpose(0, 1),
+ values=torch.ones(len(edges_pos[0]), dtype=adj_mat.dtype))
+
+ if r.node_type_row != r.node_type_col:
+ adj_mat_trans = r.get_adjacency_matrix(r.node_type_col, r.node_type_row)
+ edges_pos_trans, edges_neg_trans = prepare_adj_mat(adj_mat_trans)
+ adj_mat_train_trans = torch.sparse_coo_tensor(indices = edges_pos_trans[0].transpose(0, 1),
+ values=torch.ones(len(edges_pos_trans[0]), dtype=adj_mat_trans.dtype))
+ else:
+ adj_mat_train_trans = adj_mat_trans = \
+ edge_pos_trans = edge_neg_trans = None
+
+ return PreparedRelation(r.node_type_row, r.node_type_column,
+ adj_mat_train, adj_mat_trans_train,
+ edges_pos, edges_neg, edges_pos_trans, edges_neg_trans)
+
+
+def prepare_training(data):
+ for (node_type_row, node_type_column), rels in data.relation_types:
+ for r in rels:
+ prep_relation_edges()
diff --git a/decagon_pytorch/convolve.py b/src/decagon_pytorch/decode/__init__.py
old mode 100755
new mode 100644
similarity index 100%
rename from decagon_pytorch/convolve.py
rename to src/decagon_pytorch/decode/__init__.py
diff --git a/src/decagon_pytorch/decode/cartesian.py b/src/decagon_pytorch/decode/cartesian.py
new file mode 100644
index 0000000..910a8ce
--- /dev/null
+++ b/src/decagon_pytorch/decode/cartesian.py
@@ -0,0 +1,123 @@
+#
+# Copyright (C) Stanislaw Adaszewski, 2020
+# License: GPLv3
+#
+
+
+import torch
+from ..weights import init_glorot
+from ..dropout import dropout
+
+
+class DEDICOMDecoder(torch.nn.Module):
+ """DEDICOM Tensor Factorization Decoder model layer for link prediction."""
+ def __init__(self, input_dim, num_relation_types, drop_prob=0.,
+ activation=torch.sigmoid, **kwargs):
+
+ super().__init__(**kwargs)
+ self.input_dim = input_dim
+ self.num_relation_types = num_relation_types
+ self.drop_prob = drop_prob
+ self.activation = activation
+
+ self.global_interaction = init_glorot(input_dim, input_dim)
+ self.local_variation = [
+ torch.flatten(init_glorot(input_dim, 1)) \
+ for _ in range(num_relation_types)
+ ]
+
+ def forward(self, inputs_row, inputs_col):
+ outputs = []
+ for k in range(self.num_relation_types):
+ inputs_row = dropout(inputs_row, 1.-self.drop_prob)
+ inputs_col = dropout(inputs_col, 1.-self.drop_prob)
+
+ relation = torch.diag(self.local_variation[k])
+
+ product1 = torch.mm(inputs_row, relation)
+ product2 = torch.mm(product1, self.global_interaction)
+ product3 = torch.mm(product2, relation)
+ rec = torch.mm(product3, torch.transpose(inputs_col, 0, 1))
+ outputs.append(self.activation(rec))
+ return outputs
+
+
+class DistMultDecoder(torch.nn.Module):
+ """DEDICOM Tensor Factorization Decoder model layer for link prediction."""
+ def __init__(self, input_dim, num_relation_types, drop_prob=0.,
+ activation=torch.sigmoid, **kwargs):
+
+ super().__init__(**kwargs)
+ self.input_dim = input_dim
+ self.num_relation_types = num_relation_types
+ self.drop_prob = drop_prob
+ self.activation = activation
+
+ self.relation = [
+ torch.flatten(init_glorot(input_dim, 1)) \
+ for _ in range(num_relation_types)
+ ]
+
+ def forward(self, inputs_row, inputs_col):
+ outputs = []
+ for k in range(self.num_relation_types):
+ inputs_row = dropout(inputs_row, 1.-self.drop_prob)
+ inputs_col = dropout(inputs_col, 1.-self.drop_prob)
+
+ relation = torch.diag(self.relation[k])
+
+ intermediate_product = torch.mm(inputs_row, relation)
+ rec = torch.mm(intermediate_product, torch.transpose(inputs_col, 0, 1))
+ outputs.append(self.activation(rec))
+ return outputs
+
+
+class BilinearDecoder(torch.nn.Module):
+ """DEDICOM Tensor Factorization Decoder model layer for link prediction."""
+ def __init__(self, input_dim, num_relation_types, drop_prob=0.,
+ activation=torch.sigmoid, **kwargs):
+
+ super().__init__(**kwargs)
+ self.input_dim = input_dim
+ self.num_relation_types = num_relation_types
+ self.drop_prob = drop_prob
+ self.activation = activation
+
+ self.relation = [
+ init_glorot(input_dim, input_dim) \
+ for _ in range(num_relation_types)
+ ]
+
+ def forward(self, inputs_row, inputs_col):
+ outputs = []
+ for k in range(self.num_relation_types):
+ inputs_row = dropout(inputs_row, 1.-self.drop_prob)
+ inputs_col = dropout(inputs_col, 1.-self.drop_prob)
+
+ intermediate_product = torch.mm(inputs_row, self.relation[k])
+ rec = torch.mm(intermediate_product, torch.transpose(inputs_col, 0, 1))
+ outputs.append(self.activation(rec))
+ return outputs
+
+
+class InnerProductDecoder(torch.nn.Module):
+ """DEDICOM Tensor Factorization Decoder model layer for link prediction."""
+ def __init__(self, input_dim, num_relation_types, drop_prob=0.,
+ activation=torch.sigmoid, **kwargs):
+
+ super().__init__(**kwargs)
+ self.input_dim = input_dim
+ self.num_relation_types = num_relation_types
+ self.drop_prob = drop_prob
+ self.activation = activation
+
+
+ def forward(self, inputs_row, inputs_col):
+ outputs = []
+ for k in range(self.num_relation_types):
+ inputs_row = dropout(inputs_row, 1.-self.drop_prob)
+ inputs_col = dropout(inputs_col, 1.-self.drop_prob)
+
+ rec = torch.mm(inputs_row, torch.transpose(inputs_col, 0, 1))
+ outputs.append(self.activation(rec))
+ return outputs
diff --git a/src/decagon_pytorch/decode/pairwise.py b/src/decagon_pytorch/decode/pairwise.py
new file mode 100644
index 0000000..93ff68c
--- /dev/null
+++ b/src/decagon_pytorch/decode/pairwise.py
@@ -0,0 +1,131 @@
+#
+# Copyright (C) Stanislaw Adaszewski, 2020
+# License: GPLv3
+#
+
+
+import torch
+from ..weights import init_glorot
+from ..dropout import dropout
+
+
+class DEDICOMDecoder(torch.nn.Module):
+ """DEDICOM Tensor Factorization Decoder model layer for link prediction."""
+ def __init__(self, input_dim, num_relation_types, drop_prob=0.,
+ activation=torch.sigmoid, **kwargs):
+
+ super().__init__(**kwargs)
+ self.input_dim = input_dim
+ self.num_relation_types = num_relation_types
+ self.drop_prob = drop_prob
+ self.activation = activation
+
+ self.global_interaction = init_glorot(input_dim, input_dim)
+ self.local_variation = [
+ torch.flatten(init_glorot(input_dim, 1)) \
+ for _ in range(num_relation_types)
+ ]
+
+ def forward(self, inputs_row, inputs_col):
+ outputs = []
+ for k in range(self.num_relation_types):
+ inputs_row = dropout(inputs_row, 1.-self.drop_prob)
+ inputs_col = dropout(inputs_col, 1.-self.drop_prob)
+
+ relation = torch.diag(self.local_variation[k])
+
+ product1 = torch.mm(inputs_row, relation)
+ product2 = torch.mm(product1, self.global_interaction)
+ product3 = torch.mm(product2, relation)
+ rec = torch.bmm(product3.view(product3.shape[0], 1, product3.shape[1]),
+ inputs_col.view(inputs_col.shape[0], inputs_col.shape[1], 1))
+ rec = torch.flatten(rec)
+ outputs.append(self.activation(rec))
+ return outputs
+
+
+class DistMultDecoder(torch.nn.Module):
+ """DEDICOM Tensor Factorization Decoder model layer for link prediction."""
+ def __init__(self, input_dim, num_relation_types, drop_prob=0.,
+ activation=torch.sigmoid, **kwargs):
+
+ super().__init__(**kwargs)
+ self.input_dim = input_dim
+ self.num_relation_types = num_relation_types
+ self.drop_prob = drop_prob
+ self.activation = activation
+
+ self.relation = [
+ torch.flatten(init_glorot(input_dim, 1)) \
+ for _ in range(num_relation_types)
+ ]
+
+ def forward(self, inputs_row, inputs_col):
+ outputs = []
+ for k in range(self.num_relation_types):
+ inputs_row = dropout(inputs_row, 1.-self.drop_prob)
+ inputs_col = dropout(inputs_col, 1.-self.drop_prob)
+
+ relation = torch.diag(self.relation[k])
+
+ intermediate_product = torch.mm(inputs_row, relation)
+ rec = torch.bmm(intermediate_product.view(intermediate_product.shape[0], 1, intermediate_product.shape[1]),
+ inputs_col.view(inputs_col.shape[0], inputs_col.shape[1], 1))
+ rec = torch.flatten(rec)
+ outputs.append(self.activation(rec))
+ return outputs
+
+
+class BilinearDecoder(torch.nn.Module):
+ """DEDICOM Tensor Factorization Decoder model layer for link prediction."""
+ def __init__(self, input_dim, num_relation_types, drop_prob=0.,
+ activation=torch.sigmoid, **kwargs):
+
+ super().__init__(**kwargs)
+ self.input_dim = input_dim
+ self.num_relation_types = num_relation_types
+ self.drop_prob = drop_prob
+ self.activation = activation
+
+ self.relation = [
+ init_glorot(input_dim, input_dim) \
+ for _ in range(num_relation_types)
+ ]
+
+ def forward(self, inputs_row, inputs_col):
+ outputs = []
+ for k in range(self.num_relation_types):
+ inputs_row = dropout(inputs_row, 1.-self.drop_prob)
+ inputs_col = dropout(inputs_col, 1.-self.drop_prob)
+
+ intermediate_product = torch.mm(inputs_row, self.relation[k])
+ rec = torch.bmm(intermediate_product.view(intermediate_product.shape[0], 1, intermediate_product.shape[1]),
+ inputs_col.view(inputs_col.shape[0], inputs_col.shape[1], 1))
+ rec = torch.flatten(rec)
+ outputs.append(self.activation(rec))
+ return outputs
+
+
+class InnerProductDecoder(torch.nn.Module):
+ """DEDICOM Tensor Factorization Decoder model layer for link prediction."""
+ def __init__(self, input_dim, num_relation_types, drop_prob=0.,
+ activation=torch.sigmoid, **kwargs):
+
+ super().__init__(**kwargs)
+ self.input_dim = input_dim
+ self.num_relation_types = num_relation_types
+ self.drop_prob = drop_prob
+ self.activation = activation
+
+
+ def forward(self, inputs_row, inputs_col):
+ outputs = []
+ for k in range(self.num_relation_types):
+ inputs_row = dropout(inputs_row, 1.-self.drop_prob)
+ inputs_col = dropout(inputs_col, 1.-self.drop_prob)
+
+ rec = torch.bmm(inputs_row.view(inputs_row.shape[0], 1, inputs_row.shape[1]),
+ inputs_col.view(inputs_col.shape[0], inputs_col.shape[1], 1))
+ rec = torch.flatten(rec)
+ outputs.append(self.activation(rec))
+ return outputs
diff --git a/src/decagon_pytorch/dropout.py b/src/decagon_pytorch/dropout.py
new file mode 100644
index 0000000..f31b5dc
--- /dev/null
+++ b/src/decagon_pytorch/dropout.py
@@ -0,0 +1,37 @@
+#
+# Copyright (C) Stanislaw Adaszewski, 2020
+# License: GPLv3
+#
+
+
+import torch
+
+
+def dropout_sparse(x, keep_prob):
+ """Dropout for sparse tensors.
+ """
+ x = x.coalesce()
+ i = x._indices()
+ v = x._values()
+ size = x.size()
+
+ n = keep_prob + torch.rand(len(v))
+ n = torch.floor(n).to(torch.bool)
+ i = i[:,n]
+ v = v[n]
+ x = torch.sparse_coo_tensor(i, v, size=size)
+
+ return x * (1./keep_prob)
+
+
+def dropout(x, keep_prob):
+ """Dropout for dense tensors.
+ """
+ shape = x.shape
+ x = torch.flatten(x)
+ n = keep_prob + torch.rand(len(x))
+ n = (1. - torch.floor(n)).to(torch.bool)
+ x[n] = 0
+ x = torch.reshape(x, shape)
+ # x = torch.nn.functional.dropout(x, p=1.-keep_prob)
+ return x * (1./keep_prob)
diff --git a/src/decagon_pytorch/layer/__init__.py b/src/decagon_pytorch/layer/__init__.py
new file mode 100644
index 0000000..dfb8b70
--- /dev/null
+++ b/src/decagon_pytorch/layer/__init__.py
@@ -0,0 +1,32 @@
+#
+# Copyright (C) Stanislaw Adaszewski, 2020
+# License: GPLv3
+#
+
+
+#
+# This module implements a single layer of the Decagon
+# model. This is going to be already quite complex, as
+# we will be using all the graph convolutional building
+# blocks.
+#
+# h_{i}^(k+1) = ϕ(∑_r ∑_{j∈N{r}^{i}} c_{r}^{ij} * \
+# W_{r}^(k) h_{j}^{k} + c_{r}^{i} h_{i}^(k))
+#
+# N{r}^{i} - set of neighbors of node i under relation r
+# W_{r}^(k) - relation-type specific weight matrix
+# h_{i}^(k) - hidden state of node i in layer k
+# h_{i}^(k)∈R^{d(k)} where d(k) is the dimensionality
+# of the representation in k-th layer
+# Ï• - activation function
+# c_{r}^{ij} - normalization constants
+# c_{r}^{ij} = 1/sqrt(|N_{r}^{i}| |N_{r}^{j}|)
+# c_{r}^{i} - normalization constants
+# c_{r}^{i} = 1/|N_{r}^{i}|
+#
+
+
+from .layer import *
+from .input import *
+from .convolve import *
+from .decode import *
diff --git a/src/decagon_pytorch/layer/convolve.py b/src/decagon_pytorch/layer/convolve.py
new file mode 100644
index 0000000..baebd88
--- /dev/null
+++ b/src/decagon_pytorch/layer/convolve.py
@@ -0,0 +1,80 @@
+#
+# Copyright (C) Stanislaw Adaszewski, 2020
+# License: GPLv3
+#
+
+
+from .layer import Layer
+import torch
+from ..convolve import DropoutGraphConvActivation
+from ..data import Data
+from typing import List, \
+ Union, \
+ Callable
+from collections import defaultdict
+
+
+class DecagonLayer(Layer):
+ def __init__(self,
+ data: Data,
+ previous_layer: Layer,
+ output_dim: Union[int, List[int]],
+ keep_prob: float = 1.,
+ rel_activation: Callable[[torch.Tensor], torch.Tensor] = lambda x: x,
+ layer_activation: Callable[[torch.Tensor], torch.Tensor] = torch.nn.functional.relu,
+ **kwargs):
+ if not isinstance(output_dim, list):
+ output_dim = [ output_dim ] * len(data.node_types)
+ super().__init__(output_dim, is_sparse=False, **kwargs)
+ self.data = data
+ self.previous_layer = previous_layer
+ self.input_dim = previous_layer.output_dim
+ self.keep_prob = keep_prob
+ self.rel_activation = rel_activation
+ self.layer_activation = layer_activation
+ self.next_layer_repr = None
+ self.build()
+
+ def build(self):
+ self.next_layer_repr = defaultdict(list)
+
+ for (nt_row, nt_col), relation_types in self.data.relation_types.items():
+ row_convs = []
+ col_convs = []
+
+ for rel in relation_types:
+ conv = DropoutGraphConvActivation(self.input_dim[nt_col],
+ self.output_dim[nt_row], rel.adjacency_matrix,
+ self.keep_prob, self.rel_activation)
+ row_convs.append(conv)
+
+ if nt_row == nt_col:
+ continue
+
+ conv = DropoutGraphConvActivation(self.input_dim[nt_row],
+ self.output_dim[nt_col], rel.adjacency_matrix.transpose(0, 1),
+ self.keep_prob, self.rel_activation)
+ col_convs.append(conv)
+
+ self.next_layer_repr[nt_row].append((row_convs, nt_col))
+
+ if nt_row == nt_col:
+ continue
+
+ self.next_layer_repr[nt_col].append((col_convs, nt_row))
+
+ def __call__(self):
+ prev_layer_repr = self.previous_layer()
+ next_layer_repr = [ [] for _ in range(len(self.data.node_types)) ]
+ print('next_layer_repr:', next_layer_repr)
+ for i in range(len(self.data.node_types)):
+ for convs, neighbor_type in self.next_layer_repr[i]:
+ convs = [ conv(prev_layer_repr[neighbor_type]) \
+ for conv in convs ]
+ convs = sum(convs)
+ convs = torch.nn.functional.normalize(convs, p=2, dim=1)
+ next_layer_repr[i].append(convs)
+ next_layer_repr[i] = sum(next_layer_repr[i])
+ next_layer_repr[i] = self.layer_activation(next_layer_repr[i])
+ print('next_layer_repr:', next_layer_repr)
+ return next_layer_repr
diff --git a/src/decagon_pytorch/layer/decode.py b/src/decagon_pytorch/layer/decode.py
new file mode 100644
index 0000000..f354142
--- /dev/null
+++ b/src/decagon_pytorch/layer/decode.py
@@ -0,0 +1,66 @@
+#
+# Copyright (C) Stanislaw Adaszewski, 2020
+# License: GPLv3
+#
+
+
+from .layer import Layer
+import torch
+from ..data import Data
+from typing import Type, \
+ List, \
+ Callable, \
+ Union, \
+ Dict, \
+ Tuple
+from ..decode.cartesian import DEDICOMDecoder
+
+
+class DecodeLayer(torch.nn.Module):
+ def __init__(self,
+ data: Data,
+ last_layer: Layer,
+ keep_prob: float = 1.,
+ activation: Callable[[torch.Tensor], torch.Tensor] = torch.sigmoid,
+ decoder_class: Union[Type, Dict[Tuple[int, int], Type]] = DEDICOMDecoder, **kwargs) -> None:
+
+ super().__init__(**kwargs)
+ self.data = data
+ self.last_layer = last_layer
+ self.keep_prob = keep_prob
+ self.activation = activation
+ assert all([a == last_layer.output_dim[0] \
+ for a in last_layer.output_dim])
+ self.input_dim = last_layer.output_dim[0]
+ self.output_dim = 1
+ self.decoder_class = decoder_class
+ self.decoders = None
+ self.build()
+
+ def build(self) -> None:
+ self.decoders = {}
+ for (node_type_row, node_type_col), rels in self.data.relation_types.items():
+ key = (node_type_row, node_type_col)
+ if isinstance(self.decoder_class, dict):
+ if key in self.decoder_class:
+ decoder_class = self.decoder_class[key]
+ else:
+ raise KeyError('Decoder not specified for edge type: %d -- %d' % key)
+ else:
+ decoder_class = self.decoder_class
+
+ self.decoders[key] = decoder_class(self.input_dim,
+ num_relation_types = len(rels),
+ drop_prob = 1. - self.keep_prob,
+ activation = self.activation)
+
+
+ def forward(self, last_layer_repr: List[torch.Tensor]):
+ res = {}
+ for (node_type_row, node_type_col), rel in self.data.relation_types.items():
+ key = (node_type_row, node_type_col)
+ inputs_row = last_layer_repr[node_type_row]
+ inputs_col = last_layer_repr[node_type_col]
+ pred_adj_matrices = self.decoders[key](inputs_row, inputs_col)
+ res[node_type_row, node_type_col] = pred_adj_matrices
+ return res
diff --git a/src/decagon_pytorch/layer/input.py b/src/decagon_pytorch/layer/input.py
new file mode 100644
index 0000000..1b3a3f1
--- /dev/null
+++ b/src/decagon_pytorch/layer/input.py
@@ -0,0 +1,71 @@
+#
+# Copyright (C) Stanislaw Adaszewski, 2020
+# License: GPLv3
+#
+
+
+from .layer import Layer
+import torch
+from typing import Union, \
+ List
+from ..data import Data
+
+
+class InputLayer(Layer):
+ def __init__(self, data: Data, output_dim: Union[int, List[int]]= None, **kwargs) -> None:
+ output_dim = output_dim or \
+ list(map(lambda a: a.count, data.node_types))
+ if not isinstance(output_dim, list):
+ output_dim = [output_dim,] * len(data.node_types)
+
+ super().__init__(output_dim, is_sparse=False, **kwargs)
+ self.data = data
+ self.node_reps = None
+ self.build()
+
+ def build(self) -> None:
+ self.node_reps = []
+ for i, nt in enumerate(self.data.node_types):
+ reps = torch.rand(nt.count, self.output_dim[i])
+ reps = torch.nn.Parameter(reps)
+ self.register_parameter('node_reps[%d]' % i, reps)
+ self.node_reps.append(reps)
+
+ def forward(self) -> List[torch.nn.Parameter]:
+ return self.node_reps
+
+ def __repr__(self) -> str:
+ s = ''
+ s += 'GNN input layer with output_dim: %s\n' % self.output_dim
+ s += ' # of node types: %d\n' % len(self.data.node_types)
+ for nt in self.data.node_types:
+ s += ' - %s (%d)\n' % (nt.name, nt.count)
+ return s.strip()
+
+
+class OneHotInputLayer(Layer):
+ def __init__(self, data: Data, **kwargs) -> None:
+ output_dim = [ a.count for a in data.node_types ]
+ super().__init__(output_dim, is_sparse=True, **kwargs)
+ self.data = data
+ self.node_reps = None
+ self.build()
+
+ def build(self) -> None:
+ self.node_reps = []
+ for i, nt in enumerate(self.data.node_types):
+ reps = torch.eye(nt.count).to_sparse()
+ reps = torch.nn.Parameter(reps)
+ self.register_parameter('node_reps[%d]' % i, reps)
+ self.node_reps.append(reps)
+
+ def forward(self) -> List[torch.nn.Parameter]:
+ return self.node_reps
+
+ def __repr__(self) -> str:
+ s = ''
+ s += 'One-hot GNN input layer\n'
+ s += ' # of node types: %d\n' % len(self.data.node_types)
+ for nt in self.data.node_types:
+ s += ' - %s (%d)\n' % (nt.name, nt.count)
+ return s.strip()
diff --git a/src/decagon_pytorch/layer/layer.py b/src/decagon_pytorch/layer/layer.py
new file mode 100644
index 0000000..09a04d4
--- /dev/null
+++ b/src/decagon_pytorch/layer/layer.py
@@ -0,0 +1,19 @@
+#
+# Copyright (C) Stanislaw Adaszewski, 2020
+# License: GPLv3
+#
+
+
+import torch
+from typing import List, \
+ Union
+
+
+class Layer(torch.nn.Module):
+ def __init__(self,
+ output_dim: Union[int, List[int]],
+ is_sparse: bool,
+ **kwargs) -> None:
+ super().__init__(**kwargs)
+ self.output_dim = output_dim
+ self.is_sparse = is_sparse
diff --git a/src/decagon_pytorch/model.py b/src/decagon_pytorch/model.py
new file mode 100644
index 0000000..f99dc11
--- /dev/null
+++ b/src/decagon_pytorch/model.py
@@ -0,0 +1,12 @@
+#
+# Copyright (C) Stanislaw Adaszewski, 2020
+# License: GPLv3
+#
+
+
+class Model(object):
+ def __init__(self, data):
+ self.data = data
+
+ def build(self):
+ pass
diff --git a/src/decagon_pytorch/normalize.py b/src/decagon_pytorch/normalize.py
new file mode 100644
index 0000000..4af86cd
--- /dev/null
+++ b/src/decagon_pytorch/normalize.py
@@ -0,0 +1,56 @@
+#
+# Copyright (C) Stanislaw Adaszewski, 2020
+# License: GPLv3
+#
+
+
+import numpy as np
+import scipy.sparse as sp
+
+
+def sparse_to_tuple(sparse_mx):
+ if not sp.isspmatrix_coo(sparse_mx):
+ sparse_mx = sparse_mx.tocoo()
+ coords = np.vstack((sparse_mx.row, sparse_mx.col)).transpose()
+ values = sparse_mx.data
+ shape = sparse_mx.shape
+ return coords, values, shape
+
+
+def normalize_adjacency_matrix(adj):
+ adj = sp.coo_matrix(adj)
+
+ if adj.shape[0] == adj.shape[1]:
+ adj_ = adj + sp.eye(adj.shape[0])
+ rowsum = np.array(adj_.sum(1))
+ degree_mat_inv_sqrt = np.power(rowsum, -0.5).flatten()
+ degree_mat_inv_sqrt = sp.diags(degree_mat_inv_sqrt)
+ adj_normalized = adj_.dot(degree_mat_inv_sqrt).transpose().dot(degree_mat_inv_sqrt)
+ else:
+ rowsum = np.array(adj.sum(1))
+ colsum = np.array(adj.sum(0))
+ rowdegree_mat_inv = sp.diags(np.nan_to_num(np.power(rowsum, -0.5)).flatten())
+ coldegree_mat_inv = sp.diags(np.nan_to_num(np.power(colsum, -0.5)).flatten())
+ adj_normalized = rowdegree_mat_inv.dot(adj).dot(coldegree_mat_inv).tocoo()
+ return sparse_to_tuple(adj_normalized)
+
+
+def norm_adj_mat_one_node_type(adj):
+ adj = sp.coo_matrix(adj)
+ assert adj.shape[0] == adj.shape[1]
+ adj_ = adj + sp.eye(adj.shape[0])
+ rowsum = np.array(adj_.sum(1))
+ degree_mat_inv_sqrt = np.power(rowsum, -0.5).flatten()
+ degree_mat_inv_sqrt = sp.diags(degree_mat_inv_sqrt)
+ adj_normalized = adj_.dot(degree_mat_inv_sqrt).transpose().dot(degree_mat_inv_sqrt)
+ return adj_normalized
+
+
+def norm_adj_mat_two_node_types(adj):
+ adj = sp.coo_matrix(adj)
+ rowsum = np.array(adj.sum(1))
+ colsum = np.array(adj.sum(0))
+ rowdegree_mat_inv = sp.diags(np.nan_to_num(np.power(rowsum, -0.5)).flatten())
+ coldegree_mat_inv = sp.diags(np.nan_to_num(np.power(colsum, -0.5)).flatten())
+ adj_normalized = rowdegree_mat_inv.dot(adj).dot(coldegree_mat_inv).tocoo()
+ return adj_normalized
diff --git a/src/decagon_pytorch/sampling.py b/src/decagon_pytorch/sampling.py
new file mode 100644
index 0000000..c2357a3
--- /dev/null
+++ b/src/decagon_pytorch/sampling.py
@@ -0,0 +1,36 @@
+import numpy as np
+import torch
+import torch.utils.data
+from typing import List, \
+ Union
+
+
+def fixed_unigram_candidate_sampler(
+ true_classes: Union[np.array, torch.Tensor],
+ num_samples: int,
+ unigrams: List[Union[int, float]],
+ distortion: float = 1.):
+
+ if isinstance(true_classes, torch.Tensor):
+ true_classes = true_classes.detach().cpu().numpy()
+ if true_classes.shape[0] != num_samples:
+ raise ValueError('true_classes must be a 2D matrix with shape (num_samples, num_true)')
+ unigrams = np.array(unigrams)
+ if distortion != 1.:
+ unigrams = unigrams.astype(np.float64) ** distortion
+ # print('unigrams:', unigrams)
+ indices = np.arange(num_samples)
+ result = np.zeros(num_samples, dtype=np.int64)
+ while len(indices) > 0:
+ # print('len(indices):', len(indices))
+ sampler = torch.utils.data.WeightedRandomSampler(unigrams, len(indices))
+ candidates = np.array(list(sampler))
+ candidates = np.reshape(candidates, (len(indices), 1))
+ # print('candidates:', candidates)
+ # print('true_classes:', true_classes[indices, :])
+ result[indices] = candidates.T
+ mask = (candidates == true_classes[indices, :])
+ mask = mask.sum(1).astype(np.bool)
+ # print('mask:', mask)
+ indices = indices[mask]
+ return result
diff --git a/src/decagon_pytorch/splits.py b/src/decagon_pytorch/splits.py
new file mode 100644
index 0000000..9d219b6
--- /dev/null
+++ b/src/decagon_pytorch/splits.py
@@ -0,0 +1,60 @@
+import torch
+from .data import Data, \
+ AdjListData
+
+
+def train_val_test_split_adj_mat(adj_mat, train_ratio, val_ratio, test_ratio,
+ return_edges=False):
+
+ if train_ratio + val_ratio + test_ratio != 1.0:
+ raise ValueError('Train, validation and test ratios must add up to 1')
+
+ edges = torch.nonzero(adj_mat)
+ order = torch.randperm(len(edges))
+ edges = edges[order, :]
+ n = round(len(edges) * train_ratio)
+ edges_train = edges[:n]
+ n_1 = round(len(edges) * (train_ratio + val_ratio))
+ edges_val = edges[n:n_1]
+ edges_test = edges[n_1:]
+
+ adj_mat_train = torch.zeros_like(adj_mat)
+ adj_mat_train[edges_train[:, 0], edges_train[:, 1]] = 1
+
+ adj_mat_val = torch.zeros_like(adj_mat)
+ adj_mat_val[edges_val[:, 0], edges_val[:, 1]] = 1
+
+ adj_mat_test = torch.zeros_like(adj_mat)
+ adj_mat_test[edges_test[:, 0], edges_test[:, 1]] = 1
+
+ res = (adj_mat_train, adj_mat_val, adj_mat_test)
+ if return_edges:
+ res += (edges_train, edges_val, edges_test)
+
+ return res
+
+
+def train_val_test_split_edges(adj_mat, train_ratio, val_ratio, test_ratio):
+ if train_ratio + val_ratio + test_ratio != 1.0:
+ raise ValueError('Train, validation and test ratios must add up to 1')
+
+ edges = torch.nonzero(adj_mat)
+ order = torch.randperm(len(edges))
+ edges = edges[order, :]
+ n = round(len(edges) * train_ratio)
+ edges_train = edges[:n]
+ n_1 = round(len(edges) * (train_ratio + val_ratio))
+ edges_val = edges[n:n_1]
+ edges_test = edges[n_1:]
+
+ adj_mat_train = torch.zeros_like(adj_mat)
+ adj_mat_train[edges_train[:, 0], edges_train[:, 1]] = 1
+
+ adj_mat_val = torch.zeros_like(adj_mat)
+ adj_mat_val[edges_val[:, 0], edges_val[:, 1]] = 1
+
+ adj_mat_test = torch.zeros_like(adj_mat)
+ adj_mat_test[edges_test[:, 0], edges_test[:, 1]] = 1
+
+ return adj_mat_train, adj_mat_val, adj_mat_test, \
+ edges_train, edges_val, edges_test
diff --git a/src/decagon_pytorch/weights.py b/src/decagon_pytorch/weights.py
new file mode 100644
index 0000000..2dcb7b4
--- /dev/null
+++ b/src/decagon_pytorch/weights.py
@@ -0,0 +1,19 @@
+#
+# Copyright (C) Stanislaw Adaszewski, 2020
+# License: GPLv3
+#
+
+
+import torch
+import numpy as np
+
+
+def init_glorot(in_channels, out_channels, dtype=torch.float32):
+ """Create a weight variable with Glorot & Bengio (AISTATS 2010)
+ initialization.
+ """
+ init_range = np.sqrt(6.0 / (in_channels + out_channels))
+ initial = -init_range + 2 * init_range * \
+ torch.rand(( in_channels, out_channels ), dtype=dtype)
+ initial = initial.requires_grad_(True)
+ return initial
diff --git a/src/icosagon/__init__.py b/src/icosagon/__init__.py
new file mode 100644
index 0000000..78237bd
--- /dev/null
+++ b/src/icosagon/__init__.py
@@ -0,0 +1,7 @@
+#
+# Copyright (C) Stanislaw Adaszewski, 2020
+# License: GPLv3
+#
+
+
+from .data import Data
diff --git a/src/icosagon/batch.py b/src/icosagon/batch.py
new file mode 100644
index 0000000..e395f3f
--- /dev/null
+++ b/src/icosagon/batch.py
@@ -0,0 +1,102 @@
+from .declayer import Predictions
+import torch
+from dataclasses import dataclass
+from .trainprep import PreparedData
+from typing import Tuple
+
+
+@dataclass
+class FlatPredictions(object):
+ predictions: torch.Tensor
+ truth: torch.Tensor
+ part_type: str
+
+
+def flatten_predictions(pred: Predictions, part_type: str = 'train'):
+ if not isinstance(pred, Predictions):
+ raise TypeError('pred must be an instance of Predictions')
+
+ if part_type not in ['train', 'val', 'test']:
+ raise ValueError('part_type must be set to train, val or test')
+
+ edge_types = [('edges_pos', 1), ('edges_neg', 0),
+ ('edges_back_pos', 1), ('edges_back_neg', 0)]
+
+ input = []
+ target = []
+
+ for fam in pred.relation_families:
+ for rel in fam.relation_types:
+ for (et, tgt) in edge_types:
+ edge_pred = getattr(getattr(rel, et), part_type)
+ input.append(edge_pred)
+ target.append(torch.ones_like(edge_pred) * tgt)
+
+ input = torch.cat(input)
+ target = torch.cat(target)
+
+ return FlatPredictions(input, target, part_type)
+
+
+@dataclass
+class BatchIndices(object):
+ indices: torch.Tensor
+ part_type: str
+
+
+def gather_batch_indices(pred: FlatPredictions,
+ indices: BatchIndices) -> Tuple[torch.Tensor, torch.Tensor]:
+
+ if not isinstance(pred, FlatPredictions):
+ raise TypeError('pred must be an instance of FlatPredictions')
+
+ if not isinstance(indices, BatchIndices):
+ raise TypeError('indices must be an instance of BatchIndices')
+
+ if pred.part_type != indices.part_type:
+ raise ValueError('part_type must be the same in pred and indices')
+
+ return (pred.predictions[indices.indices],
+ pred.truth[indices.indices])
+
+
+class PredictionsBatch(object):
+ def __init__(self, prep_d: PreparedData, part_type: str = 'train',
+ batch_size: int = 100, shuffle: bool = False,
+ generator: torch.Generator = None) -> None:
+
+ if not isinstance(prep_d, PreparedData):
+ raise TypeError('prep_d must be an instance of PreparedData')
+
+ if part_type not in ['train', 'val', 'test']:
+ raise ValueError('part_type must be set to train, val or test')
+
+ batch_size = int(batch_size)
+
+ shuffle = bool(shuffle)
+
+ if generator is not None and not isinstance(generator, torch.Generator):
+ raise TypeError('generator must be an instance of torch.Generator')
+
+ self.prep_d = prep_d
+ self.part_type = part_type
+ self.batch_size = batch_size
+ self.shuffle = shuffle
+ self.generator = generator or torch.default_generator
+
+ count = 0
+ for fam in prep_d.relation_families:
+ for rel in fam.relation_types:
+ for et in ['edges_pos', 'edges_neg',
+ 'edges_back_pos', 'edges_back_neg']:
+ count += len(getattr(getattr(rel, et), part_type))
+ self.total_edge_count = count
+
+ def __iter__(self):
+ values = torch.arange(self.total_edge_count)
+ if self.shuffle:
+ perm = torch.randperm(len(values))
+ values = values[perm]
+
+ for i in range(0, len(values), self.batch_size):
+ yield BatchIndices(values[i:i+self.batch_size], self.part_type)
diff --git a/src/icosagon/bulkdec.py b/src/icosagon/bulkdec.py
new file mode 100644
index 0000000..cbf615e
--- /dev/null
+++ b/src/icosagon/bulkdec.py
@@ -0,0 +1,119 @@
+from icosagon.data import Data
+from icosagon.trainprep import PreparedData
+from icosagon.decode import DEDICOMDecoder, \
+ DistMultDecoder, \
+ BilinearDecoder, \
+ InnerProductDecoder
+from icosagon.dropout import dropout
+import torch
+from typing import List, \
+ Callable, \
+ Union
+
+
+'''
+Let's say that I have dense latent representations row and col.
+Then let's take relation matrix rel in a list of relations REL.
+A single computation currenty looks like this:
+(((row * rel) * glob) * rel) * col
+
+Shouldn't then this basically work:
+
+prod1 = torch.matmul(row, REL)
+prod2 = torch.matmul(prod1, glob)
+prod3 = torch.matmul(prod2, REL)
+res = torch.matmul(prod3, col)
+res = activation(res)
+
+res should then have shape: (num_relations, num_rows, num_columns)
+'''
+
+
+def convert_decoder(dec):
+ if isinstance(dec, DEDICOMDecoder):
+ global_interaction = dec.global_interaction
+ local_variation = map(torch.diag, dec.local_variation)
+ elif isinstance(dec, DistMultDecoder):
+ global_interaction = torch.eye(dec.input_dim, dec.input_dim)
+ local_variation = map(torch.diag, dec.relation)
+ elif isinstance(dec, BilinearDecoder):
+ global_interaction = torch.eye(dec.input_dim, dec.input_dim)
+ local_variation = dec.relation
+ elif isinstance(dec, InnerProductDecoder):
+ global_interaction = torch.eye(dec.input_dim, dec.input_dim)
+ local_variation = torch.eye(dec.input_dim, dec.input_dim)
+ local_variation = [ local_variation ] * dec.num_relation_types
+ else:
+ raise TypeError('Unknown decoder type in convert_decoder()')
+
+ if not isinstance(local_variation, torch.Tensor):
+ local_variation = map(lambda a: a.view(1, *a.shape), local_variation)
+ local_variation = torch.cat(list(local_variation))
+
+ return (global_interaction, local_variation)
+
+
+class BulkDecodeLayer(torch.nn.Module):
+ def __init__(self,
+ input_dim: List[int],
+ data: Union[Data, PreparedData],
+ keep_prob: float = 1.,
+ activation: Callable[[torch.Tensor], torch.Tensor] = torch.sigmoid,
+ **kwargs) -> None:
+
+ super().__init__(**kwargs)
+
+ self._check_params(input_dim, data)
+
+ self.input_dim = input_dim[0]
+ self.data = data
+ self.keep_prob = keep_prob
+ self.activation = activation
+
+ self.decoders = None
+ self.global_interaction = None
+ self.local_variation = None
+ self.build()
+
+ def build(self) -> None:
+ self.decoders = torch.nn.ModuleList()
+ self.global_interaction = torch.nn.ParameterList()
+ self.local_variation = torch.nn.ParameterList()
+ for fam in self.data.relation_families:
+ dec = fam.decoder_class(self.input_dim,
+ len(fam.relation_types),
+ self.keep_prob,
+ self.activation)
+ self.decoders.append(dec)
+ global_interaction, local_variation = convert_decoder(dec)
+ self.global_interaction.append(torch.nn.Parameter(global_interaction))
+ self.local_variation.append(torch.nn.Parameter(local_variation))
+
+ def forward(self, last_layer_repr: List[torch.Tensor]) -> List[torch.Tensor]:
+ res = []
+ for i, fam in enumerate(self.data.relation_families):
+ repr_row = last_layer_repr[fam.node_type_row]
+ repr_column = last_layer_repr[fam.node_type_column]
+ repr_row = dropout(repr_row, keep_prob=self.keep_prob)
+ repr_column = dropout(repr_column, keep_prob=self.keep_prob)
+ prod_1 = torch.matmul(repr_row, self.local_variation[i])
+ print(f'local_variation[{i}].shape: {self.local_variation[i].shape}')
+ prod_2 = torch.matmul(prod_1, self.global_interaction[i])
+ prod_3 = torch.matmul(prod_2, self.local_variation[i])
+ pred = torch.matmul(prod_3, repr_column.transpose(0, 1))
+ res.append(pred)
+ return res
+
+ @staticmethod
+ def _check_params(input_dim, data):
+ if not isinstance(input_dim, list):
+ raise TypeError('input_dim must be a list')
+
+ if len(input_dim) != len(data.node_types):
+ raise ValueError('input_dim must have length equal to num_node_types')
+
+ if not all([ a == input_dim[0] for a in input_dim ]):
+ raise ValueError('All elements of input_dim must have the same value')
+
+ if not isinstance(data, Data) and not isinstance(data, PreparedData):
+ raise TypeError('data must be an instance of Data or PreparedData')
diff --git a/src/icosagon/compile.py b/src/icosagon/compile.py
new file mode 100644
index 0000000..a16c29d
--- /dev/null
+++ b/src/icosagon/compile.py
@@ -0,0 +1,28 @@
+#
+# The goal of this module is to make Icosagon more efficient.
+# It takes the nice Icosagon model architecture and tries to
+# formulate it in terms of batch matrix multiplications instead
+# of using Python for loops.
+#
+
+from .weights import init_glorot
+from .input
+import torch
+
+
+class EncodeLayer(object):
+ def __init__(self, num_relation_types, input_dim, output_dim):
+ weights = [ init_glorot(input_dim, output_dim) \
+ for _ in range(num_relation_types) ]
+ weights = torch.cat(weights)
+
+
+class Compiler(object):
+ def __init__(self, data: Data, layer_dimensions: List[int] = [32, 64]) -> None:
+ self.data = data
+ self.layer_dimensions = layer_dimensions
+ self.build()
+
+ def build(self) -> None:
+ for fam in data.relation_families:
+ init_glorot(in_channels, out_channels)
diff --git a/src/icosagon/convlayer.py b/src/icosagon/convlayer.py
new file mode 100644
index 0000000..3c5b603
--- /dev/null
+++ b/src/icosagon/convlayer.py
@@ -0,0 +1,126 @@
+import torch
+from .convolve import DropoutGraphConvActivation
+from .data import Data
+from .trainprep import PreparedData
+from typing import List, \
+ Union, \
+ Callable
+from collections import defaultdict
+from dataclasses import dataclass
+import time
+
+
+class Convolutions(torch.nn.Module):
+ node_type_column: int
+ convolutions: torch.nn.ModuleList # [DropoutGraphConvActivation]
+
+ def __init__(self, node_type_column: int,
+ convolutions: torch.nn.ModuleList, **kwargs):
+
+ super().__init__(**kwargs)
+ self.node_type_column = node_type_column
+ self.convolutions = convolutions
+
+
+class DecagonLayer(torch.nn.Module):
+ def __init__(self,
+ input_dim: List[int],
+ output_dim: List[int],
+ data: Union[Data, PreparedData],
+ keep_prob: float = 1.,
+ rel_activation: Callable[[torch.Tensor], torch.Tensor] = lambda x: x,
+ layer_activation: Callable[[torch.Tensor], torch.Tensor] = torch.nn.functional.relu,
+ **kwargs):
+
+ super().__init__(**kwargs)
+
+ if not isinstance(input_dim, list):
+ raise ValueError('input_dim must be a list')
+
+ if not output_dim:
+ raise ValueError('output_dim must be specified')
+
+ if not isinstance(output_dim, list):
+ output_dim = [output_dim] * len(data.node_types)
+
+ if not isinstance(data, Data) and not isinstance(data, PreparedData):
+ raise ValueError('data must be of type Data or PreparedData')
+
+ self.input_dim = input_dim
+ self.output_dim = output_dim
+ self.data = data
+ self.keep_prob = float(keep_prob)
+ self.rel_activation = rel_activation
+ self.layer_activation = layer_activation
+
+ self.is_sparse = False
+ self.next_layer_repr = None
+ self.build()
+
+ def build_fam_one_node_type(self, fam):
+ convolutions = torch.nn.ModuleList()
+
+ for r in fam.relation_types:
+ conv = DropoutGraphConvActivation(self.input_dim[fam.node_type_column],
+ self.output_dim[fam.node_type_row], r.adjacency_matrix,
+ self.keep_prob, self.rel_activation)
+ convolutions.append(conv)
+
+ self.next_layer_repr[fam.node_type_row].append(
+ Convolutions(fam.node_type_column, convolutions))
+
+ def build_fam_two_node_types(self, fam) -> None:
+ convolutions_row = torch.nn.ModuleList()
+ convolutions_column = torch.nn.ModuleList()
+
+ for r in fam.relation_types:
+ if r.adjacency_matrix is not None:
+ conv = DropoutGraphConvActivation(self.input_dim[fam.node_type_column],
+ self.output_dim[fam.node_type_row], r.adjacency_matrix,
+ self.keep_prob, self.rel_activation)
+ convolutions_row.append(conv)
+
+ if r.adjacency_matrix_backward is not None:
+ conv = DropoutGraphConvActivation(self.input_dim[fam.node_type_row],
+ self.output_dim[fam.node_type_column], r.adjacency_matrix_backward,
+ self.keep_prob, self.rel_activation)
+ convolutions_column.append(conv)
+
+ self.next_layer_repr[fam.node_type_row].append(
+ Convolutions(fam.node_type_column, convolutions_row))
+
+ self.next_layer_repr[fam.node_type_column].append(
+ Convolutions(fam.node_type_row, convolutions_column))
+
+ def build_family(self, fam) -> None:
+ if fam.node_type_row == fam.node_type_column:
+ self.build_fam_one_node_type(fam)
+ else:
+ self.build_fam_two_node_types(fam)
+
+ def build(self):
+ self.next_layer_repr = torch.nn.ModuleList([
+ torch.nn.ModuleList() for _ in range(len(self.data.node_types)) ])
+ for fam in self.data.relation_families:
+ self.build_family(fam)
+
+ def __call__(self, prev_layer_repr):
+ t = time.time()
+ next_layer_repr = [ [] for _ in range(len(self.data.node_types)) ]
+ n = len(self.data.node_types)
+
+ for node_type_row in range(n):
+ for convolutions in self.next_layer_repr[node_type_row]:
+ repr_ = [ conv(prev_layer_repr[convolutions.node_type_column]) \
+ for conv in convolutions.convolutions ]
+ repr_ = sum(repr_)
+ repr_ = torch.nn.functional.normalize(repr_, p=2, dim=1)
+ next_layer_repr[node_type_row].append(repr_)
+ if len(next_layer_repr[node_type_row]) == 0:
+ next_layer_repr[node_type_row] = torch.zeros(self.output_dim[node_type_row])
+ else:
+ next_layer_repr[node_type_row] = sum(next_layer_repr[node_type_row])
+ next_layer_repr[node_type_row] = self.layer_activation(next_layer_repr[node_type_row])
+
+ # print('DecagonLayer.forward() took', time.time() - t)
+ return next_layer_repr
diff --git a/src/icosagon/convolve.py b/src/icosagon/convolve.py
new file mode 100644
index 0000000..09b6eca
--- /dev/null
+++ b/src/icosagon/convolve.py
@@ -0,0 +1,58 @@
+#
+# Copyright (C) Stanislaw Adaszewski, 2020
+# License: GPLv3
+#
+
+
+import torch
+from .dropout import dropout
+from .weights import init_glorot
+from typing import List, Callable
+import pdb
+
+
+class GraphConv(torch.nn.Module):
+ def __init__(self, in_channels: int, out_channels: int,
+ adjacency_matrix: torch.Tensor, **kwargs) -> None:
+ super().__init__(**kwargs)
+ self.in_channels = in_channels
+ self.out_channels = out_channels
+ self.weight = torch.nn.Parameter(init_glorot(in_channels, out_channels))
+ self.adjacency_matrix = adjacency_matrix
+
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
+ x = torch.sparse.mm(x, self.weight) \
+ if x.is_sparse \
+ else torch.mm(x, self.weight)
+ x = torch.sparse.mm(self.adjacency_matrix, x) \
+ if self.adjacency_matrix.is_sparse \
+ else torch.mm(self.adjacency_matrix, x)
+ return x
+
+
+class DropoutGraphConvActivation(torch.nn.Module):
+ def __init__(self, input_dim: int, output_dim: int,
+ adjacency_matrix: torch.Tensor, keep_prob: float=1.,
+ activation: Callable[[torch.Tensor], torch.Tensor]=torch.nn.functional.relu,
+ **kwargs) -> None:
+ super().__init__(**kwargs)
+ self.input_dim = input_dim
+ self.output_dim = output_dim
+ self.adjacency_matrix = adjacency_matrix
+ self.keep_prob = keep_prob
+ self.activation = activation
+ self.graph_conv = GraphConv(input_dim, output_dim, adjacency_matrix)
+
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
+ # pdb.set_trace()
+ x = dropout(x, self.keep_prob)
+ x = self.graph_conv(x)
+ x = self.activation(x)
+ return x
+
+ def clone(self, adjacency_matrix) -> 'DropoutGraphConvActivation':
+ res = DropoutGraphConvActivation(self.input_dim,
+ self.output_dim, adjacency_matrix, self.keep_prob,
+ self.activation)
+ res.graph_conv.weight = self.graph_conv.weight
+ return res
diff --git a/src/icosagon/data.py b/src/icosagon/data.py
new file mode 100644
index 0000000..4505adf
--- /dev/null
+++ b/src/icosagon/data.py
@@ -0,0 +1,209 @@
+#
+# Copyright (C) Stanislaw Adaszewski, 2020
+# License: GPLv3
+#
+
+
+from collections import defaultdict
+from dataclasses import dataclass, field
+import torch
+from typing import List, \
+ Dict, \
+ Tuple, \
+ Any, \
+ Type
+from .decode import DEDICOMDecoder, \
+ BilinearDecoder
+import numpy as np
+
+
+def _equal(x: torch.Tensor, y: torch.Tensor):
+ if x.is_sparse ^ y.is_sparse:
+ raise ValueError('Cannot mix sparse and dense tensors')
+
+ if not x.is_sparse:
+ return (x == y)
+
+ return ((x - y).coalesce().values() == 0)
+
+
+@dataclass
+class NodeType(object):
+ name: str
+ count: int
+
+
+@dataclass
+class RelationTypeBase(object):
+ name: str
+ node_type_row: int
+ node_type_column: int
+ adjacency_matrix: torch.Tensor
+ adjacency_matrix_backward: torch.Tensor
+
+
+@dataclass
+class RelationType(RelationTypeBase):
+ pass
+
+
+@dataclass
+class RelationFamilyBase(object):
+ data: 'Data'
+ name: str
+ node_type_row: int
+ node_type_column: int
+ is_symmetric: bool
+ decoder_class: Type
+
+
+@dataclass
+class RelationFamily(RelationFamilyBase):
+ relation_types: List[RelationType] = None
+
+ def __post_init__(self) -> None:
+ if not self.is_symmetric and \
+ self.decoder_class != DEDICOMDecoder and \
+ self.decoder_class != BilinearDecoder:
+ raise TypeError('Family is assymetric but the specified decoder_class supports symmetric relations only')
+
+ self.relation_types = []
+
+ def add_relation_type(self,
+ name: str, adjacency_matrix: torch.Tensor,
+ adjacency_matrix_backward: torch.Tensor = None) -> None:
+
+ name = str(name)
+ node_type_row = self.node_type_row
+ node_type_column = self.node_type_column
+
+ if adjacency_matrix is None and adjacency_matrix_backward is None:
+ raise ValueError('adjacency_matrix and adjacency_matrix_backward cannot both be None')
+
+ if adjacency_matrix is not None and \
+ not isinstance(adjacency_matrix, torch.Tensor):
+ raise ValueError('adjacency_matrix must be a torch.Tensor')
+
+ if adjacency_matrix_backward is not None \
+ and not isinstance(adjacency_matrix_backward, torch.Tensor):
+ raise ValueError('adjacency_matrix_backward must be a torch.Tensor')
+
+ if adjacency_matrix is not None and \
+ adjacency_matrix.shape != (self.data.node_types[node_type_row].count,
+ self.data.node_types[node_type_column].count):
+ raise ValueError('adjacency_matrix shape must be (num_row_nodes, num_column_nodes)')
+
+ if adjacency_matrix_backward is not None and \
+ adjacency_matrix_backward.shape != (self.data.node_types[node_type_column].count,
+ self.data.node_types[node_type_row].count):
+ raise ValueError('adjacency_matrix_backward shape must be (num_column_nodes, num_row_nodes)')
+
+ if node_type_row == node_type_column and \
+ adjacency_matrix_backward is not None:
+ raise ValueError('Relation between nodes of the same type must be expressed using a single matrix')
+
+ if self.is_symmetric and adjacency_matrix_backward is not None:
+ raise ValueError('Cannot use a custom adjacency_matrix_backward in a symmetric relation family')
+
+ if self.is_symmetric and node_type_row == node_type_column and \
+ not torch.all(_equal(adjacency_matrix,
+ adjacency_matrix.transpose(0, 1))):
+ raise ValueError('Relation family is symmetric but adjacency_matrix is assymetric')
+
+ if not self.is_symmetric and node_type_row != node_type_column and \
+ adjacency_matrix_backward is None:
+ raise ValueError('Relation is asymmetric but adjacency_matrix_backward is None')
+
+ if self.is_symmetric and node_type_row != node_type_column:
+ adjacency_matrix_backward = adjacency_matrix.transpose(0, 1)
+
+ self.relation_types.append(RelationType(name,
+ node_type_row, node_type_column,
+ adjacency_matrix, adjacency_matrix_backward))
+
+ def node_name(self, index):
+ return self.data.node_types[index].name
+
+ def __repr__(self):
+ s = 'Relation family %s' % self.name
+
+ for r in self.relation_types:
+ s += '\n - %s%s' % (r.name, ' (two-way)' \
+ if (r.adjacency_matrix is not None \
+ and r.adjacency_matrix_backward is not None) \
+ or self.node_type_row == self.node_type_column \
+ else '%s <- %s' % (self.node_name(self.node_type_row),
+ self.node_name(self.node_type_column)))
+
+ return s
+
+ def repr_indented(self):
+ s = ' - %s' % self.name
+
+ for r in self.relation_types:
+ s += '\n - %s%s' % (r.name, ' (two-way)' \
+ if (r.adjacency_matrix is not None \
+ and r.adjacency_matrix_backward is not None) \
+ or self.node_type_row == self.node_type_column \
+ else '%s <- %s' % (self.node_name(self.node_type_row),
+ self.node_name(self.node_type_column)))
+
+ return s
+
+
+class Data(object):
+ node_types: List[NodeType]
+ relation_families: List[RelationFamily]
+
+ def __init__(self) -> None:
+ self.node_types = []
+ self.relation_families = []
+
+ def add_node_type(self, name: str, count: int) -> None:
+ name = str(name)
+ count = int(count)
+ if not name:
+ raise ValueError('You must provide a non-empty node type name')
+ if count <= 0:
+ raise ValueError('You must provide a positive node count')
+ self.node_types.append(NodeType(name, count))
+
+ def add_relation_family(self, name: str, node_type_row: int,
+ node_type_column: int, is_symmetric: bool,
+ decoder_class: Type = DEDICOMDecoder):
+
+ name = str(name)
+ node_type_row = int(node_type_row)
+ node_type_column = int(node_type_column)
+ is_symmetric = bool(is_symmetric)
+
+ if node_type_row < 0 or node_type_row >= len(self.node_types):
+ raise ValueError('node_type_row outside of the valid range of node types')
+
+ if node_type_column < 0 or node_type_column >= len(self.node_types):
+ raise ValueError('node_type_column outside of the valid range of node types')
+
+ fam = RelationFamily(self, name, node_type_row, node_type_column,
+ is_symmetric, decoder_class)
+ self.relation_families.append(fam)
+
+ return fam
+
+ def __repr__(self):
+ n = len(self.node_types)
+ if n == 0:
+ return 'Empty Icosagon Data'
+ s = ''
+ s += 'Icosagon Data with:\n'
+ s += '- ' + str(n) + ' node type(s):\n'
+ for nt in self.node_types:
+ s += ' - ' + nt.name + '\n'
+ if len(self.relation_families) == 0:
+ s += '- No relation families\n'
+ return s.strip()
+
+ s += '- %d relation families:\n' % len(self.relation_families)
+ for fam in self.relation_families:
+ s += fam.repr_indented() + '\n'
+
+ return s.strip()
diff --git a/src/icosagon/databatch.py b/src/icosagon/databatch.py
new file mode 100644
index 0000000..3602d6d
--- /dev/null
+++ b/src/icosagon/databatch.py
@@ -0,0 +1,117 @@
+from icosagon.trainprep import PreparedData, \
+ PreparedRelationFamily, \
+ PreparedRelationType, \
+ _empty_edge_list_tvt
+import torch
+import random
+
+
+class BatchedData(PreparedData):
+ def __init__(self, *args, **kwargs):
+ super().__init__(*args, **kwargs)
+
+
+class BatchedDataPointer(object):
+ def __init__(self, batched_data):
+ self.batched_data = batched_data
+
+
+def batched_data_skeleton(data: PreparedData) -> BatchedData:
+ if not isinstance(data, PreparedData):
+ raise TypeError('data must be an instance of PreparedData')
+
+ fam_skels = []
+ for fam in data.relation_families:
+ rel_types_skel = []
+ for rel in fam.relation_types:
+ rel_skel = PreparedRelationType(rel.name,
+ rel.node_type_row, rel.node_type_column,
+ rel.adjacency_matrix, rel.adjacency_matrix_backward,
+ _empty_edge_list_tvt(), _empty_edge_list_tvt(),
+ _empty_edge_list_tvt(), _empty_edge_list_tvt())
+ rel_types_skel.append(rel_skel)
+ fam_skels.append(PreparedRelationFamily(fam.data, fam.name,
+ fam.node_type_row, fam.node_type_column,
+ fam.is_symmetric, fam.decoder_class,
+ rel_types_skel))
+ return BatchedData(data.node_types, fam_skels)
+
+
+class DataBatcher(object):
+ def __init__(self, data: PreparedData, batch_size: int,
+ shuffle: bool = True) -> None:
+ self._check_params(data, batch_size)
+
+ self.data = data
+ self.batch_size = batch_size
+ self.shuffle = shuffle
+
+ # def batched_data_iter(self, fam_idx: int, rel_idx: int,
+ # part_type: str) -> BatchedData:
+ #
+ # rel = self.data.relation_families[fam_idx].relation_types[rel_idx]
+ #
+ # edges = getattr(rel.edges_pos, part_type)
+ # for m in range(0, len(edges), self.batch_size):
+ # batched_data = batched_data_skeleton(self.data)
+ # setattr(batched_data.relation_families[fam_idx].relation_types[rel_idx].edges_pos,
+ # part_type, edges[m : m + self.batch_size])
+ # yield batched_data
+ #
+ # edges = getattr(rel.edges_neg, part_type)
+ # for m in range(0, len(edges), self.batch_size):
+ # batched_data = batched_data_skeleton(self.data)
+ # setattr(batched_data.relation_families[fam_idx].relation_types[rel_idx].edges_neg,
+ # part_type, edges[m : m + self.batch_size])
+ # yield batched_data
+ #
+ # edges = getattr(rel.edges_pos_back, part_type)
+ # for m in range(0, len(edges), self.batch_size):
+ # batched_data = batched_data_skeleton(self.data)
+ # setattr(batched_data.relation_families[i].relation_types[k].edges_pos_back,
+ # part_type, edges[m : m + self.batch_size])
+ # yield batched_data
+ #
+ # edges = getattr(rel.edges_neg_back, part_type)
+ # for m in range(0, len(), self.batch_size):
+ # batched_data = batched_data_skeleton(self.data)
+ # setattr(batched_data.relation_families[i].relation_types[k].edges_neg_back,
+ # edges[m : m + self.batch_size])
+ # yield batched_data
+
+ def __iter__(self) -> BatchedData:
+ gen = self.shuffle_iter() \
+ if self.shuffle \
+ else self.iter_base()
+
+ for batched_data in gen:
+ yield batched_data
+
+ def iter_base(self) -> BatchedData:
+ for i, fam in enumerate(self.data.relation_families):
+ for k, rel in enumerate(fam.relation_types):
+ for edge_type in ['edges_pos', 'edges_neg', 'edges_back_pos', 'edges_back_neg']:
+ for part_type in ['train', 'val', 'test']:
+ edges = getattr(getattr(rel, edge_type), part_type)
+ if self.shuffle:
+ perm = torch.randperm(len(edges))
+ edges = edges[perm]
+ for m in range(0, len(edges), self.batch_size):
+ batched_data = batched_data_skeleton(self.data)
+ setattr(getattr(batched_data.relation_families[i].relation_types[k],
+ edge_type), part_type, edges[m : m + self.batch_size])
+ yield batched_data
+
+ def shuffle_iter(self) -> BatchedData:
+ res = list(self.iter_base())
+ random.shuffle(res)
+ for batched_data in res:
+ yield batched_data
+
+ @staticmethod
+ def _check_params(data, batch_size):
+ if not isinstance(data, PreparedData):
+ raise TypeError('data must be an instance of PreparedData')
+
+ if not isinstance(batch_size, int):
+ raise TypeError('batch_size must be an int')
diff --git a/src/icosagon/declayer.py b/src/icosagon/declayer.py
new file mode 100644
index 0000000..25d9c5f
--- /dev/null
+++ b/src/icosagon/declayer.py
@@ -0,0 +1,124 @@
+#
+# Copyright (C) Stanislaw Adaszewski, 2020
+# License: GPLv3
+#
+
+
+import torch
+from .data import Data
+from .trainprep import PreparedData, \
+ TrainValTest
+from typing import Type, \
+ List, \
+ Callable, \
+ Union, \
+ Dict, \
+ Tuple
+from .decode import DEDICOMDecoder
+from dataclasses import dataclass
+import time
+from .databatch import BatchedDataPointer
+
+
+@dataclass
+class RelationPredictions(object):
+ edges_pos: TrainValTest
+ edges_neg: TrainValTest
+ edges_back_pos: TrainValTest
+ edges_back_neg: TrainValTest
+
+
+@dataclass
+class RelationFamilyPredictions(object):
+ relation_types: List[RelationPredictions]
+
+
+@dataclass
+class Predictions(object):
+ relation_families: List[RelationFamilyPredictions]
+
+
+class DecodeLayer(torch.nn.Module):
+ def __init__(self,
+ input_dim: List[int],
+ data: PreparedData,
+ keep_prob: float = 1.,
+ activation: Callable[[torch.Tensor], torch.Tensor] = torch.sigmoid,
+ batched_data_pointer: BatchedDataPointer = None,
+ **kwargs) -> None:
+
+ super().__init__(**kwargs)
+
+ if not isinstance(input_dim, list):
+ raise TypeError('input_dim must be a List')
+
+ if len(input_dim) != len(data.node_types):
+ raise ValueError('input_dim must have length equal to num_node_types')
+
+ if not all([ a == input_dim[0] for a in input_dim ]):
+ raise ValueError('All elements of input_dim must have the same value')
+
+ if not isinstance(data, PreparedData):
+ raise TypeError('data must be an instance of PreparedData')
+
+ if batched_data_pointer is not None and \
+ not isinstance(batched_data_pointer, BatchedDataPointer):
+ raise TypeError('batched_data_pointer must be an instance of BatchedDataPointer')
+
+ # if batched_data_pointer is not None and not batched_data_pointer.compatible_with(data):
+ # raise ValueError('batched_data_pointer must be compatible with data')
+
+ self.input_dim = input_dim[0]
+ self.output_dim = 1
+ self.data = data
+ self.keep_prob = keep_prob
+ self.activation = activation
+ self.batched_data_pointer = batched_data_pointer
+
+ self.decoders = None
+ self.build()
+
+ def build(self) -> None:
+ self.decoders = torch.nn.ModuleList()
+ for fam in self.data.relation_families:
+ dec = fam.decoder_class(self.input_dim, len(fam.relation_types),
+ self.keep_prob, self.activation)
+ self.decoders.append(dec)
+
+ def _get_tvt(self, r, edge_list_attr_names, row, column, k, last_layer_repr, dec):
+ start_time = time.time()
+ pred = []
+ for p in edge_list_attr_names:
+ tvt = []
+ for t in ['train', 'val', 'test']:
+ # print('r:', r)
+ edges = getattr(getattr(r, p), t)
+ inputs_row = last_layer_repr[row][edges[:, 0]]
+ inputs_column = last_layer_repr[column][edges[:, 1]]
+ tvt.append(dec(inputs_row, inputs_column, k))
+ tvt = TrainValTest(*tvt)
+ pred.append(tvt)
+ # print('DecodeLayer._get_tvt() took:', time.time() - start_time)
+ return pred
+
+ def forward(self, last_layer_repr: List[torch.Tensor]) -> List[List[torch.Tensor]]:
+ t = time.time()
+ res = []
+ data = self.batched_data_pointer.batched_data \
+ if self.batched_data_pointer is not None \
+ else self.data
+ for i, fam in enumerate(data.relation_families):
+ fam_pred = []
+ for k, r in enumerate(fam.relation_types):
+ pred = []
+ pred += self._get_tvt(r, ['edges_pos', 'edges_neg'],
+ r.node_type_row, r.node_type_column, k, last_layer_repr, self.decoders[i])
+ pred += self._get_tvt(r, ['edges_back_pos', 'edges_back_neg'],
+ r.node_type_column, r.node_type_row, k, last_layer_repr, self.decoders[i])
+ pred = RelationPredictions(*pred)
+ fam_pred.append(pred)
+ fam_pred = RelationFamilyPredictions(fam_pred)
+ res.append(fam_pred)
+ res = Predictions(res)
+ # print('DecodeLayer.forward() took', time.time() - t)
+ return res
diff --git a/src/icosagon/decode.py b/src/icosagon/decode.py
new file mode 100644
index 0000000..00df8b2
--- /dev/null
+++ b/src/icosagon/decode.py
@@ -0,0 +1,123 @@
+#
+# Copyright (C) Stanislaw Adaszewski, 2020
+# License: GPLv3
+#
+
+
+import torch
+from .weights import init_glorot
+from .dropout import dropout
+
+
+class DEDICOMDecoder(torch.nn.Module):
+ """DEDICOM Tensor Factorization Decoder model layer for link prediction."""
+ def __init__(self, input_dim, num_relation_types, keep_prob=1.,
+ activation=torch.sigmoid, **kwargs):
+
+ super().__init__(**kwargs)
+ self.input_dim = input_dim
+ self.num_relation_types = num_relation_types
+ self.keep_prob = keep_prob
+ self.activation = activation
+
+ self.global_interaction = torch.nn.Parameter(init_glorot(input_dim, input_dim))
+ self.local_variation = torch.nn.ParameterList([
+ torch.nn.Parameter(torch.flatten(init_glorot(input_dim, 1))) \
+ for _ in range(num_relation_types)
+ ])
+
+ def forward(self, inputs_row, inputs_col, relation_index):
+ inputs_row = dropout(inputs_row, self.keep_prob)
+ inputs_col = dropout(inputs_col, self.keep_prob)
+
+ relation = torch.diag(self.local_variation[relation_index])
+
+ product1 = torch.mm(inputs_row, relation)
+ product2 = torch.mm(product1, self.global_interaction)
+ product3 = torch.mm(product2, relation)
+ rec = torch.bmm(product3.view(product3.shape[0], 1, product3.shape[1]),
+ inputs_col.view(inputs_col.shape[0], inputs_col.shape[1], 1))
+ rec = torch.flatten(rec)
+
+ return self.activation(rec)
+
+
+class DistMultDecoder(torch.nn.Module):
+ """DEDICOM Tensor Factorization Decoder model layer for link prediction."""
+ def __init__(self, input_dim, num_relation_types, keep_prob=1.,
+ activation=torch.sigmoid, **kwargs):
+
+ super().__init__(**kwargs)
+ self.input_dim = input_dim
+ self.num_relation_types = num_relation_types
+ self.keep_prob = keep_prob
+ self.activation = activation
+
+ self.relation = torch.nn.ParameterList([
+ torch.nn.Parameter(torch.flatten(init_glorot(input_dim, 1))) \
+ for _ in range(num_relation_types)
+ ])
+
+ def forward(self, inputs_row, inputs_col, relation_index):
+ inputs_row = dropout(inputs_row, self.keep_prob)
+ inputs_col = dropout(inputs_col, self.keep_prob)
+
+ relation = torch.diag(self.relation[relation_index])
+
+ intermediate_product = torch.mm(inputs_row, relation)
+ rec = torch.bmm(intermediate_product.view(intermediate_product.shape[0], 1, intermediate_product.shape[1]),
+ inputs_col.view(inputs_col.shape[0], inputs_col.shape[1], 1))
+ rec = torch.flatten(rec)
+
+ return self.activation(rec)
+
+
+class BilinearDecoder(torch.nn.Module):
+ """DEDICOM Tensor Factorization Decoder model layer for link prediction."""
+ def __init__(self, input_dim, num_relation_types, keep_prob=1.,
+ activation=torch.sigmoid, **kwargs):
+
+ super().__init__(**kwargs)
+ self.input_dim = input_dim
+ self.num_relation_types = num_relation_types
+ self.keep_prob = keep_prob
+ self.activation = activation
+
+ self.relation = torch.nn.ParameterList([
+ torch.nn.Parameter(init_glorot(input_dim, input_dim)) \
+ for _ in range(num_relation_types)
+ ])
+
+ def forward(self, inputs_row, inputs_col, relation_index):
+ inputs_row = dropout(inputs_row, self.keep_prob)
+ inputs_col = dropout(inputs_col, self.keep_prob)
+
+ intermediate_product = torch.mm(inputs_row, self.relation[relation_index])
+ rec = torch.bmm(intermediate_product.view(intermediate_product.shape[0], 1, intermediate_product.shape[1]),
+ inputs_col.view(inputs_col.shape[0], inputs_col.shape[1], 1))
+ rec = torch.flatten(rec)
+
+ return self.activation(rec)
+
+
+class InnerProductDecoder(torch.nn.Module):
+ """DEDICOM Tensor Factorization Decoder model layer for link prediction."""
+ def __init__(self, input_dim, num_relation_types, keep_prob=1.,
+ activation=torch.sigmoid, **kwargs):
+
+ super().__init__(**kwargs)
+ self.input_dim = input_dim
+ self.num_relation_types = num_relation_types
+ self.keep_prob = keep_prob
+ self.activation = activation
+
+
+ def forward(self, inputs_row, inputs_col, _):
+ inputs_row = dropout(inputs_row, self.keep_prob)
+ inputs_col = dropout(inputs_col, self.keep_prob)
+
+ rec = torch.bmm(inputs_row.view(inputs_row.shape[0], 1, inputs_row.shape[1]),
+ inputs_col.view(inputs_col.shape[0], inputs_col.shape[1], 1))
+ rec = torch.flatten(rec)
+
+ return self.activation(rec)
diff --git a/src/icosagon/dropout.py b/src/icosagon/dropout.py
new file mode 100644
index 0000000..63cfb58
--- /dev/null
+++ b/src/icosagon/dropout.py
@@ -0,0 +1,42 @@
+#
+# Copyright (C) Stanislaw Adaszewski, 2020
+# License: GPLv3
+#
+
+
+import torch
+from .normalize import _sparse_coo_tensor
+
+
+def dropout_sparse(x, keep_prob):
+ x = x.coalesce()
+ i = x._indices()
+ v = x._values()
+ size = x.size()
+
+ n = keep_prob + torch.rand(len(v))
+ n = torch.floor(n).to(torch.bool)
+ i = i[:,n]
+ v = v[n]
+ x = _sparse_coo_tensor(i, v, size=size)
+
+ return x * (1./keep_prob)
+
+
+def dropout_dense(x, keep_prob):
+ # print('dropout_dense()')
+ x = x.clone()
+ i = torch.nonzero(x)
+
+ n = keep_prob + torch.rand(len(i))
+ n = (1. - torch.floor(n)).to(torch.bool)
+ x[i[n, 0], i[n, 1]] = 0.
+
+ return x * (1./keep_prob)
+
+
+def dropout(x, keep_prob):
+ if x.is_sparse:
+ return dropout_sparse(x, keep_prob)
+ else:
+ return dropout_dense(x, keep_prob)
diff --git a/src/icosagon/fastconv.py b/src/icosagon/fastconv.py
new file mode 100644
index 0000000..038e2fc
--- /dev/null
+++ b/src/icosagon/fastconv.py
@@ -0,0 +1,255 @@
+from typing import List, \
+ Union, \
+ Callable
+from .data import Data, \
+ RelationFamily
+from .trainprep import PreparedData, \
+ PreparedRelationFamily
+import torch
+from .weights import init_glorot
+from .normalize import _sparse_coo_tensor
+import types
+
+
+def _sparse_diag_cat(matrices: List[torch.Tensor]):
+ if len(matrices) == 0:
+ raise ValueError('The list of matrices must be non-empty')
+
+ if not all(m.is_sparse for m in matrices):
+ raise ValueError('All matrices must be sparse')
+
+ if not all(len(m.shape) == 2 for m in matrices):
+ raise ValueError('All matrices must be 2D')
+
+ indices = []
+ values = []
+ row_offset = 0
+ col_offset = 0
+
+ for m in matrices:
+ ind = m._indices().clone()
+ ind[0] += row_offset
+ ind[1] += col_offset
+ indices.append(ind)
+ values.append(m._values())
+ row_offset += m.shape[0]
+ col_offset += m.shape[1]
+
+ indices = torch.cat(indices, dim=1)
+ values = torch.cat(values)
+
+ return _sparse_coo_tensor(indices, values, size=(row_offset, col_offset))
+
+
+def _cat(matrices: List[torch.Tensor]):
+ if len(matrices) == 0:
+ raise ValueError('Empty list passed to _cat()')
+
+ n = sum(a.is_sparse for a in matrices)
+ if n != 0 and n != len(matrices):
+ raise ValueError('All matrices must have the same layout (dense or sparse)')
+
+ if not all(a.shape[1:] == matrices[0].shape[1:] for a in matrices):
+ raise ValueError('All matrices must have the same dimensions apart from dimension 0')
+
+ if not matrices[0].is_sparse:
+ return torch.cat(matrices)
+
+ total_rows = sum(a.shape[0] for a in matrices)
+ indices = []
+ values = []
+ row_offset = 0
+
+ for a in matrices:
+ ind = a._indices().clone()
+ val = a._values()
+ ind[0] += row_offset
+ ind = ind.transpose(0, 1)
+ indices.append(ind)
+ values.append(val)
+ row_offset += a.shape[0]
+
+ indices = torch.cat(indices).transpose(0, 1)
+ values = torch.cat(values)
+
+ res = _sparse_coo_tensor(indices, values, size=(row_offset, matrices[0].shape[1]))
+ return res
+
+
+class FastGraphConv(torch.nn.Module):
+ def __init__(self,
+ in_channels: int,
+ out_channels: int,
+ adjacency_matrices: List[torch.Tensor],
+ keep_prob: float = 1.,
+ activation: Callable[[torch.Tensor], torch.Tensor] = lambda x: x,
+ **kwargs) -> None:
+
+ super().__init__(**kwargs)
+
+ in_channels = int(in_channels)
+ out_channels = int(out_channels)
+ if not isinstance(adjacency_matrices, list):
+ raise TypeError('adjacency_matrices must be a list')
+ if len(adjacency_matrices) == 0:
+ raise ValueError('adjacency_matrices must not be empty')
+ if not all(isinstance(m, torch.Tensor) for m in adjacency_matrices):
+ raise TypeError('adjacency_matrices elements must be of class torch.Tensor')
+ if not all(m.is_sparse for m in adjacency_matrices):
+ raise ValueError('adjacency_matrices elements must be sparse')
+ keep_prob = float(keep_prob)
+ if not isinstance(activation, types.FunctionType):
+ raise TypeError('activation must be a function')
+
+ self.in_channels = in_channels
+ self.out_channels = out_channels
+ self.adjacency_matrices = adjacency_matrices
+ self.keep_prob = keep_prob
+ self.activation = activation
+
+ self.num_row_nodes = len(adjacency_matrices[0])
+ self.num_relation_types = len(adjacency_matrices)
+
+ self.adjacency_matrices = _sparse_diag_cat(adjacency_matrices)
+
+ self.weights = torch.cat([
+ init_glorot(in_channels, out_channels) \
+ for _ in range(self.num_relation_types)
+ ], dim=1)
+
+ def forward(self, x) -> torch.Tensor:
+ if self.keep_prob < 1.:
+ x = dropout(x, self.keep_prob)
+ res = torch.sparse.mm(x, self.weights) \
+ if x.is_sparse \
+ else torch.mm(x, self.weights)
+ res = torch.split(res, res.shape[1] // self.num_relation_types, dim=1)
+ res = torch.cat(res)
+ res = torch.sparse.mm(self.adjacency_matrices, res) \
+ if self.adjacency_matrices.is_sparse \
+ else torch.mm(self.adjacency_matrices, res)
+ res = res.view(self.num_relation_types, self.num_row_nodes, self.out_channels)
+ if self.activation is not None:
+ res = self.activation(res)
+
+ return res
+
+
+class FastConvLayer(torch.nn.Module):
+ def __init__(self,
+ input_dim: List[int],
+ output_dim: List[int],
+ data: Union[Data, PreparedData],
+ keep_prob: float = 1.,
+ rel_activation: Callable[[torch.Tensor], torch.Tensor] = lambda x: x,
+ layer_activation: Callable[[torch.Tensor], torch.Tensor] = torch.nn.functional.relu,
+ **kwargs):
+
+ super().__init__(**kwargs)
+
+ self._check_params(input_dim, output_dim, data, keep_prob,
+ rel_activation, layer_activation)
+
+ self.input_dim = input_dim
+ self.output_dim = output_dim
+ self.data = data
+ self.keep_prob = keep_prob
+ self.rel_activation = rel_activation
+ self.layer_activation = layer_activation
+
+ self.is_sparse = False
+ self.next_layer_repr = None
+ self.build()
+
+ def build(self):
+ self.next_layer_repr = torch.nn.ModuleList([
+ torch.nn.ModuleList() \
+ for _ in range(len(self.data.node_types))
+ ])
+ for fam in self.data.relation_families:
+ self.build_family(fam)
+
+ def build_family(self, fam) -> None:
+ if fam.node_type_row == fam.node_type_column:
+ self.build_fam_one_node_type(fam)
+ else:
+ self.build_fam_two_node_types(fam)
+
+ def build_fam_one_node_type(self, fam) -> None:
+ adjacency_matrices = [
+ r.adjacency_matrix \
+ for r in fam.relation_types
+ ]
+ conv = FastGraphConv(self.input_dim[fam.node_type_column],
+ self.output_dim[fam.node_type_row],
+ adjacency_matrices,
+ self.keep_prob,
+ self.rel_activation)
+ conv.input_node_type = fam.node_type_column
+ self.next_layer_repr[fam.node_type_row].append(conv)
+
+ def build_fam_two_node_types(self, fam) -> None:
+ adjacency_matrices = [
+ r.adjacency_matrix \
+ for r in fam.relation_types \
+ if r.adjacency_matrix is not None
+ ]
+
+ adjacency_matrices_backward = [
+ r.adjacency_matrix_backward \
+ for r in fam.relation_types \
+ if r.adjacency_matrix_backward is not None
+ ]
+
+ conv = FastGraphConv(self.input_dim[fam.node_type_column],
+ self.output_dim[fam.node_type_row],
+ adjacency_matrices,
+ self.keep_prob,
+ self.rel_activation)
+
+ conv_backward = FastGraphConv(self.input_dim[fam.node_type_row],
+ self.output_dim[fam.node_type_column],
+ adjacency_matrices_backward,
+ self.keep_prob,
+ self.rel_activation)
+
+ conv.input_node_type = fam.node_type_column
+ conv_backward.input_node_type = fam.node_type_row
+
+ self.next_layer_repr[fam.node_type_row].append(conv)
+ self.next_layer_repr[fam.node_type_column].append(conv_backward)
+
+ def forward(self, prev_layer_repr):
+ next_layer_repr = [ [] \
+ for _ in range(len(self.data.node_types)) ]
+ for output_node_type in range(len(self.data.node_types)):
+ for conv in self.next_layer_repr[output_node_type]:
+ rep = conv(prev_layer_repr[conv.input_node_type])
+ rep = torch.sum(rep, dim=0)
+ rep = torch.nn.functional.normalize(rep, p=2, dim=1)
+ next_layer_repr[output_node_type].append(rep)
+ if len(next_layer_repr[output_node_type]) == 0:
+ next_layer_repr[output_node_type] = \
+ torch.zeros(self.data.node_types[output_node_type].count, self.output_dim[output_node_type])
+ else:
+ next_layer_repr[output_node_type] = \
+ sum(next_layer_repr[output_node_type])
+ next_layer_repr[output_node_type] = \
+ self.layer_activation(next_layer_repr[output_node_type])
+ return next_layer_repr
+
+ @staticmethod
+ def _check_params(input_dim, output_dim, data, keep_prob,
+ rel_activation, layer_activation):
+
+ if not isinstance(input_dim, list):
+ raise ValueError('input_dim must be a list')
+
+ if not output_dim:
+ raise ValueError('output_dim must be specified')
+
+ if not isinstance(output_dim, list):
+ output_dim = [output_dim] * len(data.node_types)
+
+ if not isinstance(data, Data) and not isinstance(data, PreparedData):
+ raise ValueError('data must be of type Data or PreparedData')
diff --git a/src/icosagon/fastloop.py b/src/icosagon/fastloop.py
new file mode 100644
index 0000000..f955932
--- /dev/null
+++ b/src/icosagon/fastloop.py
@@ -0,0 +1,166 @@
+from .fastmodel import FastModel
+from .trainprep import PreparedData
+import torch
+from typing import Callable
+from types import FunctionType
+import time
+import random
+
+
+class FastBatcher(object):
+ def __init__(self, prep_d: PreparedData, batch_size: int,
+ shuffle: bool, generator: torch.Generator,
+ part_type: str) -> None:
+
+ if not isinstance(prep_d, PreparedData):
+ raise TypeError('prep_d must be an instance of PreparedData')
+
+ if not isinstance(generator, torch.Generator):
+ raise TypeError('generator must be an instance of torch.Generator')
+
+ if part_type not in ['train', 'val', 'test']:
+ raise ValueError('part_type must be set to train, val or test')
+
+ self.prep_d = prep_d
+ self.batch_size = int(batch_size)
+ self.shuffle = bool(shuffle)
+ self.generator = generator
+ self.part_type = part_type
+
+ self.edges = None
+ self.targets = None
+ self.build()
+
+ def build(self):
+ self.edges = []
+ self.targets = []
+
+ for fam in self.prep_d.relation_families:
+ edges = []
+ targets = []
+ for i, rel in enumerate(fam.relation_types):
+
+ edges_pos = getattr(rel.edges_pos, self.part_type)
+ edges_neg = getattr(rel.edges_neg, self.part_type)
+ edges_back_pos = getattr(rel.edges_back_pos, self.part_type)
+ edges_back_neg = getattr(rel.edges_back_neg, self.part_type)
+
+ e = torch.cat([ edges_pos,
+ torch.cat([edges_back_pos[:, 1], edges_back_pos[:, 0]], dim=1) ])
+ e = torch.cat([torch.ones(len(e), 1, dtype=torch.long) * i , e ], dim=1)
+ t = torch.ones(len(e))
+ edges.append(e)
+ targets.append(t)
+
+ e = torch.cat([ edges_neg,
+ torch.cat([edges_back_neg[:, 1], edges_back_neg[:, 0]], dim=1) ])
+ e = torch.cat([ torch.ones(len(e), 1, dtype=torch.long) * i, e ], dim=1)
+ t = torch.zeros(len(e))
+ edges.append(e)
+ targets.append(t)
+
+ edges = torch.cat(edges)
+ targets = torch.cat(targets)
+
+ self.edges.append(edges)
+ self.targets.append(targets)
+
+ # print(self.edges)
+ # print(self.targets)
+
+ if self.shuffle:
+ self.shuffle_families()
+
+ def shuffle_families(self):
+ for i in range(len(self.edges)):
+ edges = self.edges[i]
+ targets = self.targets[i]
+ order = torch.randperm(len(edges), generator=self.generator)
+ self.edges[i] = edges[order]
+ self.targets[i] = targets[order]
+
+ def __iter__(self):
+ offsets = [ 0 for _ in self.edges ]
+
+ while True:
+ choice = [ i for i in range(len(offsets)) \
+ if offsets[i] < len(self.edges[i]) ]
+ if len(choice) == 0:
+ break
+ fam_idx = torch.randint(len(choice), (1,), generator=self.generator).item()
+ ofs = offsets[fam_idx]
+ edges = self.edges[fam_idx][ofs:ofs + self.batch_size]
+ targets = self.targets[fam_idx][ofs:ofs + self.batch_size]
+ offsets[fam_idx] += self.batch_size
+ yield (fam_idx, edges, targets)
+
+
+class FastLoop(object):
+ def __init__(
+ self,
+ model: FastModel,
+ lr: float = 0.001,
+ loss: Callable[[torch.Tensor, torch.Tensor], torch.Tensor] = \
+ torch.nn.functional.binary_cross_entropy_with_logits,
+ batch_size: int = 100,
+ shuffle: bool = True,
+ generator: torch.Generator = None) -> None:
+
+ self._check_params(model, loss, generator)
+
+ self.model = model
+ self.lr = float(lr)
+ self.loss = loss
+ self.batch_size = int(batch_size)
+ self.shuffle = bool(shuffle)
+ self.generator = generator or torch.default_generator
+
+ self.opt = None
+
+ self.build()
+
+ def _check_params(self, model, loss, generator):
+ if not isinstance(model, FastModel):
+ raise TypeError('model must be an instance of FastModel')
+
+ if not isinstance(loss, FunctionType):
+ raise TypeError('loss must be a function')
+
+ if generator is not None and not isinstance(generator, torch.Generator):
+ raise TypeError('generator must be an instance of torch.Generator')
+
+ def build(self) -> None:
+ opt = torch.optim.Adam(self.model.parameters(), lr=self.lr)
+ self.opt = opt
+
+ def run_epoch(self):
+ prep_d = self.model.prep_d
+
+ batcher = FastBatcher(self.model.prep_d, batch_size=self.batch_size,
+ shuffle = self.shuffle, generator=self.generator)
+ # pred = self.model(None)
+ # n = len(list(iter(batch)))
+ loss_sum = 0
+ for fam_idx, edges, targets in batcher:
+ self.opt.zero_grad()
+ pred = self.model(None)
+
+ # process pred, get input and targets
+ input = pred[fam_idx][edges[:, 0], edges[:, 1]]
+
+ loss = self.loss(input, targets)
+ loss.backward()
+ self.opt.step()
+ loss_sum += loss.detach().cpu().item()
+ return loss_sum
+
+
+ def train(self, max_epochs):
+ best_loss = None
+ best_epoch = None
+ for i in range(max_epochs):
+ loss = self.run_epoch()
+ if best_loss is None or loss < best_loss:
+ best_loss = loss
+ best_epoch = i
+ return loss, best_loss, best_epoch
diff --git a/src/icosagon/fastmodel.py b/src/icosagon/fastmodel.py
new file mode 100644
index 0000000..a68fe58
--- /dev/null
+++ b/src/icosagon/fastmodel.py
@@ -0,0 +1,79 @@
+from .fastconv import FastConvLayer
+from .bulkdec import BulkDecodeLayer
+from .input import OneHotInputLayer
+from .trainprep import PreparedData
+import torch
+import types
+from typing import List, \
+ Union, \
+ Callable
+
+
+class FastModel(torch.nn.Module):
+ def __init__(self, prep_d: PreparedData,
+ layer_dimensions: List[int] = [32, 64],
+ keep_prob: float = 1.,
+ rel_activation: Callable[[torch.Tensor], torch.Tensor] = lambda x: x,
+ layer_activation: Callable[[torch.Tensor], torch.Tensor] = torch.nn.functional.relu,
+ dec_activation: Callable[[torch.Tensor], torch.Tensor] = lambda x: x,
+ **kwargs) -> None:
+
+ super().__init__(**kwargs)
+
+ self._check_params(prep_d, layer_dimensions, rel_activation,
+ layer_activation, dec_activation)
+
+ self.prep_d = prep_d
+ self.layer_dimensions = layer_dimensions
+ self.keep_prob = float(keep_prob)
+ self.rel_activation = rel_activation
+ self.layer_activation = layer_activation
+ self.dec_activation = dec_activation
+
+ self.seq = None
+ self.build()
+
+ def build(self):
+ in_layer = OneHotInputLayer(self.prep_d)
+ last_output_dim = in_layer.output_dim
+ seq = [ in_layer ]
+
+ for dim in self.layer_dimensions:
+ conv_layer = FastConvLayer(input_dim = last_output_dim,
+ output_dim = [dim] * len(self.prep_d.node_types),
+ data = self.prep_d,
+ keep_prob = self.keep_prob,
+ rel_activation = self.rel_activation,
+ layer_activation = self.layer_activation)
+ last_output_dim = conv_layer.output_dim
+ seq.append(conv_layer)
+
+ dec_layer = BulkDecodeLayer(input_dim = last_output_dim,
+ data = self.prep_d,
+ keep_prob = self.keep_prob,
+ activation = self.dec_activation)
+ seq.append(dec_layer)
+
+ seq = torch.nn.Sequential(*seq)
+ self.seq = seq
+
+ def forward(self, _):
+ return self.seq(None)
+
+ def _check_params(self, prep_d, layer_dimensions, rel_activation,
+ layer_activation, dec_activation):
+
+ if not isinstance(prep_d, PreparedData):
+ raise TypeError('prep_d must be an instanced of PreparedData')
+
+ if not isinstance(layer_dimensions, list):
+ raise TypeError('layer_dimensions must be a list')
+
+ if not isinstance(rel_activation, types.FunctionType):
+ raise TypeError('rel_activation must be a function')
+
+ if not isinstance(layer_activation, types.FunctionType):
+ raise TypeError('layer_activation must be a function')
+
+ if not isinstance(dec_activation, types.FunctionType):
+ raise TypeError('dec_activation must be a function')
diff --git a/src/icosagon/input.py b/src/icosagon/input.py
new file mode 100644
index 0000000..3bf5824
--- /dev/null
+++ b/src/icosagon/input.py
@@ -0,0 +1,79 @@
+#
+# Copyright (C) Stanislaw Adaszewski, 2020
+# License: GPLv3
+#
+
+
+import torch
+from typing import Union, \
+ List
+from .data import Data
+
+
+class InputLayer(torch.nn.Module):
+ def __init__(self, data: Data, output_dim: Union[int, List[int]] = None,
+ **kwargs) -> None:
+
+ output_dim = output_dim or \
+ list(map(lambda a: a.count, data.node_types))
+
+ if not isinstance(output_dim, list):
+ output_dim = [output_dim,] * len(data.node_types)
+
+ super().__init__(**kwargs)
+ self.output_dim = output_dim
+ self.data = data
+
+ self.is_sparse=False
+ self.node_reps = None
+ self.build()
+
+ def build(self) -> None:
+ self.node_reps = []
+ for i, nt in enumerate(self.data.node_types):
+ reps = torch.rand(nt.count, self.output_dim[i])
+ reps = torch.nn.Parameter(reps)
+ self.register_parameter('node_reps[%d]' % i, reps)
+ self.node_reps.append(reps)
+
+ def forward(self, x) -> List[torch.nn.Parameter]:
+ return self.node_reps
+
+ def __repr__(self) -> str:
+ s = ''
+ s += 'Icosagon input layer with output_dim: %s\n' % self.output_dim
+ s += ' # of node types: %d\n' % len(self.data.node_types)
+ for nt in self.data.node_types:
+ s += ' - %s (%d)\n' % (nt.name, nt.count)
+ return s.strip()
+
+
+class OneHotInputLayer(torch.nn.Module):
+ def __init__(self, data: Data, **kwargs) -> None:
+ output_dim = [ a.count for a in data.node_types ]
+ super().__init__(**kwargs)
+ self.output_dim = output_dim
+ self.data = data
+
+ self.is_sparse=True
+ self.node_reps = None
+ self.build()
+
+ def build(self) -> None:
+ self.node_reps = torch.nn.ParameterList()
+ for i, nt in enumerate(self.data.node_types):
+ reps = torch.eye(nt.count).to_sparse()
+ reps = torch.nn.Parameter(reps, requires_grad=False)
+ # self.register_parameter('node_reps[%d]' % i, reps)
+ self.node_reps.append(reps)
+
+ def forward(self, x) -> List[torch.nn.Parameter]:
+ return self.node_reps
+
+ def __repr__(self) -> str:
+ s = ''
+ s += 'Icosagon one-hot input layer\n'
+ s += ' # of node types: %d\n' % len(self.data.node_types)
+ for nt in self.data.node_types:
+ s += ' - %s (%d)\n' % (nt.name, nt.count)
+ return s.strip()
diff --git a/src/icosagon/model.py b/src/icosagon/model.py
new file mode 100644
index 0000000..1c9e413
--- /dev/null
+++ b/src/icosagon/model.py
@@ -0,0 +1,77 @@
+from .data import Data
+from typing import List, \
+ Callable
+from .trainprep import PreparedData
+import torch
+from .convlayer import DecagonLayer
+from .input import OneHotInputLayer
+from types import FunctionType
+from .declayer import DecodeLayer
+from .batch import PredictionsBatch
+
+
+class Model(torch.nn.Module):
+ def __init__(self, prep_d: PreparedData,
+ layer_dimensions: List[int] = [32, 64],
+ keep_prob: float = 1.,
+ rel_activation: Callable[[torch.Tensor], torch.Tensor] = lambda x: x,
+ layer_activation: Callable[[torch.Tensor], torch.Tensor] = torch.nn.functional.relu,
+ dec_activation: Callable[[torch.Tensor], torch.Tensor] = lambda x: x,
+ **kwargs) -> None:
+
+ super().__init__(**kwargs)
+
+ if not isinstance(prep_d, PreparedData):
+ raise TypeError('prep_d must be an instance of PreparedData')
+
+ if not isinstance(layer_dimensions, list):
+ raise TypeError('layer_dimensions must be a list')
+
+ keep_prob = float(keep_prob)
+
+ if not isinstance(rel_activation, FunctionType):
+ raise TypeError('rel_activation must be a function')
+
+ if not isinstance(layer_activation, FunctionType):
+ raise TypeError('layer_activation must be a function')
+
+ if not isinstance(dec_activation, FunctionType):
+ raise TypeError('dec_activation must be a function')
+
+ self.prep_d = prep_d
+ self.layer_dimensions = layer_dimensions
+ self.keep_prob = keep_prob
+ self.rel_activation = rel_activation
+ self.layer_activation = layer_activation
+ self.dec_activation = dec_activation
+
+ self.seq = None
+
+ self.build()
+
+ def build(self):
+ in_layer = OneHotInputLayer(self.prep_d)
+ last_output_dim = in_layer.output_dim
+ seq = [ in_layer ]
+
+ for dim in self.layer_dimensions:
+ conv_layer = DecagonLayer(input_dim = last_output_dim,
+ output_dim = [ dim ] * len(self.prep_d.node_types),
+ data = self.prep_d,
+ keep_prob = self.keep_prob,
+ rel_activation = self.rel_activation,
+ layer_activation = self.layer_activation)
+ last_output_dim = conv_layer.output_dim
+ seq.append(conv_layer)
+
+ dec_layer = DecodeLayer(input_dim = last_output_dim,
+ data = self.prep_d,
+ keep_prob = self.keep_prob,
+ activation = self.dec_activation)
+ seq.append(dec_layer)
+
+ seq = torch.nn.Sequential(*seq)
+ self.seq = seq
+
+ def forward(self, _):
+ return self.seq(None)
diff --git a/src/icosagon/normalize.py b/src/icosagon/normalize.py
new file mode 100644
index 0000000..e13fb05
--- /dev/null
+++ b/src/icosagon/normalize.py
@@ -0,0 +1,145 @@
+#
+# Copyright (C) Stanislaw Adaszewski, 2020
+# License: GPLv3
+#
+
+
+import numpy as np
+import scipy.sparse as sp
+import torch
+
+
+def _check_tensor(adj_mat):
+ if not isinstance(adj_mat, torch.Tensor):
+ raise ValueError('adj_mat must be a torch.Tensor')
+
+
+def _check_sparse(adj_mat):
+ if not adj_mat.is_sparse:
+ raise ValueError('adj_mat must be sparse')
+
+
+def _check_dense(adj_mat):
+ if adj_mat.is_sparse:
+ raise ValueError('adj_mat must be dense')
+
+
+def _check_square(adj_mat):
+ if len(adj_mat.shape) != 2 or \
+ adj_mat.shape[0] != adj_mat.shape[1]:
+ raise ValueError('adj_mat must be a square matrix')
+
+
+def _check_2d(adj_mat):
+ if len(adj_mat.shape) != 2:
+ raise ValueError('adj_mat must be a square matrix')
+
+
+def _sparse_coo_tensor(indices, values, size):
+ ctor = { torch.float32: torch.sparse.FloatTensor,
+ torch.float32: torch.sparse.DoubleTensor,
+ torch.uint8: torch.sparse.ByteTensor,
+ torch.long: torch.sparse.LongTensor,
+ torch.int: torch.sparse.IntTensor,
+ torch.short: torch.sparse.ShortTensor,
+ torch.bool: torch.sparse.ByteTensor }[values.dtype]
+ return ctor(indices, values, size)
+
+
+def add_eye_sparse(adj_mat: torch.Tensor) -> torch.Tensor:
+ _check_tensor(adj_mat)
+ _check_sparse(adj_mat)
+ _check_square(adj_mat)
+
+ adj_mat = adj_mat.coalesce()
+ indices = adj_mat.indices()
+ values = adj_mat.values()
+
+ eye_indices = torch.arange(adj_mat.shape[0], dtype=indices.dtype,
+ device=adj_mat.device).view(1, -1)
+ eye_indices = torch.cat((eye_indices, eye_indices), 0)
+ eye_values = torch.ones(adj_mat.shape[0], dtype=values.dtype,
+ device=adj_mat.device)
+
+ indices = torch.cat((indices, eye_indices), 1)
+ values = torch.cat((values, eye_values), 0)
+
+ adj_mat = _sparse_coo_tensor(indices, values, adj_mat.shape)
+
+ return adj_mat
+
+
+def norm_adj_mat_one_node_type_sparse(adj_mat: torch.Tensor) -> torch.Tensor:
+ _check_tensor(adj_mat)
+ _check_sparse(adj_mat)
+ _check_square(adj_mat)
+
+ adj_mat = add_eye_sparse(adj_mat)
+ adj_mat = norm_adj_mat_two_node_types_sparse(adj_mat)
+
+ return adj_mat
+
+
+def norm_adj_mat_one_node_type_dense(adj_mat: torch.Tensor) -> torch.Tensor:
+ _check_tensor(adj_mat)
+ _check_dense(adj_mat)
+ _check_square(adj_mat)
+
+ adj_mat = adj_mat + torch.eye(adj_mat.shape[0], dtype=adj_mat.dtype,
+ device=adj_mat.device)
+ adj_mat = norm_adj_mat_two_node_types_dense(adj_mat)
+
+ return adj_mat
+
+
+def norm_adj_mat_one_node_type(adj_mat: torch.Tensor) -> torch.Tensor:
+ _check_tensor(adj_mat)
+ _check_square(adj_mat)
+
+ if adj_mat.is_sparse:
+ return norm_adj_mat_one_node_type_sparse(adj_mat)
+ else:
+ return norm_adj_mat_one_node_type_dense(adj_mat)
+
+
+def norm_adj_mat_two_node_types_sparse(adj_mat: torch.Tensor) -> torch.Tensor:
+ _check_tensor(adj_mat)
+ _check_sparse(adj_mat)
+ _check_2d(adj_mat)
+
+ adj_mat = adj_mat.coalesce()
+ indices = adj_mat.indices()
+ values = adj_mat.values()
+ degrees_row = torch.zeros(adj_mat.shape[0], device=adj_mat.device)
+ degrees_row = degrees_row.index_add(0, indices[0], values.to(degrees_row.dtype))
+ degrees_col = torch.zeros(adj_mat.shape[1], device=adj_mat.device)
+ degrees_col = degrees_col.index_add(0, indices[1], values.to(degrees_col.dtype))
+ values = values.to(degrees_row.dtype) / torch.sqrt(degrees_row[indices[0]] * degrees_col[indices[1]])
+ adj_mat = _sparse_coo_tensor(indices, values, adj_mat.shape)
+
+ return adj_mat
+
+
+def norm_adj_mat_two_node_types_dense(adj_mat: torch.Tensor) -> torch.Tensor:
+ _check_tensor(adj_mat)
+ _check_dense(adj_mat)
+ _check_2d(adj_mat)
+
+ degrees_row = adj_mat.sum(1).view(-1, 1).to(torch.float32)
+ degrees_col = adj_mat.sum(0).view(1, -1).to(torch.float32)
+ degrees_row = torch.sqrt(degrees_row)
+ degrees_col = torch.sqrt(degrees_col)
+ adj_mat = adj_mat.to(degrees_row.dtype) / degrees_row
+ adj_mat = adj_mat / degrees_col
+
+ return adj_mat
+
+
+def norm_adj_mat_two_node_types(adj_mat: torch.Tensor) -> torch.Tensor:
+ _check_tensor(adj_mat)
+ _check_2d(adj_mat)
+
+ if adj_mat.is_sparse:
+ return norm_adj_mat_two_node_types_sparse(adj_mat)
+ else:
+ return norm_adj_mat_two_node_types_dense(adj_mat)
diff --git a/src/icosagon/sampling.py b/src/icosagon/sampling.py
new file mode 100644
index 0000000..7c55944
--- /dev/null
+++ b/src/icosagon/sampling.py
@@ -0,0 +1,47 @@
+#
+# Copyright (C) Stanislaw Adaszewski, 2020
+# License: GPLv3
+#
+
+
+import numpy as np
+import torch
+import torch.utils.data
+from typing import List, \
+ Union
+
+
+def fixed_unigram_candidate_sampler(
+ true_classes: Union[np.array, torch.Tensor],
+ unigrams: List[Union[int, float]],
+ distortion: float = 1.):
+
+ if isinstance(true_classes, torch.Tensor):
+ true_classes = true_classes.detach().cpu().numpy()
+
+ if isinstance(unigrams, torch.Tensor):
+ unigrams = unigrams.detach().cpu().numpy()
+
+ if len(true_classes.shape) != 2:
+ raise ValueError('true_classes must be a 2D matrix with shape (num_samples, num_true)')
+
+ num_samples = true_classes.shape[0]
+ unigrams = np.array(unigrams)
+ if distortion != 1.:
+ unigrams = unigrams.astype(np.float64) ** distortion
+ # print('unigrams:', unigrams)
+ indices = np.arange(num_samples)
+ result = np.zeros(num_samples, dtype=np.int64)
+ while len(indices) > 0:
+ # print('len(indices):', len(indices))
+ sampler = torch.utils.data.WeightedRandomSampler(unigrams, len(indices))
+ candidates = np.array(list(sampler))
+ candidates = np.reshape(candidates, (len(indices), 1))
+ # print('candidates:', candidates)
+ # print('true_classes:', true_classes[indices, :])
+ result[indices] = candidates.T
+ mask = (candidates == true_classes[indices, :])
+ mask = mask.sum(1).astype(np.bool)
+ # print('mask:', mask)
+ indices = indices[mask]
+ return torch.tensor(result)
diff --git a/src/icosagon/trainloop.py b/src/icosagon/trainloop.py
new file mode 100644
index 0000000..40cb122
--- /dev/null
+++ b/src/icosagon/trainloop.py
@@ -0,0 +1,100 @@
+from .model import Model
+import torch
+from .batch import PredictionsBatch, \
+ flatten_predictions, \
+ gather_batch_indices
+from typing import Callable
+from types import FunctionType
+import time
+
+
+class TrainLoop(object):
+ def __init__(
+ self,
+ model: Model,
+ lr: float = 0.001,
+ loss: Callable[[torch.Tensor, torch.Tensor], torch.Tensor] = \
+ torch.nn.functional.binary_cross_entropy_with_logits,
+ batch_size: int = 100,
+ shuffle: bool = False,
+ generator: torch.Generator = None) -> None:
+
+ if not isinstance(model, Model):
+ raise TypeError('model must be an instance of Model')
+
+ lr = float(lr)
+
+ if not isinstance(loss, FunctionType):
+ raise TypeError('loss must be a function')
+
+ batch_size = int(batch_size)
+
+ if generator is not None and not isinstance(generator, torch.Generator):
+ raise TypeError('generator must be an instance of torch.Generator')
+
+ self.model = model
+ self.lr = lr
+ self.loss = loss
+ self.batch_size = batch_size
+ self.shuffle = shuffle
+ self.generator = generator or torch.default_generator
+
+ self.opt = None
+
+ self.build()
+
+ def build(self) -> None:
+ opt = torch.optim.Adam(self.model.parameters(), lr=self.lr)
+ self.opt = opt
+
+ def run_epoch(self):
+ batch = PredictionsBatch(self.model.prep_d, batch_size=self.batch_size,
+ shuffle = self.shuffle, generator=self.generator)
+ # pred = self.model(None)
+ # n = len(list(iter(batch)))
+ loss_sum = 0
+ for i, indices in enumerate(batch):
+ print('%.2f%% (%d/%d)' % (i * batch.batch_size * 100 / batch.total_edge_count, i * batch.batch_size, batch.total_edge_count))
+ t = time.time()
+ self.opt.zero_grad()
+ print('zero_grad() took:', time.time() - t)
+ t = time.time()
+ pred = self.model(None)
+ print('model() took:', time.time() - t)
+ t = time.time()
+ pred = flatten_predictions(pred)
+ print('flatten_predictions() took:', time.time() - t)
+ # batch = PredictionsBatch(pred, batch_size=self.batch_size, shuffle=True)
+ # seed = torch.rand(1).item()
+ # rng_state = torch.get_rng_state()
+ # torch.manual_seed(seed)
+ #it = iter(batch)
+ #torch.set_rng_state(rng_state)
+ #for k in range(i):
+ #_ = next(it)
+ #(input, target) = next(it)
+ t = time.time()
+ (input, target) = gather_batch_indices(pred, indices)
+ print('gather_batch_indices() took:', time.time() - t)
+ t = time.time()
+ loss = self.loss(input, target)
+ print('loss() took:', time.time() - t)
+ t = time.time()
+ loss.backward()
+ print('backward() took:', time.time() - t)
+ t = time.time()
+ self.opt.step()
+ print('step() took:', time.time() - t)
+ loss_sum += loss.detach().cpu().item()
+ return loss_sum
+
+
+ def train(self, max_epochs):
+ best_loss = None
+ best_epoch = None
+ for i in range(max_epochs):
+ loss = self.run_epoch()
+ if best_loss is None or loss < best_loss:
+ best_loss = loss
+ best_epoch = i
+ return loss, best_loss, best_epoch
diff --git a/src/icosagon/trainprep.py b/src/icosagon/trainprep.py
new file mode 100644
index 0000000..c49300a
--- /dev/null
+++ b/src/icosagon/trainprep.py
@@ -0,0 +1,215 @@
+#
+# Copyright (C) Stanislaw Adaszewski, 2020
+# License: GPLv3
+#
+
+
+from .sampling import fixed_unigram_candidate_sampler
+import torch
+from dataclasses import dataclass, \
+ field
+from typing import Any, \
+ List, \
+ Tuple, \
+ Dict
+from .data import NodeType, \
+ RelationType, \
+ RelationTypeBase, \
+ RelationFamily, \
+ RelationFamilyBase, \
+ Data
+from collections import defaultdict
+from .normalize import norm_adj_mat_one_node_type, \
+ norm_adj_mat_two_node_types
+import numpy as np
+
+
+@dataclass
+class TrainValTest(object):
+ train: Any
+ val: Any
+ test: Any
+
+
+@dataclass
+class PreparedRelationType(RelationTypeBase):
+ edges_pos: TrainValTest
+ edges_neg: TrainValTest
+ edges_back_pos: TrainValTest
+ edges_back_neg: TrainValTest
+
+
+@dataclass
+class PreparedRelationFamily(RelationFamilyBase):
+ relation_types: List[PreparedRelationType]
+
+
+@dataclass
+class PreparedData(object):
+ node_types: List[NodeType]
+ relation_families: List[PreparedRelationFamily]
+
+
+def _empty_edge_list_tvt() -> TrainValTest:
+ return TrainValTest(*[ torch.zeros((0, 2), dtype=torch.long) for _ in range(3) ])
+
+
+def train_val_test_split_edges(edges: torch.Tensor,
+ ratios: TrainValTest) -> TrainValTest:
+
+ if not isinstance(edges, torch.Tensor):
+ raise ValueError('edges must be a torch.Tensor')
+
+ if len(edges.shape) != 2 or edges.shape[1] != 2:
+ raise ValueError('edges shape must be (num_edges, 2)')
+
+ if not isinstance(ratios, TrainValTest):
+ raise ValueError('ratios must be a TrainValTest')
+
+ if ratios.train + ratios.val + ratios.test != 1.0:
+ raise ValueError('Train, validation and test ratios must add up to 1')
+
+ order = torch.randperm(len(edges))
+ edges = edges[order, :]
+ n = round(len(edges) * ratios.train)
+ edges_train = edges[:n]
+ n_1 = round(len(edges) * (ratios.train + ratios.val))
+ edges_val = edges[n:n_1]
+ edges_test = edges[n_1:]
+
+ return TrainValTest(edges_train, edges_val, edges_test)
+
+
+def get_edges_and_degrees(adj_mat: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
+ if adj_mat.is_sparse:
+ adj_mat = adj_mat.coalesce()
+ degrees = torch.zeros(adj_mat.shape[1], dtype=torch.int64,
+ device=adj_mat.device)
+ degrees = degrees.index_add(0, adj_mat.indices()[1],
+ torch.ones(adj_mat.indices().shape[1], dtype=torch.int64,
+ device=adj_mat.device))
+ edges_pos = adj_mat.indices().transpose(0, 1)
+ else:
+ degrees = adj_mat.sum(0)
+ edges_pos = torch.nonzero(adj_mat)
+ return edges_pos, degrees
+
+
+def prepare_adj_mat(adj_mat: torch.Tensor,
+ ratios: TrainValTest) -> Tuple[TrainValTest, TrainValTest]:
+
+ if not isinstance(adj_mat, torch.Tensor):
+ raise ValueError('adj_mat must be a torch.Tensor')
+
+ edges_pos, degrees = get_edges_and_degrees(adj_mat)
+
+ neg_neighbors = fixed_unigram_candidate_sampler(
+ edges_pos[:, 1].view(-1, 1), degrees, 0.75).to(adj_mat.device)
+ print(edges_pos.dtype)
+ print(neg_neighbors.dtype)
+ edges_neg = torch.cat((edges_pos[:, 0].view(-1, 1), neg_neighbors.view(-1, 1)), 1)
+
+ edges_pos = train_val_test_split_edges(edges_pos, ratios)
+ edges_neg = train_val_test_split_edges(edges_neg, ratios)
+
+ adj_mat_train = torch.sparse_coo_tensor(indices = edges_pos.train.transpose(0, 1),
+ values=torch.ones(len(edges_pos.train)), size=adj_mat.shape, dtype=adj_mat.dtype,
+ device=adj_mat.device)
+
+ return adj_mat_train, edges_pos, edges_neg
+
+
+def prep_rel_one_node_type(r: RelationType,
+ ratios: TrainValTest) -> PreparedRelationType:
+
+ adj_mat = r.adjacency_matrix
+ adj_mat_train, edges_pos, edges_neg = prepare_adj_mat(adj_mat, ratios)
+ adj_mat_back_train, edges_back_pos, edges_back_neg = \
+ None, _empty_edge_list_tvt(), _empty_edge_list_tvt()
+
+ print('adj_mat_train:', adj_mat_train)
+ adj_mat_train = norm_adj_mat_one_node_type(adj_mat_train)
+
+ return PreparedRelationType(r.name, r.node_type_row, r.node_type_column,
+ adj_mat_train, adj_mat_back_train, edges_pos, edges_neg,
+ edges_back_pos, edges_back_neg)
+
+
+def prep_rel_two_node_types_sym(r: RelationType,
+ ratios: TrainValTest) -> PreparedRelationType:
+
+ adj_mat = r.adjacency_matrix
+ adj_mat_train, edges_pos, edges_neg = prepare_adj_mat(adj_mat, ratios)
+ edges_back_pos, edges_back_neg = \
+ _empty_edge_list_tvt(), _empty_edge_list_tvt()
+
+ return PreparedRelationType(r.name, r.node_type_row,
+ r.node_type_column,
+ norm_adj_mat_two_node_types(adj_mat_train),
+ norm_adj_mat_two_node_types(adj_mat_train.transpose(0, 1)),
+ edges_pos, edges_neg, edges_back_pos, edges_back_neg)
+
+
+def prep_rel_two_node_types_asym(r: RelationType,
+ ratios: TrainValTest) -> PreparedRelationType:
+
+ if r.adjacency_matrix is not None:
+ adj_mat_train, edges_pos, edges_neg =\
+ prepare_adj_mat(r.adjacency_matrix, ratios)
+ else:
+ adj_mat_train, edges_pos, edges_neg = \
+ None, _empty_edge_list_tvt(), _empty_edge_list_tvt()
+
+ if r.adjacency_matrix_backward is not None:
+ adj_mat_back_train, edges_back_pos, edges_back_neg = \
+ prepare_adj_mat(r.adjacency_matrix_backward, ratios)
+ else:
+ adj_mat_back_train, edges_back_pos, edges_back_neg = \
+ None, _empty_edge_list_tvt(), _empty_edge_list_tvt()
+
+ return PreparedRelationType(r.name, r.node_type_row,
+ r.node_type_column,
+ norm_adj_mat_two_node_types(adj_mat_train),
+ norm_adj_mat_two_node_types(adj_mat_back_train),
+ edges_pos, edges_neg, edges_back_pos, edges_back_neg)
+
+
+def prepare_relation_type(r: RelationType,
+ ratios: TrainValTest, is_symmetric: bool) -> PreparedRelationType:
+
+ if not isinstance(r, RelationType):
+ raise ValueError('r must be a RelationType')
+
+ if not isinstance(ratios, TrainValTest):
+ raise ValueError('ratios must be a TrainValTest')
+
+ if r.node_type_row == r.node_type_column:
+ return prep_rel_one_node_type(r, ratios)
+ elif is_symmetric:
+ return prep_rel_two_node_types_sym(r, ratios)
+ else:
+ return prep_rel_two_node_types_asym(r, ratios)
+
+
+def prepare_relation_family(fam: RelationFamily,
+ ratios: TrainValTest) -> PreparedRelationFamily:
+
+ relation_types = []
+
+ for r in fam.relation_types:
+ relation_types.append(prepare_relation_type(r, ratios, fam.is_symmetric))
+
+ return PreparedRelationFamily(fam.data, fam.name,
+ fam.node_type_row, fam.node_type_column,
+ fam.is_symmetric, fam.decoder_class,
+ relation_types)
+
+
+def prepare_training(data: Data, ratios: TrainValTest) -> PreparedData:
+ if not isinstance(data, Data):
+ raise ValueError('data must be of class Data')
+
+ relation_families = [ prepare_relation_family(fam, ratios) \
+ for fam in data.relation_families ]
+
+ return PreparedData(data.node_types, relation_families)
diff --git a/src/icosagon/unused/loss.py b/src/icosagon/unused/loss.py
new file mode 100644
index 0000000..eb2e0fe
--- /dev/null
+++ b/src/icosagon/unused/loss.py
@@ -0,0 +1,44 @@
+import torch
+from icosagon.trainprep import PreparedData
+from icosagon.declayer import Predictions
+
+
+class CrossEntropyLoss(torch.nn.Module):
+ def __init__(self, data: PreparedData, partition_type: str = 'train',
+ reduction: str = 'sum', **kwargs) -> None:
+
+ super().__init__(**kwargs)
+
+ if not isinstance(data, PreparedData):
+ raise TypeError('data must be an instance of PreparedData')
+
+ if partition_type not in ['train', 'val', 'test']:
+ raise ValueError('partition_type must be set to train, val or test')
+
+ if reduction not in ['sum', 'mean']:
+ raise ValueError('reduction must be set to sum or mean')
+
+ self.data = data
+ self.partition_type = partition_type
+ self.reduction = reduction
+
+ def forward(self, pred: Predictions) -> torch.Tensor:
+ input = []
+ target = []
+ for fam in pred.relation_families:
+ for rel in fam.relation_types:
+ for edge_type in ['edges_pos', 'edges_back_pos']:
+ x = getattr(getattr(rel, edge_type), self.partition_type)
+ assert len(x.shape) == 1
+ input.append(x)
+ target.append(torch.ones_like(x))
+ for edge_type in ['edges_neg', 'edges_back_neg']:
+ x = getattr(getattr(rel, edge_type), self.partition_type)
+ assert len(x.shape) == 1
+ input.append(x)
+ target.append(torch.zeros_like(x))
+ input = torch.cat(input, dim=0)
+ target = torch.cat(target, dim=0)
+ res = torch.nn.functional.binary_cross_entropy(input, target,
+ reduction=self.reduction)
+ return res
diff --git a/src/icosagon/weights.py b/src/icosagon/weights.py
new file mode 100644
index 0000000..2dcb7b4
--- /dev/null
+++ b/src/icosagon/weights.py
@@ -0,0 +1,19 @@
+#
+# Copyright (C) Stanislaw Adaszewski, 2020
+# License: GPLv3
+#
+
+
+import torch
+import numpy as np
+
+
+def init_glorot(in_channels, out_channels, dtype=torch.float32):
+ """Create a weight variable with Glorot & Bengio (AISTATS 2010)
+ initialization.
+ """
+ init_range = np.sqrt(6.0 / (in_channels + out_channels))
+ initial = -init_range + 2 * init_range * \
+ torch.rand(( in_channels, out_channels ), dtype=dtype)
+ initial = initial.requires_grad_(True)
+ return initial
diff --git a/src/torch_stablesort/dispatch.h b/src/torch_stablesort/dispatch.h
new file mode 100644
index 0000000..ef91b46
--- /dev/null
+++ b/src/torch_stablesort/dispatch.h
@@ -0,0 +1,44 @@
+#pragma once
+
+#include
+#include
+
+template class F, typename R, typename... Ts>
+R dispatch(torch::Tensor input, Ts&& ... args) {
+ switch(input.type().scalarType()) {
+ case torch::ScalarType::Double:
+ return F()(input, std::forward(args)...);
+ case torch::ScalarType::Float:
+ return F()(input, std::forward(args)...);
+ case torch::ScalarType::Half:
+ throw std::runtime_error("Half-precision float not supported");
+ case torch::ScalarType::ComplexHalf:
+ throw std::runtime_error("Half-precision complex float not supported");
+ case torch::ScalarType::ComplexFloat:
+ throw std::runtime_error("Complex float not supported");
+ case torch::ScalarType::ComplexDouble:
+ throw std::runtime_error("Complex double not supported");
+ case torch::ScalarType::Long:
+ return F()(input, std::forward(args)...);
+ case torch::ScalarType::Int:
+ return F()(input, std::forward(args)...);
+ case torch::ScalarType::Short:
+ return F()(input, std::forward(args)...);
+ case torch::ScalarType::Char:
+ return F()(input, std::forward(args)...);
+ case torch::ScalarType::Byte:
+ return F()(input, std::forward(args)...);
+ case torch::ScalarType::Bool:
+ return F()(input, std::forward(args)...);
+ case torch::ScalarType::QInt32:
+ throw std::runtime_error("QInt32 not supported");
+ //case torch::ScalarType::QInt16:
+ // throw std::runtime_error("QInt16 not supported");
+ case torch::ScalarType::QInt8:
+ throw std::runtime_error("QInt8 not supported");
+ case torch::ScalarType::BFloat16:
+ throw std::runtime_error("BFloat16 not supported");
+ default:
+ throw std::runtime_error("Unknown scalar type");
+ }
+}
diff --git a/src/torch_stablesort/dispatch_test.cpp b/src/torch_stablesort/dispatch_test.cpp
new file mode 100644
index 0000000..76e043f
--- /dev/null
+++ b/src/torch_stablesort/dispatch_test.cpp
@@ -0,0 +1,31 @@
+#include
+#include
+
+template class F, typename... Ts>
+void dispatch(int x, Ts&& ...args) {
+ if (x > 0)
+ F()(std::forward(args)...);
+ else
+ F()(std::forward(args)...);
+}
+
+template
+struct bla {
+ void operator()(int&& a, char&& b, double&& c) const {
+ std::cout << sizeof(T) << " " << typeid(T).name() << " " << a << " " << b << " " << c << std::endl;
+ }
+};
+
+template
+struct bla128 {
+ void operator()(int&& a, char&& b, __float128&& c) const {
+ std::cout << sizeof(T) << " " << typeid(T).name() << " " << a << " " << b << " " << (double) c << std::endl;
+ }
+};
+
+main() {
+ std::cout << "main()" << std::endl;
+ //bla()(1, 'a', 5.5);
+ dispatch(5, 1, 'a', (__float128) 5.5);
+ dispatch(-5, 1, 'a', 5.5);
+}
diff --git a/src/torch_stablesort/openmp_test.cpp b/src/torch_stablesort/openmp_test.cpp
new file mode 100644
index 0000000..10b364e
--- /dev/null
+++ b/src/torch_stablesort/openmp_test.cpp
@@ -0,0 +1,8 @@
+#include
+
+main() {
+ #pragma omp parallel for
+ for (int i = 0; i < 10; i++) {
+ std::cout << i << std::endl;
+ }
+}
diff --git a/src/torch_stablesort/setup.py b/src/torch_stablesort/setup.py
new file mode 100644
index 0000000..ae03d78
--- /dev/null
+++ b/src/torch_stablesort/setup.py
@@ -0,0 +1,14 @@
+from setuptools import setup, Extension
+from torch.utils import cpp_extension
+
+setup(name='torch_stablesort',
+ py_modules=['torch_stablesort'],
+ ext_modules=[ cpp_extension.CUDAExtension( 'torch_stablesort_cpp',
+ ['torch_stablesort.cpp', 'torch_stablesort_cpu.cpp', 'torch_stablesort_cuda.cu'],
+ extra_compile_args={
+ 'cxx': ['-fopenmp', '-ggdb', '-std=c++1z'],
+ 'nvcc': [ '-I/pstore/home/adaszews/scratch/thrust',
+ '-ccbin', '/pstore/data/data_science/app/modules/anaconda3-2020.07/bin/x86_64-conda_cos6-linux-gnu-gcc',
+ '-std=c++14', '--expt-extended-lambda', '-O99']
+ } ) ],
+ cmdclass={'build_ext': cpp_extension.BuildExtension})
diff --git a/src/torch_stablesort/torch_stablesort.cpp b/src/torch_stablesort/torch_stablesort.cpp
new file mode 100644
index 0000000..bbcac3c
--- /dev/null
+++ b/src/torch_stablesort/torch_stablesort.cpp
@@ -0,0 +1,30 @@
+#include
+
+#include
+#include
+#include
+
+#include "torch_stablesort_cuda.h"
+#include "torch_stablesort_cpu.h"
+
+std::vector stable_sort(
+ torch::Tensor input,
+ int dim = -1,
+ bool descending = false,
+ torch::optional> out = torch::nullopt) {
+
+ switch (input.device().type()) {
+ case torch::DeviceType::CUDA:
+ return dispatch_cuda(input, dim, descending, out);
+ case torch::DeviceType::CPU:
+ return dispatch_cpu(input, dim, descending, out);
+ default:
+ throw std::runtime_error("Unsupported device type");
+ }
+}
+
+PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
+ m.def("stable_sort", &stable_sort, "Stable sort",
+ py::arg("input"), py::arg("dim") = -1, py::arg("descending") = false,
+ py::arg("out") = nullptr);
+}
diff --git a/src/torch_stablesort/torch_stablesort.md b/src/torch_stablesort/torch_stablesort.md
new file mode 100644
index 0000000..e15d2ee
--- /dev/null
+++ b/src/torch_stablesort/torch_stablesort.md
@@ -0,0 +1,29 @@
+# torch_stablesort
+
+## Introduction
+
+### Stable sorting algorithms
+
+Stable sort algorithms sort repeated elements in the same order that they appear in the input. When sorting some kinds of data, only part of the data is examined when determining the sort order. For example, in the card sorting example to the right, the cards are being sorted by their rank, and their suit is being ignored. This allows the possibility of multiple different correctly sorted versions of the original list. Stable sorting algorithms choose one of these, according to the following rule: if two items compare as equal, like the two 5 cards, then their relative order will be preserved, so that if one came before the other in the input, it will also come before the other in the output.
+
+### PyTorch
+
+PyTorch is an open source machine learning library based on the Torch library, used for applications such as computer vision and natural language processing, primarily developed by Facebook's AI Research lab. It is free and open-source software released under the Modified BSD license.
+
+### PyTorch Extensions
+
+PyTorch provides a plethora of operations related to neural networks, arbitrary tensor algebra, data wrangling and other purposes. However, you may still find yourself in need of a more customized operation. For example, you might want to use a novel activation function you found in a paper, or implement an operation you developed as part of your research.
+
+The easiest way of integrating such a custom operation in PyTorch is to write it in Python by extending Function and Module as outlined here. This gives you the full power of automatic differentiation (spares you from writing derivative functions) as well as the usual expressiveness of Python. However, there may be times when your operation is better implemented in C++. For example, your code may need to be really fast because it is called very frequently in your model or is very expensive even for few calls. Another plausible reason is that it depends on or interacts with other C or C++ libraries. To address such cases, PyTorch provides a very easy way of writing custom C++ extensions.
+
+## Implementation
+
+### setup.py
+
+
+
+### dispatch.h
+
+### torch_stablesort.cpp
+
+### torch_stablesort.py
diff --git a/src/torch_stablesort/torch_stablesort.py b/src/torch_stablesort/torch_stablesort.py
new file mode 100644
index 0000000..83b7574
--- /dev/null
+++ b/src/torch_stablesort/torch_stablesort.py
@@ -0,0 +1,28 @@
+import torch
+import torch_stablesort_cpp
+
+
+class StableSort(torch.autograd.Function):
+ @staticmethod
+ def forward(ctx, input, dim=-1, descending=False, out=None):
+ values, indices = \
+ torch_stablesort_cpp.stable_sort(input, dim, descending, out)
+ ctx.save_for_backward(input, indices, torch.tensor(dim))
+ return values, indices.detach()
+
+ @staticmethod
+ def backward(ctx, grad_values, grad_indices):
+ input, indices, dim = ctx.saved_variables
+ # print('backward(), grad_indices:', grad_indices, 'indices:', indices,
+ # 'grad_values:', grad_values)
+
+ res = torch.gather(grad_values, dim, indices)
+
+ # res = torch.empty_like(grad_values)
+ # print('here')
+ # res = res.view(-1, res.size(-1))
+ # indices = indices.view(-1, indices.size(-1))
+ # torch.repeat_interleave(torch.arange(indices.size(0))
+ # res[indices] = grad_values # + grad_indices
+ # print('here 2')
+ return res
diff --git a/src/torch_stablesort/torch_stablesort_cpu.cpp b/src/torch_stablesort/torch_stablesort_cpu.cpp
new file mode 100644
index 0000000..12cef54
--- /dev/null
+++ b/src/torch_stablesort/torch_stablesort_cpu.cpp
@@ -0,0 +1,115 @@
+#include
+
+#include "dispatch.h"
+#include "torch_stablesort_cpu.h"
+
+template
+struct stable_sort_impl {
+ std::vector operator()(
+ torch::Tensor input,
+ int dim,
+ torch::optional> out
+ ) const {
+
+ if (input.is_sparse())
+ throw std::runtime_error("Sparse tensors are not supported");
+
+ if (input.device().type() != torch::DeviceType::CPU)
+ throw std::runtime_error("Only CPU tensors are supported");
+
+ if (out != torch::nullopt)
+ throw std::runtime_error("out argument is not supported");
+
+ auto in = (dim != -1) ?
+ torch::transpose(input, dim, -1) :
+ input;
+
+ auto in_sizes = in.sizes();
+
+ // std::cout << "in_sizes: " << in_sizes << std::endl;
+
+ in = in.view({ -1, in.size(-1) }).contiguous();
+
+ auto in_outer_stride = in.stride(-2);
+ auto in_inner_stride = in.stride(-1);
+
+ auto pin = static_cast(in.data_ptr());
+
+ auto x = in.clone();
+
+ auto x_outer_stride = x.stride(-2);
+ auto x_inner_stride = x.stride(-1);
+
+ auto n_cols = x.size(1);
+ auto n_rows = x.size(0);
+ auto px = static_cast(x.data_ptr());
+
+ auto y = torch::empty({ n_rows, n_cols },
+ torch::TensorOptions().dtype(torch::kInt64));
+
+ auto y_outer_stride = y.stride(-2);
+ auto y_inner_stride = y.stride(-1);
+
+ auto py = static_cast(y.data_ptr());
+
+ #pragma omp parallel for
+ for (decltype(n_rows) i = 0; i < n_rows; i++) {
+ std::vector indices(n_cols);
+ for (decltype(n_cols) k = 0; k < n_cols; k++) {
+ indices[k] = k;
+ }
+
+ std::stable_sort(std::begin(indices), std::end(indices),
+ [pin, i, in_outer_stride, in_inner_stride](const auto &a, const auto &b) {
+ auto va = pin[i * in_outer_stride + a * in_inner_stride];
+ auto vb = pin[i * in_outer_stride + b * in_inner_stride];
+ if constexpr(descending)
+ return (vb < va);
+ else
+ return (va < vb);
+ });
+
+ for (decltype(n_cols) k = 0; k < n_cols; k++) {
+ py[i * y_outer_stride + k * y_inner_stride] = indices[k];
+ px[i * x_outer_stride + k * x_inner_stride] =
+ pin[i * in_outer_stride + indices[k] * in_inner_stride];
+ }
+ }
+
+ // std::cout << "Here" << std::endl;
+
+ x = x.view(in_sizes);
+ y = y.view(in_sizes);
+
+ x = (dim == -1) ?
+ x :
+ torch::transpose(x, dim, -1).contiguous();
+
+ y = (dim == -1) ?
+ y :
+ torch::transpose(y, dim, -1).contiguous();
+
+ // std::cout << "Here 2" << std::endl;
+
+ return { x, y };
+ }
+};
+
+template
+struct stable_sort_impl_desc: stable_sort_impl {};
+
+template
+struct stable_sort_impl_asc: stable_sort_impl {};
+
+std::vector dispatch_cpu(torch::Tensor input,
+ int dim,
+ bool descending,
+ torch::optional> out) {
+
+ if (descending)
+ return dispatch>(
+ input, dim, out);
+ else
+ return dispatch>(
+ input, dim, out);
+}
diff --git a/src/torch_stablesort/torch_stablesort_cpu.h b/src/torch_stablesort/torch_stablesort_cpu.h
new file mode 100644
index 0000000..7b54251
--- /dev/null
+++ b/src/torch_stablesort/torch_stablesort_cpu.h
@@ -0,0 +1,11 @@
+#pragma once
+
+#include
+
+#include
+#include
+
+std::vector dispatch_cpu(torch::Tensor input,
+ int dim,
+ bool descending,
+ torch::optional> out);
diff --git a/src/torch_stablesort/torch_stablesort_cuda.cu b/src/torch_stablesort/torch_stablesort_cuda.cu
new file mode 100644
index 0000000..cf9b0d7
--- /dev/null
+++ b/src/torch_stablesort/torch_stablesort_cuda.cu
@@ -0,0 +1,113 @@
+#pragma once
+
+#include
+#include
+#include
+
+#include "dispatch.h"
+#include "torch_stablesort_cuda.h"
+
+template
+struct stable_sort_impl_cuda {
+ std::vector operator()(
+ torch::Tensor input,
+ int dim,
+ torch::optional> out
+ ) const {
+
+ if (input.is_sparse())
+ throw std::runtime_error("Sparse tensors are not supported");
+
+ if (input.device().type() != torch::DeviceType::CUDA)
+ throw std::runtime_error("Only CUDA tensors are supported");
+
+ if (out != torch::nullopt)
+ throw std::runtime_error("out argument is not supported");
+
+ auto values = input.clone();
+
+ if (dim != -1)
+ values = torch::transpose(values, dim, -1);
+
+ auto orig_sizes = values.sizes();
+
+ values = values.view({ -1, values.size(-1) }).contiguous();
+
+ auto n_cols = values.size(1);
+ auto n_rows = values.size(0);
+ auto n = n_rows * n_cols;
+
+ assert(values.stride(-2) == n_cols);
+ assert(values.stride(-1) == 1);
+
+ auto values_ptr = values.data_ptr();
+
+ auto indices = torch::arange(0, n, 1, torch::TensorOptions()
+ .dtype(torch::kInt64)
+ .device(values.device())).view({ n_rows, n_cols });
+
+ assert(indices.stride(-2) == n_cols);
+ assert(indices.stride(-1) == 1);
+
+ auto indices_ptr = indices.data_ptr();
+
+ auto ind_beg = thrust::device_pointer_cast(indices_ptr);
+ auto val_beg = thrust::device_pointer_cast(values_ptr);
+
+ if (descending)
+ thrust::stable_sort_by_key(thrust::device, val_beg, val_beg + n, ind_beg, thrust::greater());
+ else
+ thrust::stable_sort_by_key(thrust::device, val_beg, val_beg + n, ind_beg);
+
+ thrust::device_vector segments(n);
+ thrust::constant_iterator n_cols_iter(n_cols);
+ thrust::transform(thrust::device,
+ ind_beg, ind_beg + n, n_cols_iter,
+ segments.begin(), thrust::divides());
+
+ thrust::stable_sort_by_key(thrust::device, segments.begin(),
+ segments.end(), val_beg);
+
+ thrust::transform(thrust::device,
+ ind_beg, ind_beg + n, n_cols_iter,
+ segments.begin(), thrust::divides());
+
+ thrust::stable_sort_by_key(thrust::device, segments.begin(),
+ segments.end(), ind_beg);
+
+ thrust::transform(thrust::device, ind_beg, ind_beg + n,
+ n_cols_iter, ind_beg, thrust::modulus());
+
+ cudaDeviceSynchronize();
+
+ values = values.view(orig_sizes);
+ indices = indices.view(orig_sizes);
+
+ if (dim != -1)
+ values = torch::transpose(values, dim, -1).contiguous();
+
+ if (dim != -1)
+ indices = torch::transpose(indices, dim, -1).contiguous();
+
+ return { values, indices };
+ }
+};
+
+template
+struct stable_sort_impl_desc_cuda: stable_sort_impl_cuda {};
+
+template
+struct stable_sort_impl_asc_cuda: stable_sort_impl_cuda {};
+
+std::vector dispatch_cuda(torch::Tensor input,
+ int dim,
+ bool descending,
+ torch::optional> out) {
+
+ if (descending)
+ return dispatch>(
+ input, dim, out);
+ else
+ return dispatch>(
+ input, dim, out);
+}
diff --git a/src/torch_stablesort/torch_stablesort_cuda.h b/src/torch_stablesort/torch_stablesort_cuda.h
new file mode 100644
index 0000000..e5f10e4
--- /dev/null
+++ b/src/torch_stablesort/torch_stablesort_cuda.h
@@ -0,0 +1,11 @@
+#pragma once
+
+#include
+
+#include
+#include
+
+std::vector dispatch_cuda(torch::Tensor input,
+ int dim,
+ bool descending,
+ torch::optional> out);
diff --git a/decagon_pytorch/model.py b/src/triacontagon/__init__.py
similarity index 100%
rename from decagon_pytorch/model.py
rename to src/triacontagon/__init__.py
diff --git a/src/triacontagon/batch.py b/src/triacontagon/batch.py
new file mode 100644
index 0000000..9688d03
--- /dev/null
+++ b/src/triacontagon/batch.py
@@ -0,0 +1,193 @@
+from .data import Data
+from .model import TrainingBatch
+import torch
+from functools import reduce
+
+
+def _shuffle(x: torch.Tensor) -> torch.Tensor:
+ order = torch.randperm(len(x))
+ return x[order]
+
+
+def _same_data_org(pos_data: Data, neg_data: Data):
+ if len(pos_data.vertex_types) != len(neg_data.vertex_types):
+ return False
+
+ test = [ pos_data.vertex_types[i].name == neg_data.vertex_types[i].name \
+ and pos_data.vertex_types[i].count == neg_data.vertex_types[i].count \
+ for i in range(len(pos_data.vertex_types)) ]
+ if not all(test):
+ return False
+
+ if not set(pos_data.edge_types.keys()) == \
+ set(neg_data.edge_types.keys()):
+
+ return False
+
+ test = [ pos_data.edge_types[i].name == \
+ neg_data.edge_types[i].name \
+ and pos_data.edge_types[i].vertex_type_row == \
+ neg_data.edge_types[i].vertex_type_row \
+ and pos_data.edge_types[i].vertex_type_column == \
+ neg_data.edge_types[i].vertex_type_column \
+ and len(pos_data.edge_types[i].adjacency_matrices) == \
+ len(neg_data.edge_types[i].adjacency_matrices) \
+ for i in pos_data.edge_types.keys() ]
+ if not all(test):
+ return False
+
+ test = [ [ len(pos_data.edge_types[i].adjacency_matrices[k].values()) == \
+ len(neg_data.edge_types[i].adjacency_matrices[k].values()) \
+ for k in range(len(pos_data.edge_types[i].adjacency_matrices)) ] \
+ for i in pos_data.edge_types.keys() ]
+ test = reduce(list.__add__, test, [])
+ if not all(test):
+ return False
+
+ return True
+
+
+class DualBatcher(object):
+ def __init__(self, pos_data: Data, neg_data: Data,
+ batch_size: int=512, shuffle: bool=True) -> None:
+
+ if not isinstance(pos_data, Data):
+ raise TypeError('pos_data must be an instance of Data')
+
+ if not isinstance(neg_data, Data):
+ raise TypeError('neg_data must be an instance of Data')
+
+ if not _same_data_org(pos_data, neg_data):
+ raise ValueError('pos_data and neg_data must have the same organization')
+
+ self.pos_data = pos_data
+ self.neg_data = neg_data
+ self.batch_size = int(batch_size)
+ self.shuffle = bool(shuffle)
+
+ def get_edge_lists(self, data: Data):
+ edge_types = list(data.edge_types.items())
+ edge_keys = [ a[0] for a in edge_types ]
+ edge_types = [ a[1] for a in edge_types ]
+
+ edge_lists = [ [ adj_mat.indices().transpose(0, 1) \
+ for adj_mat in et.adjacency_matrices ] \
+ for et in edge_types ]
+
+ if self.shuffle:
+ edge_lists = [ [ _shuffle(lst) for lst in edge_lst ] \
+ for edge_lst in edge_lists ]
+
+ offsets = [ [ 0 ] * len(et.adjacency_matrices) \
+ for et in edge_types ]
+
+ return (edge_keys, edge_types, edge_lists, offsets)
+
+ def get_candidates(self, edge_lists, offsets):
+ candidates = [ edge_idx for edge_idx, edge_ofs in enumerate(offsets) \
+ if len([ rel_idx for rel_idx, rel_ofs in enumerate(edge_ofs) \
+ if rel_ofs < len(edge_lists[edge_idx][rel_idx]) ]) > 0 ]
+
+ if len(candidates) == 0:
+ return None, None
+
+ edge_idx = torch.randint(0, len(candidates), (1,)).item()
+ edge_idx = candidates[edge_idx]
+ candidates = [ rel_idx \
+ for rel_idx, rel_ofs in enumerate(offsets[edge_idx]) \
+ if rel_ofs < len(edge_lists[edge_idx][rel_idx]) ]
+
+ rel_idx = torch.randint(0, len(candidates), (1,)).item()
+ rel_idx = candidates[rel_idx]
+
+ return edge_idx, rel_idx
+
+ def take_edges(self, edge_idx, rel_idx, edge_lists, offsets,
+ edge_types, target_value):
+
+ lst = edge_lists[edge_idx][rel_idx]
+ et = edge_types[edge_idx]
+ ofs = offsets[edge_idx][rel_idx]
+ lst = lst[ofs:ofs+self.batch_size]
+ offsets[edge_idx][rel_idx] += self.batch_size
+
+ res = TrainingBatch(et.vertex_type_row, et.vertex_type_column,
+ rel_idx, lst, torch.full(( len(lst), ), target_value,
+ dtype=torch.float32))
+
+ return res
+
+ def __iter__(self):
+ pos_edge_keys, pos_edge_types, pos_edge_lists, pos_offsets = \
+ self.get_edge_lists(self.pos_data)
+
+ neg_edge_keys, neg_edge_types, neg_edge_lists, neg_offsets = \
+ self.get_edge_lists(self.neg_data)
+
+ while True:
+ edge_idx, rel_idx = self.get_candidates(pos_edge_lists, pos_offsets)
+
+ if edge_idx is None:
+ return
+
+ pos_batch = self.take_edges(edge_idx, rel_idx, pos_edge_lists,
+ pos_offsets, pos_edge_types, 1)
+
+ neg_batch = self.take_edges(edge_idx, rel_idx, neg_edge_lists,
+ neg_offsets, neg_edge_types, 0)
+
+ yield (pos_batch, neg_batch)
+
+
+class Batcher(object):
+ def __init__(self, data: Data, batch_size: int=512,
+ shuffle: bool=True) -> None:
+
+ if not isinstance(data, Data):
+ raise TypeError('data must be an instance of Data')
+
+ self.data = data
+ self.batch_size = int(batch_size)
+ self.shuffle = bool(shuffle)
+
+ def __iter__(self) -> TrainingBatch:
+ edge_types = list(self.data.edge_types.values())
+
+ edge_lists = [ [ adj_mat.indices().transpose(0, 1) \
+ for adj_mat in et.adjacency_matrices ] \
+ for et in edge_types ]
+
+ if self.shuffle:
+ edge_lists = [ [ _shuffle(lst) for lst in edge_lst ] \
+ for edge_lst in edge_lists ]
+
+ offsets = [ [ 0 ] * len(et.adjacency_matrices) \
+ for et in edge_types ]
+
+ while True:
+ candidates = [ edge_idx for edge_idx, edge_ofs in enumerate(offsets) \
+ if len([ rel_idx for rel_idx, rel_ofs in enumerate(edge_ofs) \
+ if rel_ofs < len(edge_lists[edge_idx][rel_idx]) ]) > 0 ]
+ if len(candidates) == 0:
+ break
+
+ edge_idx = torch.randint(0, len(candidates), (1,)).item()
+ edge_idx = candidates[edge_idx]
+ candidates = [ rel_idx \
+ for rel_idx, rel_ofs in enumerate(offsets[edge_idx]) \
+ if rel_ofs < len(edge_lists[edge_idx][rel_idx]) ]
+
+ rel_idx = torch.randint(0, len(candidates), (1,)).item()
+ rel_idx = candidates[rel_idx]
+
+ lst = edge_lists[edge_idx][rel_idx]
+ et = edge_types[edge_idx]
+ ofs = offsets[edge_idx][rel_idx]
+ lst = lst[ofs:ofs+self.batch_size]
+ offsets[edge_idx][rel_idx] += self.batch_size
+
+ b = TrainingBatch(et.vertex_type_row, et.vertex_type_column,
+ rel_idx, lst, torch.full((len(lst),), self.data.target_value,
+ dtype=torch.float32))
+
+ yield b
diff --git a/src/triacontagon/cumcount.py b/src/triacontagon/cumcount.py
new file mode 100644
index 0000000..33169f9
--- /dev/null
+++ b/src/triacontagon/cumcount.py
@@ -0,0 +1,31 @@
+import torch
+import numpy as np
+
+
+def dfill(a):
+ n = torch.numel(a)
+ b = torch.cat([
+ torch.tensor([0]),
+ torch.nonzero(a[:-1] != a[1:], as_tuple=True)[0] + 1,
+ torch.tensor([n])
+ ])
+ # print('b:',b)
+ res = torch.arange(n)[b[:-1]]
+ res = torch.repeat_interleave(res, b[1:] - b[:-1])
+ return res
+
+
+def argunsort(s):
+ n = torch.numel(s)
+ u = torch.empty(n, dtype=torch.int64)
+ u[s] = torch.arange(n)
+ return u
+
+
+def cumcount(a):
+ n = torch.numel(a)
+ s = np.argsort(a.detach().cpu().numpy(), kind='mergesort')
+ s = torch.tensor(s, device=a.device)
+ i = argunsort(s)
+ b = a[s]
+ return (torch.arange(n) - dfill(b))[i]
diff --git a/src/triacontagon/data.py b/src/triacontagon/data.py
new file mode 100644
index 0000000..48b0a58
--- /dev/null
+++ b/src/triacontagon/data.py
@@ -0,0 +1,89 @@
+#
+# Copyright (C) Stanislaw Adaszewski, 2020
+# License: GPLv3
+#
+
+
+from dataclasses import dataclass
+from typing import Callable, \
+ Tuple, \
+ List
+import types
+from .util import _nonzero_sum, \
+ _diag
+import torch
+
+
+@dataclass
+class DecodingMatrices(object):
+ global_interaction: torch.Tensor
+ local_variation: torch.Tensor
+
+
+@dataclass
+class VertexType(object):
+ name: str
+ count: int
+
+
+@dataclass
+class EdgeType(object):
+ name: str
+ vertex_type_row: int
+ vertex_type_column: int
+ adjacency_matrices: List[torch.Tensor]
+ decoder_factory: Callable[[], DecodingMatrices]
+ total_connectivity: torch.Tensor
+
+
+class Data(object):
+ vertex_types: List[VertexType]
+ edge_types: List[EdgeType]
+
+ def __init__(self, target_value: int = 1) -> None:
+ self.vertex_types = []
+ self.edge_types = {}
+ self.target_value = int(target_value)
+
+ def add_vertex_type(self, name: str, count: int) -> None:
+ name = str(name)
+ count = int(count)
+ if not name:
+ raise ValueError('You must provide a non-empty vertex type name')
+ if count <= 0:
+ raise ValueError('You must provide a positive vertex count')
+ self.vertex_types.append(VertexType(name, count))
+
+ def add_edge_type(self, name: str,
+ vertex_type_row: int, vertex_type_column: int,
+ adjacency_matrices: List[torch.Tensor],
+ decoder_factory: Callable[[], DecodingMatrices]) -> None:
+
+ name = str(name)
+ vertex_type_row = int(vertex_type_row)
+ vertex_type_column = int(vertex_type_column)
+
+ if not isinstance(adjacency_matrices, list):
+ raise TypeError('adjacency_matrices must be a list of tensors')
+
+ if not callable(decoder_factory):
+ raise TypeError('decoder_factory must be callable')
+
+ if (vertex_type_row, vertex_type_column) in self.edge_types:
+ raise KeyError('Edge type for given combination of row and column already exists')
+
+ if vertex_type_row == vertex_type_column and \
+ any(torch.any(_diag(adj_mat).to(torch.bool)) \
+ for adj_mat in adjacency_matrices):
+ raise ValueError('Adjacency matrices for same row/column vertex types must have empty diagonals')
+
+ if any(adj_mat.shape[0] != self.vertex_types[vertex_type_row].count \
+ or adj_mat.shape[1] != self.vertex_types[vertex_type_column].count \
+ for adj_mat in adjacency_matrices):
+ raise ValueError('Adjacency matrices must have as many rows as row vertex type count and as many columns as column vertex type count')
+
+ total_connectivity = _nonzero_sum(adjacency_matrices)
+
+ self.edge_types[vertex_type_row, vertex_type_column] = \
+ EdgeType(name, vertex_type_row, vertex_type_column,
+ adjacency_matrices, decoder_factory, total_connectivity)
diff --git a/src/triacontagon/decode.py b/src/triacontagon/decode.py
new file mode 100644
index 0000000..d82d29d
--- /dev/null
+++ b/src/triacontagon/decode.py
@@ -0,0 +1,53 @@
+#
+# Copyright (C) Stanislaw Adaszewski, 2020
+# License: GPLv3
+#
+
+
+import torch
+from .weights import init_glorot
+from .dropout import dropout
+from typing import Tuple, \
+ List
+
+
+def dedicom_decoder(input_dim: int, num_relation_types: int) -> \
+ Tuple[torch.Tensor, List[torch.Tensor]]:
+
+ global_interaction = init_glorot(input_dim, input_dim)
+ local_variation = [
+ torch.diag(torch.flatten(init_glorot(input_dim, 1))) \
+ for _ in range(num_relation_types)
+ ]
+ return (global_interaction, local_variation)
+
+
+def dist_mult_decoder(input_dim: int, num_relation_types: int) -> \
+ Tuple[torch.Tensor, List[torch.Tensor]]:
+
+ global_interaction = torch.eye(input_dim, input_dim)
+ local_variation = [
+ torch.diag(torch.flatten(init_glorot(input_dim, 1))) \
+ for _ in range(num_relation_types)
+ ]
+ return (global_interaction, local_variation)
+
+
+def bilinear_decoder(input_dim: int, num_relation_types: int) -> \
+ Tuple[torch.Tensor, List[torch.Tensor]]:
+
+ global_interaction = torch.eye(input_dim, input_dim)
+ local_variation = [
+ init_glorot(input_dim, input_dim) \
+ for _ in range(num_relation_types)
+ ]
+ return (global_interaction, local_variation)
+
+
+def inner_product_decoder(input_dim: int, num_relation_types: int) -> \
+ Tuple[torch.Tensor, List[torch.Tensor]]:
+
+ global_interaction = torch.eye(input_dim, input_dim)
+ local_variation = torch.eye(input_dim, input_dim)
+ local_variation = [ local_variation ] * num_relation_types
+ return (global_interaction, local_variation)
diff --git a/src/triacontagon/deprecated/fastconv.py b/src/triacontagon/deprecated/fastconv.py
new file mode 100644
index 0000000..3acc9ef
--- /dev/null
+++ b/src/triacontagon/deprecated/fastconv.py
@@ -0,0 +1,195 @@
+from typing import List, \
+ Union, \
+ Callable
+from .data import Data, \
+ RelationFamily
+from .trainprep import PreparedData, \
+ PreparedRelationFamily
+import torch
+from .weights import init_glorot
+from .normalize import _sparse_coo_tensor
+import types
+from .util import _sparse_diag_cat,
+ _cat
+
+
+
+
+
+class FastGraphConv(torch.nn.Module):
+ def __init__(self,
+ in_channels: int,
+ out_channels: int,
+ adjacency_matrices: List[torch.Tensor],
+ keep_prob: float = 1.,
+ activation: Callable[[torch.Tensor], torch.Tensor] = lambda x: x,
+ **kwargs) -> None:
+
+ super().__init__(**kwargs)
+
+ in_channels = int(in_channels)
+ out_channels = int(out_channels)
+ if not isinstance(adjacency_matrices, list):
+ raise TypeError('adjacency_matrices must be a list')
+ if len(adjacency_matrices) == 0:
+ raise ValueError('adjacency_matrices must not be empty')
+ if not all(isinstance(m, torch.Tensor) for m in adjacency_matrices):
+ raise TypeError('adjacency_matrices elements must be of class torch.Tensor')
+ if not all(m.is_sparse for m in adjacency_matrices):
+ raise ValueError('adjacency_matrices elements must be sparse')
+ keep_prob = float(keep_prob)
+ if not isinstance(activation, types.FunctionType):
+ raise TypeError('activation must be a function')
+
+ self.in_channels = in_channels
+ self.out_channels = out_channels
+ self.adjacency_matrices = adjacency_matrices
+ self.keep_prob = keep_prob
+ self.activation = activation
+
+ self.num_row_nodes = len(adjacency_matrices[0])
+ self.num_relation_types = len(adjacency_matrices)
+
+ self.adjacency_matrices = _sparse_diag_cat(adjacency_matrices)
+
+ self.weights = torch.cat([
+ init_glorot(in_channels, out_channels) \
+ for _ in range(self.num_relation_types)
+ ], dim=1)
+
+ def forward(self, x) -> torch.Tensor:
+ if self.keep_prob < 1.:
+ x = dropout(x, self.keep_prob)
+ res = torch.sparse.mm(x, self.weights) \
+ if x.is_sparse \
+ else torch.mm(x, self.weights)
+ res = torch.split(res, res.shape[1] // self.num_relation_types, dim=1)
+ res = torch.cat(res)
+ res = torch.sparse.mm(self.adjacency_matrices, res) \
+ if self.adjacency_matrices.is_sparse \
+ else torch.mm(self.adjacency_matrices, res)
+ res = res.view(self.num_relation_types, self.num_row_nodes, self.out_channels)
+ if self.activation is not None:
+ res = self.activation(res)
+
+ return res
+
+
+class FastConvLayer(torch.nn.Module):
+ def __init__(self,
+ input_dim: List[int],
+ output_dim: List[int],
+ data: Union[Data, PreparedData],
+ keep_prob: float = 1.,
+ rel_activation: Callable[[torch.Tensor], torch.Tensor] = lambda x: x,
+ layer_activation: Callable[[torch.Tensor], torch.Tensor] = torch.nn.functional.relu,
+ **kwargs):
+
+ super().__init__(**kwargs)
+
+ self._check_params(input_dim, output_dim, data, keep_prob,
+ rel_activation, layer_activation)
+
+ self.input_dim = input_dim
+ self.output_dim = output_dim
+ self.data = data
+ self.keep_prob = keep_prob
+ self.rel_activation = rel_activation
+ self.layer_activation = layer_activation
+
+ self.is_sparse = False
+ self.next_layer_repr = None
+ self.build()
+
+ def build(self):
+ self.next_layer_repr = torch.nn.ModuleList([
+ torch.nn.ModuleList() \
+ for _ in range(len(self.data.node_types))
+ ])
+ for fam in self.data.relation_families:
+ self.build_family(fam)
+
+ def build_family(self, fam) -> None:
+ if fam.node_type_row == fam.node_type_column:
+ self.build_fam_one_node_type(fam)
+ else:
+ self.build_fam_two_node_types(fam)
+
+ def build_fam_one_node_type(self, fam) -> None:
+ adjacency_matrices = [
+ r.adjacency_matrix \
+ for r in fam.relation_types
+ ]
+ conv = FastGraphConv(self.input_dim[fam.node_type_column],
+ self.output_dim[fam.node_type_row],
+ adjacency_matrices,
+ self.keep_prob,
+ self.rel_activation)
+ conv.input_node_type = fam.node_type_column
+ self.next_layer_repr[fam.node_type_row].append(conv)
+
+ def build_fam_two_node_types(self, fam) -> None:
+ adjacency_matrices = [
+ r.adjacency_matrix \
+ for r in fam.relation_types \
+ if r.adjacency_matrix is not None
+ ]
+
+ adjacency_matrices_backward = [
+ r.adjacency_matrix_backward \
+ for r in fam.relation_types \
+ if r.adjacency_matrix_backward is not None
+ ]
+
+ conv = FastGraphConv(self.input_dim[fam.node_type_column],
+ self.output_dim[fam.node_type_row],
+ adjacency_matrices,
+ self.keep_prob,
+ self.rel_activation)
+
+ conv_backward = FastGraphConv(self.input_dim[fam.node_type_row],
+ self.output_dim[fam.node_type_column],
+ adjacency_matrices_backward,
+ self.keep_prob,
+ self.rel_activation)
+
+ conv.input_node_type = fam.node_type_column
+ conv_backward.input_node_type = fam.node_type_row
+
+ self.next_layer_repr[fam.node_type_row].append(conv)
+ self.next_layer_repr[fam.node_type_column].append(conv_backward)
+
+ def forward(self, prev_layer_repr):
+ next_layer_repr = [ [] \
+ for _ in range(len(self.data.node_types)) ]
+ for output_node_type in range(len(self.data.node_types)):
+ for conv in self.next_layer_repr[output_node_type]:
+ rep = conv(prev_layer_repr[conv.input_node_type])
+ rep = torch.sum(rep, dim=0)
+ rep = torch.nn.functional.normalize(rep, p=2, dim=1)
+ next_layer_repr[output_node_type].append(rep)
+ if len(next_layer_repr[output_node_type]) == 0:
+ next_layer_repr[output_node_type] = \
+ torch.zeros(self.data.node_types[output_node_type].count, self.output_dim[output_node_type])
+ else:
+ next_layer_repr[output_node_type] = \
+ sum(next_layer_repr[output_node_type])
+ next_layer_repr[output_node_type] = \
+ self.layer_activation(next_layer_repr[output_node_type])
+ return next_layer_repr
+
+ @staticmethod
+ def _check_params(input_dim, output_dim, data, keep_prob,
+ rel_activation, layer_activation):
+
+ if not isinstance(input_dim, list):
+ raise ValueError('input_dim must be a list')
+
+ if not output_dim:
+ raise ValueError('output_dim must be specified')
+
+ if not isinstance(output_dim, list):
+ output_dim = [output_dim] * len(data.node_types)
+
+ if not isinstance(data, Data) and not isinstance(data, PreparedData):
+ raise ValueError('data must be of type Data or PreparedData')
diff --git a/src/triacontagon/deprecated/fastdec.py b/src/triacontagon/deprecated/fastdec.py
new file mode 100644
index 0000000..ca08892
--- /dev/null
+++ b/src/triacontagon/deprecated/fastdec.py
@@ -0,0 +1,138 @@
+import torch
+from typing import List
+from .trainprep import PreparedData
+from dataclasses import dataclass
+import random
+from collections import defaultdict
+
+
+@dataclass
+class TrainingBatch(object):
+ relation_family_index: int
+ relation_type_index: int
+ node_type_row: int
+ node_type_column: int
+ edges: torch.Tensor
+
+
+class FastBatcher(object):
+ def __init__(self,
+ prep_d: PreparedData,
+ batch_size: int) -> None:
+
+ if not isinstance(prep_d, PreparedData):
+ raise TypeError('prep_d must be an instance of PreparedData')
+
+ self.prep_d = prep_d
+ self.batch_size = int(batch_size)
+
+ self.edges = None
+ self.build()
+
+ def build(self):
+ self.edges = []
+ for fam_idx, fam in enumerate(self.prep_d.relation_families):
+ edges = []
+ targets = []
+ edges_back = []
+ targets_back = []
+ for rel_idx, rel in enumerate(fam.relation_types):
+ edges.append(rel.edges_pos.train)
+ edges.append(rel.edges_neg.train)
+ targets.append(torch.ones(len(rel.edges_pos.train)))
+ targets.append(torch.zeros(len(rel.edges_neg.train)))
+
+ edges_back.append(rel.edges_back_pos.train)
+ edges_back.append(rel.edges_back_neg.train)
+ targets_back.apend(torch.zeros(len(rel.edges_back_pos.train)))
+ targets_back.apend(torch.zeros(len(rel.edges_back_neg.train)))
+
+ edges = torch.cat(edges)
+ targets = torch.cat(targets)
+ edges_back = torch.cat(edges_back)
+ targets_back = torch.cat(targets_back)
+
+ order = torch.randperm(len(edges))
+ edges = edges[order]
+ targets = targets[order]
+
+ order_back = torch.randperm(len(edges_back))
+ edges_back = edges_back[order_back]
+ targets_back = targets_back[order_back]
+
+ self.edges.append({'fam_idx': fam_idx, 'rel_idx': rel_idx, 'back': False,
+ 'edges': edges, 'targets': targets, 'ofs': 0})
+ self.edges.append({'fam_idx': fam_idx, 'rel_idx': rel_idx, 'back': True,
+ 'edges': edges_back, 'targets': targets_back, 'ofs': 0})
+
+ def __iter__(self):
+ while True:
+ edges = [ e for e in self.edges \
+ if e['ofs'] < len(e['edges']) ]
+ # TODO: need to finish this
+
+ def __iter_old__(self):
+ edge_types = ['edges_pos', 'edges_neg', 'edges_back_pos', 'edges_back_neg']
+
+ offsets = {}
+ orders = {}
+ done = {}
+
+ for fam_idx, fam in enumerate(self.prep_d.relation_families):
+ for rel_idx, rel in enumerate(fam.relation_types):
+ for et in edge_types:
+ done[fam_idx, rel_idx, et] = False
+
+ while True:
+ fam_idx = torch.randint(0, len(self.prep_d.relation_families), (1,)).item()
+ fam = self.prep_d.relation_families[fam_idx]
+
+ rel_idx = torch.randint(0, len(fam.relation_types), (1,)).item()
+ rel = fam.relation_types[rel_idx]
+
+ et = random.choice(edge_types)
+ edges = getattr(rel, et).train
+
+ key = (fam_idx, rel_idx, et)
+ if key not in orders:
+ orders[key] = torch.randperm(len(edges))
+ offsets[key] = 0
+
+ ord = orders[key]
+ ofs = offsets[key]
+
+ nt_row = rel.node_type_row
+ nt_col = rel.node_type_column
+
+ if 'back' in et:
+ nt_row, nt_col = nt_col, nt_row
+
+ if ofs < len(edges):
+ offsets[key] += self.batch_size
+ ord = ord[ofs:ofs+self.batch_size]
+ edges = edges[ord]
+ yield TrainingBatch(fam_idx, rel_idx, nt_row, nt_column, edges)
+ else:
+ done[key] = True
+
+
+
+
+ for fam in self.prep_d.relation_families:
+ edges = []
+ for rel in fam.relation_types:
+ edges.append(rel.edges_pos.train)
+ edges.append(rel.edges_back_pos.train)
+ edges.append(rel.edges_neg.train)
+ edges.append(rel.edges_back_neg.train)
+ edges = torch.cat(e)
+
+
+
+class FastDecLayer(torch.nn.Module):
+ def __init__(self, **kwargs):
+ super().__init__(**kwargs)
+
+ def forward(self,
+ last_layer_repr: List[torch.Tensor],
+ training_batch: TrainingBatch):
diff --git a/src/triacontagon/deprecated/fastloop.py b/src/triacontagon/deprecated/fastloop.py
new file mode 100644
index 0000000..f955932
--- /dev/null
+++ b/src/triacontagon/deprecated/fastloop.py
@@ -0,0 +1,166 @@
+from .fastmodel import FastModel
+from .trainprep import PreparedData
+import torch
+from typing import Callable
+from types import FunctionType
+import time
+import random
+
+
+class FastBatcher(object):
+ def __init__(self, prep_d: PreparedData, batch_size: int,
+ shuffle: bool, generator: torch.Generator,
+ part_type: str) -> None:
+
+ if not isinstance(prep_d, PreparedData):
+ raise TypeError('prep_d must be an instance of PreparedData')
+
+ if not isinstance(generator, torch.Generator):
+ raise TypeError('generator must be an instance of torch.Generator')
+
+ if part_type not in ['train', 'val', 'test']:
+ raise ValueError('part_type must be set to train, val or test')
+
+ self.prep_d = prep_d
+ self.batch_size = int(batch_size)
+ self.shuffle = bool(shuffle)
+ self.generator = generator
+ self.part_type = part_type
+
+ self.edges = None
+ self.targets = None
+ self.build()
+
+ def build(self):
+ self.edges = []
+ self.targets = []
+
+ for fam in self.prep_d.relation_families:
+ edges = []
+ targets = []
+ for i, rel in enumerate(fam.relation_types):
+
+ edges_pos = getattr(rel.edges_pos, self.part_type)
+ edges_neg = getattr(rel.edges_neg, self.part_type)
+ edges_back_pos = getattr(rel.edges_back_pos, self.part_type)
+ edges_back_neg = getattr(rel.edges_back_neg, self.part_type)
+
+ e = torch.cat([ edges_pos,
+ torch.cat([edges_back_pos[:, 1], edges_back_pos[:, 0]], dim=1) ])
+ e = torch.cat([torch.ones(len(e), 1, dtype=torch.long) * i , e ], dim=1)
+ t = torch.ones(len(e))
+ edges.append(e)
+ targets.append(t)
+
+ e = torch.cat([ edges_neg,
+ torch.cat([edges_back_neg[:, 1], edges_back_neg[:, 0]], dim=1) ])
+ e = torch.cat([ torch.ones(len(e), 1, dtype=torch.long) * i, e ], dim=1)
+ t = torch.zeros(len(e))
+ edges.append(e)
+ targets.append(t)
+
+ edges = torch.cat(edges)
+ targets = torch.cat(targets)
+
+ self.edges.append(edges)
+ self.targets.append(targets)
+
+ # print(self.edges)
+ # print(self.targets)
+
+ if self.shuffle:
+ self.shuffle_families()
+
+ def shuffle_families(self):
+ for i in range(len(self.edges)):
+ edges = self.edges[i]
+ targets = self.targets[i]
+ order = torch.randperm(len(edges), generator=self.generator)
+ self.edges[i] = edges[order]
+ self.targets[i] = targets[order]
+
+ def __iter__(self):
+ offsets = [ 0 for _ in self.edges ]
+
+ while True:
+ choice = [ i for i in range(len(offsets)) \
+ if offsets[i] < len(self.edges[i]) ]
+ if len(choice) == 0:
+ break
+ fam_idx = torch.randint(len(choice), (1,), generator=self.generator).item()
+ ofs = offsets[fam_idx]
+ edges = self.edges[fam_idx][ofs:ofs + self.batch_size]
+ targets = self.targets[fam_idx][ofs:ofs + self.batch_size]
+ offsets[fam_idx] += self.batch_size
+ yield (fam_idx, edges, targets)
+
+
+class FastLoop(object):
+ def __init__(
+ self,
+ model: FastModel,
+ lr: float = 0.001,
+ loss: Callable[[torch.Tensor, torch.Tensor], torch.Tensor] = \
+ torch.nn.functional.binary_cross_entropy_with_logits,
+ batch_size: int = 100,
+ shuffle: bool = True,
+ generator: torch.Generator = None) -> None:
+
+ self._check_params(model, loss, generator)
+
+ self.model = model
+ self.lr = float(lr)
+ self.loss = loss
+ self.batch_size = int(batch_size)
+ self.shuffle = bool(shuffle)
+ self.generator = generator or torch.default_generator
+
+ self.opt = None
+
+ self.build()
+
+ def _check_params(self, model, loss, generator):
+ if not isinstance(model, FastModel):
+ raise TypeError('model must be an instance of FastModel')
+
+ if not isinstance(loss, FunctionType):
+ raise TypeError('loss must be a function')
+
+ if generator is not None and not isinstance(generator, torch.Generator):
+ raise TypeError('generator must be an instance of torch.Generator')
+
+ def build(self) -> None:
+ opt = torch.optim.Adam(self.model.parameters(), lr=self.lr)
+ self.opt = opt
+
+ def run_epoch(self):
+ prep_d = self.model.prep_d
+
+ batcher = FastBatcher(self.model.prep_d, batch_size=self.batch_size,
+ shuffle = self.shuffle, generator=self.generator)
+ # pred = self.model(None)
+ # n = len(list(iter(batch)))
+ loss_sum = 0
+ for fam_idx, edges, targets in batcher:
+ self.opt.zero_grad()
+ pred = self.model(None)
+
+ # process pred, get input and targets
+ input = pred[fam_idx][edges[:, 0], edges[:, 1]]
+
+ loss = self.loss(input, targets)
+ loss.backward()
+ self.opt.step()
+ loss_sum += loss.detach().cpu().item()
+ return loss_sum
+
+
+ def train(self, max_epochs):
+ best_loss = None
+ best_epoch = None
+ for i in range(max_epochs):
+ loss = self.run_epoch()
+ if best_loss is None or loss < best_loss:
+ best_loss = loss
+ best_epoch = i
+ return loss, best_loss, best_epoch
diff --git a/src/triacontagon/deprecated/fastmodel.py b/src/triacontagon/deprecated/fastmodel.py
new file mode 100644
index 0000000..a68fe58
--- /dev/null
+++ b/src/triacontagon/deprecated/fastmodel.py
@@ -0,0 +1,79 @@
+from .fastconv import FastConvLayer
+from .bulkdec import BulkDecodeLayer
+from .input import OneHotInputLayer
+from .trainprep import PreparedData
+import torch
+import types
+from typing import List, \
+ Union, \
+ Callable
+
+
+class FastModel(torch.nn.Module):
+ def __init__(self, prep_d: PreparedData,
+ layer_dimensions: List[int] = [32, 64],
+ keep_prob: float = 1.,
+ rel_activation: Callable[[torch.Tensor], torch.Tensor] = lambda x: x,
+ layer_activation: Callable[[torch.Tensor], torch.Tensor] = torch.nn.functional.relu,
+ dec_activation: Callable[[torch.Tensor], torch.Tensor] = lambda x: x,
+ **kwargs) -> None:
+
+ super().__init__(**kwargs)
+
+ self._check_params(prep_d, layer_dimensions, rel_activation,
+ layer_activation, dec_activation)
+
+ self.prep_d = prep_d
+ self.layer_dimensions = layer_dimensions
+ self.keep_prob = float(keep_prob)
+ self.rel_activation = rel_activation
+ self.layer_activation = layer_activation
+ self.dec_activation = dec_activation
+
+ self.seq = None
+ self.build()
+
+ def build(self):
+ in_layer = OneHotInputLayer(self.prep_d)
+ last_output_dim = in_layer.output_dim
+ seq = [ in_layer ]
+
+ for dim in self.layer_dimensions:
+ conv_layer = FastConvLayer(input_dim = last_output_dim,
+ output_dim = [dim] * len(self.prep_d.node_types),
+ data = self.prep_d,
+ keep_prob = self.keep_prob,
+ rel_activation = self.rel_activation,
+ layer_activation = self.layer_activation)
+ last_output_dim = conv_layer.output_dim
+ seq.append(conv_layer)
+
+ dec_layer = BulkDecodeLayer(input_dim = last_output_dim,
+ data = self.prep_d,
+ keep_prob = self.keep_prob,
+ activation = self.dec_activation)
+ seq.append(dec_layer)
+
+ seq = torch.nn.Sequential(*seq)
+ self.seq = seq
+
+ def forward(self, _):
+ return self.seq(None)
+
+ def _check_params(self, prep_d, layer_dimensions, rel_activation,
+ layer_activation, dec_activation):
+
+ if not isinstance(prep_d, PreparedData):
+ raise TypeError('prep_d must be an instanced of PreparedData')
+
+ if not isinstance(layer_dimensions, list):
+ raise TypeError('layer_dimensions must be a list')
+
+ if not isinstance(rel_activation, types.FunctionType):
+ raise TypeError('rel_activation must be a function')
+
+ if not isinstance(layer_activation, types.FunctionType):
+ raise TypeError('layer_activation must be a function')
+
+ if not isinstance(dec_activation, types.FunctionType):
+ raise TypeError('dec_activation must be a function')
diff --git a/src/triacontagon/deprecated/input.py b/src/triacontagon/deprecated/input.py
new file mode 100644
index 0000000..3bf5824
--- /dev/null
+++ b/src/triacontagon/deprecated/input.py
@@ -0,0 +1,79 @@
+#
+# Copyright (C) Stanislaw Adaszewski, 2020
+# License: GPLv3
+#
+
+
+import torch
+from typing import Union, \
+ List
+from .data import Data
+
+
+class InputLayer(torch.nn.Module):
+ def __init__(self, data: Data, output_dim: Union[int, List[int]] = None,
+ **kwargs) -> None:
+
+ output_dim = output_dim or \
+ list(map(lambda a: a.count, data.node_types))
+
+ if not isinstance(output_dim, list):
+ output_dim = [output_dim,] * len(data.node_types)
+
+ super().__init__(**kwargs)
+ self.output_dim = output_dim
+ self.data = data
+
+ self.is_sparse=False
+ self.node_reps = None
+ self.build()
+
+ def build(self) -> None:
+ self.node_reps = []
+ for i, nt in enumerate(self.data.node_types):
+ reps = torch.rand(nt.count, self.output_dim[i])
+ reps = torch.nn.Parameter(reps)
+ self.register_parameter('node_reps[%d]' % i, reps)
+ self.node_reps.append(reps)
+
+ def forward(self, x) -> List[torch.nn.Parameter]:
+ return self.node_reps
+
+ def __repr__(self) -> str:
+ s = ''
+ s += 'Icosagon input layer with output_dim: %s\n' % self.output_dim
+ s += ' # of node types: %d\n' % len(self.data.node_types)
+ for nt in self.data.node_types:
+ s += ' - %s (%d)\n' % (nt.name, nt.count)
+ return s.strip()
+
+
+class OneHotInputLayer(torch.nn.Module):
+ def __init__(self, data: Data, **kwargs) -> None:
+ output_dim = [ a.count for a in data.node_types ]
+ super().__init__(**kwargs)
+ self.output_dim = output_dim
+ self.data = data
+
+ self.is_sparse=True
+ self.node_reps = None
+ self.build()
+
+ def build(self) -> None:
+ self.node_reps = torch.nn.ParameterList()
+ for i, nt in enumerate(self.data.node_types):
+ reps = torch.eye(nt.count).to_sparse()
+ reps = torch.nn.Parameter(reps, requires_grad=False)
+ # self.register_parameter('node_reps[%d]' % i, reps)
+ self.node_reps.append(reps)
+
+ def forward(self, x) -> List[torch.nn.Parameter]:
+ return self.node_reps
+
+ def __repr__(self) -> str:
+ s = ''
+ s += 'Icosagon one-hot input layer\n'
+ s += ' # of node types: %d\n' % len(self.data.node_types)
+ for nt in self.data.node_types:
+ s += ' - %s (%d)\n' % (nt.name, nt.count)
+ return s.strip()
diff --git a/src/triacontagon/deprecated/trainprep.py b/src/triacontagon/deprecated/trainprep.py
new file mode 100644
index 0000000..c49300a
--- /dev/null
+++ b/src/triacontagon/deprecated/trainprep.py
@@ -0,0 +1,215 @@
+#
+# Copyright (C) Stanislaw Adaszewski, 2020
+# License: GPLv3
+#
+
+
+from .sampling import fixed_unigram_candidate_sampler
+import torch
+from dataclasses import dataclass, \
+ field
+from typing import Any, \
+ List, \
+ Tuple, \
+ Dict
+from .data import NodeType, \
+ RelationType, \
+ RelationTypeBase, \
+ RelationFamily, \
+ RelationFamilyBase, \
+ Data
+from collections import defaultdict
+from .normalize import norm_adj_mat_one_node_type, \
+ norm_adj_mat_two_node_types
+import numpy as np
+
+
+@dataclass
+class TrainValTest(object):
+ train: Any
+ val: Any
+ test: Any
+
+
+@dataclass
+class PreparedRelationType(RelationTypeBase):
+ edges_pos: TrainValTest
+ edges_neg: TrainValTest
+ edges_back_pos: TrainValTest
+ edges_back_neg: TrainValTest
+
+
+@dataclass
+class PreparedRelationFamily(RelationFamilyBase):
+ relation_types: List[PreparedRelationType]
+
+
+@dataclass
+class PreparedData(object):
+ node_types: List[NodeType]
+ relation_families: List[PreparedRelationFamily]
+
+
+def _empty_edge_list_tvt() -> TrainValTest:
+ return TrainValTest(*[ torch.zeros((0, 2), dtype=torch.long) for _ in range(3) ])
+
+
+def train_val_test_split_edges(edges: torch.Tensor,
+ ratios: TrainValTest) -> TrainValTest:
+
+ if not isinstance(edges, torch.Tensor):
+ raise ValueError('edges must be a torch.Tensor')
+
+ if len(edges.shape) != 2 or edges.shape[1] != 2:
+ raise ValueError('edges shape must be (num_edges, 2)')
+
+ if not isinstance(ratios, TrainValTest):
+ raise ValueError('ratios must be a TrainValTest')
+
+ if ratios.train + ratios.val + ratios.test != 1.0:
+ raise ValueError('Train, validation and test ratios must add up to 1')
+
+ order = torch.randperm(len(edges))
+ edges = edges[order, :]
+ n = round(len(edges) * ratios.train)
+ edges_train = edges[:n]
+ n_1 = round(len(edges) * (ratios.train + ratios.val))
+ edges_val = edges[n:n_1]
+ edges_test = edges[n_1:]
+
+ return TrainValTest(edges_train, edges_val, edges_test)
+
+
+def get_edges_and_degrees(adj_mat: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
+ if adj_mat.is_sparse:
+ adj_mat = adj_mat.coalesce()
+ degrees = torch.zeros(adj_mat.shape[1], dtype=torch.int64,
+ device=adj_mat.device)
+ degrees = degrees.index_add(0, adj_mat.indices()[1],
+ torch.ones(adj_mat.indices().shape[1], dtype=torch.int64,
+ device=adj_mat.device))
+ edges_pos = adj_mat.indices().transpose(0, 1)
+ else:
+ degrees = adj_mat.sum(0)
+ edges_pos = torch.nonzero(adj_mat)
+ return edges_pos, degrees
+
+
+def prepare_adj_mat(adj_mat: torch.Tensor,
+ ratios: TrainValTest) -> Tuple[TrainValTest, TrainValTest]:
+
+ if not isinstance(adj_mat, torch.Tensor):
+ raise ValueError('adj_mat must be a torch.Tensor')
+
+ edges_pos, degrees = get_edges_and_degrees(adj_mat)
+
+ neg_neighbors = fixed_unigram_candidate_sampler(
+ edges_pos[:, 1].view(-1, 1), degrees, 0.75).to(adj_mat.device)
+ print(edges_pos.dtype)
+ print(neg_neighbors.dtype)
+ edges_neg = torch.cat((edges_pos[:, 0].view(-1, 1), neg_neighbors.view(-1, 1)), 1)
+
+ edges_pos = train_val_test_split_edges(edges_pos, ratios)
+ edges_neg = train_val_test_split_edges(edges_neg, ratios)
+
+ adj_mat_train = torch.sparse_coo_tensor(indices = edges_pos.train.transpose(0, 1),
+ values=torch.ones(len(edges_pos.train)), size=adj_mat.shape, dtype=adj_mat.dtype,
+ device=adj_mat.device)
+
+ return adj_mat_train, edges_pos, edges_neg
+
+
+def prep_rel_one_node_type(r: RelationType,
+ ratios: TrainValTest) -> PreparedRelationType:
+
+ adj_mat = r.adjacency_matrix
+ adj_mat_train, edges_pos, edges_neg = prepare_adj_mat(adj_mat, ratios)
+ adj_mat_back_train, edges_back_pos, edges_back_neg = \
+ None, _empty_edge_list_tvt(), _empty_edge_list_tvt()
+
+ print('adj_mat_train:', adj_mat_train)
+ adj_mat_train = norm_adj_mat_one_node_type(adj_mat_train)
+
+ return PreparedRelationType(r.name, r.node_type_row, r.node_type_column,
+ adj_mat_train, adj_mat_back_train, edges_pos, edges_neg,
+ edges_back_pos, edges_back_neg)
+
+
+def prep_rel_two_node_types_sym(r: RelationType,
+ ratios: TrainValTest) -> PreparedRelationType:
+
+ adj_mat = r.adjacency_matrix
+ adj_mat_train, edges_pos, edges_neg = prepare_adj_mat(adj_mat, ratios)
+ edges_back_pos, edges_back_neg = \
+ _empty_edge_list_tvt(), _empty_edge_list_tvt()
+
+ return PreparedRelationType(r.name, r.node_type_row,
+ r.node_type_column,
+ norm_adj_mat_two_node_types(adj_mat_train),
+ norm_adj_mat_two_node_types(adj_mat_train.transpose(0, 1)),
+ edges_pos, edges_neg, edges_back_pos, edges_back_neg)
+
+
+def prep_rel_two_node_types_asym(r: RelationType,
+ ratios: TrainValTest) -> PreparedRelationType:
+
+ if r.adjacency_matrix is not None:
+ adj_mat_train, edges_pos, edges_neg =\
+ prepare_adj_mat(r.adjacency_matrix, ratios)
+ else:
+ adj_mat_train, edges_pos, edges_neg = \
+ None, _empty_edge_list_tvt(), _empty_edge_list_tvt()
+
+ if r.adjacency_matrix_backward is not None:
+ adj_mat_back_train, edges_back_pos, edges_back_neg = \
+ prepare_adj_mat(r.adjacency_matrix_backward, ratios)
+ else:
+ adj_mat_back_train, edges_back_pos, edges_back_neg = \
+ None, _empty_edge_list_tvt(), _empty_edge_list_tvt()
+
+ return PreparedRelationType(r.name, r.node_type_row,
+ r.node_type_column,
+ norm_adj_mat_two_node_types(adj_mat_train),
+ norm_adj_mat_two_node_types(adj_mat_back_train),
+ edges_pos, edges_neg, edges_back_pos, edges_back_neg)
+
+
+def prepare_relation_type(r: RelationType,
+ ratios: TrainValTest, is_symmetric: bool) -> PreparedRelationType:
+
+ if not isinstance(r, RelationType):
+ raise ValueError('r must be a RelationType')
+
+ if not isinstance(ratios, TrainValTest):
+ raise ValueError('ratios must be a TrainValTest')
+
+ if r.node_type_row == r.node_type_column:
+ return prep_rel_one_node_type(r, ratios)
+ elif is_symmetric:
+ return prep_rel_two_node_types_sym(r, ratios)
+ else:
+ return prep_rel_two_node_types_asym(r, ratios)
+
+
+def prepare_relation_family(fam: RelationFamily,
+ ratios: TrainValTest) -> PreparedRelationFamily:
+
+ relation_types = []
+
+ for r in fam.relation_types:
+ relation_types.append(prepare_relation_type(r, ratios, fam.is_symmetric))
+
+ return PreparedRelationFamily(fam.data, fam.name,
+ fam.node_type_row, fam.node_type_column,
+ fam.is_symmetric, fam.decoder_class,
+ relation_types)
+
+
+def prepare_training(data: Data, ratios: TrainValTest) -> PreparedData:
+ if not isinstance(data, Data):
+ raise ValueError('data must be of class Data')
+
+ relation_families = [ prepare_relation_family(fam, ratios) \
+ for fam in data.relation_families ]
+
+ return PreparedData(data.node_types, relation_families)
diff --git a/src/triacontagon/dropout.py b/src/triacontagon/dropout.py
new file mode 100644
index 0000000..e9dde92
--- /dev/null
+++ b/src/triacontagon/dropout.py
@@ -0,0 +1,44 @@
+#
+# Copyright (C) Stanislaw Adaszewski, 2020
+# License: GPLv3
+#
+
+
+import torch
+from .normalize import _sparse_coo_tensor
+
+
+def dropout_sparse(x, keep_prob):
+ x = x.coalesce()
+ i = x._indices()
+ v = x._values()
+ size = x.size()
+
+ n = keep_prob + torch.rand(len(v))
+ n = torch.floor(n).to(torch.bool)
+ i = i[:,n]
+ v = v[n]
+ x = _sparse_coo_tensor(i, v, size=size)
+
+ return x * (1./keep_prob)
+
+
+def dropout_dense(x, keep_prob):
+ # print('dropout_dense()')
+ x = x.clone()
+ i = torch.nonzero(x, as_tuple=False)
+
+ n = keep_prob + torch.rand(len(i))
+ n = (1. - torch.floor(n)).to(torch.bool)
+ x[i[n, 0], i[n, 1]] = 0.
+
+ return x * (1./keep_prob)
+
+
+def dropout(x, keep_prob):
+ if keep_prob == 1:
+ return x
+ if x.is_sparse:
+ return dropout_sparse(x, keep_prob)
+ else:
+ return dropout_dense(x, keep_prob)
diff --git a/src/triacontagon/loop.py b/src/triacontagon/loop.py
new file mode 100644
index 0000000..870c07e
--- /dev/null
+++ b/src/triacontagon/loop.py
@@ -0,0 +1,109 @@
+from .model import Model, \
+ TrainingBatch
+from .batch import DualBatcher
+from .sampling import negative_sample_data
+from .data import Data
+import torch
+from typing import List, \
+ Callable
+
+
+def _merge_pos_neg_batches(pos_batch, neg_batch):
+ assert len(pos_batch.edges) == len(neg_batch.edges)
+ assert pos_batch.vertex_type_row == neg_batch.vertex_type_row
+ assert pos_batch.vertex_type_column == neg_batch.vertex_type_column
+ assert pos_batch.relation_type_index == neg_batch.relation_type_index
+
+ batch = TrainingBatch(pos_batch.vertex_type_row,
+ pos_batch.vertex_type_column,
+ pos_batch.relation_type_index,
+ torch.cat([ pos_batch.edges, neg_batch.edges ]),
+ torch.cat([
+ torch.ones(len(pos_batch.edges)),
+ torch.zeros(len(neg_batch.edges))
+ ]))
+ return batch
+
+
+class TrainLoop(object):
+ def __init__(self, model: Model,
+ val_data: Data, test_data: Data,
+ initial_repr: List[torch.Tensor],
+ max_epochs: int = 50,
+ batch_size: int = 512,
+ loss: Callable[[torch.Tensor, torch.Tensor], torch.Tensor] = \
+ torch.nn.functional.binary_cross_entropy_with_logits,
+ lr: float = 0.001) -> None:
+
+ assert isinstance(model, Model)
+ assert isinstance(val_data, Data)
+ assert isinstance(test_data, Data)
+ assert callable(loss)
+
+ self.model = model
+ self.test_data = test_data
+ self.initial_repr = list(initial_repr)
+ self.max_epochs = int(max_epochs)
+ self.batch_size = int(batch_size)
+ self.loss = loss
+ self.lr = float(lr)
+
+ self.pos_data = model.data
+ self.neg_data = negative_sample_data(model.data)
+
+ self.pos_val_data = val_data
+ self.neg_val_data = negative_sample_data(val_data)
+
+ self.batcher = DualBatcher(self.pos_data, self.neg_data,
+ batch_size=batch_size)
+ self.val_batcher = DualBatcher(self.pos_val_data, self.neg_val_data)
+
+ self.opt = torch.optim.Adam(self.model.parameters(), lr=self.lr)
+
+ def run_epoch(self) -> None:
+ loss_sum = 0.
+ for pos_batch, neg_batch in self.batcher:
+ batch = _merge_pos_neg_batches(pos_batch, neg_batch)
+
+ self.opt.zero_grad()
+ last_layer_repr = self.model.convolve(self.initial_repr)
+ pred = self.model.decode(last_layer_repr, batch)
+ loss = self.loss(pred, batch.target_values)
+ loss.backward()
+ self.opt.step()
+
+ loss = loss.detach().cpu().item()
+ loss_sum += loss
+ print('loss:', loss)
+ return loss_sum
+
+ def validate_epoch(self):
+ loss_sum = 0.
+ for pos_batch, neg_batch in self.val_batcher:
+ batch = _merge_pos_neg_batches(pos_batch, neg_batch)
+ with torch.no_grad():
+ last_layer_repr = self.model.convolve(self.initial_repr, eval_mode=True)
+ pred = self.model.decode(last_layer_repr, batch, eval_mode=True)
+ loss = self.loss(pred, batch.target_values)
+ loss = loss.detach().cpu().item()
+ loss_sum += loss
+ return loss_sum
+
+ def run(self) -> None:
+ best_loss = float('inf')
+ epochs_without_improvement = 0
+ for epoch in range(self.max_epochs):
+ print('Epoch', epoch)
+ loss_sum = self.run_epoch()
+ print('train loss_sum:', loss_sum)
+ loss_sum = self.validate_epoch()
+ print('val loss_sum:', loss_sum)
+ if loss_sum >= best_loss:
+ epochs_without_improvement += 1
+ else:
+ epochs_without_improvement = 0
+ best_loss = loss_sum
+ if epochs_without_improvement == 2:
+ print('Early stopping after epoch', epoch, 'due to no improvement')
+
+ return (epoch, best_loss, loss_sum)
diff --git a/src/triacontagon/model.py b/src/triacontagon/model.py
new file mode 100644
index 0000000..09e6de7
--- /dev/null
+++ b/src/triacontagon/model.py
@@ -0,0 +1,277 @@
+from .data import Data, \
+ EdgeType
+import torch
+from dataclasses import dataclass
+from .weights import init_glorot
+import types
+from typing import List, \
+ Dict, \
+ Callable, \
+ Tuple
+from .util import _sparse_coo_tensor, \
+ _sparse_diag_cat, \
+ _mm
+from .normalize import norm_adj_mat_one_node_type, \
+ norm_adj_mat_two_node_types
+from .dropout import dropout
+
+
+@dataclass
+class TrainingBatch(object):
+ vertex_type_row: int
+ vertex_type_column: int
+ relation_type_index: int
+ edges: torch.Tensor
+ target_values: torch.Tensor
+
+
+def _per_layer_required_vertices(data: Data, batch: TrainingBatch,
+ num_layers: int) -> List[List[EdgeType]]:
+
+ Q = [
+ ( batch.vertex_type_row, batch.edges[:, 0] ),
+ ( batch.vertex_type_column, batch.edges[:, 1] )
+ ]
+ print('Q:', Q)
+ res = []
+
+ for _ in range(num_layers):
+ R = []
+ required_rows = [ [] for _ in range(len(data.vertex_types)) ]
+
+ for vertex_type, vertices in Q:
+ for et in data.edge_types.values():
+ if et.vertex_type_row == vertex_type:
+ required_rows[vertex_type].append(vertices)
+ indices = et.total_connectivity.indices()
+ mask = torch.zeros(et.total_connectivity.shape[0])
+ mask[vertices] = 1
+ mask = torch.nonzero(mask[indices[0]], as_tuple=True)[0]
+ R.append((et.vertex_type_column,
+ indices[1, mask]))
+ else:
+ pass # required_rows[et.vertex_type_row].append(torch.zeros(0))
+
+ required_rows = [ torch.unique(torch.cat(x)) \
+ if len(x) > 0 \
+ else None \
+ for x in required_rows ]
+
+ res.append(required_rows)
+ Q = R
+
+ return res
+
+
+class Model(torch.nn.Module):
+ def __init__(self, data: Data, layer_dimensions: List[int],
+ keep_prob: float,
+ conv_activation: Callable[[torch.Tensor], torch.Tensor],
+ dec_activation: Callable[[torch.Tensor], torch.Tensor],
+ **kwargs) -> None:
+ super().__init__(**kwargs)
+
+ if not isinstance(data, Data):
+ raise TypeError('data must be an instance of Data')
+
+ if not callable(conv_activation):
+ raise TypeError('conv_activation must be callable')
+
+ if not callable(dec_activation):
+ raise TypeError('dec_activation must be callable')
+
+ self.data = data
+ self.layer_dimensions = list(layer_dimensions)
+ self.keep_prob = float(keep_prob)
+ self.conv_activation = conv_activation
+ self.dec_activation = dec_activation
+
+ self.adj_matrices = None
+ self.conv_weights = None
+ self.dec_weights = None
+ self.build()
+
+
+ def build(self) -> None:
+ self.adj_matrices = torch.nn.ParameterDict()
+ for _, et in self.data.edge_types.items():
+ adj_matrices = [
+ norm_adj_mat_one_node_type(x) \
+ if et.vertex_type_row == et.vertex_type_column \
+ else norm_adj_mat_two_node_types(x) \
+ for x in et.adjacency_matrices
+ ]
+ adj_matrices = _sparse_diag_cat(et.adjacency_matrices)
+ print('adj_matrices:', adj_matrices)
+ self.adj_matrices['%d-%d' % (et.vertex_type_row, et.vertex_type_column)] = \
+ torch.nn.Parameter(adj_matrices, requires_grad=False)
+
+ self.conv_weights = torch.nn.ParameterDict()
+ for i in range(len(self.layer_dimensions) - 1):
+ in_dimension = self.layer_dimensions[i]
+ out_dimension = self.layer_dimensions[i + 1]
+
+ for _, et in self.data.edge_types.items():
+ weights = [ init_glorot(in_dimension, out_dimension) \
+ for _ in range(len(et.adjacency_matrices)) ]
+ weights = torch.cat(weights, dim=1)
+ self.conv_weights['%d-%d-%d' % (et.vertex_type_row, et.vertex_type_column, i)] = \
+ torch.nn.Parameter(weights)
+
+ self.dec_weights = torch.nn.ParameterDict()
+ for _, et in self.data.edge_types.items():
+ global_interaction, local_variation = \
+ et.decoder_factory(self.layer_dimensions[-1],
+ len(et.adjacency_matrices))
+ self.dec_weights['%d-%d-global-interaction' % (et.vertex_type_row, et.vertex_type_column)] = \
+ torch.nn.Parameter(global_interaction)
+ for i in range(len(local_variation)):
+ self.dec_weights['%d-%d-local-variation-%d' % (et.vertex_type_row, et.vertex_type_column, i)] = \
+ torch.nn.Parameter(local_variation[i])
+
+
+ def convolve(self, in_layer_repr: List[torch.Tensor], eval_mode=False) -> \
+ List[torch.Tensor]:
+
+ cur_layer_repr = in_layer_repr
+
+ for i in range(len(self.layer_dimensions) - 1):
+ next_layer_repr = [ [] for _ in range(len(self.data.vertex_types)) ]
+
+ for _, et in self.data.edge_types.items():
+ vt_row, vt_col = et.vertex_type_row, et.vertex_type_column
+ adj_matrices = self.adj_matrices['%d-%d' % (vt_row, vt_col)]
+ conv_weights = self.conv_weights['%d-%d-%d' % (vt_row, vt_col, i)]
+
+ num_relation_types = len(et.adjacency_matrices)
+ x = cur_layer_repr[vt_col]
+ if self.keep_prob != 1 and not eval_mode:
+ x = dropout(x, self.keep_prob)
+
+ # print('a, Layer:', i, 'x.shape:', x.shape)
+
+ x = _mm(x, conv_weights)
+ x = torch.split(x,
+ x.shape[1] // num_relation_types,
+ dim=1)
+ x = torch.cat(x)
+ x = _mm(adj_matrices, x)
+ x = x.view(num_relation_types,
+ self.data.vertex_types[vt_row].count,
+ self.layer_dimensions[i + 1])
+
+ # print('b, Layer:', i, 'vt_row:', vt_row, 'x.shape:', x.shape)
+
+ x = x.sum(dim=0)
+ x = torch.nn.functional.normalize(x, p=2, dim=1)
+ # x = self.rel_activation(x)
+ # print('c, Layer:', i, 'vt_row:', vt_row, 'x.shape:', x.shape)
+
+ next_layer_repr[vt_row].append(x)
+
+ next_layer_repr = [ self.conv_activation(sum(x)) \
+ for x in next_layer_repr ]
+
+ cur_layer_repr = next_layer_repr
+ return cur_layer_repr
+
+ def decode(self, last_layer_repr: List[torch.Tensor],
+ batch: TrainingBatch, eval_mode=False) -> torch.Tensor:
+
+ vt_row = batch.vertex_type_row
+ vt_col = batch.vertex_type_column
+ rel_idx = batch.relation_type_index
+ global_interaction = \
+ self.dec_weights['%d-%d-global-interaction' % (vt_row, vt_col)]
+ local_variation = \
+ self.dec_weights['%d-%d-local-variation-%d' % (vt_row, vt_col, rel_idx)]
+
+ in_row = last_layer_repr[vt_row]
+ in_col = last_layer_repr[vt_col]
+
+ if in_row.is_sparse or in_col.is_sparse:
+ raise ValueError('Inputs to Model.decode() must be dense')
+
+ in_row = in_row[batch.edges[:, 0]]
+ in_col = in_col[batch.edges[:, 1]]
+
+ if self.keep_prob != 1 and not eval_mode:
+ in_row = dropout(in_row, self.keep_prob)
+ in_col = dropout(in_col, self.keep_prob)
+
+ # in_row = in_row.to_dense()
+ # in_col = in_col.to_dense()
+
+ print('in_row.is_sparse:', in_row.is_sparse)
+ print('in_col.is_sparse:', in_col.is_sparse)
+
+ x = torch.mm(in_row, local_variation)
+ x = torch.mm(x, global_interaction)
+ x = torch.mm(x, local_variation)
+ x = torch.bmm(x.view(x.shape[0], 1, x.shape[1]),
+ in_col.view(in_col.shape[0], in_col.shape[1], 1))
+ x = torch.flatten(x)
+
+ x = self.dec_activation(x)
+
+ return x
+
+
+ def convolve_old(self, batch: TrainingBatch) -> List[torch.Tensor]:
+ edges = []
+ cur_edges = batch.edges
+ for _ in range(len(self.layer_dimensions) - 1):
+ edges.append(cur_edges)
+ key = (batch.vertex_type_row, batch.vertex_type_column)
+ tot_conn = self.data.relation_types[key].total_connectivity
+ cur_edges = _edges_for_rows(tot_conn, cur_edges[:, 1])
+
+
+ def temporary_adjacency_matrix(self, adjacency_matrix: torch.Tensor,
+ batch: TrainingBatch, total_connectivity: torch.Tensor) -> torch.Tensor:
+
+ col = batch.vertex_type_column
+ rows = batch.edges[:, 0]
+
+ columns = batch.edges[:, 1].sum(dim=0).flatten()
+ columns = torch.nonzero(columns)
+
+ for i in range(len(self.layer_dimensions) - 1):
+ pass # columns =
+ # TODO: finish
+
+ return None
+
+
+ def temporary_adjacency_matrices(self, batch: TrainingBatch) -> Dict[Tuple[int, int], List[List[torch.Tensor]]]:
+
+ col = batch.vertex_type_column
+ batch.edges[:, 1]
+
+ res = {}
+
+ for _, et in self.data.edge_types.items():
+ sum_nonzero = _nonzero_sum(et.adjacency_matrices)
+ res[et.vertex_type_row, et.vertex_type_column] = \
+ [ self.temporary_adjacency_matrix(adj_mat, batch,
+ et.total_connectivity) \
+ for adj_mat in et.adjacency_matrices ]
+
+ return res
+
+ def forward(self, initial_repr: List[torch.Tensor],
+ batch: TrainingBatch) -> torch.Tensor:
+
+ if not isinstance(initial_repr, list):
+ raise TypeError('initial_repr must be a list')
+
+ if len(initial_repr) != len(self.data.vertex_types):
+ raise ValueError('initial_repr must contain representations for all vertex types')
+
+ if not isinstance(batch, TrainingBatch):
+ raise TypeError('batch must be an instance of TrainingBatch')
+
+ adj_matrices = self.temporary_adjacency_matrices(batch)
+
+ row_vertices = initial_repr[batch.vertex_type_row]
+ column_vertices = initial_repr[batch.vertex_type_column]
diff --git a/src/triacontagon/normalize.py b/src/triacontagon/normalize.py
new file mode 100644
index 0000000..e13fb05
--- /dev/null
+++ b/src/triacontagon/normalize.py
@@ -0,0 +1,145 @@
+#
+# Copyright (C) Stanislaw Adaszewski, 2020
+# License: GPLv3
+#
+
+
+import numpy as np
+import scipy.sparse as sp
+import torch
+
+
+def _check_tensor(adj_mat):
+ if not isinstance(adj_mat, torch.Tensor):
+ raise ValueError('adj_mat must be a torch.Tensor')
+
+
+def _check_sparse(adj_mat):
+ if not adj_mat.is_sparse:
+ raise ValueError('adj_mat must be sparse')
+
+
+def _check_dense(adj_mat):
+ if adj_mat.is_sparse:
+ raise ValueError('adj_mat must be dense')
+
+
+def _check_square(adj_mat):
+ if len(adj_mat.shape) != 2 or \
+ adj_mat.shape[0] != adj_mat.shape[1]:
+ raise ValueError('adj_mat must be a square matrix')
+
+
+def _check_2d(adj_mat):
+ if len(adj_mat.shape) != 2:
+ raise ValueError('adj_mat must be a square matrix')
+
+
+def _sparse_coo_tensor(indices, values, size):
+ ctor = { torch.float32: torch.sparse.FloatTensor,
+ torch.float32: torch.sparse.DoubleTensor,
+ torch.uint8: torch.sparse.ByteTensor,
+ torch.long: torch.sparse.LongTensor,
+ torch.int: torch.sparse.IntTensor,
+ torch.short: torch.sparse.ShortTensor,
+ torch.bool: torch.sparse.ByteTensor }[values.dtype]
+ return ctor(indices, values, size)
+
+
+def add_eye_sparse(adj_mat: torch.Tensor) -> torch.Tensor:
+ _check_tensor(adj_mat)
+ _check_sparse(adj_mat)
+ _check_square(adj_mat)
+
+ adj_mat = adj_mat.coalesce()
+ indices = adj_mat.indices()
+ values = adj_mat.values()
+
+ eye_indices = torch.arange(adj_mat.shape[0], dtype=indices.dtype,
+ device=adj_mat.device).view(1, -1)
+ eye_indices = torch.cat((eye_indices, eye_indices), 0)
+ eye_values = torch.ones(adj_mat.shape[0], dtype=values.dtype,
+ device=adj_mat.device)
+
+ indices = torch.cat((indices, eye_indices), 1)
+ values = torch.cat((values, eye_values), 0)
+
+ adj_mat = _sparse_coo_tensor(indices, values, adj_mat.shape)
+
+ return adj_mat
+
+
+def norm_adj_mat_one_node_type_sparse(adj_mat: torch.Tensor) -> torch.Tensor:
+ _check_tensor(adj_mat)
+ _check_sparse(adj_mat)
+ _check_square(adj_mat)
+
+ adj_mat = add_eye_sparse(adj_mat)
+ adj_mat = norm_adj_mat_two_node_types_sparse(adj_mat)
+
+ return adj_mat
+
+
+def norm_adj_mat_one_node_type_dense(adj_mat: torch.Tensor) -> torch.Tensor:
+ _check_tensor(adj_mat)
+ _check_dense(adj_mat)
+ _check_square(adj_mat)
+
+ adj_mat = adj_mat + torch.eye(adj_mat.shape[0], dtype=adj_mat.dtype,
+ device=adj_mat.device)
+ adj_mat = norm_adj_mat_two_node_types_dense(adj_mat)
+
+ return adj_mat
+
+
+def norm_adj_mat_one_node_type(adj_mat: torch.Tensor) -> torch.Tensor:
+ _check_tensor(adj_mat)
+ _check_square(adj_mat)
+
+ if adj_mat.is_sparse:
+ return norm_adj_mat_one_node_type_sparse(adj_mat)
+ else:
+ return norm_adj_mat_one_node_type_dense(adj_mat)
+
+
+def norm_adj_mat_two_node_types_sparse(adj_mat: torch.Tensor) -> torch.Tensor:
+ _check_tensor(adj_mat)
+ _check_sparse(adj_mat)
+ _check_2d(adj_mat)
+
+ adj_mat = adj_mat.coalesce()
+ indices = adj_mat.indices()
+ values = adj_mat.values()
+ degrees_row = torch.zeros(adj_mat.shape[0], device=adj_mat.device)
+ degrees_row = degrees_row.index_add(0, indices[0], values.to(degrees_row.dtype))
+ degrees_col = torch.zeros(adj_mat.shape[1], device=adj_mat.device)
+ degrees_col = degrees_col.index_add(0, indices[1], values.to(degrees_col.dtype))
+ values = values.to(degrees_row.dtype) / torch.sqrt(degrees_row[indices[0]] * degrees_col[indices[1]])
+ adj_mat = _sparse_coo_tensor(indices, values, adj_mat.shape)
+
+ return adj_mat
+
+
+def norm_adj_mat_two_node_types_dense(adj_mat: torch.Tensor) -> torch.Tensor:
+ _check_tensor(adj_mat)
+ _check_dense(adj_mat)
+ _check_2d(adj_mat)
+
+ degrees_row = adj_mat.sum(1).view(-1, 1).to(torch.float32)
+ degrees_col = adj_mat.sum(0).view(1, -1).to(torch.float32)
+ degrees_row = torch.sqrt(degrees_row)
+ degrees_col = torch.sqrt(degrees_col)
+ adj_mat = adj_mat.to(degrees_row.dtype) / degrees_row
+ adj_mat = adj_mat / degrees_col
+
+ return adj_mat
+
+
+def norm_adj_mat_two_node_types(adj_mat: torch.Tensor) -> torch.Tensor:
+ _check_tensor(adj_mat)
+ _check_2d(adj_mat)
+
+ if adj_mat.is_sparse:
+ return norm_adj_mat_two_node_types_sparse(adj_mat)
+ else:
+ return norm_adj_mat_two_node_types_dense(adj_mat)
diff --git a/src/triacontagon/sampling.py b/src/triacontagon/sampling.py
new file mode 100644
index 0000000..73d7cf2
--- /dev/null
+++ b/src/triacontagon/sampling.py
@@ -0,0 +1,341 @@
+#
+# Copyright (C) Stanislaw Adaszewski, 2020
+# License: GPLv3
+#
+
+
+import numpy as np
+import torch
+import torch.utils.data
+from typing import List, \
+ Union, \
+ Tuple
+from .data import Data, \
+ EdgeType
+from .cumcount import cumcount
+import time
+import multiprocessing
+import multiprocessing.pool
+from itertools import product, \
+ repeat
+from functools import reduce
+
+
+def fixed_unigram_candidate_sampler(
+ true_classes: torch.Tensor,
+ num_repeats: torch.Tensor,
+ unigrams: torch.Tensor,
+ distortion: float = 1.) -> torch.Tensor:
+
+ assert isinstance(true_classes, torch.Tensor)
+ assert isinstance(num_repeats, torch.Tensor)
+ assert isinstance(unigrams, torch.Tensor)
+ distortion = float(distortion)
+
+ if len(true_classes.shape) != 2:
+ raise ValueError('true_classes must be a 2D matrix with shape (num_samples, num_true)')
+
+ if len(num_repeats.shape) != 1:
+ raise ValueError('num_repeats must be 1D')
+
+ if torch.any((unigrams > 0).sum() - \
+ (true_classes >= 0).sum(dim=1) < \
+ num_repeats):
+ raise ValueError('Not enough classes to choose from')
+
+ true_class_count = true_classes.shape[1] - (true_classes == -1).sum(dim=1)
+ true_classes = torch.cat([
+ true_classes,
+ torch.full(( len(true_classes), torch.max(num_repeats) ), -1,
+ dtype=true_classes.dtype)
+ ], dim=1)
+
+ indices = torch.repeat_interleave(torch.arange(len(true_classes)), num_repeats)
+ indices = torch.cat([ torch.arange(len(indices)).view(-1, 1),
+ indices.view(-1, 1) ], dim=1)
+
+ result = torch.zeros(len(indices), dtype=torch.long)
+
+ while len(indices) > 0:
+ print(len(indices))
+
+ candidates = torch.utils.data.WeightedRandomSampler(unigrams, len(indices))
+ candidates = torch.tensor(list(candidates)).view(-1, 1)
+
+ inner_order = torch.argsort(candidates[:, 0])
+ indices_np = indices[inner_order].detach().cpu().numpy()
+ outer_order = np.argsort(indices_np[:, 1], kind='stable')
+ outer_order = torch.tensor(outer_order, device=inner_order.device)
+
+ candidates = candidates[inner_order][outer_order]
+ indices = indices[inner_order][outer_order]
+
+ mask = (true_classes[indices[:, 1]] == candidates).sum(dim=1).to(torch.bool)
+
+ # can_cum = cumcount(candidates[:, 0])
+ can_diff = torch.cat([ torch.tensor([1]), candidates[1:, 0] - candidates[:-1, 0] ])
+ ind_cum = cumcount(indices[:, 1])
+ repeated = (can_diff == 0) & (ind_cum > 0)
+ # TODO: this is wrong, still requires work
+
+ mask = mask | repeated
+
+ updated = indices[~mask]
+ if len(updated) > 0:
+ ofs = true_class_count[updated[:, 1]] + \
+ cumcount(updated[:, 1])
+ true_classes[updated[:, 1], ofs] = candidates[~mask].transpose(0, 1)
+ true_class_count[updated[:, 1]] = ofs + 1
+
+ result[indices[:, 0]] = candidates.transpose(0, 1)
+ indices = indices[mask]
+
+ return result
+
+
+def fixed_unigram_candidate_sampler_slow(
+ true_classes: torch.Tensor,
+ num_repeats: torch.Tensor,
+ unigrams: torch.Tensor,
+ distortion: float = 1.) -> torch.Tensor:
+
+ assert isinstance(true_classes, torch.Tensor)
+ assert isinstance(num_repeats, torch.Tensor)
+ assert isinstance(unigrams, torch.Tensor)
+ distortion = float(distortion)
+
+ if len(true_classes.shape) != 2:
+ raise ValueError('true_classes must be a 2D matrix with shape (num_samples, num_true)')
+
+ if len(num_repeats.shape) != 1:
+ raise ValueError('num_repeats must be 1D')
+
+ if torch.any((unigrams > 0).sum() - \
+ (true_classes >= 0).sum(dim=1) < \
+ num_repeats):
+ raise ValueError('Not enough classes to choose from')
+
+ res = []
+
+ if distortion != 1.:
+ unigrams = unigrams.to(torch.float64)
+ unigrams = unigrams ** distortion
+
+ def fun(i):
+ if i and i % 100 == 0:
+ print(i)
+ if num_repeats[i] == 0:
+ return []
+ pos = torch.flatten(true_classes[i, :])
+ pos = pos[pos >= 0]
+ w = unigrams.clone().detach()
+ w[pos] = 0
+ sampler = torch.utils.data.WeightedRandomSampler(w,
+ num_repeats[i].item(), replacement=False)
+ res = list(sampler)
+ return res
+
+ with multiprocessing.pool.ThreadPool() as p:
+ res = p.map(fun, range(len(num_repeats)))
+ res = reduce(list.__add__, res, [])
+
+ return torch.tensor(res)
+
+
+def fixed_unigram_candidate_sampler_old(
+ true_classes: torch.Tensor,
+ num_repeats: torch.Tensor,
+ unigrams: torch.Tensor,
+ distortion: float = 1.) -> torch.Tensor:
+
+ if len(true_classes.shape) != 2:
+ raise ValueError('true_classes must be a 2D matrix with shape (num_samples, num_true)')
+
+ if len(num_repeats.shape) != 1:
+ raise ValueError('num_repeats must be 1D')
+
+ if torch.any((unigrams > 0).sum() - \
+ (true_classes >= 0).sum(dim=1) < \
+ num_repeats):
+ raise ValueError('Not enough classes to choose from')
+
+ num_rows = true_classes.shape[0]
+ print('true_classes.shape:', true_classes.shape)
+ # unigrams = np.array(unigrams)
+ if distortion != 1.:
+ unigrams = unigrams.to(torch.float64) ** distortion
+ print('unigrams:', unigrams)
+
+ indices = torch.arange(num_rows)
+ indices = torch.repeat_interleave(indices, num_repeats)
+ indices = torch.cat([ torch.arange(len(indices)).view(-1, 1),
+ indices.view(-1, 1) ], dim=1)
+
+ num_samples = len(indices)
+ result = torch.zeros(num_samples, dtype=torch.long)
+ print('num_rows:', num_rows, 'num_samples:', num_samples)
+
+ while len(indices) > 0:
+ print('len(indices):', len(indices))
+ print('indices:', indices)
+ sampler = torch.utils.data.WeightedRandomSampler(unigrams, len(indices))
+ candidates = torch.tensor(list(sampler))
+ candidates = candidates.view(len(indices), 1)
+ print('candidates:', candidates)
+ print('true_classes:', true_classes[indices[:, 1], :])
+ result[indices[:, 0]] = candidates.transpose(0, 1)
+ print('result:', result)
+ mask = (candidates == true_classes[indices[:, 1], :])
+ mask = mask.sum(1).to(torch.bool)
+ # append_true_classes = torch.full(( len(true_classes), ), -1)
+ # append_true_classes[~mask] = torch.flatten(candidates)[~mask]
+ # true_classes = torch.cat([
+ # append_true_classes.view(-1, 1),
+ # true_classes
+ # ], dim=1)
+ print('mask:', mask)
+ indices = indices[mask]
+ # result[indices] = 0
+ return result
+
+
+def get_edges_and_degrees(adj_mat: torch.Tensor) -> \
+ Tuple[torch.Tensor, torch.Tensor]:
+
+ if adj_mat.is_sparse:
+ adj_mat = adj_mat.coalesce()
+ degrees = torch.zeros(adj_mat.shape[1], dtype=torch.int64,
+ device=adj_mat.device)
+ degrees = degrees.index_add(0, adj_mat.indices()[1],
+ torch.ones(adj_mat.indices().shape[1], dtype=torch.int64,
+ device=adj_mat.device))
+ edges_pos = adj_mat.indices().transpose(0, 1)
+ else:
+ degrees = adj_mat.sum(0)
+ edges_pos = torch.nonzero(adj_mat, as_tuple=False)
+ return edges_pos, degrees
+
+
+def get_true_classes(adj_mat: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
+ indices = adj_mat.indices()
+ row_count = torch.zeros(adj_mat.shape[0], dtype=torch.long)
+ #print('indices[0]:', indices[0], count[indices[0]])
+ row_count = row_count.index_add(0, indices[0],
+ torch.ones(indices.shape[1], dtype=torch.long))
+ #print('count:', count)
+ max_true_classes = torch.max(row_count).item()
+ #print('max_true_classes:', max_true_classes)
+ true_classes = torch.full((adj_mat.shape[0], max_true_classes),
+ -1, dtype=torch.long)
+
+
+ # inv = torch.unique(indices[0], return_inverse=True)
+
+ # indices = indices.copy()
+ # true_classes[indices[0], 0] = indices[1]
+ t = time.time()
+ cc = cumcount(indices[0])
+ print('cumcount() took:', time.time() - t)
+ # cc = torch.tensor(cc)
+ t = time.time()
+ true_classes[indices[0], cc] = indices[1]
+ print('assignment took:', time.time() - t)
+
+ ''' count = torch.zeros(adj_mat.shape[0], dtype=torch.long)
+ for i in range(indices.shape[1]):
+ # print('looping...')
+ row = indices[0, i]
+ col = indices[1, i]
+ #print('row:', row, 'col:', col, 'count[row]:', count[row])
+ true_classes[row, count[row]] = col
+ count[row] += 1 '''
+
+ # t = time.time()
+ # true_classes = torch.repeat_interleave(true_classes, row_count, dim=0)
+ # print('repeat_interleave() took:', time.time() - t)
+
+ return true_classes, row_count
+
+
+def negative_sample_adj_mat(adj_mat: torch.Tensor,
+ remove_diagonal: bool=False) -> torch.Tensor:
+
+ if not isinstance(adj_mat, torch.Tensor):
+ raise ValueError('adj_mat must be a torch.Tensor, got: %s' % adj_mat.__class__.__name__)
+
+ edges_pos, degrees = get_edges_and_degrees(adj_mat)
+ degrees = degrees.to(torch.float32) + 1.0 / torch.numel(adj_mat)
+
+ true_classes, row_count = get_true_classes(adj_mat)
+ if remove_diagonal:
+ true_classes = torch.cat([ torch.arange(len(adj_mat)).view(-1, 1),
+ true_classes ], dim=1)
+ # true_classes = edges_pos[:, 1].view(-1, 1)
+ # print('true_classes:', true_classes)
+
+ neg_neighbors = fixed_unigram_candidate_sampler(
+ true_classes, row_count, degrees, 0.75).to(adj_mat.device)
+
+ print('neg_neighbors:', neg_neighbors)
+
+ pos_vertices = torch.repeat_interleave(torch.arange(len(adj_mat)),
+ row_count)
+
+ edges_neg = torch.cat([ pos_vertices.view(-1, 1),
+ neg_neighbors.view(-1, 1) ], 1)
+
+ adj_mat_neg = torch.sparse_coo_tensor(indices = edges_neg.transpose(0, 1),
+ values=torch.ones(len(edges_neg)), size=adj_mat.shape,
+ dtype=adj_mat.dtype, device=adj_mat.device)
+
+ adj_mat_neg = adj_mat_neg.coalesce()
+ indices = adj_mat_neg.indices()
+ adj_mat_neg = torch.sparse_coo_tensor(indices,
+ torch.ones(indices.shape[1]), adj_mat.shape,
+ dtype=adj_mat.dtype, device=adj_mat.device)
+
+ adj_mat_neg = adj_mat_neg.coalesce()
+
+ return adj_mat_neg
+
+
+def negative_sample_data(data: Data) -> Data:
+ new_edge_types = {}
+ res = Data(target_value=0)
+ for vt in data.vertex_types:
+ res.add_vertex_type(vt.name, vt.count)
+ for key, et in data.edge_types.items():
+ print('key:', key)
+ adjacency_matrices_neg = []
+ for adj_mat in et.adjacency_matrices:
+ remove_diagonal = True \
+ if et.vertex_type_row == et.vertex_type_column \
+ else False
+ adj_mat_neg = negative_sample_adj_mat(adj_mat, remove_diagonal)
+ adjacency_matrices_neg.append(adj_mat_neg)
+ res.add_edge_type(et.name,
+ et.vertex_type_row, et.vertex_type_column,
+ adjacency_matrices_neg, et.decoder_factory)
+ #new_et = EdgeType(et.name, et.vertex_type_row,
+ # et.vertex_type_column, adjacency_matrices_neg,
+ # et.decoder_factory, et.total_connectivity)
+ #new_edge_types[key] = new_et
+ #res = Data(data.vertex_types, new_edge_types)
+ return res
+
+
+def merge_data(pos_data: Data, neg_data: Data) -> Data:
+ assert isinstance(pos_data, Data)
+ assert isinstance(neg_data, Data)
+
+ res = PosNegData()
+
+ for vt in pos_data.vertex_types:
+ res.add_vertex_type(vt.name, vt.count)
+
+ for key, pos_et in pos_data.edge_types.items():
+ neg_et = neg_data.edge_types[key]
+ res.add_edge_type(pos_et.name,
+ pos_et.vertex_type_row, pos_et.vertex_type_column,
+ pos_et.adjacency_matrices, neg_et.adjacency_matrices,
+ pos_et.decoder_factory)
diff --git a/src/triacontagon/split.py b/src/triacontagon/split.py
new file mode 100644
index 0000000..c105e4e
--- /dev/null
+++ b/src/triacontagon/split.py
@@ -0,0 +1,81 @@
+from .data import Data, \
+ EdgeType
+from typing import Tuple, \
+ List
+from .util import _sparse_coo_tensor
+import torch
+
+
+def split_adj_mat(adj_mat: torch.Tensor, ratios: List[float]):
+ ratios = list(ratios)
+ if sum(ratios) != 1:
+ raise ValueError('Sum of ratios must be 1')
+
+ indices = adj_mat.indices()
+ values = adj_mat.values()
+
+ order = torch.randperm(indices.shape[1])
+
+ indices = indices[:, order]
+ values = values[order]
+
+ ofs = 0
+ res = []
+ for r in ratios:
+ # cnt = r * len(values)
+
+ beg = int(ofs * len(values))
+ end = int((ofs + r) * len(values))
+ ofs += r
+
+ ind = indices[:, beg:end]
+ val = values[beg:end]
+ res.append(_sparse_coo_tensor(ind, val, adj_mat.shape).coalesce())
+ # ofs += cnt
+
+ return res
+
+
+def split_edge_type(et: EdgeType, ratios: Tuple[float, float, float]):
+ ratios = list(ratios)
+ if sum(ratios) != 1:
+ raise ValueError('Sum of ratios must be 1')
+
+ res = [ split_adj_mat(adj_mat, ratios) \
+ for adj_mat in et.adjacency_matrices ]
+
+ res = [ EdgeType(et.name,
+ et.vertex_type_row,
+ et.vertex_type_column,
+ [ a[i] for a in res ],
+ et.decoder_factory,
+ None ) for i in range(len(ratios)) ]
+
+ return res
+
+
+def split_data(data: Data,
+ ratios: List[float]):
+
+ if not isinstance(data, Data):
+ raise TypeError('data must be an instance of Data')
+
+ ratios = list(ratios)
+
+ if sum(ratios) != 1:
+ raise ValueError('ratios must sum to 1')
+
+ res = [ {} for _ in range(len(ratios)) ]
+
+ for key, et in data.edge_types.items():
+ for i, new_et in enumerate(split_edge_type(et, ratios)):
+ res[i][key] = new_et
+
+ res_1 = []
+ for new_edge_types in res:
+ d = Data()
+ d.vertex_types = data.vertex_types
+ d.edge_types = new_edge_types
+ res_1.append(d)
+
+ return res_1
diff --git a/src/triacontagon/util.py b/src/triacontagon/util.py
new file mode 100644
index 0000000..27e8524
--- /dev/null
+++ b/src/triacontagon/util.py
@@ -0,0 +1,271 @@
+import torch
+from typing import List, \
+ Set
+import time
+
+
+def _diag(x: torch.Tensor, make_sparse: bool=False):
+ if len(x.shape) < 1 or len(x.shape) > 2:
+ raise ValueError('Matrix or vector expected')
+
+ if not x.is_sparse and not make_sparse:
+ return torch.diag(x)
+
+ if len(x.shape) == 1:
+ indices = torch.arange(len(x)).view(1, -1)
+ indices = torch.cat([ indices, indices ])
+ return _sparse_coo_tensor(indices, x.to_dense(), (len(x),) * 2)
+
+ values = x.values()
+ indices = x.indices()
+ mask = torch.nonzero(indices[0] == indices[1], as_tuple=True)[0]
+ indices = torch.flatten(indices[0, mask])
+ order = torch.argsort(indices)
+ values = values[mask][order]
+ res = torch.zeros(min(x.shape[0], x.shape[1]), dtype=values.dtype)
+ res[indices] = values
+ return res
+
+
+def _equal(x: torch.Tensor, y: torch.Tensor):
+ if x.is_sparse ^ y.is_sparse:
+ raise ValueError('Cannot mix sparse and dense tensors')
+
+ if not x.is_sparse:
+ return (x == y)
+
+ return ((x - y).coalesce().values() == 0)
+
+
+def _sparse_coo_tensor(indices, values, size):
+ ctor = { torch.float32: torch.sparse.FloatTensor,
+ torch.float32: torch.sparse.DoubleTensor,
+ torch.uint8: torch.sparse.ByteTensor,
+ torch.long: torch.sparse.LongTensor,
+ torch.int: torch.sparse.IntTensor,
+ torch.short: torch.sparse.ShortTensor,
+ torch.bool: torch.sparse.ByteTensor }[values.dtype]
+ return ctor(indices, values, size)
+
+
+def _nonzero_sum(adjacency_matrices: List[torch.Tensor]):
+ if len(adjacency_matrices) == 0:
+ raise ValueError('adjacency_matrices must be non-empty')
+
+ if not all([x.is_sparse for x in adjacency_matrices]):
+ raise ValueError('All adjacency matrices must be sparse')
+
+ indices = [ x.indices() for x in adjacency_matrices ]
+ indices = torch.cat(indices, dim=1)
+
+ values = torch.ones(indices.shape[1])
+ res = _sparse_coo_tensor(indices, values, adjacency_matrices[0].shape)
+ res = res.coalesce()
+
+ indices = res.indices()
+ res = _sparse_coo_tensor(indices,
+ torch.ones(indices.shape[1], dtype=torch.uint8),
+ adjacency_matrices[0].shape)
+ res = res.coalesce()
+
+ return res
+
+
+def _clear_adjacency_matrix_except_rows(adjacency_matrix: torch.Tensor,
+ rows: torch.Tensor, row_vertex_count: int, num_relation_types: int) -> torch.Tensor:
+
+ if not adjacency_matrix.is_sparse:
+ raise ValueError('adjacency_matrix must be sparse')
+
+ if not adjacency_matrix.shape[0] == row_vertex_count * num_relation_types:
+ raise ValueError('adjacency_matrix must have as many rows as row vertex count times number of relation types')
+
+ t = time.time()
+ rows = [ rows + row_vertex_count * i \
+ for i in range(num_relation_types) ]
+ print('rows took:', time.time() - t)
+ t = time.time()
+ rows = torch.cat(rows)
+ print('cat took:', time.time() - t)
+ # print('rows:', rows)
+ # rows = set(rows.tolist())
+ # print('rows:', rows)
+
+ t = time.time()
+ adj_mat = adjacency_matrix.coalesce()
+ indices = adj_mat.indices()
+ values = adj_mat.values()
+ print('indices[0]:', indices[0])
+ # print('indices[0][1]:', indices[0][1], indices[0][1] in rows)
+
+ lookup = torch.zeros(row_vertex_count * num_relation_types,
+ dtype=torch.uint8, device=adj_mat.device)
+ lookup[rows] = 1
+ values = values * lookup[indices[0]]
+ mask = torch.nonzero(values > 0, as_tuple=True)[0]
+ indices = indices[:, mask]
+ values = values[mask]
+
+ res = _sparse_coo_tensor(indices, values, adjacency_matrix.shape)
+ # res = res.coalesce()
+ print('res:', res)
+ print('"index_select()" took:', time.time() - t)
+
+ return res
+
+ selection = torch.tensor([ (idx.item() in rows) for idx in indices[0] ])
+ # print('selection:', selection)
+ selection = torch.nonzero(selection, as_tuple=True)[0]
+ # print('selection:', selection)
+ indices = indices[:, selection]
+ values = values[selection]
+ print('"index_select()" took:', time.time() - t)
+
+ t = time.time()
+ res = _sparse_coo_tensor(indices, values, adjacency_matrix.shape)
+ print('_sparse_coo_tensor() took:', time.time() - t)
+
+ return res
+
+ # t = time.time()
+ # adj_mat = torch.index_select(adjacency_matrix, 0, rows)
+ # print('index_select took:', time.time() - t)
+
+ t = time.time()
+ adj_mat = adj_mat.coalesce()
+ print('coalesce() took:', time.time() - t)
+ indices = adj_mat.indices()
+
+ # print('indices:', indices)
+
+ values = adj_mat.values()
+ t = time.time()
+ indices[0] = rows[indices[0]]
+ print('Lookup took:', time.time() - t)
+
+ t = time.time()
+ adj_mat = _sparse_coo_tensor(indices, values, adjacency_matrix.shape)
+ print('_sparse_coo_tensor() took:', time.time() - t)
+
+ return adj_mat
+
+
+def _sparse_diag_cat(matrices: List[torch.Tensor]):
+ if len(matrices) == 0:
+ raise ValueError('The list of matrices must be non-empty')
+
+ if not all(m.is_sparse for m in matrices):
+ raise ValueError('All matrices must be sparse')
+
+ if not all(len(m.shape) == 2 for m in matrices):
+ raise ValueError('All matrices must be 2D')
+
+ indices = []
+ values = []
+ row_offset = 0
+ col_offset = 0
+
+ for m in matrices:
+ ind = m._indices().clone()
+ ind[0] += row_offset
+ ind[1] += col_offset
+ indices.append(ind)
+ values.append(m._values())
+ row_offset += m.shape[0]
+ col_offset += m.shape[1]
+
+ indices = torch.cat(indices, dim=1)
+ values = torch.cat(values)
+
+ return _sparse_coo_tensor(indices, values, size=(row_offset, col_offset))
+
+
+def _cat(matrices: List[torch.Tensor]):
+ if len(matrices) == 0:
+ raise ValueError('Empty list passed to _cat()')
+
+ n = sum(a.is_sparse for a in matrices)
+ if n != 0 and n != len(matrices):
+ raise ValueError('All matrices must have the same layout (dense or sparse)')
+
+ if not all(a.shape[1:] == matrices[0].shape[1:] for a in matrices):
+ raise ValueError('All matrices must have the same dimensions apart from dimension 0')
+
+ if not matrices[0].is_sparse:
+ return torch.cat(matrices)
+
+ total_rows = sum(a.shape[0] for a in matrices)
+ indices = []
+ values = []
+ row_offset = 0
+
+ for a in matrices:
+ ind = a._indices().clone()
+ val = a._values()
+ ind[0] += row_offset
+ ind = ind.transpose(0, 1)
+ indices.append(ind)
+ values.append(val)
+ row_offset += a.shape[0]
+
+ indices = torch.cat(indices).transpose(0, 1)
+ values = torch.cat(values)
+
+ res = _sparse_coo_tensor(indices, values, size=(row_offset, matrices[0].shape[1]))
+ return res
+
+
+def _mm(a: torch.Tensor, b: torch.Tensor):
+ if a.is_sparse:
+ return torch.sparse.mm(a, b)
+ else:
+ return torch.mm(a, b)
+
+
+def _select_rows(a: torch.Tensor, rows: torch.Tensor):
+ if not a.is_sparse:
+ return a[rows]
+
+ indices = a.indices()
+ values = a.values()
+
+ mask = torch.zeros(a.shape[0])
+ mask[rows] = 1
+ if mask.sum() != len(rows):
+ raise ValueError('Rows must be unique')
+ mask = mask[indices[0]]
+ mask = torch.nonzero(mask, as_tuple=True)[0]
+
+ new_rows[rows] = torch.arange(len(rows))
+ new_rows = new_rows[indices[0]]
+
+ indices = indices[:, mask]
+ indices[0] = new_rows
+ values = values[mask]
+
+ res = _sparse_coo_tensor(indices, values,
+ size=(len(rows), a.shape[1]))
+ return res
+
+
+def common_one_hot_encoding(vertex_type_counts: List[int], device=None) -> \
+ List[torch.Tensor]:
+
+ tot = sum(vertex_type_counts)
+ # indices = torch.cat([ torch.arange(tot).view(1, -1) ] * 2, dim=0)
+ # print('indices.shape:', indices.shape)
+ ofs = 0
+ res = []
+
+ for cnt in vertex_type_counts:
+ ind = torch.cat([
+ torch.arange(cnt).view(1, -1),
+ torch.arange(ofs, ofs+cnt).view(1, -1)
+ ])
+ val = torch.ones(cnt)
+ x = _sparse_coo_tensor(ind, val, size=(cnt, tot))
+ x = x.to(device)
+ res.append(x)
+ ofs += cnt
+
+ return res
diff --git a/src/triacontagon/weights.py b/src/triacontagon/weights.py
new file mode 100644
index 0000000..2dcb7b4
--- /dev/null
+++ b/src/triacontagon/weights.py
@@ -0,0 +1,19 @@
+#
+# Copyright (C) Stanislaw Adaszewski, 2020
+# License: GPLv3
+#
+
+
+import torch
+import numpy as np
+
+
+def init_glorot(in_channels, out_channels, dtype=torch.float32):
+ """Create a weight variable with Glorot & Bengio (AISTATS 2010)
+ initialization.
+ """
+ init_range = np.sqrt(6.0 / (in_channels + out_channels))
+ initial = -init_range + 2 * init_range * \
+ torch.rand(( in_channels, out_channels ), dtype=dtype)
+ initial = initial.requires_grad_(True)
+ return initial
diff --git a/tests/decagon_pytorch/layer/test_layer_convolve.py b/tests/decagon_pytorch/layer/test_layer_convolve.py
new file mode 100644
index 0000000..69e3a75
--- /dev/null
+++ b/tests/decagon_pytorch/layer/test_layer_convolve.py
@@ -0,0 +1,169 @@
+from decagon_pytorch.layer import InputLayer, \
+ OneHotInputLayer, \
+ DecagonLayer
+from decagon_pytorch.data import Data
+import torch
+import pytest
+from decagon_pytorch.convolve import SparseDropoutGraphConvActivation, \
+ SparseMultiDGCA, \
+ DropoutGraphConvActivation
+
+
+def _some_data():
+ d = Data()
+ d.add_node_type('Gene', 1000)
+ d.add_node_type('Drug', 100)
+ d.add_relation_type('Target', 1, 0, None)
+ d.add_relation_type('Interaction', 0, 0, None)
+ d.add_relation_type('Side Effect: Nausea', 1, 1, None)
+ d.add_relation_type('Side Effect: Infertility', 1, 1, None)
+ d.add_relation_type('Side Effect: Death', 1, 1, None)
+ return d
+
+
+def _some_data_with_interactions():
+ d = Data()
+ d.add_node_type('Gene', 1000)
+ d.add_node_type('Drug', 100)
+ d.add_relation_type('Target', 1, 0,
+ torch.rand((100, 1000), dtype=torch.float32).round())
+ d.add_relation_type('Interaction', 0, 0,
+ torch.rand((1000, 1000), dtype=torch.float32).round())
+ d.add_relation_type('Side Effect: Nausea', 1, 1,
+ torch.rand((100, 100), dtype=torch.float32).round())
+ d.add_relation_type('Side Effect: Infertility', 1, 1,
+ torch.rand((100, 100), dtype=torch.float32).round())
+ d.add_relation_type('Side Effect: Death', 1, 1,
+ torch.rand((100, 100), dtype=torch.float32).round())
+ return d
+
+
+def test_decagon_layer_01():
+ d = _some_data_with_interactions()
+ in_layer = InputLayer(d)
+ d_layer = DecagonLayer(d, in_layer, output_dim=32)
+
+
+def test_decagon_layer_02():
+ d = _some_data_with_interactions()
+ in_layer = OneHotInputLayer(d)
+ d_layer = DecagonLayer(d, in_layer, output_dim=32)
+ _ = d_layer() # dummy call
+
+
+def test_decagon_layer_03():
+ d = _some_data_with_interactions()
+ in_layer = OneHotInputLayer(d)
+ d_layer = DecagonLayer(d, in_layer, output_dim=32)
+ assert d_layer.data == d
+ assert d_layer.previous_layer == in_layer
+ assert d_layer.input_dim == [ 1000, 100 ]
+ assert not d_layer.is_sparse
+ assert d_layer.keep_prob == 1.
+ assert d_layer.rel_activation(0.5) == 0.5
+ x = torch.tensor([-1, 0, 0.5, 1])
+ assert (d_layer.layer_activation(x) == torch.nn.functional.relu(x)).all()
+ assert len(d_layer.next_layer_repr) == 2
+
+ for i in range(2):
+ assert len(d_layer.next_layer_repr[i]) == 2
+ assert isinstance(d_layer.next_layer_repr[i], list)
+ assert isinstance(d_layer.next_layer_repr[i][0], tuple)
+ assert isinstance(d_layer.next_layer_repr[i][0][0], list)
+ assert isinstance(d_layer.next_layer_repr[i][0][1], int)
+ assert all([
+ isinstance(dgca, DropoutGraphConvActivation) \
+ for dgca in d_layer.next_layer_repr[i][0][0]
+ ])
+ assert all([
+ dgca.output_dim == 32 \
+ for dgca in d_layer.next_layer_repr[i][0][0]
+ ])
+
+
+def test_decagon_layer_04():
+ # check if it is equivalent to MultiDGCA, as it should be
+
+ d = Data()
+ d.add_node_type('Dummy', 100)
+ d.add_relation_type('Dummy Relation', 0, 0,
+ torch.rand((100, 100), dtype=torch.float32).round().to_sparse())
+
+ in_layer = OneHotInputLayer(d)
+
+ multi_dgca = SparseMultiDGCA([10], 32,
+ [r.adjacency_matrix for r in d.relation_types[0, 0]],
+ keep_prob=1., activation=lambda x: x)
+
+ d_layer = DecagonLayer(d, in_layer, output_dim=32,
+ keep_prob=1., rel_activation=lambda x: x,
+ layer_activation=lambda x: x)
+
+ assert isinstance(d_layer.next_layer_repr[0][0][0][0],
+ DropoutGraphConvActivation)
+
+ weight = d_layer.next_layer_repr[0][0][0][0].graph_conv.weight
+ assert isinstance(weight, torch.Tensor)
+
+ assert len(multi_dgca.sparse_dgca) == 1
+ assert isinstance(multi_dgca.sparse_dgca[0], SparseDropoutGraphConvActivation)
+
+ multi_dgca.sparse_dgca[0].sparse_graph_conv.weight = weight
+
+ out_d_layer = d_layer()
+ out_multi_dgca = multi_dgca(in_layer())
+
+ assert isinstance(out_d_layer, list)
+ assert len(out_d_layer) == 1
+
+ assert torch.all(out_d_layer[0] == out_multi_dgca)
+
+
+def test_decagon_layer_05():
+ # check if it is equivalent to MultiDGCA, as it should be
+ # this time for two relations, same edge type
+
+ d = Data()
+ d.add_node_type('Dummy', 100)
+ d.add_relation_type('Dummy Relation 1', 0, 0,
+ torch.rand((100, 100), dtype=torch.float32).round().to_sparse())
+ d.add_relation_type('Dummy Relation 2', 0, 0,
+ torch.rand((100, 100), dtype=torch.float32).round().to_sparse())
+
+ in_layer = OneHotInputLayer(d)
+
+ multi_dgca = SparseMultiDGCA([100, 100], 32,
+ [r.adjacency_matrix for r in d.relation_types[0, 0]],
+ keep_prob=1., activation=lambda x: x)
+
+ d_layer = DecagonLayer(d, in_layer, output_dim=32,
+ keep_prob=1., rel_activation=lambda x: x,
+ layer_activation=lambda x: x)
+
+ assert all([
+ isinstance(dgca, DropoutGraphConvActivation) \
+ for dgca in d_layer.next_layer_repr[0][0][0]
+ ])
+
+ weight = [ dgca.graph_conv.weight \
+ for dgca in d_layer.next_layer_repr[0][0][0] ]
+ assert all([
+ isinstance(w, torch.Tensor) \
+ for w in weight
+ ])
+
+ assert len(multi_dgca.sparse_dgca) == 2
+ for i in range(2):
+ assert isinstance(multi_dgca.sparse_dgca[i], SparseDropoutGraphConvActivation)
+
+ for i in range(2):
+ multi_dgca.sparse_dgca[i].sparse_graph_conv.weight = weight[i]
+
+ out_d_layer = d_layer()
+ x = in_layer()
+ out_multi_dgca = multi_dgca([ x[0], x[0] ])
+
+ assert isinstance(out_d_layer, list)
+ assert len(out_d_layer) == 1
+
+ assert torch.all(out_d_layer[0] == out_multi_dgca)
diff --git a/tests/decagon_pytorch/layer/test_layer_decode.py b/tests/decagon_pytorch/layer/test_layer_decode.py
new file mode 100644
index 0000000..8c2e336
--- /dev/null
+++ b/tests/decagon_pytorch/layer/test_layer_decode.py
@@ -0,0 +1,22 @@
+from decagon_pytorch.layer import OneHotInputLayer, \
+ DecagonLayer, \
+ DecodeLayer
+from decagon_pytorch.decode.cartesian import DEDICOMDecoder
+from decagon_pytorch.data import Data
+import torch
+
+
+def test_decode_layer_01():
+ d = Data()
+ d.add_node_type('Dummy', 100)
+ d.add_relation_type('Dummy Relation 1', 0, 0,
+ torch.rand((100, 100), dtype=torch.float32).round().to_sparse())
+ in_layer = OneHotInputLayer(d)
+ d_layer = DecagonLayer(d, in_layer, 32)
+ last_layer_repr = d_layer()
+ dec = DecodeLayer(d, last_layer = d_layer, decoder_class = DEDICOMDecoder)
+ pred_adj_matrices = dec(last_layer_repr)
+ assert isinstance(pred_adj_matrices, dict)
+ assert len(pred_adj_matrices) == 1
+ assert isinstance(pred_adj_matrices[0, 0], list)
+ assert len(pred_adj_matrices[0, 0]) == 1
diff --git a/tests/decagon_pytorch/layer/test_layer_input.py b/tests/decagon_pytorch/layer/test_layer_input.py
new file mode 100644
index 0000000..27b84e7
--- /dev/null
+++ b/tests/decagon_pytorch/layer/test_layer_input.py
@@ -0,0 +1,110 @@
+from decagon_pytorch.layer import InputLayer, \
+ OneHotInputLayer, \
+ DecagonLayer
+from decagon_pytorch.data import Data
+import torch
+import pytest
+from decagon_pytorch.convolve import SparseDropoutGraphConvActivation, \
+ SparseMultiDGCA, \
+ DropoutGraphConvActivation
+
+
+def _some_data():
+ d = Data()
+ d.add_node_type('Gene', 1000)
+ d.add_node_type('Drug', 100)
+ d.add_relation_type('Target', 1, 0, None)
+ d.add_relation_type('Interaction', 0, 0, None)
+ d.add_relation_type('Side Effect: Nausea', 1, 1, None)
+ d.add_relation_type('Side Effect: Infertility', 1, 1, None)
+ d.add_relation_type('Side Effect: Death', 1, 1, None)
+ return d
+
+
+def _some_data_with_interactions():
+ d = Data()
+ d.add_node_type('Gene', 1000)
+ d.add_node_type('Drug', 100)
+ d.add_relation_type('Target', 1, 0,
+ torch.rand((100, 1000), dtype=torch.float32).round())
+ d.add_relation_type('Interaction', 0, 0,
+ torch.rand((1000, 1000), dtype=torch.float32).round())
+ d.add_relation_type('Side Effect: Nausea', 1, 1,
+ torch.rand((100, 100), dtype=torch.float32).round())
+ d.add_relation_type('Side Effect: Infertility', 1, 1,
+ torch.rand((100, 100), dtype=torch.float32).round())
+ d.add_relation_type('Side Effect: Death', 1, 1,
+ torch.rand((100, 100), dtype=torch.float32).round())
+ return d
+
+
+def test_input_layer_01():
+ d = _some_data()
+ for output_dim in [32, 64, 128]:
+ layer = InputLayer(d, output_dim)
+ assert layer.output_dim[0] == output_dim
+ assert len(layer.node_reps) == 2
+ assert layer.node_reps[0].shape == (1000, output_dim)
+ assert layer.node_reps[1].shape == (100, output_dim)
+ assert layer.data == d
+
+
+def test_input_layer_02():
+ d = _some_data()
+ layer = InputLayer(d, 32)
+ res = layer()
+ assert isinstance(res[0], torch.Tensor)
+ assert isinstance(res[1], torch.Tensor)
+ assert res[0].shape == (1000, 32)
+ assert res[1].shape == (100, 32)
+ assert torch.all(res[0] == layer.node_reps[0])
+ assert torch.all(res[1] == layer.node_reps[1])
+
+
+def test_input_layer_03():
+ if torch.cuda.device_count() == 0:
+ pytest.skip('No CUDA devices on this host')
+ d = _some_data()
+ layer = InputLayer(d, 32)
+ device = torch.device('cuda:0')
+ layer = layer.to(device)
+ print(list(layer.parameters()))
+ # assert layer.device.type == 'cuda:0'
+ assert layer.node_reps[0].device == device
+ assert layer.node_reps[1].device == device
+
+
+def test_one_hot_input_layer_01():
+ d = _some_data()
+ layer = OneHotInputLayer(d)
+ assert layer.output_dim == [1000, 100]
+ assert len(layer.node_reps) == 2
+ assert layer.node_reps[0].shape == (1000, 1000)
+ assert layer.node_reps[1].shape == (100, 100)
+ assert layer.data == d
+ assert layer.is_sparse
+
+
+def test_one_hot_input_layer_02():
+ d = _some_data()
+ layer = OneHotInputLayer(d)
+ res = layer()
+ assert isinstance(res[0], torch.Tensor)
+ assert isinstance(res[1], torch.Tensor)
+ assert res[0].shape == (1000, 1000)
+ assert res[1].shape == (100, 100)
+ assert torch.all(res[0].to_dense() == layer.node_reps[0].to_dense())
+ assert torch.all(res[1].to_dense() == layer.node_reps[1].to_dense())
+
+
+def test_one_hot_input_layer_03():
+ if torch.cuda.device_count() == 0:
+ pytest.skip('No CUDA devices on this host')
+ d = _some_data()
+ layer = OneHotInputLayer(d)
+ device = torch.device('cuda:0')
+ layer = layer.to(device)
+ print(list(layer.parameters()))
+ # assert layer.device.type == 'cuda:0'
+ assert layer.node_reps[0].device == device
+ assert layer.node_reps[1].device == device
diff --git a/tests/decagon_pytorch/test_convolve.py b/tests/decagon_pytorch/test_convolve.py
new file mode 100644
index 0000000..8dee490
--- /dev/null
+++ b/tests/decagon_pytorch/test_convolve.py
@@ -0,0 +1,295 @@
+import decagon_pytorch.convolve
+import decagon.deep.layers
+import torch
+import tensorflow as tf
+import numpy as np
+
+
+def prepare_data():
+ np.random.seed(0)
+ latent = np.random.random((5, 10)).astype(np.float32)
+ latent[latent < .5] = 0
+ latent = np.ceil(latent)
+ adjacency_matrices = []
+ for _ in range(5):
+ adj_mat = np.random.random((len(latent),) * 2).astype(np.float32)
+ adj_mat[adj_mat < .5] = 0
+ adj_mat = np.ceil(adj_mat)
+ adjacency_matrices.append(adj_mat)
+ print('latent:', latent)
+ print('adjacency_matrices[0]:', adjacency_matrices[0])
+ return latent, adjacency_matrices
+
+
+def dense_to_sparse_tf(x):
+ a, b = np.where(x)
+ indices = np.array([a, b]).T
+ values = x[a, b]
+ return tf.sparse.SparseTensor(indices, values, x.shape)
+
+
+def dropout_sparse_tf(x, keep_prob, num_nonzero_elems):
+ """Dropout for sparse tensors. Currently fails for very large sparse tensors (>1M elements)
+ """
+ noise_shape = [num_nonzero_elems]
+ random_tensor = keep_prob
+ random_tensor += tf.convert_to_tensor(torch.rand(noise_shape).detach().numpy())
+ # tf.convert_to_tensor(np.random.random(noise_shape))
+ # tf.random_uniform(noise_shape)
+ dropout_mask = tf.cast(tf.floor(random_tensor), dtype=tf.bool)
+ pre_out = tf.sparse_retain(x, dropout_mask)
+ return pre_out * (1./keep_prob)
+
+
+def dense_graph_conv_torch():
+ torch.random.manual_seed(0)
+ latent, adjacency_matrices = prepare_data()
+ latent = torch.tensor(latent)
+ adj_mat = adjacency_matrices[0]
+ adj_mat = torch.tensor(adj_mat)
+ conv = decagon_pytorch.convolve.DenseGraphConv(10, 10,
+ adj_mat)
+ latent = conv(latent)
+ return latent
+
+
+def dense_dropout_graph_conv_activation_torch(keep_prob=1.):
+ torch.random.manual_seed(0)
+ latent, adjacency_matrices = prepare_data()
+ latent = torch.tensor(latent)
+ adj_mat = adjacency_matrices[0]
+ adj_mat = torch.tensor(adj_mat)
+ conv = decagon_pytorch.convolve.DenseDropoutGraphConvActivation(10, 10,
+ adj_mat, keep_prob=keep_prob)
+ latent = conv(latent)
+ return latent
+
+
+def sparse_graph_conv_torch():
+ torch.random.manual_seed(0)
+ latent, adjacency_matrices = prepare_data()
+ print('latent.dtype:', latent.dtype)
+ latent = torch.tensor(latent).to_sparse()
+ adj_mat = adjacency_matrices[0]
+ adj_mat = torch.tensor(adj_mat).to_sparse()
+ print('adj_mat.dtype:', adj_mat.dtype,
+ 'latent.dtype:', latent.dtype)
+ conv = decagon_pytorch.convolve.SparseGraphConv(10, 10,
+ adj_mat)
+ latent = conv(latent)
+ return latent
+
+
+def sparse_graph_conv_tf():
+ torch.random.manual_seed(0)
+ latent, adjacency_matrices = prepare_data()
+ conv_torch = decagon_pytorch.convolve.SparseGraphConv(10, 10,
+ torch.tensor(adjacency_matrices[0]).to_sparse())
+ weight = tf.constant(conv_torch.weight.detach().numpy())
+ latent = dense_to_sparse_tf(latent)
+ adj_mat = dense_to_sparse_tf(adjacency_matrices[0])
+ latent = tf.sparse_tensor_dense_matmul(latent, weight)
+ latent = tf.sparse_tensor_dense_matmul(adj_mat, latent)
+ return latent
+
+
+def sparse_dropout_graph_conv_activation_torch(keep_prob=1.):
+ torch.random.manual_seed(0)
+ latent, adjacency_matrices = prepare_data()
+ latent = torch.tensor(latent).to_sparse()
+ adj_mat = adjacency_matrices[0]
+ adj_mat = torch.tensor(adj_mat).to_sparse()
+ conv = decagon_pytorch.convolve.SparseDropoutGraphConvActivation(10, 10,
+ adj_mat, keep_prob=keep_prob)
+ latent = conv(latent)
+ return latent
+
+
+def sparse_dropout_graph_conv_activation_tf(keep_prob=1.):
+ torch.random.manual_seed(0)
+ latent, adjacency_matrices = prepare_data()
+ conv_torch = decagon_pytorch.convolve.SparseGraphConv(10, 10,
+ torch.tensor(adjacency_matrices[0]).to_sparse())
+
+ weight = tf.constant(conv_torch.weight.detach().numpy())
+ nonzero_feat = np.sum(latent > 0)
+
+ latent = dense_to_sparse_tf(latent)
+ latent = dropout_sparse_tf(latent, keep_prob,
+ nonzero_feat)
+
+ adj_mat = dense_to_sparse_tf(adjacency_matrices[0])
+
+ latent = tf.sparse_tensor_dense_matmul(latent, weight)
+ latent = tf.sparse_tensor_dense_matmul(adj_mat, latent)
+
+ latent = tf.nn.relu(latent)
+
+ return latent
+
+
+def test_sparse_graph_conv():
+ latent_torch = sparse_graph_conv_torch()
+ latent_tf = sparse_graph_conv_tf()
+ assert np.all(latent_torch.detach().numpy() == latent_tf.eval(session = tf.Session()))
+
+
+def test_sparse_dropout_graph_conv_activation():
+ for i in range(11):
+ keep_prob = i/10. + np.finfo(np.float32).eps
+
+ latent_torch = sparse_dropout_graph_conv_activation_torch(keep_prob)
+ latent_tf = sparse_dropout_graph_conv_activation_tf(keep_prob)
+
+ latent_torch = latent_torch.detach().numpy()
+ latent_tf = latent_tf.eval(session = tf.Session())
+ print('latent_torch:', latent_torch)
+ print('latent_tf:', latent_tf)
+
+ assert np.all(latent_torch - latent_tf < .000001)
+
+
+def test_sparse_multi_dgca():
+ latent_torch = None
+ latent_tf = []
+
+ for i in range(11):
+ keep_prob = i/10. + np.finfo(np.float32).eps
+
+ latent_torch = sparse_dropout_graph_conv_activation_torch(keep_prob) \
+ if latent_torch is None \
+ else latent_torch + sparse_dropout_graph_conv_activation_torch(keep_prob)
+
+ latent_tf.append(sparse_dropout_graph_conv_activation_tf(keep_prob))
+
+ latent_torch = torch.nn.functional.normalize(latent_torch, p=2, dim=1)
+ latent_tf = tf.add_n(latent_tf)
+ latent_tf = tf.nn.l2_normalize(latent_tf, dim=1)
+
+ latent_torch = latent_torch.detach().numpy()
+ latent_tf = latent_tf.eval(session = tf.Session())
+
+ assert np.all(latent_torch - latent_tf < .000001)
+
+
+def test_graph_conv():
+ latent_dense = dense_graph_conv_torch()
+ latent_sparse = sparse_graph_conv_torch()
+
+ assert np.all(latent_dense.detach().numpy() == latent_sparse.detach().numpy())
+
+
+# def setup_function(fun):
+# if fun == test_dropout_graph_conv_activation or \
+# fun == test_multi_dgca:
+# print('Disabling dropout for testing...')
+# setup_function.old_dropout = decagon_pytorch.convolve.dropout, \
+# decagon_pytorch.convolve.dropout_sparse
+#
+# decagon_pytorch.convolve.dropout = lambda x, keep_prob: x
+# decagon_pytorch.convolve.dropout_sparse = lambda x, keep_prob: x
+#
+#
+# def teardown_function(fun):
+# print('Re-enabling dropout...')
+# if fun == test_dropout_graph_conv_activation or \
+# fun == test_multi_dgca:
+# decagon_pytorch.convolve.dropout, \
+# decagon_pytorch.convolve.dropout_sparse = \
+# setup_function.old_dropout
+
+
+def flexible_dropout_graph_conv_activation_torch(keep_prob=1.):
+ torch.random.manual_seed(0)
+ latent, adjacency_matrices = prepare_data()
+ latent = torch.tensor(latent).to_sparse()
+ adj_mat = adjacency_matrices[0]
+ adj_mat = torch.tensor(adj_mat).to_sparse()
+ conv = decagon_pytorch.convolve.DropoutGraphConvActivation(10, 10,
+ adj_mat, keep_prob=keep_prob)
+ latent = conv(latent)
+ return latent
+
+
+def _disable_dropout(monkeypatch):
+ monkeypatch.setattr(decagon_pytorch.convolve.dense, 'dropout',
+ lambda x, keep_prob: x)
+ monkeypatch.setattr(decagon_pytorch.convolve.sparse, 'dropout_sparse',
+ lambda x, keep_prob: x)
+ monkeypatch.setattr(decagon_pytorch.convolve.universal, 'dropout',
+ lambda x, keep_prob: x)
+ monkeypatch.setattr(decagon_pytorch.convolve.universal, 'dropout_sparse',
+ lambda x, keep_prob: x)
+
+
+def test_dropout_graph_conv_activation(monkeypatch):
+ _disable_dropout(monkeypatch)
+
+ for i in range(11):
+ keep_prob = i/10.
+ if keep_prob == 0:
+ keep_prob += np.finfo(np.float32).eps
+ print('keep_prob:', keep_prob)
+
+ latent_dense = dense_dropout_graph_conv_activation_torch(keep_prob)
+ latent_dense = latent_dense.detach().numpy()
+ print('latent_dense:', latent_dense)
+
+ latent_sparse = sparse_dropout_graph_conv_activation_torch(keep_prob)
+ latent_sparse = latent_sparse.detach().numpy()
+ print('latent_sparse:', latent_sparse)
+
+ latent_flex = flexible_dropout_graph_conv_activation_torch(keep_prob)
+ latent_flex = latent_flex.detach().numpy()
+ print('latent_flex:', latent_flex)
+
+ nonzero = (latent_dense != 0) & (latent_sparse != 0)
+
+ assert np.all(latent_dense[nonzero] == latent_sparse[nonzero])
+
+ nonzero = (latent_dense != 0) & (latent_flex != 0)
+
+ assert np.all(latent_dense[nonzero] == latent_flex[nonzero])
+
+ nonzero = (latent_sparse != 0) & (latent_flex != 0)
+
+ assert np.all(latent_sparse[nonzero] == latent_flex[nonzero])
+
+
+def test_multi_dgca(monkeypatch):
+ _disable_dropout(monkeypatch)
+
+ keep_prob = .5
+
+ torch.random.manual_seed(0)
+ latent, adjacency_matrices = prepare_data()
+
+ latent_sparse = torch.tensor(latent).to_sparse()
+ latent = torch.tensor(latent)
+ assert np.all(latent_sparse.to_dense().numpy() == latent.numpy())
+
+ adjacency_matrices_sparse = [ torch.tensor(a).to_sparse() for a in adjacency_matrices ]
+ adjacency_matrices = [ torch.tensor(a) for a in adjacency_matrices ]
+
+ for i in range(len(adjacency_matrices)):
+ assert np.all(adjacency_matrices[i].numpy() == adjacency_matrices_sparse[i].to_dense().numpy())
+
+ torch.random.manual_seed(0)
+ multi_sparse = decagon_pytorch.convolve.SparseMultiDGCA([10,] * len(adjacency_matrices), 10, adjacency_matrices_sparse, keep_prob=keep_prob)
+
+ torch.random.manual_seed(0)
+ multi = decagon_pytorch.convolve.DenseMultiDGCA([10,] * len(adjacency_matrices), 10, adjacency_matrices, keep_prob=keep_prob)
+
+ print('len(adjacency_matrices):', len(adjacency_matrices))
+ print('len(multi_sparse.sparse_dgca):', len(multi_sparse.sparse_dgca))
+ print('len(multi.dgca):', len(multi.dgca))
+
+ for i in range(len(adjacency_matrices)):
+ assert np.all(multi_sparse.sparse_dgca[i].sparse_graph_conv.weight.detach().numpy() == multi.dgca[i].graph_conv.weight.detach().numpy())
+
+ # torch.random.manual_seed(0)
+ latent_sparse = multi_sparse([latent_sparse,] * len(adjacency_matrices))
+ # torch.random.manual_seed(0)
+ latent = multi([latent,] * len(adjacency_matrices))
+
+ assert np.all(latent_sparse.detach().numpy() == latent.detach().numpy())
diff --git a/tests/decagon_pytorch/test_data_list.py b/tests/decagon_pytorch/test_data_list.py
new file mode 100644
index 0000000..a9cdb51
--- /dev/null
+++ b/tests/decagon_pytorch/test_data_list.py
@@ -0,0 +1,67 @@
+from decagon_pytorch.data import AdjListData, \
+ AdjListRelationType
+import torch
+import pytest
+
+
+def _get_list():
+ lst = torch.tensor([
+ [0, 1],
+ [0, 3],
+ [0, 5],
+ [0, 7]
+ ])
+ return lst
+
+
+def test_adj_list_relation_type_01():
+ lst = _get_list()
+ rel = AdjListRelationType('Test', 0, 0, lst)
+ assert torch.all(rel.get_adjacency_list(0, 0) == lst)
+
+
+def test_adj_list_relation_type_02():
+ lst = _get_list()
+ rel = AdjListRelationType('Test', 0, 1, lst)
+ assert torch.all(rel.get_adjacency_list(0, 1) == lst)
+ lst_2 = torch.tensor([
+ [1, 0],
+ [3, 0],
+ [5, 0],
+ [7, 0]
+ ])
+ assert torch.all(rel.get_adjacency_list(1, 0) == lst_2)
+
+
+def test_adj_list_relation_type_03():
+ lst = _get_list()
+ lst_2 = torch.tensor([
+ [2, 0],
+ [4, 0],
+ [6, 0],
+ [8, 0]
+ ])
+ rel = AdjListRelationType('Test', 0, 1, lst, lst_2)
+ assert torch.all(rel.get_adjacency_list(0, 1) == lst)
+ assert torch.all(rel.get_adjacency_list(1, 0) == lst_2)
+
+
+def test_adj_list_data_01():
+ lst = _get_list()
+ d = AdjListData()
+ with pytest.raises(AssertionError):
+ d.add_relation_type('Test', 0, 1, lst)
+ d.add_node_type('Drugs', 5)
+ with pytest.raises(AssertionError):
+ d.add_relation_type('Test', 0, 0, lst)
+ d = AdjListData()
+ d.add_node_type('Drugs', 8)
+ d.add_relation_type('Test', 0, 0, lst)
+
+
+def test_adj_list_data_02():
+ lst = _get_list()
+ d = AdjListData()
+ d.add_node_type('Drugs', 10)
+ d.add_node_type('Proteins', 10)
+ d.add_relation_type('Target', 0, 1, lst)
diff --git a/tests/decagon_pytorch/test_data_matrix.py b/tests/decagon_pytorch/test_data_matrix.py
new file mode 100644
index 0000000..51426ee
--- /dev/null
+++ b/tests/decagon_pytorch/test_data_matrix.py
@@ -0,0 +1,13 @@
+from decagon_pytorch.data import Data
+
+
+def test_data():
+ d = Data()
+ d.add_node_type('Gene', 1000)
+ d.add_node_type('Drug', 100)
+ d.add_relation_type('Target', 1, 0, None)
+ d.add_relation_type('Interaction', 0, 0, None)
+ d.add_relation_type('Side Effect: Nausea', 1, 1, None)
+ d.add_relation_type('Side Effect: Infertility', 1, 1, None)
+ d.add_relation_type('Side Effect: Death', 1, 1, None)
+ print(d)
diff --git a/tests/decagon_pytorch/test_decode.py b/tests/decagon_pytorch/test_decode.py
new file mode 100644
index 0000000..de7d1a4
--- /dev/null
+++ b/tests/decagon_pytorch/test_decode.py
@@ -0,0 +1,80 @@
+import decagon_pytorch.decode.cartesian
+import decagon.deep.layers
+import numpy as np
+import tensorflow as tf
+import torch
+
+
+def _common(decoder_torch, decoder_tf):
+ inputs = np.random.rand(20, 10).astype(np.float32)
+ inputs_torch = torch.tensor(inputs)
+ inputs_tf = {
+ 0: tf.convert_to_tensor(inputs)
+ }
+ out_torch = decoder_torch(inputs_torch, inputs_torch)
+ out_tf = decoder_tf(inputs_tf)
+
+ assert len(out_torch) == len(out_tf)
+ assert len(out_tf) == 7
+
+ for i in range(len(out_torch)):
+ assert out_torch[i].shape == out_tf[i].shape
+
+ sess = tf.Session()
+ for i in range(len(out_torch)):
+ item_torch = out_torch[i].detach().numpy()
+ item_tf = out_tf[i].eval(session=sess)
+ print('item_torch:', item_torch)
+ print('item_tf:', item_tf)
+ assert np.all(item_torch - item_tf < .000001)
+ sess.close()
+
+
+def test_dedicom_decoder():
+ dedicom_torch = decagon_pytorch.decode.cartesian.DEDICOMDecoder(input_dim=10,
+ num_relation_types=7)
+ dedicom_tf = decagon.deep.layers.DEDICOMDecoder(input_dim=10, num_types=7,
+ edge_type=(0, 0))
+
+ dedicom_tf.vars['global_interaction'] = \
+ tf.convert_to_tensor(dedicom_torch.global_interaction.detach().numpy())
+ for i in range(dedicom_tf.num_types):
+ dedicom_tf.vars['local_variation_%d' % i] = \
+ tf.convert_to_tensor(dedicom_torch.local_variation[i].detach().numpy())
+
+ _common(dedicom_torch, dedicom_tf)
+
+
+def test_dist_mult_decoder():
+ distmult_torch = decagon_pytorch.decode.cartesian.DistMultDecoder(input_dim=10,
+ num_relation_types=7)
+ distmult_tf = decagon.deep.layers.DistMultDecoder(input_dim=10, num_types=7,
+ edge_type=(0, 0))
+
+ for i in range(distmult_tf.num_types):
+ distmult_tf.vars['relation_%d' % i] = \
+ tf.convert_to_tensor(distmult_torch.relation[i].detach().numpy())
+
+ _common(distmult_torch, distmult_tf)
+
+
+def test_bilinear_decoder():
+ bilinear_torch = decagon_pytorch.decode.cartesian.BilinearDecoder(input_dim=10,
+ num_relation_types=7)
+ bilinear_tf = decagon.deep.layers.BilinearDecoder(input_dim=10, num_types=7,
+ edge_type=(0, 0))
+
+ for i in range(bilinear_tf.num_types):
+ bilinear_tf.vars['relation_%d' % i] = \
+ tf.convert_to_tensor(bilinear_torch.relation[i].detach().numpy())
+
+ _common(bilinear_torch, bilinear_tf)
+
+
+def test_inner_product_decoder():
+ inner_torch = decagon_pytorch.decode.cartesian.InnerProductDecoder(input_dim=10,
+ num_relation_types=7)
+ inner_tf = decagon.deep.layers.InnerProductDecoder(input_dim=10, num_types=7,
+ edge_type=(0, 0))
+
+ _common(inner_torch, inner_tf)
diff --git a/tests/decagon_pytorch/test_decode_dims.py b/tests/decagon_pytorch/test_decode_dims.py
new file mode 100644
index 0000000..2c3a144
--- /dev/null
+++ b/tests/decagon_pytorch/test_decode_dims.py
@@ -0,0 +1,106 @@
+from decagon_pytorch.decode.cartesian import DEDICOMDecoder, \
+ DistMultDecoder, \
+ BilinearDecoder, \
+ InnerProductDecoder
+import torch
+
+
+def _common(decoder_class):
+ decoder = decoder_class(input_dim=10, num_relation_types=1)
+ inputs = torch.rand((20, 10))
+ pred = decoder(inputs, inputs)
+
+ assert isinstance(pred, list)
+ assert len(pred) == 1
+
+ assert isinstance(pred[0], torch.Tensor)
+ assert pred[0].shape == (20, 20)
+
+
+
+def test_dedicom_decoder():
+ _common(DEDICOMDecoder)
+
+
+def test_dist_mult_decoder():
+ _common(DistMultDecoder)
+
+
+def test_bilinear_decoder():
+ _common(BilinearDecoder)
+
+
+def test_inner_product_decoder():
+ _common(InnerProductDecoder)
+
+
+def test_batch_matrix_multiplication():
+ input_dim = 10
+ inputs = torch.rand((20, 10))
+
+ decoder = DEDICOMDecoder(input_dim=input_dim, num_relation_types=1)
+ out_dec = decoder(inputs, inputs)
+
+ relation = decoder.local_variation[0]
+ global_interaction = decoder.global_interaction
+ act = decoder.activation
+ relation = torch.diag(relation)
+ product1 = torch.mm(inputs, relation)
+ product2 = torch.mm(product1, global_interaction)
+ product3 = torch.mm(product2, relation)
+ rec = torch.mm(product3, torch.transpose(inputs, 0, 1))
+ rec = act(rec)
+
+ print('rec:', rec)
+ print('out_dec:', out_dec)
+
+ assert torch.all(rec == out_dec[0])
+
+
+def test_single_prediction_01():
+ input_dim = 10
+ inputs = torch.rand((20, 10))
+
+ decoder = DEDICOMDecoder(input_dim=input_dim, num_relation_types=1)
+ dec_all = decoder(inputs, inputs)
+ dec_one = decoder(inputs[0:1], inputs[0:1])
+
+ assert torch.abs(dec_all[0][0, 0] - dec_one[0][0, 0]) < 0.000001
+
+
+def test_single_prediction_02():
+ input_dim = 10
+ inputs = torch.rand((20, 10))
+
+ decoder = DEDICOMDecoder(input_dim=input_dim, num_relation_types=1)
+ dec_all = decoder(inputs, inputs)
+ dec_one = decoder(inputs[0:1], inputs[1:2])
+
+ assert torch.abs(dec_all[0][0, 1] - dec_one[0][0, 0]) < 0.000001
+ assert dec_one[0].shape == (1, 1)
+
+
+def test_pairwise_prediction():
+ n_nodes = 20
+ input_dim = 10
+ inputs_row = torch.rand((n_nodes, input_dim))
+ inputs_col = torch.rand((n_nodes, input_dim))
+
+ decoder = DEDICOMDecoder(input_dim=input_dim, num_relation_types=1)
+ dec_all = decoder(inputs_row, inputs_col)
+
+ relation = torch.diag(decoder.local_variation[0])
+ global_interaction = decoder.global_interaction
+ act = decoder.activation
+ product1 = torch.mm(inputs_row, relation)
+ product2 = torch.mm(product1, global_interaction)
+ product3 = torch.mm(product2, relation)
+ rec = torch.bmm(product3.view(product3.shape[0], 1, product3.shape[1]),
+ inputs_col.view(inputs_col.shape[0], inputs_col.shape[1], 1))
+
+ assert rec.shape == (n_nodes, 1, 1)
+
+ rec = torch.flatten(rec)
+ rec = act(rec)
+
+ assert torch.all(torch.abs(rec - torch.diag(dec_all[0])) < 0.000001)
diff --git a/tests/decagon_pytorch/test_decode_pairwise.py b/tests/decagon_pytorch/test_decode_pairwise.py
new file mode 100644
index 0000000..9b47eef
--- /dev/null
+++ b/tests/decagon_pytorch/test_decode_pairwise.py
@@ -0,0 +1,59 @@
+import decagon_pytorch.decode.cartesian as cart
+import decagon_pytorch.decode.pairwise as pair
+import torch
+
+
+def _common(cart_class, pair_class):
+ input_dim = 10
+ n_nodes = 20
+ num_relation_types = 7
+
+ inputs_row = torch.rand((n_nodes, input_dim))
+ inputs_col = torch.rand((n_nodes, input_dim))
+
+ cart_dec = cart_class(input_dim=input_dim,
+ num_relation_types=num_relation_types)
+ pair_dec = pair_class(input_dim=input_dim,
+ num_relation_types=num_relation_types)
+
+ if isinstance(cart_dec, cart.DEDICOMDecoder):
+ pair_dec.global_interaction = cart_dec.global_interaction
+ pair_dec.local_variation = cart_dec.local_variation
+ elif isinstance(cart_dec, cart.InnerProductDecoder):
+ pass
+ else:
+ pair_dec.relation = cart_dec.relation
+
+ cart_pred = cart_dec(inputs_row, inputs_col)
+ pair_pred = pair_dec(inputs_row, inputs_col)
+
+ assert isinstance(cart_pred, list)
+ assert isinstance(pair_pred, list)
+
+ assert len(cart_pred) == len(pair_pred)
+ assert len(cart_pred) == num_relation_types
+
+ for i in range(num_relation_types):
+ assert isinstance(cart_pred[i], torch.Tensor)
+ assert isinstance(pair_pred[i], torch.Tensor)
+
+ assert cart_pred[i].shape == (n_nodes, n_nodes)
+ assert pair_pred[i].shape == (n_nodes,)
+
+ assert torch.all(torch.abs(pair_pred[i] - torch.diag(cart_pred[i])) < 0.000001)
+
+
+def test_dedicom_decoder():
+ _common(cart.DEDICOMDecoder, pair.DEDICOMDecoder)
+
+
+def test_dist_mult_decoder():
+ _common(cart.DistMultDecoder, pair.DistMultDecoder)
+
+
+def test_bilinear_decoder():
+ _common(cart.BilinearDecoder, pair.BilinearDecoder)
+
+
+def test_inner_product_decoder():
+ _common(cart.InnerProductDecoder, pair.InnerProductDecoder)
diff --git a/tests/decagon_pytorch/test_dropout.py b/tests/decagon_pytorch/test_dropout.py
new file mode 100644
index 0000000..60366ec
--- /dev/null
+++ b/tests/decagon_pytorch/test_dropout.py
@@ -0,0 +1,34 @@
+from decagon_pytorch.dropout import dropout_sparse
+import torch
+import numpy as np
+
+
+def dropout_dense(a, keep_prob):
+ i = np.array(np.where(a))
+ v = a[i[0, :], i[1, :]]
+
+ # torch.random.manual_seed(0)
+ n = keep_prob + torch.rand(len(v))
+ n = torch.floor(n).to(torch.bool)
+ i = i[:, n]
+ v = v[n]
+ x = torch.sparse_coo_tensor(i, v, size=a.shape)
+
+ return x * (1./keep_prob)
+
+
+def test_dropout_sparse():
+ for i in range(11):
+ torch.random.manual_seed(i)
+ a = torch.rand((5, 10))
+ a[a < .5] = 0
+
+ keep_prob=i/10. + np.finfo(np.float32).eps
+
+ torch.random.manual_seed(i)
+ b = dropout_dense(a, keep_prob=keep_prob)
+
+ torch.random.manual_seed(i)
+ c = dropout_sparse(a.to_sparse(), keep_prob=keep_prob)
+
+ assert np.all(np.array(b.to_dense()) == np.array(c.to_dense()))
diff --git a/tests/decagon_pytorch/test_normalize.py b/tests/decagon_pytorch/test_normalize.py
new file mode 100644
index 0000000..c7e0180
--- /dev/null
+++ b/tests/decagon_pytorch/test_normalize.py
@@ -0,0 +1,25 @@
+import decagon_pytorch.normalize
+import decagon.deep.minibatch
+import numpy as np
+
+
+def test_normalize_adjacency_matrix_square():
+ mx = np.random.rand(10, 10)
+ mx[mx < .5] = 0
+ mx = np.ceil(mx)
+ res_torch = decagon_pytorch.normalize.normalize_adjacency_matrix(mx)
+ res_tf = decagon.deep.minibatch.EdgeMinibatchIterator.preprocess_graph(None, mx)
+ assert len(res_torch) == len(res_tf)
+ for i in range(len(res_torch)):
+ assert np.all(res_torch[i] == res_tf[i])
+
+
+def test_normalize_adjacency_matrix_nonsquare():
+ mx = np.random.rand(5, 10)
+ mx[mx < .5] = 0
+ mx = np.ceil(mx)
+ res_torch = decagon_pytorch.normalize.normalize_adjacency_matrix(mx)
+ res_tf = decagon.deep.minibatch.EdgeMinibatchIterator.preprocess_graph(None, mx)
+ assert len(res_torch) == len(res_tf)
+ for i in range(len(res_torch)):
+ assert np.all(res_torch[i] == res_tf[i])
diff --git a/tests/decagon_pytorch/test_sampling.py b/tests/decagon_pytorch/test_sampling.py
new file mode 100644
index 0000000..028d5b5
--- /dev/null
+++ b/tests/decagon_pytorch/test_sampling.py
@@ -0,0 +1,172 @@
+import tensorflow as tf
+import numpy as np
+from collections import defaultdict
+import torch
+import torch.utils.data
+from typing import List, \
+ Union
+import decagon_pytorch.sampling
+import scipy.stats
+
+
+def test_unigram_01():
+ range_max = 7
+ distortion = 0.75
+ batch_size = 500
+ unigrams = [ 1, 3, 2, 1, 2, 1, 3]
+ num_true = 1
+
+ true_classes = np.zeros((batch_size, num_true), dtype=np.int64)
+ for i in range(batch_size):
+ true_classes[i, 0] = i % range_max
+ true_classes = tf.convert_to_tensor(true_classes)
+
+ neg_samples, _, _ = tf.nn.fixed_unigram_candidate_sampler(
+ true_classes=true_classes,
+ num_true=num_true,
+ num_sampled=batch_size,
+ unique=False,
+ range_max=range_max,
+ distortion=distortion,
+ unigrams=unigrams)
+
+ assert neg_samples.shape == (batch_size,)
+
+ for i in range(batch_size):
+ assert neg_samples[i] != true_classes[i, 0]
+
+ counts = defaultdict(int)
+ with tf.Session() as sess:
+ neg_samples = neg_samples.eval()
+ for x in neg_samples:
+ counts[x] += 1
+
+ print('counts:', counts)
+
+ assert counts[0] < counts[1] and \
+ counts[0] < counts[2] and \
+ counts[0] < counts[4] and \
+ counts[0] < counts[6]
+
+ assert counts[2] < counts[1] and \
+ counts[0] < counts[6]
+
+ assert counts[3] < counts[1] and \
+ counts[3] < counts[2] and \
+ counts[3] < counts[4] and \
+ counts[3] < counts[6]
+
+ assert counts[4] < counts[1] and \
+ counts[4] < counts[6]
+
+ assert counts[5] < counts[1] and \
+ counts[5] < counts[2] and \
+ counts[5] < counts[4] and \
+ counts[5] < counts[6]
+
+
+def test_unigram_02():
+ range_max = 7
+ distortion = 0.75
+ batch_size = 500
+ unigrams = [ 1, 3, 2, 1, 2, 1, 3]
+ num_true = 1
+
+ true_classes = np.zeros((batch_size, num_true), dtype=np.int64)
+ for i in range(batch_size):
+ true_classes[i, 0] = i % range_max
+ true_classes = torch.tensor(true_classes)
+
+ neg_samples = decagon_pytorch.sampling.fixed_unigram_candidate_sampler(
+ true_classes=true_classes,
+ num_samples=batch_size,
+ distortion=distortion,
+ unigrams=unigrams)
+
+ assert neg_samples.shape == (batch_size,)
+
+ for i in range(batch_size):
+ assert neg_samples[i] != true_classes[i, 0]
+
+ counts = defaultdict(int)
+ for x in neg_samples:
+ counts[x] += 1
+
+ print('counts:', counts)
+
+ assert counts[0] < counts[1] and \
+ counts[0] < counts[2] and \
+ counts[0] < counts[4] and \
+ counts[0] < counts[6]
+
+ assert counts[2] < counts[1] and \
+ counts[0] < counts[6]
+
+ assert counts[3] < counts[1] and \
+ counts[3] < counts[2] and \
+ counts[3] < counts[4] and \
+ counts[3] < counts[6]
+
+ assert counts[4] < counts[1] and \
+ counts[4] < counts[6]
+
+ assert counts[5] < counts[1] and \
+ counts[5] < counts[2] and \
+ counts[5] < counts[4] and \
+ counts[5] < counts[6]
+
+
+def test_unigram_03():
+ range_max = 7
+ distortion = 0.75
+ batch_size = 25
+ unigrams = [ 1, 3, 2, 1, 2, 1, 3]
+ num_true = 1
+
+ true_classes = np.zeros((batch_size, num_true), dtype=np.int64)
+ for i in range(batch_size):
+ true_classes[i, 0] = i % range_max
+
+ true_classes_tf = tf.convert_to_tensor(true_classes)
+ true_classes_torch = torch.tensor(true_classes)
+
+ counts_tf = defaultdict(list)
+ counts_torch = defaultdict(list)
+
+ for i in range(100):
+ neg_samples, _, _ = tf.nn.fixed_unigram_candidate_sampler(
+ true_classes=true_classes_tf,
+ num_true=num_true,
+ num_sampled=batch_size,
+ unique=False,
+ range_max=range_max,
+ distortion=distortion,
+ unigrams=unigrams)
+
+ counts = defaultdict(int)
+ with tf.Session() as sess:
+ neg_samples = neg_samples.eval()
+ for x in neg_samples:
+ counts[x] += 1
+ for k, v in counts.items():
+ counts_tf[k].append(v)
+
+ neg_samples = decagon_pytorch.sampling.fixed_unigram_candidate_sampler(
+ true_classes=true_classes,
+ num_samples=batch_size,
+ distortion=distortion,
+ unigrams=unigrams)
+
+ counts = defaultdict(int)
+ for x in neg_samples:
+ counts[x] += 1
+ for k, v in counts.items():
+ counts_torch[k].append(v)
+
+ for i in range(range_max):
+ print('counts_tf[%d]:' % i, counts_tf[i])
+ print('counts_torch[%d]:' % i, counts_torch[i])
+
+ for i in range(range_max):
+ statistic, pvalue = scipy.stats.ttest_ind(counts_tf[i], counts_torch[i])
+ assert pvalue * range_max > .05
diff --git a/tests/decagon_pytorch/test_splits.py b/tests/decagon_pytorch/test_splits.py
new file mode 100644
index 0000000..79482fb
--- /dev/null
+++ b/tests/decagon_pytorch/test_splits.py
@@ -0,0 +1,95 @@
+from decagon_pytorch.data import Data
+import torch
+from decagon_pytorch.splits import train_val_test_split_adj_mat
+import pytest
+
+
+def _gen_adj_mat(n_rows, n_cols):
+ res = torch.rand((n_rows, n_cols)).round()
+ if n_rows == n_cols:
+ res -= torch.diag(torch.diag(res))
+ a, b = torch.triu_indices(n_rows, n_cols)
+ res[a, b] = res.transpose(0, 1)[a, b]
+ return res
+
+
+def train_val_test_split_1(data, train_ratio=0.8,
+ val_ratio=0.1, test_ratio=0.1):
+
+ if train_ratio + val_ratio + test_ratio != 1.0:
+ raise ValueError('Train, validation and test ratios must add up to 1')
+
+ d_train = Data()
+ d_val = Data()
+ d_test = Data()
+
+ for (node_type_row, node_type_col), rels in data.relation_types.items():
+ for r in rels:
+ adj_train, adj_val, adj_test = train_val_test_split_adj_mat(r.adjacency_matrix)
+ d_train.add_relation_type(r.name, node_type_row, node_type_col, adj_train)
+ d_val.add_relation_type(r.name, node_type_row, node_type_col, adj_train + adj_val)
+
+
+def train_val_test_split_2(data, train_ratio, val_ratio, test_ratio):
+ if train_ratio + val_ratio + test_ratio != 1.0:
+ raise ValueError('Train, validation and test ratios must add up to 1')
+ for (node_type_row, node_type_col), rels in data.relation_types.items():
+ for r in rels:
+ adj_mat = r.adjacency_matrix
+ edges = torch.nonzero(adj_mat)
+ order = torch.randperm(len(edges))
+ edges = edges[order, :]
+ n = round(len(edges) * train_ratio)
+ edges_train = edges[:n]
+ n_1 = round(len(edges) * (train_ratio + val_ratio))
+ edges_val = edges[n:n_1]
+ edges_test = edges[n_1:]
+ if len(edges_train) * len(edges_val) * len(edges_test) == 0:
+ raise ValueError('Not enough edges to split into train/val/test sets for: ' + r.name)
+
+
+def test_train_val_test_split_adj_mat():
+ adj_mat = _gen_adj_mat(50, 100)
+ adj_mat_train, adj_mat_val, adj_mat_test = \
+ train_val_test_split_adj_mat(adj_mat, train_ratio=0.8,
+ val_ratio=0.1, test_ratio=0.1)
+
+ assert adj_mat.shape == adj_mat_train.shape == \
+ adj_mat_val.shape == adj_mat_test.shape
+
+ edges_train = torch.nonzero(adj_mat_train)
+ edges_val = torch.nonzero(adj_mat_val)
+ edges_test = torch.nonzero(adj_mat_test)
+
+ edges_train = set(map(tuple, edges_train.tolist()))
+ edges_val = set(map(tuple, edges_val.tolist()))
+ edges_test = set(map(tuple, edges_test.tolist()))
+
+ assert edges_train.intersection(edges_val) == set()
+ assert edges_train.intersection(edges_test) == set()
+ assert edges_test.intersection(edges_val) == set()
+
+ assert torch.all(adj_mat_train + adj_mat_val + adj_mat_test == adj_mat)
+
+ # assert torch.all((edges_train != edges_val).sum(1).to(torch.bool))
+ # assert torch.all((edges_train != edges_test).sum(1).to(torch.bool))
+ # assert torch.all((edges_val != edges_test).sum(1).to(torch.bool))
+
+
+@pytest.mark.skip
+def test_splits_01():
+ d = Data()
+ d.add_node_type('Gene', 1000)
+ d.add_node_type('Drug', 100)
+ d.add_relation_type('Interaction', 0, 0,
+ _gen_adj_mat(1000, 1000))
+ d.add_relation_type('Target', 1, 0,
+ _gen_adj_mat(100, 1000))
+ d.add_relation_type('Side Effect: Insomnia', 1, 1,
+ _gen_adj_mat(100, 100))
+ d.add_relation_type('Side Effect: Incontinence', 1, 1,
+ _gen_adj_mat(100, 100))
+ d.add_relation_type('Side Effect: Abdominal pain', 1, 1,
+ _gen_adj_mat(100, 100))
+
+ d_train, d_val, d_test = train_val_test_split(d, 0.8, 0.1, 0.1)
diff --git a/tests/icosagon/test_batch.py b/tests/icosagon/test_batch.py
new file mode 100644
index 0000000..b5882db
--- /dev/null
+++ b/tests/icosagon/test_batch.py
@@ -0,0 +1,187 @@
+from icosagon.batch import PredictionsBatch, \
+ FlatPredictions, \
+ flatten_predictions, \
+ BatchIndices, \
+ gather_batch_indices
+from icosagon.declayer import Predictions, \
+ RelationPredictions, \
+ RelationFamilyPredictions
+from icosagon.trainprep import prepare_training, \
+ TrainValTest
+from icosagon.data import Data
+import torch
+import pytest
+
+
+def test_flat_predictions_01():
+ pred = FlatPredictions(torch.tensor([0, 1, 0, 1]),
+ torch.tensor([1, 0, 1, 0]), 'train')
+
+ assert torch.all(pred.predictions == torch.tensor([0, 1, 0, 1]))
+ assert torch.all(pred.truth == torch.tensor([1, 0, 1, 0]))
+ assert pred.part_type == 'train'
+
+
+def test_flatten_predictions_01():
+ rel_pred = RelationPredictions(
+ TrainValTest(torch.tensor([1, 0, 1, 0, 1], dtype=torch.float32), torch.zeros(0), torch.zeros(0)),
+ TrainValTest(torch.tensor([1, 0, 1, 0, 1], dtype=torch.float32), torch.zeros(0), torch.zeros(0)),
+ TrainValTest(torch.zeros(0), torch.zeros(0), torch.zeros(0)),
+ TrainValTest(torch.zeros(0), torch.zeros(0), torch.zeros(0))
+ )
+ fam_pred = RelationFamilyPredictions([ rel_pred ])
+ pred = Predictions([ fam_pred ])
+
+ pred_flat = flatten_predictions(pred, part_type='train')
+
+ assert torch.all(pred_flat.predictions == \
+ torch.tensor([1, 0, 1, 0, 1, 1, 0, 1, 0, 1], dtype=torch.float32))
+ assert torch.all(pred_flat.truth == \
+ torch.tensor([1, 1, 1, 1, 1, 0, 0, 0, 0, 0], dtype=torch.float32))
+ assert pred_flat.part_type == 'train'
+
+
+def test_flatten_predictions_02():
+ rel_pred = RelationPredictions(
+ TrainValTest(torch.tensor([1, 0, 1, 0, 1], dtype=torch.float32), torch.zeros(0), torch.zeros(0)),
+ TrainValTest(torch.tensor([1, 0, 1, 0, 1], dtype=torch.float32), torch.zeros(0), torch.zeros(0)),
+ TrainValTest(torch.zeros(0), torch.zeros(0), torch.zeros(0)),
+ TrainValTest(torch.zeros(0), torch.zeros(0), torch.zeros(0))
+ )
+ fam_pred = RelationFamilyPredictions([ rel_pred ])
+ pred = Predictions([ fam_pred ])
+
+ pred_flat = flatten_predictions(pred, part_type='val')
+
+ assert len(pred_flat.predictions) == 0
+ assert len(pred_flat.truth) == 0
+ assert pred_flat.part_type == 'val'
+
+
+def test_flatten_predictions_03():
+ rel_pred = RelationPredictions(
+ TrainValTest(torch.tensor([1, 0, 1, 0, 1], dtype=torch.float32), torch.zeros(0), torch.zeros(0)),
+ TrainValTest(torch.tensor([1, 0, 1, 0, 1], dtype=torch.float32), torch.zeros(0), torch.zeros(0)),
+ TrainValTest(torch.zeros(0), torch.zeros(0), torch.zeros(0)),
+ TrainValTest(torch.zeros(0), torch.zeros(0), torch.zeros(0))
+ )
+ fam_pred = RelationFamilyPredictions([ rel_pred ])
+ pred = Predictions([ fam_pred ])
+
+ pred_flat = flatten_predictions(pred, part_type='test')
+
+ assert len(pred_flat.predictions) == 0
+ assert len(pred_flat.truth) == 0
+ assert pred_flat.part_type == 'test'
+
+
+def test_flatten_predictions_04():
+ rel_pred = RelationPredictions(
+ TrainValTest(torch.tensor([1, 0, 1, 0, 1], dtype=torch.float32), torch.zeros(0), torch.zeros(0)),
+ TrainValTest(torch.tensor([1, 0, 1, 0, 1], dtype=torch.float32), torch.zeros(0), torch.zeros(0)),
+ TrainValTest(torch.zeros(0), torch.zeros(0), torch.zeros(0)),
+ TrainValTest(torch.zeros(0), torch.zeros(0), torch.zeros(0))
+ )
+ fam_pred = RelationFamilyPredictions([ rel_pred ])
+ pred = Predictions([ fam_pred ])
+
+ with pytest.raises(TypeError):
+ pred_flat = flatten_predictions(1, part_type='test')
+
+ with pytest.raises(ValueError):
+ pred_flat = flatten_predictions(pred, part_type='x')
+
+
+def test_flatten_predictions_05():
+ x = torch.rand(5000)
+ y = torch.cat([ x, x ])
+ z = torch.cat([ torch.ones(5000), torch.zeros(5000) ])
+
+ rel_pred = RelationPredictions(
+ TrainValTest(x, torch.zeros(0), torch.zeros(0)),
+ TrainValTest(x, torch.zeros(0), torch.zeros(0)),
+ TrainValTest(torch.zeros(0), torch.zeros(0), torch.zeros(0)),
+ TrainValTest(torch.zeros(0), torch.zeros(0), torch.zeros(0))
+ )
+ fam_pred = RelationFamilyPredictions([ rel_pred ])
+ pred = Predictions([ fam_pred ])
+
+ for _ in range(10):
+ pred_flat = flatten_predictions(pred, part_type='train')
+ assert torch.all(pred_flat.predictions == y)
+ assert torch.all(pred_flat.truth == z)
+ assert pred_flat.part_type == 'train'
+
+
+def test_batch_indices_01():
+ indices = BatchIndices(torch.tensor([0, 1, 2, 3, 4]), 'train')
+ assert torch.all(indices.indices == torch.tensor([0, 1, 2, 3, 4]))
+ assert indices.part_type == 'train'
+
+
+def test_gather_batch_indices_01():
+ rel_pred = RelationPredictions(
+ TrainValTest(torch.tensor([1, 0, 1, 0, 1], dtype=torch.float32), torch.zeros(0), torch.zeros(0)),
+ TrainValTest(torch.tensor([1, 0, 1, 0, 1], dtype=torch.float32), torch.zeros(0), torch.zeros(0)),
+ TrainValTest(torch.zeros(0), torch.zeros(0), torch.zeros(0)),
+ TrainValTest(torch.zeros(0), torch.zeros(0), torch.zeros(0))
+ )
+ fam_pred = RelationFamilyPredictions([ rel_pred ])
+ pred = Predictions([ fam_pred ])
+
+ pred_flat = flatten_predictions(pred, part_type='train')
+
+ indices = BatchIndices(torch.tensor([0, 2, 4, 5, 7, 9]), 'train')
+
+ (input, target) = gather_batch_indices(pred_flat, indices)
+ assert torch.all(input == \
+ torch.tensor([1, 1, 1, 1, 1, 1], dtype=torch.float32))
+ assert torch.all(target == \
+ torch.tensor([1, 1, 1, 0, 0, 0], dtype=torch.float32))
+
+
+def test_predictions_batch_01():
+ d = Data()
+ d.add_node_type('Dummy', 5)
+ fam = d.add_relation_family('Dummy-Dummy', 0, 0, False)
+ fam.add_relation_type('Dummy Rel', torch.tensor([
+ [0, 1, 0, 0, 0],
+ [1, 0, 0, 0, 0],
+ [0, 0, 0, 1, 0],
+ [0, 0, 0, 0, 1],
+ [0, 1, 0, 0, 0]
+ ], dtype=torch.float32))
+
+ prep_d = prepare_training(d, TrainValTest(1., 0., 0.))
+
+ assert len(prep_d.relation_families) == 1
+ assert len(prep_d.relation_families[0].relation_types) == 1
+ assert len(prep_d.relation_families[0].relation_types[0].edges_pos.train) == 5
+ assert len(prep_d.relation_families[0].relation_types[0].edges_neg.train) == 5
+ assert len(prep_d.relation_families[0].relation_types[0].edges_pos.val) == 0
+ assert len(prep_d.relation_families[0].relation_types[0].edges_pos.test) == 0
+
+ rel_pred = RelationPredictions(
+ TrainValTest(torch.tensor([1, 0, 1, 0, 1], dtype=torch.float32), torch.zeros(0), torch.zeros(0)),
+ TrainValTest(torch.tensor([1, 0, 1, 0, 1], dtype=torch.float32), torch.zeros(0), torch.zeros(0)),
+ TrainValTest(torch.zeros(0), torch.zeros(0), torch.zeros(0)),
+ TrainValTest(torch.zeros(0), torch.zeros(0), torch.zeros(0))
+ )
+ fam_pred = RelationFamilyPredictions([ rel_pred ])
+ pred = Predictions([ fam_pred ])
+
+ pred_flat = flatten_predictions(pred, part_type='train')
+
+ batch = PredictionsBatch(prep_d, part_type='train', batch_size=1)
+ count = 0
+ lst = []
+ for indices in batch:
+ (input, target) = gather_batch_indices(pred_flat, indices)
+ assert len(input) == 1
+ assert len(target) == 1
+ lst.append((input[0], target[0]))
+ count += 1
+ assert lst == [ (1, 1), (0, 1), (1, 1), (0, 1), (1, 1),
+ (1, 0), (0, 0), (1, 0), (0, 0), (1, 0) ]
+
+ assert count == 10
diff --git a/tests/icosagon/test_bulkdec.py b/tests/icosagon/test_bulkdec.py
new file mode 100644
index 0000000..b4ed3d6
--- /dev/null
+++ b/tests/icosagon/test_bulkdec.py
@@ -0,0 +1,240 @@
+from icosagon.data import Data
+from icosagon.bulkdec import BulkDecodeLayer
+from icosagon.input import OneHotInputLayer
+from icosagon.convlayer import DecagonLayer
+import torch
+import pytest
+import time
+import sys
+
+
+def test_bulk_decode_layer_01():
+ data = Data()
+ data.add_node_type('Dummy', 100)
+ fam = data.add_relation_family('Dummy-Dummy', 0, 0, False)
+ fam.add_relation_type('Dummy Relation 1',
+ torch.rand((100, 100), dtype=torch.float32).round().to_sparse())
+
+ in_layer = OneHotInputLayer(data)
+ d_layer = DecagonLayer(in_layer.output_dim, 32, data)
+ dec_layer = BulkDecodeLayer(input_dim=d_layer.output_dim, data=data,
+ keep_prob=1., activation=lambda x: x)
+ seq = torch.nn.Sequential(in_layer, d_layer, dec_layer)
+
+ pred = seq(None)
+
+ assert isinstance(pred, list)
+ assert len(pred) == len(data.relation_families)
+ assert isinstance(pred[0], torch.Tensor)
+ assert len(pred[0].shape) == 3
+ assert len(pred[0]) == len(data.relation_families[0].relation_types)
+ assert pred[0].shape[1] == data.node_types[0].count
+ assert pred[0].shape[2] == data.node_types[0].count
+
+
+def test_bulk_decode_layer_02():
+ data = Data()
+ data.add_node_type('Foo', 100)
+ data.add_node_type('Bar', 50)
+ fam = data.add_relation_family('Foo-Bar', 0, 1, False)
+ fam.add_relation_type('Foobar Relation 1',
+ torch.rand((100, 50), dtype=torch.float32).round().to_sparse(),
+ torch.rand((50, 100), dtype=torch.float32).round().to_sparse())
+
+ in_layer = OneHotInputLayer(data)
+ d_layer = DecagonLayer(in_layer.output_dim, 32, data)
+ dec_layer = BulkDecodeLayer(input_dim=d_layer.output_dim, data=data,
+ keep_prob=1., activation=lambda x: x)
+ seq = torch.nn.Sequential(in_layer, d_layer, dec_layer)
+
+ pred = seq(None)
+
+ assert isinstance(pred, list)
+ assert len(pred) == len(data.relation_families)
+ assert isinstance(pred[0], torch.Tensor)
+ assert len(pred[0].shape) == 3
+ assert len(pred[0]) == len(data.relation_families[0].relation_types)
+ assert pred[0].shape[1] == data.node_types[0].count
+ assert pred[0].shape[2] == data.node_types[1].count
+
+
+def test_bulk_decode_layer_03():
+ data = Data()
+ data.add_node_type('Foo', 100)
+ data.add_node_type('Bar', 50)
+ fam = data.add_relation_family('Foo-Bar', 0, 1, False)
+ fam.add_relation_type('Foobar Relation 1',
+ torch.rand((100, 50), dtype=torch.float32).round().to_sparse(),
+ torch.rand((50, 100), dtype=torch.float32).round().to_sparse())
+ fam.add_relation_type('Foobar Relation 2',
+ torch.rand((100, 50), dtype=torch.float32).round().to_sparse(),
+ torch.rand((50, 100), dtype=torch.float32).round().to_sparse())
+
+ in_layer = OneHotInputLayer(data)
+ d_layer = DecagonLayer(in_layer.output_dim, 32, data)
+ dec_layer = BulkDecodeLayer(input_dim=d_layer.output_dim, data=data,
+ keep_prob=1., activation=lambda x: x)
+ seq = torch.nn.Sequential(in_layer, d_layer, dec_layer)
+
+ pred = seq(None)
+
+ assert isinstance(pred, list)
+ assert len(pred) == len(data.relation_families)
+ assert isinstance(pred[0], torch.Tensor)
+ assert len(pred[0].shape) == 3
+ assert len(pred[0]) == len(data.relation_families[0].relation_types)
+ assert pred[0].shape[1] == data.node_types[0].count
+ assert pred[0].shape[2] == data.node_types[1].count
+
+
+def test_bulk_decode_layer_03_big():
+ data = Data()
+ data.add_node_type('Foo', 2000)
+ data.add_node_type('Bar', 2100)
+ fam = data.add_relation_family('Foo-Bar', 0, 1, False)
+ fam.add_relation_type('Foobar Relation 1',
+ torch.rand((2000, 2100), dtype=torch.float32).round().to_sparse(),
+ torch.rand((2100, 2000), dtype=torch.float32).round().to_sparse())
+ fam.add_relation_type('Foobar Relation 2',
+ torch.rand((2000, 2100), dtype=torch.float32).round().to_sparse(),
+ torch.rand((2100, 2000), dtype=torch.float32).round().to_sparse())
+
+ in_layer = OneHotInputLayer(data)
+ d_layer = DecagonLayer(in_layer.output_dim, 32, data)
+ dec_layer = BulkDecodeLayer(input_dim=d_layer.output_dim, data=data,
+ keep_prob=1., activation=lambda x: x)
+ seq = torch.nn.Sequential(in_layer, d_layer, dec_layer)
+
+ pred = seq(None)
+
+ assert isinstance(pred, list)
+ assert len(pred) == len(data.relation_families)
+ assert isinstance(pred[0], torch.Tensor)
+ assert len(pred[0].shape) == 3
+ assert len(pred[0]) == len(data.relation_families[0].relation_types)
+ assert pred[0].shape[1] == data.node_types[0].count
+ assert pred[0].shape[2] == data.node_types[1].count
+
+
+def test_bulk_decode_layer_03_huge_gpu():
+ if torch.cuda.device_count() == 0:
+ pytest.skip('test_bulk_decode_layer_03_huge_gpu() requires CUDA support')
+
+ device = torch.device('cuda:0')
+ data = Data()
+ data.add_node_type('Foo', 20000)
+ data.add_node_type('Bar', 21000)
+ fam = data.add_relation_family('Foo-Bar', 0, 1, False)
+ print('Adding Foobar Relation 1...')
+ fam.add_relation_type('Foobar Relation 1',
+ torch.rand((20000, 21000), dtype=torch.float32).round().to_sparse().to(device),
+ torch.rand((21000, 20000), dtype=torch.float32).round().to_sparse().to(device))
+ print('Adding Foobar Relation 2...')
+ fam.add_relation_type('Foobar Relation 2',
+ torch.rand((20000, 21000), dtype=torch.float32).round().to_sparse().to(device),
+ torch.rand((21000, 20000), dtype=torch.float32).round().to_sparse().to(device))
+
+ in_layer = OneHotInputLayer(data)
+ d_layer = DecagonLayer(in_layer.output_dim, 32, data)
+ dec_layer = BulkDecodeLayer(input_dim=d_layer.output_dim, data=data,
+ keep_prob=1., activation=lambda x: x)
+ seq = torch.nn.Sequential(in_layer, d_layer, dec_layer)
+ seq = seq.to(device)
+
+ print('Starting forward pass...')
+ t = time.time()
+ pred = seq(None)
+ print('Elapsed:', time.time() - t)
+
+ assert isinstance(pred, list)
+ assert len(pred) == len(data.relation_families)
+ assert isinstance(pred[0], torch.Tensor)
+ assert len(pred[0].shape) == 3
+ assert len(pred[0]) == len(data.relation_families[0].relation_types)
+ assert pred[0].shape[1] == data.node_types[0].count
+ assert pred[0].shape[2] == data.node_types[1].count
+
+
+def test_bulk_decode_layer_04_huge_multirel_gpu():
+ if torch.cuda.device_count() == 0:
+ pytest.skip('test_bulk_decode_layer_04_huge_multirel_gpu() requires CUDA support')
+
+ if torch.cuda.get_device_properties(0).total_memory < 64000000000:
+ pytest.skip('test_bulk_decode_layer_04_huge_multirel_gpu() requires GPU with 64GB of memory')
+
+ device = torch.device('cuda:0')
+ data = Data()
+ data.add_node_type('Foo', 20000)
+ data.add_node_type('Bar', 21000)
+ fam = data.add_relation_family('Foo-Bar', 0, 1, False)
+ print('Generating adj_mat ...')
+ adj_mat = torch.rand((20000, 21000), dtype=torch.float32).round().to_sparse().to(device)
+ print('Generating adj_mat_back ...')
+ adj_mat_back = torch.rand((21000, 20000), dtype=torch.float32).round().to_sparse().to(device)
+ print('Adding relations ...')
+ for i in range(1300):
+ sys.stdout.write('.')
+ sys.stdout.flush()
+ fam.add_relation_type(f'Foobar Relation {i}', adj_mat, adj_mat_back)
+ print()
+
+ in_layer = OneHotInputLayer(data)
+ d_layer = DecagonLayer(in_layer.output_dim, 32, data)
+ dec_layer = BulkDecodeLayer(input_dim=d_layer.output_dim, data=data,
+ keep_prob=1., activation=lambda x: x)
+ seq = torch.nn.Sequential(in_layer, d_layer, dec_layer)
+ seq = seq.to(device)
+
+ print('Starting forward pass...')
+ t = time.time()
+ pred = seq(None)
+ print('Elapsed:', time.time() - t)
+
+ assert isinstance(pred, list)
+ assert len(pred) == len(data.relation_families)
+ assert isinstance(pred[0], torch.Tensor)
+ assert len(pred[0].shape) == 3
+ assert len(pred[0]) == len(data.relation_families[0].relation_types)
+ assert pred[0].shape[1] == data.node_types[0].count
+ assert pred[0].shape[2] == data.node_types[1].count
+
+
+def test_bulk_decode_layer_04_big_multirel_gpu():
+ if torch.cuda.device_count() == 0:
+ pytest.skip('test_bulk_decode_layer_04_big_multirel_gpu() requires CUDA support')
+
+ device = torch.device('cuda:0')
+ data = Data()
+ data.add_node_type('Foo', 2000)
+ data.add_node_type('Bar', 2100)
+ fam = data.add_relation_family('Foo-Bar', 0, 1, False)
+ print('Generating adj_mat ...')
+ adj_mat = torch.rand((2000, 2100), dtype=torch.float32).round().to_sparse().to(device)
+ print('Generating adj_mat_back ...')
+ adj_mat_back = torch.rand((2100, 2000), dtype=torch.float32).round().to_sparse().to(device)
+ print('Adding relations ...')
+ for i in range(1300):
+ sys.stdout.write('.')
+ sys.stdout.flush()
+ fam.add_relation_type(f'Foobar Relation {i}', adj_mat, adj_mat_back)
+ print()
+
+ in_layer = OneHotInputLayer(data)
+ d_layer = DecagonLayer(in_layer.output_dim, 32, data)
+ dec_layer = BulkDecodeLayer(input_dim=d_layer.output_dim, data=data,
+ keep_prob=1., activation=lambda x: x)
+ seq = torch.nn.Sequential(in_layer, d_layer, dec_layer)
+ seq = seq.to(device)
+
+ print('Starting forward pass...')
+ t = time.time()
+ pred = seq(None)
+ print('Elapsed:', time.time() - t)
+
+ assert isinstance(pred, list)
+ assert len(pred) == len(data.relation_families)
+ assert isinstance(pred[0], torch.Tensor)
+ assert len(pred[0].shape) == 3
+ assert len(pred[0]) == len(data.relation_families[0].relation_types)
+ assert pred[0].shape[1] == data.node_types[0].count
+ assert pred[0].shape[2] == data.node_types[1].count
diff --git a/tests/icosagon/test_convlayer.py b/tests/icosagon/test_convlayer.py
new file mode 100644
index 0000000..a6daa99
--- /dev/null
+++ b/tests/icosagon/test_convlayer.py
@@ -0,0 +1,300 @@
+from icosagon.input import InputLayer, \
+ OneHotInputLayer
+from icosagon.convlayer import DecagonLayer, \
+ Convolutions
+from icosagon.data import Data
+import torch
+import pytest
+from icosagon.convolve import DropoutGraphConvActivation
+from decagon_pytorch.convolve import MultiDGCA
+import decagon_pytorch.convolve
+
+
+def _make_symmetric(x: torch.Tensor):
+ x = (x + x.transpose(0, 1)) / 2
+ return x
+
+
+def _symmetric_random(n_rows, n_columns):
+ return _make_symmetric(torch.rand((n_rows, n_columns),
+ dtype=torch.float32).round())
+
+
+def _some_data_with_interactions():
+ d = Data()
+ d.add_node_type('Gene', 1000)
+ d.add_node_type('Drug', 100)
+
+ fam = d.add_relation_family('Drug-Gene', 1, 0, True)
+ fam.add_relation_type('Target',
+ torch.rand((100, 1000), dtype=torch.float32).round())
+
+ fam = d.add_relation_family('Gene-Gene', 0, 0, True)
+ fam.add_relation_type('Interaction',
+ _symmetric_random(1000, 1000))
+
+ fam = d.add_relation_family('Drug-Drug', 1, 1, True)
+ fam.add_relation_type('Side Effect: Nausea',
+ _symmetric_random(100, 100))
+ fam.add_relation_type('Side Effect: Infertility',
+ _symmetric_random(100, 100))
+ fam.add_relation_type('Side Effect: Death',
+ _symmetric_random(100, 100))
+ return d
+
+
+def test_decagon_layer_01():
+ d = _some_data_with_interactions()
+ in_layer = InputLayer(d)
+ d_layer = DecagonLayer(in_layer.output_dim, 32, d)
+ seq = torch.nn.Sequential(in_layer, d_layer)
+ _ = seq(None) # dummy call
+
+
+def test_decagon_layer_02():
+ d = _some_data_with_interactions()
+ in_layer = OneHotInputLayer(d)
+ d_layer = DecagonLayer(in_layer.output_dim, 32, d)
+ seq = torch.nn.Sequential(in_layer, d_layer)
+ _ = seq(None) # dummy call
+
+
+def test_decagon_layer_03():
+ d = _some_data_with_interactions()
+ in_layer = OneHotInputLayer(d)
+ d_layer = DecagonLayer(in_layer.output_dim, 32, d)
+
+ assert d_layer.input_dim == [ 1000, 100 ]
+ assert d_layer.output_dim == [ 32, 32 ]
+ assert d_layer.data == d
+ assert d_layer.keep_prob == 1.
+ assert d_layer.rel_activation(0.5) == 0.5
+ x = torch.tensor([-1, 0, 0.5, 1])
+ assert (d_layer.layer_activation(x) == torch.nn.functional.relu(x)).all()
+
+ assert not d_layer.is_sparse
+ assert len(d_layer.next_layer_repr) == 2
+
+ for i in range(2):
+ assert len(d_layer.next_layer_repr[i]) == 2
+ assert isinstance(d_layer.next_layer_repr[i], torch.nn.ModuleList)
+ assert isinstance(d_layer.next_layer_repr[i][0], Convolutions)
+ assert isinstance(d_layer.next_layer_repr[i][0].node_type_column, int)
+ assert isinstance(d_layer.next_layer_repr[i][0].convolutions, torch.nn.ModuleList)
+ assert all([
+ isinstance(dgca, DropoutGraphConvActivation) \
+ for dgca in d_layer.next_layer_repr[i][0].convolutions
+ ])
+ assert all([
+ dgca.output_dim == 32 \
+ for dgca in d_layer.next_layer_repr[i][0].convolutions
+ ])
+
+
+def test_decagon_layer_04():
+ # check if it is equivalent to MultiDGCA, as it should be
+
+ d = Data()
+ d.add_node_type('Dummy', 100)
+ fam = d.add_relation_family('Dummy-Dummy', 0, 0, True)
+ fam.add_relation_type('Dummy Relation',
+ _symmetric_random(100, 100).to_sparse())
+
+ in_layer = OneHotInputLayer(d)
+
+ multi_dgca = MultiDGCA([10], 32,
+ [r.adjacency_matrix for r in fam.relation_types],
+ keep_prob=1., activation=lambda x: x)
+
+ d_layer = DecagonLayer(in_layer.output_dim, 32, d,
+ keep_prob=1., rel_activation=lambda x: x,
+ layer_activation=lambda x: x)
+
+ assert isinstance(d_layer.next_layer_repr[0][0].convolutions[0],
+ DropoutGraphConvActivation)
+
+ weight = d_layer.next_layer_repr[0][0].convolutions[0].graph_conv.weight
+ assert isinstance(weight, torch.Tensor)
+
+ assert len(multi_dgca.dgca) == 1
+ assert isinstance(multi_dgca.dgca[0],
+ decagon_pytorch.convolve.DropoutGraphConvActivation)
+
+ multi_dgca.dgca[0].graph_conv.weight = weight
+
+ seq = torch.nn.Sequential(in_layer, d_layer)
+ out_d_layer = seq(None)
+ out_multi_dgca = multi_dgca(list(in_layer(None)))
+
+ assert isinstance(out_d_layer, list)
+ assert len(out_d_layer) == 1
+
+ assert torch.all(out_d_layer[0] == out_multi_dgca)
+
+
+def test_decagon_layer_05():
+ # check if it is equivalent to MultiDGCA, as it should be
+ # this time for two relations, same edge type
+
+ d = Data()
+ d.add_node_type('Dummy', 100)
+ fam = d.add_relation_family('Dummy-Dummy', 0, 0, True)
+ fam.add_relation_type('Dummy Relation 1',
+ _symmetric_random(100, 100).to_sparse())
+ fam.add_relation_type('Dummy Relation 2',
+ _symmetric_random(100, 100).to_sparse())
+
+ in_layer = OneHotInputLayer(d)
+
+ multi_dgca = MultiDGCA([100, 100], 32,
+ [r.adjacency_matrix for r in fam.relation_types],
+ keep_prob=1., activation=lambda x: x)
+
+ d_layer = DecagonLayer(in_layer.output_dim, output_dim=32, data=d,
+ keep_prob=1., rel_activation=lambda x: x,
+ layer_activation=lambda x: x)
+
+ assert all([
+ isinstance(dgca, DropoutGraphConvActivation) \
+ for dgca in d_layer.next_layer_repr[0][0].convolutions
+ ])
+
+ weight = [ dgca.graph_conv.weight \
+ for dgca in d_layer.next_layer_repr[0][0].convolutions ]
+ assert all([
+ isinstance(w, torch.Tensor) \
+ for w in weight
+ ])
+
+ assert len(multi_dgca.dgca) == 2
+ for i in range(2):
+ assert isinstance(multi_dgca.dgca[i],
+ decagon_pytorch.convolve.DropoutGraphConvActivation)
+
+ for i in range(2):
+ multi_dgca.dgca[i].graph_conv.weight = weight[i]
+
+ seq = torch.nn.Sequential(in_layer, d_layer)
+ out_d_layer = seq(None)
+ x = in_layer(None)
+ out_multi_dgca = multi_dgca([ x[0], x[0] ])
+
+ assert isinstance(out_d_layer, list)
+ assert len(out_d_layer) == 1
+
+ assert torch.all(out_d_layer[0] == out_multi_dgca)
+
+
+class Dummy1(torch.nn.Module):
+ def __init__(self, **kwargs):
+ super().__init__(**kwargs)
+ self.whatever = torch.nn.Parameter(torch.rand((10, 10)))
+
+
+class Dummy2(torch.nn.Module):
+ def __init__(self, **kwargs):
+ super().__init__(**kwargs)
+ self.dummy_1 = Dummy1()
+
+
+class Dummy3(torch.nn.Module):
+ def __init__(self, **kwargs):
+ super().__init__(**kwargs)
+ self.dummy_1 = [ Dummy1() ]
+
+
+class Dummy4(torch.nn.Module):
+ def __init__(self, **kwargs):
+ super().__init__(**kwargs)
+ self.dummy_1 = torch.nn.ModuleList([ Dummy1() ])
+
+
+class Dummy5(torch.nn.Module):
+ def __init__(self, **kwargs):
+ super().__init__(**kwargs)
+ self.dummy_1 = [ torch.nn.ModuleList([ Dummy1() ]) ]
+
+
+class Dummy6(torch.nn.Module):
+ def __init__(self, **kwargs):
+ super().__init__(**kwargs)
+ self.dummy_1 = torch.nn.ModuleList([ torch.nn.ModuleList([ Dummy1() ]) ])
+
+
+class Dummy7(torch.nn.Module):
+ def __init__(self, **kwargs):
+ super().__init__(**kwargs)
+ self.dummy_1 = torch.nn.ModuleList([ torch.nn.ModuleList() ])
+ self.dummy_1[0].append(Dummy1())
+
+
+def test_module_nesting_01():
+ if torch.cuda.device_count() == 0:
+ pytest.skip('No CUDA support on this host')
+ device = torch.device('cuda:0')
+ dummy_2 = Dummy2()
+ dummy_2 = dummy_2.to(device)
+ assert dummy_2.dummy_1.whatever.device == device
+
+
+def test_module_nesting_02():
+ if torch.cuda.device_count() == 0:
+ pytest.skip('No CUDA support on this host')
+ device = torch.device('cuda:0')
+ dummy_3 = Dummy3()
+ dummy_3 = dummy_3.to(device)
+ assert dummy_3.dummy_1[0].whatever.device != device
+
+
+def test_module_nesting_03():
+ if torch.cuda.device_count() == 0:
+ pytest.skip('No CUDA support on this host')
+ device = torch.device('cuda:0')
+ dummy_4 = Dummy4()
+ dummy_4 = dummy_4.to(device)
+ assert dummy_4.dummy_1[0].whatever.device == device
+
+
+def test_module_nesting_04():
+ if torch.cuda.device_count() == 0:
+ pytest.skip('No CUDA support on this host')
+ device = torch.device('cuda:0')
+ dummy_5 = Dummy5()
+ dummy_5 = dummy_5.to(device)
+ assert dummy_5.dummy_1[0][0].whatever.device != device
+
+
+def test_module_nesting_05():
+ if torch.cuda.device_count() == 0:
+ pytest.skip('No CUDA support on this host')
+ device = torch.device('cuda:0')
+ dummy_6 = Dummy6()
+ dummy_6 = dummy_6.to(device)
+ assert dummy_6.dummy_1[0][0].whatever.device == device
+
+
+def test_module_nesting_06():
+ if torch.cuda.device_count() == 0:
+ pytest.skip('No CUDA support on this host')
+ device = torch.device('cuda:0')
+ dummy_7 = Dummy7()
+ dummy_7 = dummy_7.to(device)
+ assert dummy_7.dummy_1[0][0].whatever.device == device
+
+
+def test_parameter_count_01():
+ d = Data()
+ d.add_node_type('Dummy', 100)
+ fam = d.add_relation_family('Dummy-Dummy', 0, 0, True)
+ fam.add_relation_type('Dummy Relation 1',
+ _symmetric_random(100, 100).to_sparse())
+ fam.add_relation_type('Dummy Relation 2',
+ _symmetric_random(100, 100).to_sparse())
+
+ in_layer = OneHotInputLayer(d)
+ assert len(list(in_layer.parameters())) == 1
+
+ d_layer = DecagonLayer(in_layer.output_dim, output_dim=32, data=d,
+ keep_prob=1., rel_activation=lambda x: x,
+ layer_activation=lambda x: x)
+ assert len(list(d_layer.parameters())) == 2
diff --git a/tests/icosagon/test_convolve.py b/tests/icosagon/test_convolve.py
new file mode 100644
index 0000000..d4df6ea
--- /dev/null
+++ b/tests/icosagon/test_convolve.py
@@ -0,0 +1,206 @@
+from icosagon.convolve import GraphConv, \
+ DropoutGraphConvActivation
+import torch
+from icosagon.dropout import dropout
+
+
+def _test_graph_conv_01(use_sparse: bool):
+ adj_mat = torch.rand((10, 20))
+ adj_mat[adj_mat < .5] = 0
+ adj_mat = torch.ceil(adj_mat)
+
+ node_reprs = torch.eye(20)
+
+ graph_conv = GraphConv(20, 20, adj_mat.to_sparse() \
+ if use_sparse else adj_mat)
+ graph_conv.weight = torch.nn.Parameter(torch.eye(20))
+
+ res = graph_conv(node_reprs)
+ assert torch.all(res == adj_mat)
+
+
+def _test_graph_conv_02(use_sparse: bool):
+ adj_mat = torch.rand((10, 20))
+ adj_mat[adj_mat < .5] = 0
+ adj_mat = torch.ceil(adj_mat)
+
+ node_reprs = torch.eye(20)
+
+ graph_conv = GraphConv(20, 20, adj_mat.to_sparse() \
+ if use_sparse else adj_mat)
+ graph_conv.weight = torch.nn.Parameter(torch.eye(20) * 2)
+
+ res = graph_conv(node_reprs)
+ assert torch.all(res == adj_mat * 2)
+
+
+def _test_graph_conv_03(use_sparse: bool):
+ adj_mat = torch.tensor([
+ [1, 0, 1, 0, 1, 0], # [1, 0, 0]
+ [1, 0, 1, 0, 0, 1], # [1, 0, 0]
+ [1, 1, 0, 1, 0, 0], # [0, 1, 0]
+ [0, 0, 0, 1, 0, 1], # [0, 1, 0]
+ [1, 1, 1, 1, 1, 1], # [0, 0, 1]
+ [0, 0, 0, 1, 1, 1] # [0, 0, 1]
+ ], dtype=torch.float32)
+
+ expect = torch.tensor([
+ [1, 1, 1],
+ [1, 1, 1],
+ [2, 1, 0],
+ [0, 1, 1],
+ [2, 2, 2],
+ [0, 1, 2]
+ ], dtype=torch.float32)
+
+ node_reprs = torch.eye(6)
+
+ graph_conv = GraphConv(6, 3, adj_mat.to_sparse() \
+ if use_sparse else adj_mat)
+ graph_conv.weight = torch.nn.Parameter(torch.tensor([
+ [1, 0, 0],
+ [1, 0, 0],
+ [0, 1, 0],
+ [0, 1, 0],
+ [0, 0, 1],
+ [0, 0, 1]
+ ], dtype=torch.float32))
+
+ res = graph_conv(node_reprs)
+ assert torch.all(res == expect)
+
+
+def test_graph_conv_dense_01():
+ _test_graph_conv_01(use_sparse=False)
+
+
+def test_graph_conv_dense_02():
+ _test_graph_conv_02(use_sparse=False)
+
+
+def test_graph_conv_dense_03():
+ _test_graph_conv_03(use_sparse=False)
+
+
+def test_graph_conv_sparse_01():
+ _test_graph_conv_01(use_sparse=True)
+
+
+def test_graph_conv_sparse_02():
+ _test_graph_conv_02(use_sparse=True)
+
+
+def test_graph_conv_sparse_03():
+ _test_graph_conv_03(use_sparse=True)
+
+
+def _test_dropout_graph_conv_activation_01(use_sparse: bool):
+ adj_mat = torch.rand((10, 20))
+ adj_mat[adj_mat < .5] = 0
+ adj_mat = torch.ceil(adj_mat)
+ node_reprs = torch.eye(20)
+
+ conv_1 = DropoutGraphConvActivation(20, 20, adj_mat.to_sparse() \
+ if use_sparse else adj_mat, keep_prob=1.,
+ activation=lambda x: x)
+
+ conv_2 = GraphConv(20, 20, adj_mat.to_sparse() \
+ if use_sparse else adj_mat)
+ conv_2.weight = conv_1.graph_conv.weight
+
+ res_1 = conv_1(node_reprs)
+ res_2 = conv_2(node_reprs)
+
+ print('res_1:', res_1.detach().cpu().numpy())
+ print('res_2:', res_2.detach().cpu().numpy())
+
+ assert torch.all(res_1 == res_2)
+
+
+def _test_dropout_graph_conv_activation_02(use_sparse: bool):
+ adj_mat = torch.rand((10, 20))
+ adj_mat[adj_mat < .5] = 0
+ adj_mat = torch.ceil(adj_mat)
+ node_reprs = torch.eye(20)
+
+ conv_1 = DropoutGraphConvActivation(20, 20, adj_mat.to_sparse() \
+ if use_sparse else adj_mat, keep_prob=1.,
+ activation=lambda x: x * 2)
+
+ conv_2 = GraphConv(20, 20, adj_mat.to_sparse() \
+ if use_sparse else adj_mat)
+ conv_2.weight = conv_1.graph_conv.weight
+
+ res_1 = conv_1(node_reprs)
+ res_2 = conv_2(node_reprs)
+
+ print('res_1:', res_1.detach().cpu().numpy())
+ print('res_2:', res_2.detach().cpu().numpy())
+
+ assert torch.all(res_1 == res_2 * 2)
+
+
+def _test_dropout_graph_conv_activation_03(use_sparse: bool):
+ adj_mat = torch.rand((10, 20))
+ adj_mat[adj_mat < .5] = 0
+ adj_mat = torch.ceil(adj_mat)
+ node_reprs = torch.eye(20)
+
+ conv_1 = DropoutGraphConvActivation(20, 20, adj_mat.to_sparse() \
+ if use_sparse else adj_mat, keep_prob=.5,
+ activation=lambda x: x)
+
+ conv_2 = GraphConv(20, 20, adj_mat.to_sparse() \
+ if use_sparse else adj_mat)
+ conv_2.weight = conv_1.graph_conv.weight
+
+ torch.random.manual_seed(0)
+ res_1 = conv_1(node_reprs)
+
+ torch.random.manual_seed(0)
+ res_2 = conv_2(dropout(node_reprs, 0.5))
+
+ print('res_1:', res_1.detach().cpu().numpy())
+ print('res_2:', res_2.detach().cpu().numpy())
+
+ assert torch.all(res_1 == res_2)
+
+
+def test_dropout_graph_conv_activation_dense_01():
+ _test_dropout_graph_conv_activation_01(False)
+
+
+def test_dropout_graph_conv_activation_sparse_01():
+ _test_dropout_graph_conv_activation_01(True)
+
+
+def test_dropout_graph_conv_activation_dense_02():
+ _test_dropout_graph_conv_activation_02(False)
+
+
+def test_dropout_graph_conv_activation_sparse_02():
+ _test_dropout_graph_conv_activation_02(True)
+
+
+def test_dropout_graph_conv_activation_dense_03():
+ _test_dropout_graph_conv_activation_03(False)
+
+
+def test_dropout_graph_conv_activation_sparse_03():
+ _test_dropout_graph_conv_activation_03(True)
+
+
+def test_graph_conv_parameter_count_01():
+ adj_mat = torch.rand((10, 20)).round()
+
+ conv = GraphConv(20, 20, adj_mat)
+
+ assert len(list(conv.parameters())) == 1
+
+
+def test_dropout_graph_conv_activation_parameter_count_01():
+ adj_mat = torch.rand((10, 20)).round()
+
+ conv = DropoutGraphConvActivation(20, 20, adj_mat)
+
+ assert len(list(conv.parameters())) == 1
diff --git a/tests/icosagon/test_data.py b/tests/icosagon/test_data.py
new file mode 100644
index 0000000..57060e9
--- /dev/null
+++ b/tests/icosagon/test_data.py
@@ -0,0 +1,143 @@
+#
+# Copyright (C) Stanislaw Adaszewski, 2020
+# License: GPLv3
+#
+
+
+from icosagon.data import Data, \
+ _equal, \
+ RelationFamily
+from icosagon.decode import DEDICOMDecoder
+import torch
+import pytest
+
+
+def test_equal_01():
+ x = torch.rand((10, 10))
+ y = torch.rand((10, 10)).round().to_sparse()
+
+ print('x == x ?')
+ assert torch.all(_equal(x, x))
+ print('y == y ?')
+ assert torch.all(_equal(y, y))
+ print('x == y ?')
+ with pytest.raises(ValueError):
+ _equal(x, y)
+
+ print('y == z ?')
+ z = torch.rand((10, 10)).round().to_sparse()
+ assert not torch.all(_equal(y, z))
+
+
+def test_relation_family_01():
+ d = Data()
+ d.add_node_type('Whatever', 10)
+
+ fam = RelationFamily(d, 'Dummy-Dummy', 0, 0, True, DEDICOMDecoder)
+
+ with pytest.raises(ValueError):
+ fam.add_relation_type('Dummy-Dummy', None, None)
+
+ with pytest.raises(ValueError):
+ fam.add_relation_type('Dummy-Dummy', 'bad-value', None)
+
+ with pytest.raises(ValueError):
+ fam.add_relation_type('Dummy-Dummy', None, 'bad-value')
+
+ with pytest.raises(ValueError):
+ fam.add_relation_type('Dummy-Dummy', torch.rand((5, 5)), None)
+
+ with pytest.raises(ValueError):
+ fam.add_relation_type('Dummy-Dummy', None, torch.rand((5, 5)))
+
+ with pytest.raises(ValueError):
+ fam.add_relation_type('Dummy-Dummy', torch.rand((10, 10)), torch.rand((10, 10)))
+
+ with pytest.raises(ValueError):
+ fam.add_relation_type('Dummy-Dummy', torch.rand((10, 10)), None)
+
+
+def test_relation_family_02():
+ d = Data()
+ d.add_node_type('A', 10)
+ d.add_node_type('B', 5)
+
+ fam = RelationFamily(d, 'A-B', 0, 1, True, DEDICOMDecoder)
+
+ with pytest.raises(ValueError):
+ fam.add_relation_type('A-B', torch.rand((10, 5)).round(),
+ torch.rand((5, 10)).round())
+
+
+def test_relation_family_03():
+ d = Data()
+ d.add_node_type('A', 10)
+ d.add_node_type('B', 5)
+
+ fam = RelationFamily(d, 'A-B', 0, 1, True, DEDICOMDecoder)
+
+ fam.add_relation_type('A-B', torch.rand((10, 5)).round())
+
+ assert torch.all(fam.relation_types[0].adjacency_matrix.transpose(0, 1) == \
+ fam.relation_types[0].adjacency_matrix_backward)
+
+
+def test_data_01():
+ d = Data()
+ d.add_node_type('Gene', 1000)
+ d.add_node_type('Drug', 100)
+ dummy_0 = torch.zeros((100, 1000))
+ dummy_1 = torch.zeros((1000, 100))
+ dummy_2 = torch.zeros((100, 100))
+ dummy_3 = torch.zeros((1000, 1000))
+
+ fam = d.add_relation_family('Drug-Gene', 1, 0, True)
+ fam.add_relation_type('Target', dummy_0)
+
+ fam = d.add_relation_family('Gene-Gene', 0, 0, True)
+ fam.add_relation_type('Interaction', dummy_3)
+
+ fam = d.add_relation_family('Drug-Drug', 1, 1, True)
+ fam.add_relation_type('Side Effect: Nausea', dummy_2)
+ fam.add_relation_type('Side Effect: Infertility', dummy_2)
+ fam.add_relation_type('Side Effect: Death', dummy_2)
+
+ print(d)
+
+
+def test_data_02():
+ d = Data()
+ d.add_node_type('Gene', 1000)
+ d.add_node_type('Drug', 100)
+
+ dummy_0 = torch.zeros((100, 1000))
+ dummy_1 = torch.zeros((1000, 100))
+ dummy_2 = torch.zeros((100, 100))
+ dummy_3 = torch.zeros((1000, 1000))
+
+ fam = d.add_relation_family('Drug-Gene', 1, 0, True)
+ with pytest.raises(ValueError):
+ fam.add_relation_type('Target', dummy_1)
+
+ fam = d.add_relation_family('Gene-Gene', 0, 0, True)
+ with pytest.raises(ValueError):
+ fam.add_relation_type('Interaction', dummy_2)
+
+ fam = d.add_relation_family('Drug-Drug', 1, 1, True)
+ with pytest.raises(ValueError):
+ fam.add_relation_type('Side Effect: Nausea', dummy_3)
+ with pytest.raises(ValueError):
+ fam.add_relation_type('Side Effect: Infertility', dummy_3)
+ with pytest.raises(ValueError):
+ fam.add_relation_type('Side Effect: Death', dummy_3)
+ print(d)
+
+
+def test_data_03():
+ d = Data()
+ d.add_node_type('Gene', 1000)
+ d.add_node_type('Drug', 100)
+ fam = d.add_relation_family('Drug-Gene', 1, 0, True)
+ with pytest.raises(ValueError):
+ fam.add_relation_type('Target', None)
+ print(d)
diff --git a/tests/icosagon/test_databatch.py b/tests/icosagon/test_databatch.py
new file mode 100644
index 0000000..b36b5da
--- /dev/null
+++ b/tests/icosagon/test_databatch.py
@@ -0,0 +1,126 @@
+from icosagon.databatch import DataBatcher, \
+ BatchedData, \
+ BatchedDataPointer, \
+ batched_data_skeleton
+from icosagon.data import Data
+from icosagon.trainprep import prepare_training, \
+ TrainValTest
+from icosagon.declayer import DecodeLayer
+from icosagon.input import OneHotInputLayer
+import torch
+import time
+
+
+def _some_data():
+ data = Data()
+ data.add_node_type('Foo', 100)
+ data.add_node_type('Bar', 500)
+ fam = data.add_relation_family('Foo-Bar', 0, 1, True)
+ adj_mat = torch.rand(100, 500).round().to_sparse()
+ fam.add_relation_type('Foo-Bar', adj_mat)
+ return data
+
+
+def _some_data_big():
+ data = Data()
+ data.add_node_type('Foo', 2000)
+ data.add_node_type('Bar', 2100)
+ fam = data.add_relation_family('Foo-Bar', 0, 1, True)
+ adj_mat = torch.rand(2000, 2100).round().to_sparse()
+ fam.add_relation_type('Foo-Bar', adj_mat)
+ return data
+
+
+def test_data_batcher_01():
+ data = _some_data()
+ prep_d = prepare_training(data, TrainValTest(.8, .1, .1))
+ batcher = DataBatcher(prep_d, 512)
+
+
+def test_data_batcher_02():
+ data = _some_data()
+ prep_d = prepare_training(data, TrainValTest(.8, .1, .1))
+ batcher = DataBatcher(prep_d, 512)
+ for batch_d in batcher:
+ pass
+
+
+def test_data_batcher_03():
+ data = _some_data()
+ prep_d = prepare_training(data, TrainValTest(.8, .1, .1))
+ batcher = DataBatcher(prep_d, 512)
+ for batch_d in batcher:
+ edges_list = []
+ for fam in batch_d.relation_families:
+ for rel in fam.relation_types:
+ for edge_type in ['edges_pos', 'edges_neg',
+ 'edges_back_pos', 'edges_back_neg']:
+ for part_type in ['train', 'val', 'test']:
+ edges = getattr(getattr(rel, edge_type), part_type)
+ edges_list.append(edges)
+ assert sum([ 1 for edges in edges_list if len(edges) > 0 ]) == 1
+
+
+def test_data_batcher_04():
+ data = _some_data()
+ prep_d = prepare_training(data, TrainValTest(.8, .1, .1))
+ batcher = DataBatcher(prep_d, 512)
+ edges_list = []
+ for batch_d in batcher:
+ for fam in batch_d.relation_families:
+ for rel in fam.relation_types:
+ for edge_type in ['edges_pos', 'edges_neg',
+ 'edges_back_pos', 'edges_back_neg']:
+ for part_type in ['train', 'val', 'test']:
+ edges = getattr(getattr(rel, edge_type), part_type)
+ edges_list.append(edges)
+ assert sum([ len(edges) for edges in edges_list ]) == \
+ torch.sum(data.relation_families[0].relation_types[0].adjacency_matrix._values()) * 2
+
+
+def test_data_batcher_05():
+ data = _some_data()
+ prep_d = prepare_training(data, TrainValTest(.8, .1, .1))
+ batcher = DataBatcher(prep_d, 512)
+ for batch_d in batcher:
+ edges_list = []
+ for fam in batch_d.relation_families:
+ for rel in fam.relation_types:
+ for edge_type in ['edges_pos', 'edges_neg',
+ 'edges_back_pos', 'edges_back_neg']:
+ for part_type in ['train', 'val', 'test']:
+ edges = getattr(getattr(rel, edge_type), part_type)
+ edges_list.append(edges)
+ assert all([ len(edges) <= 512 for edges in edges_list ])
+ assert not all([ len(edges) == 0 for edges in edges_list ])
+ print(sum(map(len, edges_list)))
+
+
+def test_batch_decode_01():
+ data = _some_data()
+ prep_d = prepare_training(data, TrainValTest(.8, .1, .1))
+ batcher = DataBatcher(prep_d, 512)
+ ptr = BatchedDataPointer(batched_data_skeleton(prep_d))
+ in_repr = [ torch.rand(100, 32),
+ torch.rand(500, 32) ]
+ dec_layer = DecodeLayer([ 32, 32 ], prep_d, batched_data_pointer=ptr)
+ t = time.time()
+ for batched_data in batcher:
+ ptr.batched_data = batched_data
+ _ = dec_layer(in_repr)
+ print('Elapsed:', time.time() - t)
+
+
+def test_batch_decode_02():
+ data = _some_data_big()
+ prep_d = prepare_training(data, TrainValTest(.8, .1, .1))
+ batcher = DataBatcher(prep_d, 512)
+ ptr = BatchedDataPointer(batched_data_skeleton(prep_d))
+ in_repr = [ torch.rand(2000, 32),
+ torch.rand(2100, 32) ]
+ dec_layer = DecodeLayer([ 32, 32 ], prep_d, batched_data_pointer=ptr)
+ t = time.time()
+ for batched_data in batcher:
+ ptr.batched_data = batched_data
+ _ = dec_layer(in_repr)
+ print('Elapsed:', time.time() - t)
diff --git a/tests/icosagon/test_declayer.py b/tests/icosagon/test_declayer.py
new file mode 100644
index 0000000..014d304
--- /dev/null
+++ b/tests/icosagon/test_declayer.py
@@ -0,0 +1,287 @@
+#
+# Copyright (C) Stanislaw Adaszewski, 2020
+# License: GPLv3
+#
+
+
+from icosagon.input import OneHotInputLayer
+from icosagon.convolve import DropoutGraphConvActivation
+from icosagon.convlayer import DecagonLayer
+from icosagon.declayer import DecodeLayer, \
+ Predictions, \
+ RelationFamilyPredictions, \
+ RelationPredictions
+from icosagon.decode import DEDICOMDecoder, \
+ InnerProductDecoder
+from icosagon.data import Data
+from icosagon.trainprep import prepare_training, \
+ TrainValTest
+import torch
+
+
+def test_decode_layer_01():
+ d = Data()
+ d.add_node_type('Dummy', 100)
+
+ fam = d.add_relation_family('Dummy-Dummy', 0, 0, False)
+ fam.add_relation_type('Dummy Relation 1',
+ torch.rand((100, 100), dtype=torch.float32).round().to_sparse())
+
+ prep_d = prepare_training(d, TrainValTest(.8, .1, .1))
+ in_layer = OneHotInputLayer(d)
+ d_layer = DecagonLayer(in_layer.output_dim, 32, d)
+ seq = torch.nn.Sequential(in_layer, d_layer)
+ last_layer_repr = seq(None)
+
+ dec = DecodeLayer(input_dim=d_layer.output_dim, data=prep_d, keep_prob=1.,
+ activation=lambda x: x)
+ pred = dec(last_layer_repr)
+
+ assert isinstance(pred, Predictions)
+
+ assert isinstance(pred.relation_families, list)
+ assert len(pred.relation_families) == 1
+ assert isinstance(pred.relation_families[0], RelationFamilyPredictions)
+
+ assert isinstance(pred.relation_families[0].relation_types, list)
+ assert len(pred.relation_families[0].relation_types) == 1
+ assert isinstance(pred.relation_families[0].relation_types[0], RelationPredictions)
+
+ tmp = pred.relation_families[0].relation_types[0]
+ assert isinstance(tmp.edges_pos, TrainValTest)
+ assert isinstance(tmp.edges_neg, TrainValTest)
+ assert isinstance(tmp.edges_back_pos, TrainValTest)
+ assert isinstance(tmp.edges_back_neg, TrainValTest)
+
+
+def test_decode_layer_02():
+ d = Data()
+ d.add_node_type('Dummy', 100)
+ fam = d.add_relation_family('Dummy-Dummy', 0, 0, False)
+ fam.add_relation_type('Dummy Relation 1',
+ torch.rand((100, 100), dtype=torch.float32).round().to_sparse())
+ prep_d = prepare_training(d, TrainValTest(.8, .1, .1))
+
+ in_layer = OneHotInputLayer(d)
+ d_layer = DecagonLayer(in_layer.output_dim, 32, d)
+ dec_layer = DecodeLayer(input_dim=d_layer.output_dim, data=prep_d,
+ keep_prob=1., activation=lambda x: x)
+ seq = torch.nn.Sequential(in_layer, d_layer, dec_layer)
+
+ pred = seq(None)
+
+ assert isinstance(pred, Predictions)
+ assert len(pred.relation_families) == 1
+ assert isinstance(pred.relation_families[0], RelationFamilyPredictions)
+ assert isinstance(pred.relation_families[0].relation_types, list)
+ assert len(pred.relation_families[0].relation_types) == 1
+
+
+def test_decode_layer_03():
+ d = Data()
+ d.add_node_type('Dummy 1', 100)
+ d.add_node_type('Dummy 2', 100)
+ fam = d.add_relation_family('Dummy 1-Dummy 2', 0, 1, True)
+ fam.add_relation_type('Dummy Relation 1',
+ torch.rand((100, 100), dtype=torch.float32).round().to_sparse())
+ prep_d = prepare_training(d, TrainValTest(.8, .1, .1))
+
+ in_layer = OneHotInputLayer(d)
+ d_layer = DecagonLayer(in_layer.output_dim, 32, d)
+ dec_layer = DecodeLayer(input_dim=d_layer.output_dim, data=prep_d,
+ keep_prob=1., activation=lambda x: x)
+ seq = torch.nn.Sequential(in_layer, d_layer, dec_layer)
+
+ pred = seq(None)
+ assert isinstance(pred, Predictions)
+ assert len(pred.relation_families) == 1
+ assert isinstance(pred.relation_families[0], RelationFamilyPredictions)
+ assert isinstance(pred.relation_families[0].relation_types, list)
+ assert len(pred.relation_families[0].relation_types) == 1
+ assert isinstance(pred.relation_families[0].relation_types[0], RelationPredictions)
+
+
+def test_decode_layer_04():
+ d = Data()
+ d.add_node_type('Dummy', 100)
+ assert len(d.relation_families) == 0
+
+ prep_d = prepare_training(d, TrainValTest(.8, .1, .1))
+
+ in_layer = OneHotInputLayer(d)
+ d_layer = DecagonLayer(in_layer.output_dim, 32, d)
+ dec_layer = DecodeLayer(input_dim=d_layer.output_dim, data=prep_d,
+ keep_prob=1., activation=lambda x: x)
+ seq = torch.nn.Sequential(in_layer, d_layer, dec_layer)
+
+ pred = seq(None)
+
+ assert isinstance(pred, Predictions)
+ assert len(pred.relation_families) == 0
+
+
+def test_decode_layer_05():
+ d = Data()
+ d.add_node_type('Dummy', 10)
+ mat = torch.rand((10, 10))
+ mat = (mat + mat.transpose(0, 1)) / 2
+ mat = mat.round()
+ fam = d.add_relation_family('Dummy-Dummy', 0, 0, True,
+ decoder_class=InnerProductDecoder)
+ fam.add_relation_type('Dummy Rel', mat.to_sparse())
+ prep_d = prepare_training(d, TrainValTest(1., 0., 0.))
+
+ in_layer = OneHotInputLayer(d)
+ conv_layer = DecagonLayer(in_layer.output_dim, 32, prep_d,
+ rel_activation=lambda x: x, layer_activation=lambda x: x)
+ dec_layer = DecodeLayer(conv_layer.output_dim, prep_d,
+ keep_prob=1., activation=lambda x: x)
+ seq = torch.nn.Sequential(in_layer, conv_layer, dec_layer)
+
+ pred = seq(None)
+ rel_pred = pred.relation_families[0].relation_types[0]
+
+ for edge_type in ['edges_pos', 'edges_neg', 'edges_back_pos', 'edges_back_neg']:
+ edge_pred = getattr(rel_pred, edge_type)
+ assert isinstance(edge_pred, TrainValTest)
+ for part_type in ['train', 'val', 'test']:
+ part_pred = getattr(edge_pred, part_type)
+ assert isinstance(part_pred, torch.Tensor)
+ assert len(part_pred.shape) == 1
+ print(edge_type, part_type, part_pred.shape)
+ if (edge_type, part_type) not in [('edges_pos', 'train'), ('edges_neg', 'train')]:
+ assert part_pred.shape[0] == 0
+ else:
+ assert part_pred.shape[0] > 0
+
+ prep_rel = prep_d.relation_families[0].relation_types[0]
+ assert len(rel_pred.edges_pos.train) == len(prep_rel.edges_pos.train)
+ assert len(rel_pred.edges_neg.train) == len(prep_rel.edges_neg.train)
+
+ assert len(prep_rel.edges_pos.train) == torch.sum(mat)
+
+ # print('Predictions for positive edges:')
+ # print(rel_pred.edges_pos.train)
+ # print('Predictions for negative edges:')
+ # print(rel_pred.edges_neg.train)
+
+ repr_in = in_layer(None)
+ assert isinstance(repr_in, torch.nn.ParameterList)
+ assert len(repr_in) == 1
+ assert isinstance(repr_in[0], torch.Tensor)
+ assert torch.all(repr_in[0].to_dense() == torch.eye(10))
+
+ assert len(conv_layer.next_layer_repr[0]) == 1
+ assert len(conv_layer.next_layer_repr[0][0].convolutions) == 1
+ assert conv_layer.rel_activation(0) == 0
+ assert conv_layer.rel_activation(1) == 1
+ assert conv_layer.rel_activation(-1) == -1
+ assert conv_layer.layer_activation(0) == 0
+ assert conv_layer.layer_activation(1) == 1
+ assert conv_layer.layer_activation(-1) == -1
+
+ graph_conv = conv_layer.next_layer_repr[0][0].convolutions[0]
+ assert isinstance(graph_conv, DropoutGraphConvActivation)
+ assert graph_conv.activation(0) == 0
+ assert graph_conv.activation(1) == 1
+ assert graph_conv.activation(-1) == -1
+ weight = graph_conv.graph_conv.weight
+ adj_mat = prep_d.relation_families[0].relation_types[0].adjacency_matrix
+ repr_conv = torch.sparse.mm(repr_in[0], weight)
+ repr_conv = torch.mm(adj_mat, repr_conv)
+ repr_conv = torch.nn.functional.normalize(repr_conv, p=2, dim=1)
+ repr_conv_expect = conv_layer(repr_in)[0]
+ print('repr_conv:\n', repr_conv)
+ # print(repr_conv_expect)
+ assert torch.all(repr_conv == repr_conv_expect)
+ assert repr_conv.shape[1] == 32
+
+ dec = InnerProductDecoder(32, 1, keep_prob=1., activation=lambda x: x)
+ x, y = torch.meshgrid(torch.arange(0, 10), torch.arange(0, 10))
+ x = x.flatten()
+ y = y.flatten()
+ repr_dec_expect = dec(repr_conv[x], repr_conv[y], 0)
+ repr_dec_expect = repr_dec_expect.view(10, 10)
+
+ repr_dec = torch.mm(repr_conv, torch.transpose(repr_conv, 0, 1))
+ # repr_dec = torch.flatten(repr_dec)
+ # repr_dec -= torch.eye(10)
+ assert torch.all(torch.abs(repr_dec - repr_dec_expect) < 0.000001)
+
+ repr_dec_expect = torch.zeros((10, 10))
+ x = prep_d.relation_families[0].relation_types[0].edges_pos.train
+ repr_dec_expect[x[:, 0], x[:, 1]] = pred.relation_families[0].relation_types[0].edges_pos.train
+ x = prep_d.relation_families[0].relation_types[0].edges_neg.train
+ repr_dec_expect[x[:, 0], x[:, 1]] = pred.relation_families[0].relation_types[0].edges_neg.train
+ print(repr_dec)
+ print(repr_dec_expect)
+
+ repr_dec = torch.zeros((10, 10))
+ x = prep_d.relation_families[0].relation_types[0].edges_pos.train
+ repr_dec[x[:, 0], x[:, 1]] = dec(repr_conv[x[:, 0]], repr_conv[x[:, 1]], 0)
+ x = prep_d.relation_families[0].relation_types[0].edges_neg.train
+ repr_dec[x[:, 0], x[:, 1]] = dec(repr_conv[x[:, 0]], repr_conv[x[:, 1]], 0)
+
+ assert torch.all(torch.abs(repr_dec - repr_dec_expect) < 0.000001)
+
+ #print(prep_rel.edges_pos.train)
+ #print(prep_rel.edges_neg.train)
+
+ # assert isinstance(edge_pred.train)
+ # assert isinstance(rel_pred.edges_pos, TrainValTest)
+ # assert isinstance(rel_pred.edges_neg, TrainValTest)
+ # assert isinstance(rel_pred.edges_back_pos, TrainValTest)
+ # assert isinstance(rel_pred.edges_back_neg, TrainValTest)
+
+
+def test_decode_layer_parameter_count_01():
+ d = Data()
+ d.add_node_type('Dummy', 100)
+
+ fam = d.add_relation_family('Dummy-Dummy', 0, 0, False)
+ fam.add_relation_type('Dummy Relation 1',
+ torch.rand((100, 100), dtype=torch.float32).round().to_sparse())
+
+ prep_d = prepare_training(d, TrainValTest(.8, .1, .1))
+
+ dec = DecodeLayer(input_dim=[ 32 ], data=prep_d, keep_prob=1.,
+ activation=lambda x: x)
+
+ assert len(list(dec.parameters())) == 2
+
+
+def test_decode_layer_parameter_count_02():
+ d = Data()
+ d.add_node_type('Dummy', 100)
+
+ fam = d.add_relation_family('Dummy-Dummy', 0, 0, False)
+ fam.add_relation_type('Dummy Relation 1',
+ torch.rand((100, 100), dtype=torch.float32).round().to_sparse())
+ fam.add_relation_type('Dummy Relation 2',
+ torch.rand((100, 100), dtype=torch.float32).round().to_sparse())
+
+ prep_d = prepare_training(d, TrainValTest(.8, .1, .1))
+
+ dec = DecodeLayer(input_dim=[ 32 ], data=prep_d, keep_prob=1.,
+ activation=lambda x: x)
+
+ assert len(list(dec.parameters())) == 3
+
+
+def test_decode_layer_parameter_count_03():
+ d = Data()
+ d.add_node_type('Dummy', 100)
+
+ for _ in range(2):
+ fam = d.add_relation_family('Dummy-Dummy', 0, 0, False)
+ fam.add_relation_type('Dummy Relation 1',
+ torch.rand((100, 100), dtype=torch.float32).round().to_sparse())
+ fam.add_relation_type('Dummy Relation 2',
+ torch.rand((100, 100), dtype=torch.float32).round().to_sparse())
+
+ prep_d = prepare_training(d, TrainValTest(.8, .1, .1))
+
+ dec = DecodeLayer(input_dim=[ 32 ], data=prep_d, keep_prob=1.,
+ activation=lambda x: x)
+
+ assert len(list(dec.parameters())) == 6
diff --git a/tests/icosagon/test_decode.py b/tests/icosagon/test_decode.py
new file mode 100644
index 0000000..b8c9cea
--- /dev/null
+++ b/tests/icosagon/test_decode.py
@@ -0,0 +1,240 @@
+from icosagon.decode import DEDICOMDecoder, \
+ DistMultDecoder, \
+ BilinearDecoder, \
+ InnerProductDecoder
+import decagon_pytorch.decode.pairwise
+import torch
+
+
+def test_dedicom_decoder_01():
+ repr_ = torch.rand(20, 32)
+ dec_1 = DEDICOMDecoder(32, 7, keep_prob=1.,
+ activation=torch.sigmoid)
+ dec_2 = decagon_pytorch.decode.pairwise.DEDICOMDecoder(32, 7, drop_prob=0.,
+ activation=torch.sigmoid)
+ dec_2.global_interaction = dec_1.global_interaction
+ dec_2.local_variation = dec_1.local_variation
+
+ res_1 = [ dec_1(repr_, repr_, k) for k in range(7) ]
+ res_2 = dec_2(repr_, repr_)
+
+ assert isinstance(res_1, list)
+ assert isinstance(res_2, list)
+
+ assert len(res_1) == len(res_2)
+
+ for i in range(len(res_1)):
+ assert torch.all(res_1[i] == res_2[i])
+
+
+def test_dist_mult_decoder_01():
+ repr_ = torch.rand(20, 32)
+ dec_1 = DistMultDecoder(32, 7, keep_prob=1.,
+ activation=torch.sigmoid)
+ dec_2 = decagon_pytorch.decode.pairwise.DistMultDecoder(32, 7, drop_prob=0.,
+ activation=torch.sigmoid)
+ dec_2.relation = dec_1.relation
+
+ res_1 = [ dec_1(repr_, repr_, k) for k in range(7) ]
+ res_2 = dec_2(repr_, repr_)
+
+ assert isinstance(res_1, list)
+ assert isinstance(res_2, list)
+
+ assert len(res_1) == len(res_2)
+
+ for i in range(len(res_1)):
+ assert torch.all(res_1[i] == res_2[i])
+
+
+def test_bilinear_decoder_01():
+ repr_ = torch.rand(20, 32)
+ dec_1 = BilinearDecoder(32, 7, keep_prob=1.,
+ activation=torch.sigmoid)
+ dec_2 = decagon_pytorch.decode.pairwise.BilinearDecoder(32, 7, drop_prob=0.,
+ activation=torch.sigmoid)
+ dec_2.relation = dec_1.relation
+
+ res_1 = [ dec_1(repr_, repr_, k) for k in range(7) ]
+ res_2 = dec_2(repr_, repr_)
+
+ assert isinstance(res_1, list)
+ assert isinstance(res_2, list)
+
+ assert len(res_1) == len(res_2)
+
+ for i in range(len(res_1)):
+ assert torch.all(res_1[i] == res_2[i])
+
+
+def test_inner_product_decoder_01():
+ repr_ = torch.rand(20, 32)
+ dec_1 = InnerProductDecoder(32, 7, keep_prob=1.,
+ activation=torch.sigmoid)
+ dec_2 = decagon_pytorch.decode.pairwise.InnerProductDecoder(32, 7, drop_prob=0.,
+ activation=torch.sigmoid)
+
+ res_1 = [ dec_1(repr_, repr_, k) for k in range(7) ]
+ res_2 = dec_2(repr_, repr_)
+
+ assert isinstance(res_1, list)
+ assert isinstance(res_2, list)
+
+ assert len(res_1) == len(res_2)
+
+ for i in range(len(res_1)):
+ assert torch.all(res_1[i] == res_2[i])
+
+
+def test_is_dedicom_not_symmetric_01():
+ repr_1 = torch.rand(20, 32)
+ repr_2 = torch.rand(20, 32)
+ dec = DEDICOMDecoder(32, 7, keep_prob=1.,
+ activation=torch.sigmoid)
+
+ res_1 = [ dec(repr_1, repr_2, k) for k in range(7) ]
+ res_2 = [ dec(repr_2, repr_1, k) for k in range(7) ]
+
+
+ assert isinstance(res_1, list)
+ assert isinstance(res_2, list)
+
+ assert len(res_1) == len(res_2)
+
+ for i in range(len(res_1)):
+ assert not torch.all(res_1[i] - res_2[i] < 0.000001)
+
+
+def test_is_dist_mult_symmetric_01():
+ repr_1 = torch.rand(20, 32)
+ repr_2 = torch.rand(20, 32)
+ dec = DistMultDecoder(32, 7, keep_prob=1.,
+ activation=torch.sigmoid)
+
+ res_1 = [ dec(repr_1, repr_2, k) for k in range(7) ]
+ res_2 = [ dec(repr_2, repr_1, k) for k in range(7) ]
+
+
+ assert isinstance(res_1, list)
+ assert isinstance(res_2, list)
+
+ assert len(res_1) == len(res_2)
+
+ for i in range(len(res_1)):
+ assert torch.all(res_1[i] - res_2[i] < 0.000001)
+
+
+def test_is_bilinear_not_symmetric_01():
+ repr_1 = torch.rand(20, 32)
+ repr_2 = torch.rand(20, 32)
+ dec = BilinearDecoder(32, 7, keep_prob=1.,
+ activation=torch.sigmoid)
+
+ res_1 = [ dec(repr_1, repr_2, k) for k in range(7) ]
+ res_2 = [ dec(repr_2, repr_1, k) for k in range(7) ]
+
+ assert isinstance(res_1, list)
+ assert isinstance(res_2, list)
+
+ assert len(res_1) == len(res_2)
+
+ for i in range(len(res_1)):
+ assert not torch.all(res_1[i] - res_2[i] < 0.000001)
+
+
+def test_is_inner_product_symmetric_01():
+ repr_1 = torch.rand(20, 32)
+ repr_2 = torch.rand(20, 32)
+ dec = InnerProductDecoder(32, 7, keep_prob=1.,
+ activation=torch.sigmoid)
+
+ res_1 = [ dec(repr_1, repr_2, k) for k in range(7) ]
+ res_2 = [ dec(repr_2, repr_1, k) for k in range(7) ]
+
+ assert isinstance(res_1, list)
+ assert isinstance(res_2, list)
+
+ assert len(res_1) == len(res_2)
+
+ for i in range(len(res_1)):
+ assert torch.all(res_1[i] - res_2[i] < 0.000001)
+
+
+def test_empty_dedicom_decoder_01():
+ repr_ = torch.rand(0, 32)
+ dec = DEDICOMDecoder(32, 7, keep_prob=1.,
+ activation=torch.sigmoid)
+
+ res = [ dec(repr_, repr_, k) for k in range(7) ]
+
+ assert isinstance(res, list)
+
+ for i in range(len(res)):
+ assert res[i].shape == (0,)
+
+
+def test_empty_dist_mult_decoder_01():
+ repr_ = torch.rand(0, 32)
+ dec = DistMultDecoder(32, 7, keep_prob=1.,
+ activation=torch.sigmoid)
+
+ res = [ dec(repr_, repr_, k) for k in range(7) ]
+
+ assert isinstance(res, list)
+
+ for i in range(len(res)):
+ assert res[i].shape == (0,)
+
+
+def test_empty_bilinear_decoder_01():
+ repr_ = torch.rand(0, 32)
+ dec = BilinearDecoder(32, 7, keep_prob=1.,
+ activation=torch.sigmoid)
+
+ res = [ dec(repr_, repr_, k) for k in range(7) ]
+
+ assert isinstance(res, list)
+
+ for i in range(len(res)):
+ assert res[i].shape == (0,)
+
+
+def test_empty_inner_product_decoder_01():
+ repr_ = torch.rand(0, 32)
+ dec = InnerProductDecoder(32, 7, keep_prob=1.,
+ activation=torch.sigmoid)
+
+ res = [ dec(repr_, repr_, k) for k in range(7) ]
+
+ assert isinstance(res, list)
+
+ for i in range(len(res)):
+ assert res[i].shape == (0,)
+
+
+def test_dedicom_decoder_parameter_count_01():
+ dec = DEDICOMDecoder(32, 7, keep_prob=1.,
+ activation=torch.sigmoid)
+
+ assert len(list(dec.parameters())) == 8
+
+
+def test_dist_mult_decoder_parameter_count_01():
+ dec = DistMultDecoder(32, 7, keep_prob=1.,
+ activation=torch.sigmoid)
+
+ assert len(list(dec.parameters())) == 7
+
+
+def test_bilinear_decoder_parameter_count_01():
+ dec = BilinearDecoder(32, 7, keep_prob=1.,
+ activation=torch.sigmoid)
+
+ assert len(list(dec.parameters())) == 7
+
+
+def test_inner_product_decoder_parameter_count_01():
+ dec = InnerProductDecoder(32, 7, keep_prob=1.,
+ activation=torch.sigmoid)
+
+ assert len(list(dec.parameters())) == 0
diff --git a/tests/icosagon/test_dropout.py b/tests/icosagon/test_dropout.py
new file mode 100644
index 0000000..374f0a9
--- /dev/null
+++ b/tests/icosagon/test_dropout.py
@@ -0,0 +1,26 @@
+from icosagon.dropout import dropout_sparse, \
+ dropout_dense
+import torch
+import numpy as np
+
+
+def test_dropout_01():
+ for i in range(11):
+ torch.random.manual_seed(i)
+ a = torch.rand((5, 10))
+ a[a < .5] = 0
+
+ keep_prob=i/10. + np.finfo(np.float32).eps
+
+ torch.random.manual_seed(i)
+ b = dropout_dense(a, keep_prob=keep_prob)
+
+ torch.random.manual_seed(i)
+ c = dropout_sparse(a.to_sparse(), keep_prob=keep_prob)
+
+ print('keep_prob:', keep_prob)
+ print('a:', a.detach().cpu().numpy())
+ print('b:', b.detach().cpu().numpy())
+ print('c:', c, c.to_dense().detach().cpu().numpy())
+
+ assert torch.all(b == c.to_dense())
diff --git a/tests/icosagon/test_fastconv.py b/tests/icosagon/test_fastconv.py
new file mode 100644
index 0000000..2003316
--- /dev/null
+++ b/tests/icosagon/test_fastconv.py
@@ -0,0 +1,210 @@
+from icosagon.fastconv import _sparse_diag_cat, \
+ _cat, \
+ FastGraphConv, \
+ FastConvLayer
+from icosagon.data import _equal
+import torch
+import pdb
+import time
+from icosagon.data import Data
+from icosagon.input import OneHotInputLayer
+from icosagon.convlayer import DecagonLayer
+
+
+def _make_symmetric(x: torch.Tensor):
+ x = (x + x.transpose(0, 1)) / 2
+ return x
+
+
+def _symmetric_random(n_rows, n_columns):
+ return _make_symmetric(torch.rand((n_rows, n_columns),
+ dtype=torch.float32).round().to_sparse())
+
+
+def _some_data_with_interactions():
+ d = Data()
+ d.add_node_type('Gene', 1000)
+ d.add_node_type('Drug', 100)
+
+ fam = d.add_relation_family('Drug-Gene', 1, 0, True)
+ fam.add_relation_type('Target',
+ torch.rand((100, 1000), dtype=torch.float32).round().to_sparse())
+
+ fam = d.add_relation_family('Gene-Gene', 0, 0, True)
+ fam.add_relation_type('Interaction',
+ _symmetric_random(1000, 1000))
+
+ fam = d.add_relation_family('Drug-Drug', 1, 1, True)
+ fam.add_relation_type('Side Effect: Nausea',
+ _symmetric_random(100, 100))
+ fam.add_relation_type('Side Effect: Infertility',
+ _symmetric_random(100, 100))
+ fam.add_relation_type('Side Effect: Death',
+ _symmetric_random(100, 100))
+ return d
+
+
+def test_sparse_diag_cat_01():
+ matrices = [ torch.rand(5, 10).round() for _ in range(7) ]
+ ground_truth = torch.zeros(35, 70)
+ ground_truth[0:5, 0:10] = matrices[0]
+ ground_truth[5:10, 10:20] = matrices[1]
+ ground_truth[10:15, 20:30] = matrices[2]
+ ground_truth[15:20, 30:40] = matrices[3]
+ ground_truth[20:25, 40:50] = matrices[4]
+ ground_truth[25:30, 50:60] = matrices[5]
+ ground_truth[30:35, 60:70] = matrices[6]
+ res = _sparse_diag_cat([ m.to_sparse() for m in matrices ])
+ res = res.to_dense()
+ assert torch.all(res == ground_truth)
+
+
+def test_sparse_diag_cat_02():
+ x = [ torch.rand(5, 10).round() for _ in range(7) ]
+ a = [ m.to_sparse() for m in x ]
+ a = _sparse_diag_cat(a)
+ b = torch.rand(70, 64)
+ res = torch.sparse.mm(a, b)
+
+ ground_truth = torch.zeros(35, 64)
+ ground_truth[0:5, :] = torch.mm(x[0], b[0:10])
+ ground_truth[5:10, :] = torch.mm(x[1], b[10:20])
+ ground_truth[10:15, :] = torch.mm(x[2], b[20:30])
+ ground_truth[15:20, :] = torch.mm(x[3], b[30:40])
+ ground_truth[20:25, :] = torch.mm(x[4], b[40:50])
+ ground_truth[25:30, :] = torch.mm(x[5], b[50:60])
+ ground_truth[30:35, :] = torch.mm(x[6], b[60:70])
+
+ assert torch.all(res == ground_truth)
+
+
+def test_cat_01():
+ matrices = [ torch.rand(5, 10) for _ in range(7) ]
+ res = _cat(matrices)
+ assert res.shape == (35, 10)
+ assert not res.is_sparse
+ ground_truth = torch.zeros(35, 10)
+ for i in range(7):
+ ground_truth[i*5:(i+1)*5, :] = matrices[i]
+ assert torch.all(res == ground_truth)
+
+
+def test_cat_02():
+ matrices = [ torch.rand(5, 10) for _ in range(7) ]
+ ground_truth = torch.zeros(35, 10)
+ for i in range(7):
+ ground_truth[i*5:(i+1)*5, :] = matrices[i]
+ res = _cat([ m.to_sparse() for m in matrices ])
+ assert res.shape == (35, 10)
+ assert res.is_sparse
+ assert torch.all(res.to_dense() == ground_truth)
+
+
+def test_fast_graph_conv_01():
+ # pdb.set_trace()
+ adj_mats = [ torch.rand(10, 15).round().to_sparse() \
+ for _ in range(23) ]
+ fgc = FastGraphConv(32, 64, adj_mats)
+ in_repr = torch.rand(15, 32)
+ _ = fgc(in_repr)
+
+
+def test_fast_graph_conv_02():
+ t = time.time()
+ m = (torch.rand(2000, 2000) < .001).to(torch.float32).to_sparse()
+ adj_mats = [ m for _ in range(1300) ]
+ print('Generating adj_mats took:', time.time() - t)
+ t = time.time()
+ fgc = FastGraphConv(32, 64, adj_mats)
+ print('FGC constructor took:', time.time() - t)
+ in_repr = torch.rand(2000, 32)
+
+ for _ in range(3):
+ t = time.time()
+ _ = fgc(in_repr)
+ print('FGC forward pass took:', time.time() - t)
+
+
+def test_fast_graph_conv_03():
+ adj_mat = [
+ [ 0, 0, 1, 0, 1 ],
+ [ 0, 1, 0, 1, 0 ],
+ [ 1, 0, 1, 0, 0 ]
+ ]
+ in_repr = torch.rand(5, 32)
+ adj_mat = torch.tensor(adj_mat, dtype=torch.float32)
+ fgc = FastGraphConv(32, 64, [ adj_mat.to_sparse() ])
+ out_repr = fgc(in_repr)
+ assert out_repr.shape == (1, 3, 64)
+ assert (torch.mm(adj_mat, torch.mm(in_repr, fgc.weights)).view(1, 3, 64) == out_repr).all()
+
+
+def test_fast_graph_conv_04():
+ adj_mat = [
+ [ 0, 0, 1, 0, 1 ],
+ [ 0, 1, 0, 1, 0 ],
+ [ 1, 0, 1, 0, 0 ]
+ ]
+ in_repr = torch.rand(5, 32)
+ adj_mat = torch.tensor(adj_mat, dtype=torch.float32)
+ fgc = FastGraphConv(32, 64, [ adj_mat.to_sparse(), adj_mat.to_sparse() ])
+ out_repr = fgc(in_repr)
+ assert out_repr.shape == (2, 3, 64)
+ adj_mat_1 = torch.zeros(adj_mat.shape[0] * 2, adj_mat.shape[1] * 2)
+ adj_mat_1[0:3, 0:5] = adj_mat
+ adj_mat_1[3:6, 5:10] = adj_mat
+ res = torch.mm(in_repr, fgc.weights)
+ res = torch.split(res, res.shape[1] // 2, dim=1)
+ res = torch.cat(res)
+ res = torch.mm(adj_mat_1, res)
+ assert (res.view(2, 3, 64) == out_repr).all()
+
+
+def test_fast_conv_layer_01():
+ d = _some_data_with_interactions()
+ in_layer = OneHotInputLayer(d)
+
+ d_layer = DecagonLayer(in_layer.output_dim, [32, 32], d)
+ seq_1 = torch.nn.Sequential(in_layer, d_layer)
+ _ = seq_1(None)
+
+ conv_layer = FastConvLayer(in_layer.output_dim, [32, 32], d)
+ seq_2 = torch.nn.Sequential(in_layer, conv_layer)
+ _ = seq_2(None)
+
+
+def test_fast_conv_layer_02():
+ d = _some_data_with_interactions()
+ in_layer = OneHotInputLayer(d)
+
+ d_layer = DecagonLayer(in_layer.output_dim, [32, 32], d)
+ seq_1 = torch.nn.Sequential(in_layer, d_layer)
+ out_repr_1 = seq_1(None)
+
+ assert len(d_layer.next_layer_repr[0]) == 2
+ assert len(d_layer.next_layer_repr[1]) == 2
+
+ conv_layer = FastConvLayer(in_layer.output_dim, [32, 32], d)
+ assert len(conv_layer.next_layer_repr[1]) == 2
+ conv_layer.next_layer_repr[1][0].weights = torch.cat([
+ d_layer.next_layer_repr[1][0].convolutions[0].graph_conv.weight,
+ ], dim=1)
+ conv_layer.next_layer_repr[1][1].weights = torch.cat([
+ d_layer.next_layer_repr[1][1].convolutions[0].graph_conv.weight,
+ d_layer.next_layer_repr[1][1].convolutions[1].graph_conv.weight,
+ d_layer.next_layer_repr[1][1].convolutions[2].graph_conv.weight,
+ ], dim=1)
+ assert len(conv_layer.next_layer_repr[0]) == 2
+ conv_layer.next_layer_repr[0][0].weights = torch.cat([
+ d_layer.next_layer_repr[0][0].convolutions[0].graph_conv.weight,
+ ], dim=1)
+ conv_layer.next_layer_repr[0][1].weights = torch.cat([
+ d_layer.next_layer_repr[0][1].convolutions[0].graph_conv.weight,
+ ], dim=1)
+
+ seq_2 = torch.nn.Sequential(in_layer, conv_layer)
+ out_repr_2 = seq_2(None)
+
+ assert len(out_repr_1) == len(out_repr_2)
+ for i in range(len(out_repr_1)):
+ assert torch.all(out_repr_1[i] == out_repr_2[i])
diff --git a/tests/icosagon/test_fastloop.py b/tests/icosagon/test_fastloop.py
new file mode 100644
index 0000000..0afb285
--- /dev/null
+++ b/tests/icosagon/test_fastloop.py
@@ -0,0 +1,51 @@
+from icosagon.fastloop import FastBatcher, \
+ FastModel
+from icosagon.data import Data
+from icosagon.trainprep import prepare_training, \
+ TrainValTest
+import torch
+
+
+def test_fast_batcher_01():
+ d = Data()
+ d.add_node_type('Gene', 5)
+ d.add_node_type('Drug', 3)
+
+ fam = d.add_relation_family('Gene-Drug', 0, 1, True)
+
+ adj_mat = torch.tensor([
+ [ 1, 0, 1 ],
+ [ 0, 0, 1 ],
+ [ 0, 1, 0 ],
+ [ 1, 0, 0 ],
+ [ 0, 1, 1 ]
+ ], dtype=torch.float32).to_sparse()
+ fam.add_relation_type('Target', adj_mat)
+
+ prep_d = prepare_training(d, TrainValTest(.8, .1, .1))
+ # print(prep_d.relation_families[0])
+
+ g = torch.Generator()
+ batcher = FastBatcher(prep_d, batch_size=3, shuffle=True,
+ generator=g, part_type='train')
+
+ print(batcher.edges)
+ print(batcher.targets)
+
+ edges_check = [ set() for _ in range(len(batcher.edges)) ]
+
+ for fam_idx, edges, targets in batcher:
+ print(fam_idx, edges, targets)
+ for e in edges:
+ edges_check[fam_idx].add(tuple(e.tolist()))
+
+ edges_check_2 = [ set() for _ in range(len(batcher.edges)) ]
+ for i, edges in enumerate(batcher.edges):
+ for e in edges:
+ edges_check_2[i].add(tuple(e.tolist()))
+
+ assert edges_check == edges_check_2
+
+
+def test_fast_model_01():
+ raise NotImplementedError
diff --git a/tests/icosagon/test_fastmodel.py b/tests/icosagon/test_fastmodel.py
new file mode 100644
index 0000000..6d5948a
--- /dev/null
+++ b/tests/icosagon/test_fastmodel.py
@@ -0,0 +1,50 @@
+from icosagon.fastmodel import FastModel
+from icosagon.data import Data
+from icosagon.trainprep import prepare_training, \
+ TrainValTest
+import torch
+import time
+
+
+def _make_symmetric(x: torch.Tensor):
+ x = (x + x.transpose(0, 1)) / 2
+ return x
+
+
+def _symmetric_random(n_rows, n_columns):
+ return _make_symmetric(torch.rand((n_rows, n_columns),
+ dtype=torch.float32).round().to_sparse())
+
+
+def _some_data_with_interactions():
+ d = Data()
+ d.add_node_type('Gene', 1000)
+ d.add_node_type('Drug', 100)
+
+ fam = d.add_relation_family('Drug-Gene', 1, 0, True)
+ fam.add_relation_type('Target',
+ torch.rand((100, 1000), dtype=torch.float32).round().to_sparse())
+
+ fam = d.add_relation_family('Gene-Gene', 0, 0, True)
+ fam.add_relation_type('Interaction',
+ _symmetric_random(1000, 1000))
+
+ fam = d.add_relation_family('Drug-Drug', 1, 1, True)
+ for i in range(500):
+ fam.add_relation_type('Side Effect: Nausea %d' % i,
+ _symmetric_random(100, 100))
+ fam.add_relation_type('Side Effect: Infertility %d' % i,
+ _symmetric_random(100, 100))
+ fam.add_relation_type('Side Effect: Death %d' % i,
+ _symmetric_random(100, 100))
+ return d
+
+
+def test_fast_model_01():
+ d = _some_data_with_interactions()
+ prep_d = prepare_training(d, TrainValTest(.8, .1, .1))
+ model = FastModel(prep_d)
+ for i in range(10):
+ t = time.time()
+ _ = model(None)
+ print('Model forward took:', time.time() - t)
diff --git a/tests/icosagon/test_input.py b/tests/icosagon/test_input.py
new file mode 100644
index 0000000..1ea676a
--- /dev/null
+++ b/tests/icosagon/test_input.py
@@ -0,0 +1,122 @@
+from icosagon.input import InputLayer, \
+ OneHotInputLayer
+from icosagon.data import Data
+import torch
+import pytest
+
+
+def _some_data():
+ d = Data()
+ d.add_node_type('Gene', 1000)
+ d.add_node_type('Drug', 100)
+
+ fam = d.add_relation_family('Drug-Gene', 1, 0, True)
+ fam.add_relation_type('Target', torch.rand(100, 1000))
+
+ fam = d.add_relation_family('Gene-Gene', 0, 0, False)
+ fam.add_relation_type('Interaction', torch.rand(1000, 1000))
+
+ fam = d.add_relation_family('Drug-Drug', 1, 1, False)
+ fam.add_relation_type('Side Effect: Nausea', torch.rand(100, 100))
+ fam.add_relation_type('Side Effect: Infertility', torch.rand(100, 100))
+ fam.add_relation_type('Side Effect: Death', torch.rand(100, 100))
+ return d
+
+
+def test_input_layer_01():
+ d = _some_data()
+ for output_dim in [32, 64, 128]:
+ layer = InputLayer(d, output_dim)
+ assert layer.output_dim[0] == output_dim
+ assert len(layer.node_reps) == 2
+ assert layer.node_reps[0].shape == (1000, output_dim)
+ assert layer.node_reps[1].shape == (100, output_dim)
+ assert layer.data == d
+
+
+def test_input_layer_02():
+ d = _some_data()
+ layer = InputLayer(d, 32)
+ res = layer(None)
+ assert isinstance(res[0], torch.Tensor)
+ assert isinstance(res[1], torch.Tensor)
+ assert res[0].shape == (1000, 32)
+ assert res[1].shape == (100, 32)
+ assert torch.all(res[0] == layer.node_reps[0])
+ assert torch.all(res[1] == layer.node_reps[1])
+
+
+def test_input_layer_03():
+ if torch.cuda.device_count() == 0:
+ pytest.skip('No CUDA devices on this host')
+ d = _some_data()
+ layer = InputLayer(d, 32)
+ device = torch.device('cuda:0')
+ layer = layer.to(device)
+ print(list(layer.parameters()))
+ # assert layer.device.type == 'cuda:0'
+ assert layer.node_reps[0].device == device
+ assert layer.node_reps[1].device == device
+
+
+def test_input_layer_04():
+ d = _some_data()
+ layer = InputLayer(d, 32)
+ s = repr(layer)
+ assert s.startswith('Icosagon input layer')
+
+
+def test_one_hot_input_layer_01():
+ d = _some_data()
+ layer = OneHotInputLayer(d)
+ assert layer.output_dim == [1000, 100]
+ assert len(layer.node_reps) == 2
+ assert layer.node_reps[0].shape == (1000, 1000)
+ assert layer.node_reps[1].shape == (100, 100)
+ assert layer.data == d
+ assert layer.is_sparse
+
+
+def test_one_hot_input_layer_02():
+ d = _some_data()
+ layer = OneHotInputLayer(d)
+ res = layer(None)
+ assert isinstance(res[0], torch.Tensor)
+ assert isinstance(res[1], torch.Tensor)
+ assert res[0].shape == (1000, 1000)
+ assert res[1].shape == (100, 100)
+ assert torch.all(res[0].to_dense() == layer.node_reps[0].to_dense())
+ assert torch.all(res[1].to_dense() == layer.node_reps[1].to_dense())
+
+
+def test_one_hot_input_layer_03():
+ if torch.cuda.device_count() == 0:
+ pytest.skip('No CUDA devices on this host')
+ d = _some_data()
+ layer = OneHotInputLayer(d)
+ device = torch.device('cuda:0')
+ layer = layer.to(device)
+ print(list(layer.parameters()))
+ # assert layer.device.type == 'cuda:0'
+ assert layer.node_reps[0].device == device
+ assert layer.node_reps[1].device == device
+
+
+def test_one_hot_input_layer_04():
+ d = _some_data()
+ layer = OneHotInputLayer(d)
+ s = repr(layer)
+ assert s.startswith('Icosagon one-hot input layer')
+
+
+def test_one_hot_input_layer_parameter_count_01():
+ d = _some_data()
+ layer = OneHotInputLayer(d)
+ assert len(list(layer.parameters())) == 2
+
+
+def test_input_layer_parameter_count_01():
+ d = _some_data()
+ for output_dim in [32, 64, 128]:
+ layer = InputLayer(d, output_dim)
+ assert len(list(layer.parameters())) == 2
diff --git a/tests/icosagon/test_model.py b/tests/icosagon/test_model.py
new file mode 100644
index 0000000..5e5d482
--- /dev/null
+++ b/tests/icosagon/test_model.py
@@ -0,0 +1,204 @@
+from icosagon.data import Data, \
+ _equal
+from icosagon.model import Model
+from icosagon.trainprep import PreparedData, \
+ PreparedRelationFamily, \
+ PreparedRelationType, \
+ TrainValTest, \
+ norm_adj_mat_one_node_type, \
+ prepare_training
+import torch
+from icosagon.input import OneHotInputLayer
+from icosagon.convlayer import DecagonLayer
+from icosagon.declayer import DecodeLayer
+import pytest
+
+
+def _is_identity_function(f):
+ for x in range(-100, 101):
+ if f(x) != x:
+ return False
+ return True
+
+
+def test_model_01():
+ d = Data()
+ d.add_node_type('Dummy', 10)
+ fam = d.add_relation_family('Dummy-Dummy', 0, 0, False)
+ fam.add_relation_type('Dummy Rel', torch.rand(10, 10).round())
+
+ prep_d = prepare_training(d, TrainValTest(.8, .1, .1))
+
+ m = Model(prep_d)
+
+ assert m.prep_d == prep_d
+ assert m.layer_dimensions == [32, 64]
+ assert m.keep_prob == 1.
+ assert _is_identity_function(m.rel_activation)
+ assert m.layer_activation == torch.nn.functional.relu
+ assert _is_identity_function(m.dec_activation)
+ assert isinstance(m.seq, torch.nn.Sequential)
+
+
+def test_model_02():
+ d = Data()
+ d.add_node_type('Dummy', 10)
+ fam = d.add_relation_family('Dummy-Dummy', 0, 0, False)
+ mat = torch.rand(10, 10).round().to_sparse()
+ fam.add_relation_type('Dummy Rel', mat)
+
+ prep_d = prepare_training(d, TrainValTest(1., 0., 0.))
+
+ m = Model(prep_d)
+
+ assert isinstance(m.prep_d, PreparedData)
+ assert isinstance(m.prep_d.relation_families, list)
+ assert len(m.prep_d.relation_families) == 1
+ assert isinstance(m.prep_d.relation_families[0], PreparedRelationFamily)
+ assert len(m.prep_d.relation_families[0].relation_types) == 1
+ assert isinstance(m.prep_d.relation_families[0].relation_types[0], PreparedRelationType)
+ assert m.prep_d.relation_families[0].relation_types[0].adjacency_matrix_backward is None
+ assert torch.all(_equal(m.prep_d.relation_families[0].relation_types[0].adjacency_matrix,
+ norm_adj_mat_one_node_type(mat)))
+
+ assert isinstance(m.seq[0], OneHotInputLayer)
+ assert isinstance(m.seq[1], DecagonLayer)
+ assert isinstance(m.seq[2], DecagonLayer)
+ assert isinstance(m.seq[3], DecodeLayer)
+ assert len(m.seq) == 4
+
+
+def test_model_03():
+ d = Data()
+ d.add_node_type('Dummy', 10)
+ fam = d.add_relation_family('Dummy-Dummy', 0, 0, False)
+ mat = torch.rand(10, 10).round().to_sparse()
+ fam.add_relation_type('Dummy Rel', mat)
+
+ prep_d = prepare_training(d, TrainValTest(1., 0., 0.))
+
+ m = Model(prep_d)
+
+ assert len(list(m.seq[0].parameters())) == 1
+ assert len(list(m.seq[1].parameters())) == 1
+ assert len(list(m.seq[2].parameters())) == 1
+ assert len(list(m.seq[3].parameters())) == 2
+
+
+def test_model_04():
+ d = Data()
+ d.add_node_type('Dummy', 10)
+ fam = d.add_relation_family('Dummy-Dummy', 0, 0, False)
+ mat = torch.rand(10, 10).round().to_sparse()
+ fam.add_relation_type('Dummy Rel 1', mat)
+ fam.add_relation_type('Dummy Rel 2', mat.clone())
+
+ prep_d = prepare_training(d, TrainValTest(1., 0., 0.))
+
+ m = Model(prep_d)
+
+ assert len(list(m.seq[0].parameters())) == 1
+ assert len(list(m.seq[1].parameters())) == 2
+ assert len(list(m.seq[2].parameters())) == 2
+ assert len(list(m.seq[3].parameters())) == 3
+
+
+def test_model_05():
+ d = Data()
+ d.add_node_type('Dummy', 10)
+
+ fam = d.add_relation_family('Dummy-Dummy', 0, 0, False)
+ mat = torch.rand(10, 10).round().to_sparse()
+ fam.add_relation_type('Dummy Rel 1', mat)
+ fam.add_relation_type('Dummy Rel 2', mat.clone())
+
+ fam = d.add_relation_family('Dummy-Dummy 2', 0, 0, False)
+ mat = torch.rand(10, 10).round().to_sparse()
+ fam.add_relation_type('Dummy Rel 2-1', mat)
+ fam.add_relation_type('Dummy Rel 2-2', mat.clone())
+
+ prep_d = prepare_training(d, TrainValTest(1., 0., 0.))
+
+ m = Model(prep_d)
+
+ assert len(list(m.seq[0].parameters())) == 1
+ assert len(list(m.seq[1].parameters())) == 4
+ assert len(list(m.seq[2].parameters())) == 4
+ assert len(list(m.seq[3].parameters())) == 6
+
+
+def test_model_06():
+ d = Data()
+ d.add_node_type('Dummy', 10)
+ d.add_node_type('Foobar', 20)
+
+ fam = d.add_relation_family('Dummy-Foobar', 0, 1, True)
+ mat = torch.rand(10, 20).round().to_sparse()
+ fam.add_relation_type('Dummy Rel 1', mat)
+ fam.add_relation_type('Dummy Rel 2', mat.clone())
+
+ fam = d.add_relation_family('Dummy-Dummy 2', 0, 0, False)
+ mat = torch.rand(10, 10).round().to_sparse()
+ fam.add_relation_type('Dummy Rel 2-1', mat)
+ fam.add_relation_type('Dummy Rel 2-2', mat.clone())
+
+ prep_d = prepare_training(d, TrainValTest(1., 0., 0.))
+
+ m = Model(prep_d)
+
+ assert len(list(m.seq[0].parameters())) == 2
+ assert len(list(m.seq[1].parameters())) == 6
+ assert len(list(m.seq[2].parameters())) == 6
+ assert len(list(m.seq[3].parameters())) == 6
+
+
+def test_model_07():
+ d = Data()
+ d.add_node_type('Dummy', 10)
+ fam = d.add_relation_family('Dummy-Dummy', 0, 0, False)
+ fam.add_relation_type('Dummy Rel', torch.rand(10, 10).round())
+
+ prep_d = prepare_training(d, TrainValTest(.8, .1, .1))
+
+ with pytest.raises(TypeError):
+ m = Model(1)
+
+ with pytest.raises(TypeError):
+ m = Model(prep_d, layer_dimensions=1)
+
+ with pytest.raises(TypeError):
+ m = Model(prep_d, ratios=1)
+
+ with pytest.raises(ValueError):
+ m = Model(prep_d, keep_prob='x')
+
+ with pytest.raises(TypeError):
+ m = Model(prep_d, rel_activation='x')
+
+ with pytest.raises(TypeError):
+ m = Model(prep_d, layer_activation='x')
+
+ with pytest.raises(TypeError):
+ m = Model(prep_d, dec_activation='x')
+
+
+def test_model_08():
+ d = Data()
+ d.add_node_type('Dummy', 10)
+ d.add_node_type('Foobar', 20)
+
+ fam = d.add_relation_family('Dummy-Foobar', 0, 1, True)
+ mat = torch.rand(10, 20).round().to_sparse()
+ fam.add_relation_type('Dummy Rel 1', mat)
+ fam.add_relation_type('Dummy Rel 2', mat.clone())
+
+ fam = d.add_relation_family('Dummy-Dummy 2', 0, 0, False)
+ mat = torch.rand(10, 10).round().to_sparse()
+ fam.add_relation_type('Dummy Rel 2-1', mat)
+ fam.add_relation_type('Dummy Rel 2-2', mat.clone())
+
+ prep_d = prepare_training(d, TrainValTest(1., 0., 0.))
+
+ m = Model(prep_d)
+
+ assert len(list(m.parameters())) == 20
diff --git a/tests/icosagon/test_normalize.py b/tests/icosagon/test_normalize.py
new file mode 100644
index 0000000..f3de835
--- /dev/null
+++ b/tests/icosagon/test_normalize.py
@@ -0,0 +1,185 @@
+from icosagon.normalize import add_eye_sparse, \
+ norm_adj_mat_one_node_type_sparse, \
+ norm_adj_mat_one_node_type_dense, \
+ norm_adj_mat_one_node_type, \
+ norm_adj_mat_two_node_types_sparse, \
+ norm_adj_mat_two_node_types_dense, \
+ norm_adj_mat_two_node_types
+import decagon_pytorch.normalize
+import torch
+import pytest
+import numpy as np
+from math import sqrt
+
+
+def test_add_eye_sparse_01():
+ adj_mat_dense = torch.rand((10, 10))
+ adj_mat_sparse = adj_mat_dense.to_sparse()
+
+ adj_mat_dense += torch.eye(10)
+ adj_mat_sparse = add_eye_sparse(adj_mat_sparse)
+
+ assert torch.all(adj_mat_sparse.to_dense() == adj_mat_dense)
+
+
+def test_add_eye_sparse_02():
+ adj_mat_dense = torch.rand((10, 20))
+ adj_mat_sparse = adj_mat_dense.to_sparse()
+
+ with pytest.raises(ValueError):
+ _ = add_eye_sparse(adj_mat_sparse)
+
+
+def test_add_eye_sparse_03():
+ adj_mat_dense = torch.rand((10, 10))
+
+ with pytest.raises(ValueError):
+ _ = add_eye_sparse(adj_mat_dense)
+
+
+def test_add_eye_sparse_04():
+ adj_mat_dense = np.random.rand(10, 10)
+
+ with pytest.raises(ValueError):
+ _ = add_eye_sparse(adj_mat_dense)
+
+
+def test_norm_adj_mat_one_node_type_sparse_01():
+ adj_mat = torch.rand((10, 10))
+ adj_mat = (adj_mat > .5).to(torch.float32)
+ adj_mat = adj_mat.to_sparse()
+ _ = norm_adj_mat_one_node_type_sparse(adj_mat)
+
+
+def test_norm_adj_mat_one_node_type_sparse_02():
+ adj_mat_dense = torch.rand((10, 10))
+ adj_mat_dense = (adj_mat_dense > .5).to(torch.float32)
+ adj_mat_sparse = adj_mat_dense.to_sparse()
+ adj_mat_sparse = norm_adj_mat_one_node_type_sparse(adj_mat_sparse)
+ adj_mat_dense = norm_adj_mat_one_node_type_dense(adj_mat_dense)
+ assert torch.all(adj_mat_sparse.to_dense() - adj_mat_dense < 0.000001)
+
+
+def test_norm_adj_mat_one_node_type_dense_01():
+ adj_mat = torch.rand((10, 10))
+ adj_mat = (adj_mat > .5)
+ _ = norm_adj_mat_one_node_type_dense(adj_mat)
+
+
+def test_norm_adj_mat_one_node_type_dense_02():
+ adj_mat = torch.tensor([
+ [0, 1, 1, 0], # 3
+ [1, 0, 1, 0], # 3
+ [1, 1, 0, 1], # 4
+ [0, 0, 1, 0] # 2
+ # 3 3 4 2
+ ])
+ expect_denom = np.array([
+ [ 3, 3, sqrt(3)*2, sqrt(6) ],
+ [ 3, 3, sqrt(3)*2, sqrt(6) ],
+ [ sqrt(3)*2, sqrt(3)*2, 4, sqrt(2)*2 ],
+ [ sqrt(6), sqrt(6), sqrt(2)*2, 2 ]
+ ], dtype=np.float32)
+ expect = (adj_mat.detach().cpu().numpy().astype(np.float32) + np.eye(4)) / expect_denom
+ # expect = np.array([
+ # [1/3, 1/3, 1/3, 0],
+ # [1/3, 1/3, 1/3, 0],
+ # [1/4, 1/4, 1/4, 1/4],
+ # [0, 0, 1/2, 1/2]
+ # ], dtype=np.float32)
+ res = decagon_pytorch.normalize.norm_adj_mat_one_node_type(adj_mat)
+ res = res.todense().astype(np.float32)
+ print('res:', res)
+ print('expect:', expect)
+ assert np.all(res - expect < 0.000001)
+
+
+def test_norm_adj_mat_one_node_type_dense_03():
+ # adj_mat = torch.rand((10, 10))
+ adj_mat = torch.tensor([
+ [0, 1, 1, 0, 0],
+ [1, 0, 1, 0, 1],
+ [1, 1, 0, .5, .5],
+ [0, 0, .5, 0, 1],
+ [0, 1, .5, 1, 0]
+ ])
+ # adj_mat = (adj_mat > .5)
+ adj_mat_dec = decagon_pytorch.normalize.norm_adj_mat_one_node_type(adj_mat)
+ adj_mat_ico = norm_adj_mat_one_node_type_dense(adj_mat)
+ adj_mat_dec = adj_mat_dec.todense()
+ adj_mat_ico = adj_mat_ico.detach().cpu().numpy()
+ print('adj_mat_dec:', adj_mat_dec)
+ print('adj_mat_ico:', adj_mat_ico)
+ assert np.all(adj_mat_dec - adj_mat_ico < 0.000001)
+
+
+def test_norm_adj_mat_two_node_types_sparse_01():
+ adj_mat = torch.rand((10, 20))
+ adj_mat = (adj_mat > .5)
+ adj_mat = adj_mat.to_sparse()
+ _ = norm_adj_mat_two_node_types_sparse(adj_mat)
+
+
+def test_norm_adj_mat_two_node_types_sparse_02():
+ adj_mat_dense = torch.rand((10, 20))
+ adj_mat_dense = (adj_mat_dense > .5)
+ adj_mat_sparse = adj_mat_dense.to_sparse()
+ adj_mat_sparse = norm_adj_mat_two_node_types_sparse(adj_mat_sparse)
+ adj_mat_dense = norm_adj_mat_two_node_types_dense(adj_mat_dense)
+ assert torch.all(adj_mat_sparse.to_dense() - adj_mat_dense < 0.000001)
+
+
+def test_norm_adj_mat_two_node_types_dense_01():
+ adj_mat = torch.rand((10, 20))
+ adj_mat = (adj_mat > .5)
+ _ = norm_adj_mat_two_node_types_dense(adj_mat)
+
+
+def test_norm_adj_mat_two_node_types_dense_02():
+ adj_mat = torch.tensor([
+ [0, 1, 1, 0], # 2
+ [1, 0, 1, 0], # 2
+ [1, 1, 0, 1], # 3
+ [0, 0, 1, 0] # 1
+ # 2 2 3 1
+ ])
+ expect_denom = np.array([
+ [ 2, 2, sqrt(6), sqrt(2) ],
+ [ 2, 2, sqrt(6), sqrt(2) ],
+ [ sqrt(6), sqrt(6), 3, sqrt(3) ],
+ [ sqrt(2), sqrt(2), sqrt(3), 1 ]
+ ], dtype=np.float32)
+ expect = adj_mat.detach().cpu().numpy().astype(np.float32) / expect_denom
+ res = decagon_pytorch.normalize.norm_adj_mat_two_node_types(adj_mat)
+ res = res.todense().astype(np.float32)
+ print('res:', res)
+ print('expect:', expect)
+ assert np.all(res - expect < 0.000001)
+
+
+def test_norm_adj_mat_two_node_types_dense_03():
+ adj_mat = torch.tensor([
+ [0, 1, 1, 0, 0],
+ [1, 0, 1, 0, 1],
+ [1, 1, 0, .5, .5],
+ [0, 0, .5, 0, 1],
+ [0, 1, .5, 1, 0]
+ ])
+ adj_mat_dec = decagon_pytorch.normalize.norm_adj_mat_two_node_types(adj_mat)
+ adj_mat_ico = norm_adj_mat_two_node_types_dense(adj_mat)
+ adj_mat_dec = adj_mat_dec.todense()
+ adj_mat_ico = adj_mat_ico.detach().cpu().numpy()
+ print('adj_mat_dec:', adj_mat_dec)
+ print('adj_mat_ico:', adj_mat_ico)
+ assert np.all(adj_mat_dec - adj_mat_ico < 0.000001)
+
+
+def test_norm_adj_mat_two_node_types_dense_04():
+ adj_mat = torch.rand((10, 20))
+ adj_mat_dec = decagon_pytorch.normalize.norm_adj_mat_two_node_types(adj_mat)
+ adj_mat_ico = norm_adj_mat_two_node_types_dense(adj_mat)
+ adj_mat_dec = adj_mat_dec.todense()
+ adj_mat_ico = adj_mat_ico.detach().cpu().numpy()
+ print('adj_mat_dec:', adj_mat_dec)
+ print('adj_mat_ico:', adj_mat_ico)
+ assert np.all(adj_mat_dec - adj_mat_ico < 0.000001)
diff --git a/tests/icosagon/test_sampling.py b/tests/icosagon/test_sampling.py
new file mode 100644
index 0000000..824541b
--- /dev/null
+++ b/tests/icosagon/test_sampling.py
@@ -0,0 +1,166 @@
+import tensorflow as tf
+import numpy as np
+from collections import defaultdict
+import torch
+import torch.utils.data
+from typing import List, \
+ Union
+import icosagon.sampling
+import scipy.stats
+
+
+def test_unigram_01():
+ range_max = 7
+ distortion = 0.75
+ batch_size = 500
+ unigrams = [ 1, 3, 2, 1, 2, 1, 3]
+ num_true = 1
+
+ true_classes = np.zeros((batch_size, num_true), dtype=np.int64)
+ for i in range(batch_size):
+ true_classes[i, 0] = i % range_max
+ true_classes = tf.convert_to_tensor(true_classes)
+
+ neg_samples, _, _ = tf.nn.fixed_unigram_candidate_sampler(
+ true_classes=true_classes,
+ num_true=num_true,
+ num_sampled=batch_size,
+ unique=False,
+ range_max=range_max,
+ distortion=distortion,
+ unigrams=unigrams)
+
+ assert neg_samples.shape == (batch_size,)
+
+ for i in range(batch_size):
+ assert neg_samples[i] != true_classes[i, 0]
+
+ counts = defaultdict(int)
+ with tf.Session() as sess:
+ neg_samples = neg_samples.eval()
+ for x in neg_samples:
+ counts[x] += 1
+
+ print('counts:', counts)
+
+ assert counts[0] < counts[1] and \
+ counts[0] < counts[2] and \
+ counts[0] < counts[4] and \
+ counts[0] < counts[6]
+
+ assert counts[2] < counts[1] and \
+ counts[0] < counts[6]
+
+ assert counts[3] < counts[1] and \
+ counts[3] < counts[2] and \
+ counts[3] < counts[4] and \
+ counts[3] < counts[6]
+
+ assert counts[4] < counts[1] and \
+ counts[4] < counts[6]
+
+ assert counts[5] < counts[1] and \
+ counts[5] < counts[2] and \
+ counts[5] < counts[4] and \
+ counts[5] < counts[6]
+
+
+def test_unigram_02():
+ range_max = 7
+ distortion = 0.75
+ batch_size = 500
+ unigrams = [ 1, 3, 2, 1, 2, 1, 3]
+ num_true = 1
+
+ true_classes = np.zeros((batch_size, num_true), dtype=np.int64)
+ for i in range(batch_size):
+ true_classes[i, 0] = i % range_max
+ true_classes = torch.tensor(true_classes)
+
+ neg_samples = icosagon.sampling.fixed_unigram_candidate_sampler(
+ true_classes=true_classes,
+ unigrams=unigrams,
+ distortion=distortion)
+
+ assert neg_samples.shape == (batch_size,)
+
+ for i in range(batch_size):
+ assert neg_samples[i] != true_classes[i, 0]
+
+ counts = defaultdict(int)
+ for x in neg_samples:
+ counts[x.item()] += 1
+
+ print('counts:', counts)
+
+ assert counts[0] < counts[1] and \
+ counts[0] < counts[2] and \
+ counts[0] < counts[4] and \
+ counts[0] < counts[6]
+
+ assert counts[2] < counts[1] and \
+ counts[0] < counts[6]
+
+ assert counts[3] < counts[1] and \
+ counts[3] < counts[2] and \
+ counts[3] < counts[4] and \
+ counts[3] < counts[6]
+
+ assert counts[4] < counts[1] and \
+ counts[4] < counts[6]
+
+ assert counts[5] < counts[1] and \
+ counts[5] < counts[2] and \
+ counts[5] < counts[4] and \
+ counts[5] < counts[6]
+
+
+def test_unigram_03():
+ range_max = 7
+ distortion = 0.75
+ batch_size = 2500
+ unigrams = [ 1, 3, 2, 1, 2, 1, 3]
+ num_true = 1
+
+ true_classes = np.zeros((batch_size, num_true), dtype=np.int64)
+ for i in range(batch_size):
+ true_classes[i, 0] = i % range_max
+
+ true_classes_tf = tf.convert_to_tensor(true_classes)
+ true_classes_torch = torch.tensor(true_classes)
+
+ counts_tf = torch.zeros(range_max)
+ counts_torch = torch.zeros(range_max)
+
+ for i in range(10):
+ neg_samples, _, _ = tf.nn.fixed_unigram_candidate_sampler(
+ true_classes=true_classes_tf,
+ num_true=num_true,
+ num_sampled=batch_size,
+ unique=False,
+ range_max=range_max,
+ distortion=distortion,
+ unigrams=unigrams)
+
+ counts = torch.zeros(range_max)
+ with tf.Session() as sess:
+ neg_samples = neg_samples.eval()
+ for x in neg_samples:
+ counts[x.item()] += 1
+ counts_tf += counts
+
+ neg_samples = icosagon.sampling.fixed_unigram_candidate_sampler(
+ true_classes=true_classes,
+ distortion=distortion,
+ unigrams=unigrams)
+
+ counts = torch.zeros(range_max)
+ for x in neg_samples:
+ counts[x.item()] += 1
+ counts_torch += counts
+
+ print('counts_tf:', counts_tf)
+ print('counts_torch:', counts_torch)
+
+ distance = scipy.stats.wasserstein_distance(counts_tf, counts_torch)
+ assert distance < 2000
diff --git a/tests/icosagon/test_trainloop.py b/tests/icosagon/test_trainloop.py
new file mode 100644
index 0000000..192cdf9
--- /dev/null
+++ b/tests/icosagon/test_trainloop.py
@@ -0,0 +1,184 @@
+from icosagon.data import Data, \
+ _equal
+from icosagon.trainprep import prepare_training, \
+ TrainValTest
+from icosagon.model import Model
+from icosagon.trainloop import TrainLoop
+import torch
+import pytest
+import pdb
+import time
+
+
+def test_train_loop_01():
+ d = Data()
+ d.add_node_type('Dummy', 10)
+ fam = d.add_relation_family('Dummy-Dummy', 0, 0, False)
+ fam.add_relation_type('Dummy Rel', torch.rand(10, 10).round())
+
+ prep_d = prepare_training(d, TrainValTest(.8, .1, .1))
+
+ m = Model(prep_d)
+
+ loop = TrainLoop(m)
+
+ assert loop.model == m
+ assert loop.lr == 0.001
+ assert loop.loss == torch.nn.functional.binary_cross_entropy_with_logits
+ assert loop.batch_size == 100
+
+
+def test_train_loop_02():
+ d = Data()
+ d.add_node_type('Dummy', 10)
+ fam = d.add_relation_family('Dummy-Dummy', 0, 0, False)
+ fam.add_relation_type('Dummy Rel', torch.rand(10, 10).round())
+
+ prep_d = prepare_training(d, TrainValTest(.8, .1, .1))
+
+ m = Model(prep_d)
+
+ for prm in m.parameters():
+ print(prm.shape, prm.is_leaf, prm.requires_grad)
+
+ loop = TrainLoop(m)
+
+ loop.run_epoch()
+
+ for prm in m.parameters():
+ print(prm.shape, prm.is_leaf, prm.requires_grad)
+
+
+def test_train_loop_03():
+ # pdb.set_trace()
+ if torch.cuda.device_count() == 0:
+ pytest.skip('CUDA required for this test')
+
+ adj_mat = torch.rand(10, 10).round()
+ dev = torch.device('cuda:0')
+ adj_mat = adj_mat.to(dev)
+
+ d = Data()
+ d.add_node_type('Dummy', 10)
+ fam = d.add_relation_family('Dummy-Dummy', 0, 0, False)
+ fam.add_relation_type('Dummy Rel', adj_mat)
+
+ prep_d = prepare_training(d, TrainValTest(.8, .1, .1))
+ # pdb.set_trace()
+
+ m = Model(prep_d)
+ m = m.to(dev)
+
+ print(list(m.parameters()))
+
+ for prm in m.parameters():
+ assert prm.device == dev
+
+ loop = TrainLoop(m)
+
+ loop.run_epoch()
+
+
+def test_train_loop_04():
+ adj_mat = torch.rand(10, 10).round()
+
+ d = Data()
+ d.add_node_type('Dummy', 10)
+ fam = d.add_relation_family('Dummy-Dummy', 0, 0, False)
+ fam.add_relation_type('Dummy Rel', adj_mat)
+
+ prep_d = prepare_training(d, TrainValTest(.8, .1, .1))
+
+ m = Model(prep_d)
+
+ old_values = []
+ for prm in m.parameters():
+ old_values.append(prm.clone().detach())
+
+ loop = TrainLoop(m)
+
+ loop.run_epoch()
+
+ for i, prm in enumerate(m.parameters()):
+ assert not prm.requires_grad or \
+ not torch.all(_equal(prm, old_values[i]))
+
+
+def test_train_loop_05():
+ adj_mat = torch.rand(10, 10).round().to_sparse()
+
+ d = Data()
+ d.add_node_type('Dummy', 10)
+ fam = d.add_relation_family('Dummy-Dummy', 0, 0, False)
+ fam.add_relation_type('Dummy Rel', adj_mat)
+
+ prep_d = prepare_training(d, TrainValTest(.8, .1, .1))
+
+ m = Model(prep_d)
+
+ old_values = []
+ for prm in m.parameters():
+ old_values.append(prm.clone().detach())
+
+ loop = TrainLoop(m)
+
+ loop.run_epoch()
+
+ for i, prm in enumerate(m.parameters()):
+ assert not prm.requires_grad or \
+ not torch.all(_equal(prm, old_values[i]))
+
+
+def test_timing_01():
+ adj_mat = (torch.rand(2000, 2000) < .001).to(torch.float32).to_sparse()
+ rep = torch.eye(2000).requires_grad_(True)
+ t = time.time()
+ for _ in range(1300):
+ _ = torch.sparse.mm(adj_mat, rep)
+ print('Elapsed:', time.time() - t)
+
+
+def test_timing_02():
+ adj_mat = (torch.rand(2000, 2000) < .001).to(torch.float32)
+ adj_mat_batch = [adj_mat.view(1, 2000, 2000)] * 1300
+ adj_mat_batch = torch.cat(adj_mat_batch)
+ rep = torch.eye(2000).requires_grad_(True)
+ t = time.time()
+ res = torch.matmul(adj_mat_batch, rep)
+ print('Elapsed:', time.time() - t)
+ print('res.shape:', res.shape)
+
+
+def test_timing_03():
+ adj_mat = (torch.rand(2000, 2000) < .001).to(torch.float32)
+ adj_mat_batch = [adj_mat.view(1, 2000, 2000).to_sparse()] * 1300
+ adj_mat_batch = torch.cat(adj_mat_batch)
+ rep = torch.eye(2000).requires_grad_(True)
+ rep_batch = [rep.view(1, 2000, 2000)] * 1300
+ rep_batch = torch.cat(rep_batch)
+ t = time.time()
+ with pytest.raises(RuntimeError):
+ _ = torch.bmm(adj_mat_batch, rep)
+ print('Elapsed:', time.time() - t)
+
+
+def test_timing_04():
+ adj_mat = (torch.rand(2000, 2000) < .0001).to(torch.float32).to_sparse()
+ rep = torch.eye(2000).requires_grad_(True)
+ t = time.time()
+ for _ in range(1300):
+ _ = torch.sparse.mm(adj_mat, rep)
+ print('Elapsed:', time.time() - t)
+
+
+def test_timing_05():
+ if torch.cuda.device_count() == 0:
+ pytest.skip('Test requires CUDA')
+ dev = torch.device('cuda:0')
+ adj_mat = (torch.rand(2000, 2000) < .001).to(torch.float32).to_sparse().to(dev)
+ rep = torch.eye(2000).requires_grad_(True).to(dev)
+ t = time.time()
+ for _ in range(1300):
+ _ = torch.sparse.mm(adj_mat, rep)
+ torch.cuda.synchronize()
+ print('Elapsed:', time.time() - t)
diff --git a/tests/icosagon/test_trainprep.py b/tests/icosagon/test_trainprep.py
new file mode 100644
index 0000000..5ec86a7
--- /dev/null
+++ b/tests/icosagon/test_trainprep.py
@@ -0,0 +1,268 @@
+#
+# Copyright (C) Stanislaw Adaszewski, 2020
+# License: GPLv3
+#
+
+
+from icosagon.trainprep import TrainValTest, \
+ train_val_test_split_edges, \
+ get_edges_and_degrees, \
+ prepare_adj_mat, \
+ prepare_relation_type, \
+ prep_rel_one_node_type, \
+ prep_rel_two_node_types_sym, \
+ prep_rel_two_node_types_asym
+import torch
+import pytest
+import numpy as np
+from itertools import chain
+from icosagon.data import RelationType
+import icosagon.trainprep
+
+
+def test_train_val_test_split_edges_01():
+ edges = torch.randint(0, 10, (10, 2))
+ with pytest.raises(ValueError):
+ _ = train_val_test_split_edges(edges, TrainValTest(.5, .5, .5))
+ with pytest.raises(ValueError):
+ _ = train_val_test_split_edges(edges, TrainValTest(.2, .2, .2))
+ with pytest.raises(ValueError):
+ _ = train_val_test_split_edges(edges, None)
+ with pytest.raises(ValueError):
+ _ = train_val_test_split_edges(edges, (.8, .1, .1))
+ with pytest.raises(ValueError):
+ _ = train_val_test_split_edges(np.random.randint(0, 10, (10, 2)), TrainValTest(.8, .1, .1))
+ with pytest.raises(ValueError):
+ _ = train_val_test_split_edges(torch.randint(0, 10, (10, 3)), TrainValTest(.8, .1, .1))
+ with pytest.raises(ValueError):
+ _ = train_val_test_split_edges(torch.randint(0, 10, (10, 2, 1)), TrainValTest(.8, .1, .1))
+ with pytest.raises(ValueError):
+ _ = train_val_test_split_edges(None, TrainValTest(.8, .2, .2))
+ res = train_val_test_split_edges(edges, TrainValTest(.8, .1, .1))
+ assert res.train.shape == (8, 2) and res.val.shape == (1, 2) and \
+ res.test.shape == (1, 2)
+ res = train_val_test_split_edges(edges, TrainValTest(.8, .0, .2))
+ assert res.train.shape == (8, 2) and res.val.shape == (0, 2) and \
+ res.test.shape == (2, 2)
+ res = train_val_test_split_edges(edges, TrainValTest(.8, .2, .0))
+ assert res.train.shape == (8, 2) and res.val.shape == (2, 2) and \
+ res.test.shape == (0, 2)
+ res = train_val_test_split_edges(edges, TrainValTest(.0, .5, .5))
+ assert res.train.shape == (0, 2) and res.val.shape == (5, 2) and \
+ res.test.shape == (5, 2)
+ res = train_val_test_split_edges(edges, TrainValTest(.0, .0, 1.))
+ assert res.train.shape == (0, 2) and res.val.shape == (0, 2) and \
+ res.test.shape == (10, 2)
+ res = train_val_test_split_edges(edges, TrainValTest(.0, 1., .0))
+ assert res.train.shape == (0, 2) and res.val.shape == (10, 2) and \
+ res.test.shape == (0, 2)
+
+
+def test_train_val_test_split_edges_02():
+ edges = torch.randint(0, 30, (30, 2))
+ ratios = TrainValTest(.8, .1, .1)
+ res = train_val_test_split_edges(edges, ratios)
+ edges = [ tuple(a) for a in edges ]
+ res = [ tuple(a) for a in chain(res.train, res.val, res.test) ]
+ assert all([ a in edges for a in res ])
+
+
+def test_get_edges_and_degrees_01():
+ adj_mat_dense = (torch.rand((10, 10)) > .5)
+ adj_mat_sparse = adj_mat_dense.to_sparse()
+ edges_dense, degrees_dense = get_edges_and_degrees(adj_mat_dense)
+ edges_sparse, degrees_sparse = get_edges_and_degrees(adj_mat_sparse)
+ assert torch.all(degrees_dense == degrees_sparse)
+ edges_dense = [ tuple(a) for a in edges_dense ]
+ edges_sparse = [ tuple(a) for a in edges_dense ]
+ assert len(edges_dense) == len(edges_sparse)
+ assert all([ a in edges_dense for a in edges_sparse ])
+ assert all([ a in edges_sparse for a in edges_dense ])
+ # assert torch.all(edges_dense == edges_sparse)
+
+
+def test_prepare_adj_mat_01():
+ adj_mat = (torch.rand((10, 10)) > .5)
+ adj_mat = adj_mat.to_sparse()
+ ratios = TrainValTest(.8, .1, .1)
+ _ = prepare_adj_mat(adj_mat, ratios)
+
+
+def test_prepare_adj_mat_02():
+ adj_mat = (torch.rand((10, 10)) > .5)
+ adj_mat = adj_mat.to_sparse()
+ ratios = TrainValTest(.8, .1, .1)
+ (adj_mat_train, edges_pos, edges_neg) = prepare_adj_mat(adj_mat, ratios)
+ assert isinstance(adj_mat_train, torch.Tensor)
+ assert adj_mat_train.is_sparse
+ assert adj_mat_train.shape == adj_mat.shape
+ assert adj_mat_train.dtype == adj_mat.dtype
+ assert isinstance(edges_pos, TrainValTest)
+ assert isinstance(edges_neg, TrainValTest)
+ for a in ['train', 'val', 'test']:
+ for b in [edges_pos, edges_neg]:
+ edges = getattr(b, a)
+ assert isinstance(edges, torch.Tensor)
+ assert len(edges.shape) == 2
+ assert edges.shape[1] == 2
+
+
+def test_prepare_relation_type_01():
+ adj_mat = (torch.rand((10, 10)) > .5).to(torch.float32)
+ r = RelationType('Test', 0, 0, adj_mat, True)
+ ratios = TrainValTest(.8, .1, .1)
+ _ = prepare_relation_type(r, ratios, False)
+
+
+def test_prep_rel_one_node_type_01():
+ adj_mat = torch.zeros(100)
+ perm = torch.randperm(100)
+ adj_mat[perm[:10]] = 1
+ adj_mat = adj_mat.view(10, 10)
+ rel = RelationType('Dummy Relation', 0, 0, adj_mat, None)
+ ratios = TrainValTest(.8, .1, .1)
+ prep_rel = prep_rel_one_node_type(rel, ratios)
+ assert prep_rel.name == rel.name
+ assert prep_rel.node_type_row == rel.node_type_row
+ assert prep_rel.node_type_column == rel.node_type_column
+ assert prep_rel.adjacency_matrix.shape == rel.adjacency_matrix.shape
+ assert prep_rel.adjacency_matrix_backward is None
+ assert len(prep_rel.edges_pos.train) == 8
+ assert len(prep_rel.edges_pos.val) == 1
+ assert len(prep_rel.edges_pos.test) == 1
+ assert len(prep_rel.edges_neg.train) == 8
+ assert len(prep_rel.edges_neg.val) == 1
+ assert len(prep_rel.edges_neg.test) == 1
+
+ assert len(prep_rel.edges_back_pos.train) == 0
+ assert len(prep_rel.edges_back_pos.val) == 0
+ assert len(prep_rel.edges_back_pos.test) == 0
+ assert len(prep_rel.edges_back_neg.train) == 0
+ assert len(prep_rel.edges_back_neg.val) == 0
+ assert len(prep_rel.edges_back_neg.test) == 0
+
+
+def test_prep_rel_two_node_types_sym_01():
+ adj_mat = torch.zeros(200)
+ perm = torch.randperm(100)
+ adj_mat[perm[:10]] = 1
+ adj_mat = adj_mat.view(10, 20)
+ rel = RelationType('Dummy Relation', 0, 1, adj_mat, None)
+ ratios = TrainValTest(.8, .1, .1)
+ prep_rel = prep_rel_two_node_types_sym(rel, ratios)
+ assert prep_rel.name == rel.name
+ assert prep_rel.node_type_row == rel.node_type_row
+ assert prep_rel.node_type_column == rel.node_type_column
+ assert prep_rel.adjacency_matrix.shape == rel.adjacency_matrix.shape
+ assert prep_rel.adjacency_matrix_backward.shape == (20, 10)
+ assert len(prep_rel.edges_pos.train) == 8
+ assert len(prep_rel.edges_pos.val) == 1
+ assert len(prep_rel.edges_pos.test) == 1
+ assert len(prep_rel.edges_neg.train) == 8
+ assert len(prep_rel.edges_neg.val) == 1
+ assert len(prep_rel.edges_neg.test) == 1
+
+ assert len(prep_rel.edges_back_pos.train) == 0
+ assert len(prep_rel.edges_back_pos.val) == 0
+ assert len(prep_rel.edges_back_pos.test) == 0
+ assert len(prep_rel.edges_back_neg.train) == 0
+ assert len(prep_rel.edges_back_neg.val) == 0
+ assert len(prep_rel.edges_back_neg.test) == 0
+
+
+def test_prep_rel_two_node_types_asym_01():
+ adj_mat = torch.zeros(200)
+ perm = torch.randperm(100)
+ adj_mat[perm[:10]] = 1
+ adj_mat = adj_mat.view(10, 20)
+
+ adj_mat_back = torch.zeros(200)
+ perm = torch.randperm(100)
+ adj_mat_back[perm[:10]] = 1
+ adj_mat_back = adj_mat_back.view(20, 10)
+
+ rel = RelationType('Dummy Relation', 0, 1, adj_mat, adj_mat_back)
+ ratios = TrainValTest(.8, .1, .1)
+ prep_rel = prep_rel_two_node_types_asym(rel, ratios)
+ assert prep_rel.name == rel.name
+ assert prep_rel.node_type_row == rel.node_type_row
+ assert prep_rel.node_type_column == rel.node_type_column
+ assert prep_rel.adjacency_matrix.shape == rel.adjacency_matrix.shape
+ assert prep_rel.adjacency_matrix_backward.shape == rel.adjacency_matrix_backward.shape
+ assert len(prep_rel.edges_pos.train) == 8
+ assert len(prep_rel.edges_pos.val) == 1
+ assert len(prep_rel.edges_pos.test) == 1
+ assert len(prep_rel.edges_neg.train) == 8
+ assert len(prep_rel.edges_neg.val) == 1
+ assert len(prep_rel.edges_neg.test) == 1
+
+ assert len(prep_rel.edges_back_pos.train) == 8
+ assert len(prep_rel.edges_back_pos.val) == 1
+ assert len(prep_rel.edges_back_pos.test) == 1
+ assert len(prep_rel.edges_back_neg.train) == 8
+ assert len(prep_rel.edges_back_neg.val) == 1
+ assert len(prep_rel.edges_back_neg.test) == 1
+
+
+def test_prepare_relation_type_02():
+ with pytest.raises(ValueError):
+ prepare_relation_type(None, None, True)
+
+ adj_mat = torch.rand(10, 10).round()
+ rel = RelationType('Dummy Relation', 0, 0, adj_mat, None)
+ with pytest.raises(ValueError):
+ prepare_relation_type(rel, None, True)
+
+ ratios = TrainValTest(.8, .1, .1)
+ with pytest.raises(ValueError):
+ prepare_relation_type(None, ratios, True)
+
+ _ = prepare_relation_type(rel, ratios, True)
+
+
+def test_prepare_relation_type_03(monkeypatch):
+ a = 0
+ b = 0
+ c = 0
+ def fake_prep_rel_one_node_type(*args, **kwargs):
+ nonlocal a
+ a += 1
+ def fake_prep_rel_two_node_types_sym(*args, **kwargs):
+ nonlocal b
+ b += 1
+ def fake_prep_rel_two_node_types_asym(*args, **kwargs):
+ nonlocal c
+ c += 1
+ monkeypatch.setattr(icosagon.trainprep, 'prep_rel_one_node_type',
+ fake_prep_rel_one_node_type)
+ monkeypatch.setattr(icosagon.trainprep, 'prep_rel_two_node_types_sym',
+ fake_prep_rel_two_node_types_sym)
+ monkeypatch.setattr(icosagon.trainprep, 'prep_rel_two_node_types_asym',
+ fake_prep_rel_two_node_types_asym)
+ ratios = TrainValTest(.8, .1, .1)
+ rel = RelationType('Dummy Relation', 0, 0, None, None)
+ prepare_relation_type(rel, ratios, False)
+ assert a == 1
+ rel = RelationType('Dummy Relation', 0, 0, None, None)
+ prepare_relation_type(rel, ratios, True)
+ assert a == 2
+ rel = RelationType('Dummy Relation', 0, 1, None, None)
+ prepare_relation_type(rel, ratios, True)
+ assert b == 1
+ rel = RelationType('Dummy Relation', 0, 1, None, None)
+ prepare_relation_type(rel, ratios, False)
+ assert c == 1
+ assert a == 2 and b == 1 and c == 1
+
+
+# def prepare_relation(r, ratios):
+# adj_mat = r.adjacency_matrix
+# adj_mat_train, edges_pos, edges_neg = prepare_adj_mat(adj_mat)
+#
+# if r.node_type_row == r.node_type_column:
+# adj_mat_train = norm_adj_mat_one_node_type(adj_mat_train)
+# else:
+# adj_mat_train = norm_adj_mat_two_node_types(adj_mat_train)
+#
+# return PreparedRelation(r.name, r.node_type_row, r.node_type_column,
+# adj_mat_train, edges_pos, edges_neg)
diff --git a/tests/icosagon/test_weights.py b/tests/icosagon/test_weights.py
new file mode 100644
index 0000000..5ddb997
--- /dev/null
+++ b/tests/icosagon/test_weights.py
@@ -0,0 +1,23 @@
+from icosagon.weights import init_glorot
+import torch
+import numpy as np
+
+
+def test_init_glorot_01():
+ torch.random.manual_seed(0)
+ res = init_glorot(10, 20)
+ torch.random.manual_seed(0)
+ rnd = torch.rand((10, 20))
+ init_range = np.sqrt(6.0 / 30)
+ expected = -init_range + 2 * init_range * rnd
+ assert torch.all(res == expected)
+
+
+def test_init_glorot_02():
+ torch.random.manual_seed(0)
+ res = init_glorot(20, 10)
+ torch.random.manual_seed(0)
+ rnd = torch.rand((20, 10))
+ init_range = np.sqrt(6.0 / 30)
+ expected = -init_range + 2 * init_range * rnd
+ assert torch.all(res == expected)
diff --git a/tests/icosagon/unused/_test_loss.py b/tests/icosagon/unused/_test_loss.py
new file mode 100644
index 0000000..4f1edde
--- /dev/null
+++ b/tests/icosagon/unused/_test_loss.py
@@ -0,0 +1,46 @@
+from icosagon.loss import CrossEntropyLoss
+from icosagon.declayer import Predictions, \
+ RelationFamilyPredictions, \
+ RelationPredictions
+from icosagon.data import Data
+from icosagon.trainprep import prepare_training, \
+ TrainValTest
+import torch
+
+
+def test_cross_entropy_loss_01():
+ d = Data()
+ d.add_node_type('Dummy', 5)
+ fam = d.add_relation_family('Dummy-Dummy', 0, 0, False)
+ fam.add_relation_type('Dummy Rel', torch.tensor([
+ [0, 1, 0, 0, 0],
+ [1, 0, 0, 0, 0],
+ [0, 0, 0, 1, 0],
+ [0, 0, 0, 0, 1],
+ [0, 1, 0, 0, 0]
+ ], dtype=torch.float32))
+
+ prep_d = prepare_training(d, TrainValTest(1., 0., 0.))
+
+ assert len(prep_d.relation_families) == 1
+ assert len(prep_d.relation_families[0].relation_types) == 1
+ assert len(prep_d.relation_families[0].relation_types[0].edges_pos.train) == 5
+ assert len(prep_d.relation_families[0].relation_types[0].edges_pos.val) == 0
+ assert len(prep_d.relation_families[0].relation_types[0].edges_pos.test) == 0
+
+ rel_pred = RelationPredictions(
+ TrainValTest(torch.tensor([1, 0, 1, 0, 1], dtype=torch.float32), torch.zeros(0), torch.zeros(0)),
+ TrainValTest(torch.zeros(0), torch.zeros(0), torch.zeros(0)),
+ TrainValTest(torch.zeros(0), torch.zeros(0), torch.zeros(0)),
+ TrainValTest(torch.zeros(0), torch.zeros(0), torch.zeros(0))
+ )
+ fam_pred = RelationFamilyPredictions([ rel_pred ])
+ pred = Predictions([ fam_pred ])
+
+ loss = CrossEntropyLoss(prep_d)
+ print('loss: %.7f' % loss(pred))
+ assert torch.abs(loss(pred) - 55.262043) < 0.000001
+
+ loss = CrossEntropyLoss(prep_d, reduction='mean')
+ print('loss: %.7f' % loss(pred))
+ assert torch.abs(loss(pred) - 11.0524082) < 0.000001
diff --git a/tests/triacontagon/test_batch.py b/tests/triacontagon/test_batch.py
new file mode 100644
index 0000000..9591c45
--- /dev/null
+++ b/tests/triacontagon/test_batch.py
@@ -0,0 +1,316 @@
+from triacontagon.batch import _same_data_org, \
+ DualBatcher, \
+ Batcher
+from triacontagon.data import Data
+from triacontagon.decode import dedicom_decoder
+import torch
+
+
+def test_same_data_org_01():
+ data = Data()
+ assert _same_data_org(data, data)
+
+ data.add_vertex_type('Foo', 10)
+ assert _same_data_org(data, data)
+
+ data.add_vertex_type('Bar', 10)
+ assert _same_data_org(data, data)
+
+ data_1 = Data()
+ assert not _same_data_org(data, data_1)
+
+ data_1.add_vertex_type('Foo', 10)
+ assert not _same_data_org(data, data_1)
+
+ data_1.add_vertex_type('Bar', 10)
+ assert _same_data_org(data, data_1)
+
+
+def test_same_data_org_02():
+ data = Data()
+ data.add_vertex_type('Foo', 4)
+ data.add_edge_type('Foo-Foo', 0, 0, [
+ torch.tensor([
+ [0, 0, 0, 1],
+ [1, 0, 0, 0],
+ [0, 1, 0, 1],
+ [1, 0, 1, 0]
+ ]).to_sparse()
+ ], dedicom_decoder)
+
+ assert _same_data_org(data, data)
+
+ data_1 = Data()
+ data_1.add_vertex_type('Foo', 4)
+ data_1.add_edge_type('Foo-Foo', 0, 0, [
+ torch.tensor([
+ [0, 0, 0, 1],
+ [1, 0, 0, 0],
+ [0, 1, 0, 1],
+ [1, 0, 0, 0]
+ ]).to_sparse()
+ ], dedicom_decoder)
+
+ assert not _same_data_org(data, data_1)
+
+
+def test_batcher_01():
+ d = Data()
+ d.add_vertex_type('Gene', 5)
+
+ d.add_edge_type('Gene-Gene', 0, 0, [
+ torch.tensor([
+ [0, 1, 0, 1, 0],
+ [0, 0, 0, 0, 1],
+ [1, 0, 0, 0, 0],
+ [0, 0, 1, 0, 0],
+ [0, 0, 0, 1, 0]
+ ]).to_sparse()
+ ], dedicom_decoder)
+
+ b = Batcher(d, batch_size=1)
+
+ visited = set()
+ for t in b:
+ print(t)
+ k = tuple(t.edges[0].tolist())
+ visited.add(k)
+
+ assert visited == { (0, 1), (0, 3),
+ (1, 4), (2, 0), (3, 2), (4, 3) }
+
+
+def test_batcher_02():
+ d = Data()
+ d.add_vertex_type('Gene', 5)
+
+ d.add_edge_type('Gene-Gene', 0, 0, [
+ torch.tensor([
+ [0, 1, 0, 1, 0],
+ [0, 0, 0, 0, 1],
+ [1, 0, 0, 0, 0],
+ [0, 0, 1, 0, 0],
+ [0, 0, 0, 1, 0]
+ ]).to_sparse(),
+
+ torch.tensor([
+ [0, 0, 1, 0, 1],
+ [0, 0, 0, 1, 0],
+ [0, 0, 0, 0, 1],
+ [0, 1, 0, 0, 0],
+ [0, 0, 1, 0, 0]
+ ]).to_sparse()
+ ], dedicom_decoder)
+
+ b = Batcher(d, batch_size=1)
+
+ visited = set()
+ for t in b:
+ print(t)
+ k = (t.relation_type_index,) + \
+ tuple(t.edges[0].tolist())
+ visited.add(k)
+
+ assert visited == { (0, 0, 1), (0, 0, 3),
+ (0, 1, 4), (0, 2, 0), (0, 3, 2), (0, 4, 3),
+ (1, 0, 2), (1, 0, 4), (1, 1, 3), (1, 2, 4),
+ (1, 3, 1), (1, 4, 2) }
+
+
+def test_batcher_03():
+ d = Data()
+ d.add_vertex_type('Gene', 5)
+ d.add_vertex_type('Drug', 4)
+
+ d.add_edge_type('Gene-Gene', 0, 0, [
+ torch.tensor([
+ [0, 1, 0, 1, 0],
+ [0, 0, 0, 0, 1],
+ [1, 0, 0, 0, 0],
+ [0, 0, 1, 0, 0],
+ [0, 0, 0, 1, 0]
+ ]).to_sparse(),
+
+ torch.tensor([
+ [0, 0, 1, 0, 1],
+ [0, 0, 0, 1, 0],
+ [0, 0, 0, 0, 1],
+ [0, 1, 0, 0, 0],
+ [0, 0, 1, 0, 0]
+ ]).to_sparse()
+ ], dedicom_decoder)
+
+ d.add_edge_type('Gene-Drug', 0, 1, [
+ torch.tensor([
+ [0, 1, 0, 0],
+ [1, 0, 0, 1],
+ [0, 1, 0, 0],
+ [0, 0, 1, 0],
+ [0, 1, 1, 0]
+ ]).to_sparse()
+ ], dedicom_decoder)
+
+ b = Batcher(d, batch_size=1)
+
+ visited = set()
+ for t in b:
+ print(t)
+ k = (t.vertex_type_row, t.vertex_type_column,
+ t.relation_type_index,) + \
+ tuple(t.edges[0].tolist())
+ visited.add(k)
+
+ assert visited == { (0, 0, 0, 0, 1), (0, 0, 0, 0, 3),
+ (0, 0, 0, 1, 4), (0, 0, 0, 2, 0), (0, 0, 0, 3, 2), (0, 0, 0, 4, 3),
+ (0, 0, 1, 0, 2), (0, 0, 1, 0, 4), (0, 0, 1, 1, 3), (0, 0, 1, 2, 4),
+ (0, 0, 1, 3, 1), (0, 0, 1, 4, 2),
+ (0, 1, 0, 0, 1), (0, 1, 0, 1, 0), (0, 1, 0, 1, 3),
+ (0, 1, 0, 2, 1), (0, 1, 0, 3, 2), (0, 1, 0, 4, 1),
+ (0, 1, 0, 4, 2) }
+
+
+def test_batcher_04():
+ d = Data()
+ d.add_vertex_type('Gene', 5)
+
+ d.add_edge_type('Gene-Gene', 0, 0, [
+ torch.tensor([
+ [0, 1, 0, 1, 0],
+ [0, 0, 0, 0, 1],
+ [1, 0, 0, 0, 0],
+ [0, 0, 1, 0, 0],
+ [0, 0, 0, 1, 0]
+ ]).to_sparse()
+ ], dedicom_decoder)
+
+ b = Batcher(d, batch_size=3)
+
+ visited = set()
+ for t in b:
+ print(t)
+ for e in t.edges:
+ k = tuple(e.tolist())
+ visited.add(k)
+
+ assert visited == { (0, 1), (0, 3),
+ (1, 4), (2, 0), (3, 2), (4, 3) }
+
+
+def test_batcher_05():
+ d = Data()
+ d.add_vertex_type('Gene', 5)
+ d.add_vertex_type('Drug', 4)
+
+ d.add_edge_type('Gene-Gene', 0, 0, [
+ torch.tensor([
+ [0, 1, 0, 1, 0],
+ [0, 0, 0, 0, 1],
+ [1, 0, 0, 0, 0],
+ [0, 0, 1, 0, 0],
+ [0, 0, 0, 1, 0]
+ ]).to_sparse(),
+
+ torch.tensor([
+ [0, 0, 1, 0, 1],
+ [0, 0, 0, 1, 0],
+ [0, 0, 0, 0, 1],
+ [0, 1, 0, 0, 0],
+ [0, 0, 1, 0, 0]
+ ]).to_sparse()
+ ], dedicom_decoder)
+
+ d.add_edge_type('Gene-Drug', 0, 1, [
+ torch.tensor([
+ [0, 1, 0, 0],
+ [1, 0, 0, 1],
+ [0, 1, 0, 0],
+ [0, 0, 1, 0],
+ [0, 1, 1, 0]
+ ]).to_sparse()
+ ], dedicom_decoder)
+
+ b = Batcher(d, batch_size=5)
+
+ visited = set()
+ for t in b:
+ print(t)
+ for e in t.edges:
+ k = (t.vertex_type_row, t.vertex_type_column,
+ t.relation_type_index,) + \
+ tuple(e.tolist())
+ visited.add(k)
+
+ assert visited == { (0, 0, 0, 0, 1), (0, 0, 0, 0, 3),
+ (0, 0, 0, 1, 4), (0, 0, 0, 2, 0), (0, 0, 0, 3, 2), (0, 0, 0, 4, 3),
+ (0, 0, 1, 0, 2), (0, 0, 1, 0, 4), (0, 0, 1, 1, 3), (0, 0, 1, 2, 4),
+ (0, 0, 1, 3, 1), (0, 0, 1, 4, 2),
+ (0, 1, 0, 0, 1), (0, 1, 0, 1, 0), (0, 1, 0, 1, 3),
+ (0, 1, 0, 2, 1), (0, 1, 0, 3, 2), (0, 1, 0, 4, 1),
+ (0, 1, 0, 4, 2) }
+
+
+def test_dual_batcher_01():
+ d = Data()
+ d.add_vertex_type('Gene', 5)
+ d.add_vertex_type('Drug', 4)
+
+ d.add_edge_type('Gene-Gene', 0, 0, [
+ torch.tensor([
+ [0, 1, 0, 1, 0],
+ [0, 0, 0, 0, 1],
+ [1, 0, 0, 0, 0],
+ [0, 0, 1, 0, 0],
+ [0, 0, 0, 1, 0]
+ ]).to_sparse(),
+
+ torch.tensor([
+ [0, 0, 1, 0, 1],
+ [0, 0, 0, 1, 0],
+ [0, 0, 0, 0, 1],
+ [0, 1, 0, 0, 0],
+ [0, 0, 1, 0, 0]
+ ]).to_sparse()
+ ], dedicom_decoder)
+
+ d.add_edge_type('Gene-Drug', 0, 1, [
+ torch.tensor([
+ [0, 1, 0, 0],
+ [1, 0, 0, 1],
+ [0, 1, 0, 0],
+ [0, 0, 1, 0],
+ [0, 1, 1, 0]
+ ]).to_sparse()
+ ], dedicom_decoder)
+
+ b = DualBatcher(d, d, batch_size=5)
+
+ visited_pos = set()
+ visited_neg = set()
+ for t_pos, t_neg in b:
+ assert t_pos.vertex_type_row == t_neg.vertex_type_row
+ assert t_pos.vertex_type_column == t_neg.vertex_type_column
+ assert t_pos.relation_type_index == t_neg.relation_type_index
+ assert len(t_pos.edges) == len(t_neg.edges)
+
+ for e in t_pos.edges:
+ k = (t_pos.vertex_type_row, t_pos.vertex_type_column,
+ t_pos.relation_type_index,) + \
+ tuple(e.tolist())
+ visited_pos.add(k)
+
+ for e in t_neg.edges:
+ k = (t_neg.vertex_type_row, t_neg.vertex_type_column,
+ t_neg.relation_type_index,) + \
+ tuple(e.tolist())
+ visited_neg.add(k)
+
+ expected = { (0, 0, 0, 0, 1), (0, 0, 0, 0, 3),
+ (0, 0, 0, 1, 4), (0, 0, 0, 2, 0), (0, 0, 0, 3, 2), (0, 0, 0, 4, 3),
+ (0, 0, 1, 0, 2), (0, 0, 1, 0, 4), (0, 0, 1, 1, 3), (0, 0, 1, 2, 4),
+ (0, 0, 1, 3, 1), (0, 0, 1, 4, 2),
+ (0, 1, 0, 0, 1), (0, 1, 0, 1, 0), (0, 1, 0, 1, 3),
+ (0, 1, 0, 2, 1), (0, 1, 0, 3, 2), (0, 1, 0, 4, 1),
+ (0, 1, 0, 4, 2) }
+
+ assert visited_pos == expected
+ assert visited_neg == expected
diff --git a/tests/triacontagon/test_cumcount.py b/tests/triacontagon/test_cumcount.py
new file mode 100644
index 0000000..694b46c
--- /dev/null
+++ b/tests/triacontagon/test_cumcount.py
@@ -0,0 +1,27 @@
+from triacontagon.cumcount import dfill, \
+ argunsort, \
+ cumcount
+import torch
+import numpy as np
+
+
+def test_dfill_01():
+ input = torch.tensor([1, 1, 1, 1, 1, 2, 2, 3, 3, 3, 4, 4, 5, 5])
+ output = dfill(input)
+ expected = torch.tensor([0, 0, 0, 0, 0, 5, 5, 7, 7, 7, 10, 10, 12, 12])
+ assert torch.all(output == expected)
+
+
+def test_argunsort_01():
+ input = torch.tensor([1, 1, 2, 3, 3, 4, 1, 1, 5, 5, 2, 3, 1, 4])
+ output = np.argsort(input.numpy())
+ output = argunsort(torch.tensor(output))
+ expected = torch.tensor([0, 1, 5, 7, 8, 10, 2, 3, 12, 13, 6, 9, 4, 11])
+ assert torch.all(output == expected)
+
+
+def test_cumcount_01():
+ input = torch.tensor([1, 1, 2, 3, 3, 4, 1, 1, 5, 5, 2, 3, 1, 4])
+ output = cumcount(input)
+ expected = torch.tensor([0, 1, 0, 0, 1, 0, 2, 3, 0, 1, 1, 2, 4, 1])
+ assert torch.all(output == expected)
diff --git a/tests/triacontagon/test_dropout.py b/tests/triacontagon/test_dropout.py
new file mode 100644
index 0000000..abdb04c
--- /dev/null
+++ b/tests/triacontagon/test_dropout.py
@@ -0,0 +1,26 @@
+from triacontagon.dropout import dropout_sparse, \
+ dropout_dense
+import torch
+import numpy as np
+
+
+def test_dropout_01():
+ for i in range(11):
+ torch.random.manual_seed(i)
+ a = torch.rand((5, 10))
+ a[a < .5] = 0
+
+ keep_prob=i/10. + np.finfo(np.float32).eps
+
+ torch.random.manual_seed(i)
+ b = dropout_dense(a, keep_prob=keep_prob)
+
+ torch.random.manual_seed(i)
+ c = dropout_sparse(a.to_sparse(), keep_prob=keep_prob)
+
+ print('keep_prob:', keep_prob)
+ print('a:', a.detach().cpu().numpy())
+ print('b:', b.detach().cpu().numpy())
+ print('c:', c, c.to_dense().detach().cpu().numpy())
+
+ assert torch.all(b == c.to_dense())
diff --git a/tests/triacontagon/test_loop.py b/tests/triacontagon/test_loop.py
new file mode 100644
index 0000000..5f4359a
--- /dev/null
+++ b/tests/triacontagon/test_loop.py
@@ -0,0 +1,134 @@
+from triacontagon.loop import _merge_pos_neg_batches, \
+ TrainLoop
+from triacontagon.model import TrainingBatch, \
+ Model
+from triacontagon.data import Data
+from triacontagon.decode import dedicom_decoder
+from triacontagon.util import common_one_hot_encoding
+from triacontagon.split import split_data
+import torch
+import pytest
+
+
+def test_merge_pos_neg_batches_01():
+ b_1 = TrainingBatch(0, 0, 0, torch.tensor([
+ [0, 1],
+ [2, 3],
+ [4, 5],
+ [5, 6]
+ ]), torch.ones(4))
+ b_2 = TrainingBatch(0, 0, 0, torch.tensor([
+ [1, 6],
+ [3, 5],
+ [5, 2],
+ [4, 1]
+ ]), torch.zeros(4))
+ b = _merge_pos_neg_batches(b_1, b_2)
+
+ assert b.vertex_type_row == 0
+ assert b.vertex_type_column == 0
+ assert b.relation_type_index == 0
+ assert torch.all(b.edges == torch.tensor([
+ [0, 1],
+ [2, 3],
+ [4, 5],
+ [5, 6],
+ [1, 6],
+ [3, 5],
+ [5, 2],
+ [4, 1]
+ ]))
+ assert torch.all(b.target_values == \
+ torch.cat([ torch.ones(4), torch.zeros(4) ]))
+
+
+def test_merge_pos_neg_batches_02():
+ b_1 = TrainingBatch(0, 1, 0, torch.tensor([
+ [0, 1],
+ [2, 3],
+ [4, 5],
+ [5, 6]
+ ]), torch.ones(4))
+ b_2 = TrainingBatch(0, 0, 0, torch.tensor([
+ [1, 6],
+ [3, 5],
+ [5, 2],
+ [4, 1]
+ ]), torch.zeros(4))
+ print(b_1)
+ with pytest.raises(AssertionError):
+ _ = _merge_pos_neg_batches(b_1, b_2)
+
+ b_1.vertex_type_row, b_1.vertex_type_column = \
+ b_1.vertex_type_column, b_1.vertex_type_row
+ print(b_1)
+ with pytest.raises(AssertionError):
+ _ = _merge_pos_neg_batches(b_1, b_2)
+
+ b_1.vertex_type_row, b_1.relation_type_index = \
+ b_1.relation_type_index, b_1.vertex_type_row
+ print(b_1)
+ with pytest.raises(AssertionError):
+ _ = _merge_pos_neg_batches(b_1, b_2)
+
+
+def test_train_loop_01():
+ data = Data()
+ data.add_vertex_type('Foo', 5)
+ data.add_vertex_type('Bar', 4)
+
+ foo_foo = torch.tensor([
+ [0, 0, 1, 0, 0],
+ [0, 0, 0, 1, 0],
+ [1, 0, 0, 0, 0],
+ [0, 1, 0, 0, 0],
+ [0, 0, 0, 0, 0]
+ ], dtype=torch.float32)
+ foo_foo = (foo_foo + foo_foo.transpose(0, 1)) / 2
+
+ foo_bar = torch.tensor([
+ [0, 0, 1, 0],
+ [0, 0, 0, 1],
+ [0, 1, 0, 0],
+ [1, 0, 0, 0],
+ [0, 0, 0, 1]
+ ], dtype=torch.float32)
+ bar_foo = foo_bar.transpose(0, 1)
+
+ bar_bar = torch.tensor([
+ [0, 1, 0, 0],
+ [1, 0, 0, 0],
+ [0, 0, 0, 1],
+ [0, 0, 1, 0],
+ ], dtype=torch.float32)
+ bar_bar = (bar_bar + bar_bar.transpose(0, 1)) / 2
+
+ data.add_edge_type('Foo-Foo', 0, 0, [
+ foo_foo.to_sparse().coalesce()
+ ], dedicom_decoder)
+ data.add_edge_type('Foo-Bar', 0, 1, [
+ foo_bar.to_sparse().coalesce()
+ ], dedicom_decoder)
+ data.add_edge_type('Bar-Foo', 1, 0, [
+ bar_foo.to_sparse().coalesce()
+ ], dedicom_decoder)
+ data.add_edge_type('Bar-Bar', 1, 1, [
+ bar_bar.to_sparse().coalesce()
+ ], dedicom_decoder)
+
+ initial_repr = common_one_hot_encoding([5, 4])
+
+ model = Model(data, [9, 3, 6],
+ keep_prob=1.0,
+ conv_activation=torch.sigmoid,
+ dec_activation=torch.sigmoid)
+
+ train_data, val_data, test_data = split_data(data, (.5, .5, .0) )
+
+ print('val_data:', val_data)
+ print('val_data.vertex_types:', val_data.vertex_types)
+
+ loop = TrainLoop(model, val_data, test_data, initial_repr,
+ max_epochs=1, batch_size=1)
+
+ _ = loop.run()
diff --git a/tests/triacontagon/test_model.py b/tests/triacontagon/test_model.py
new file mode 100644
index 0000000..06c0ece
--- /dev/null
+++ b/tests/triacontagon/test_model.py
@@ -0,0 +1,113 @@
+import torch
+from triacontagon.model import Model, \
+ TrainingBatch, \
+ _per_layer_required_vertices
+from triacontagon.data import Data
+from triacontagon.decode import dedicom_decoder
+from triacontagon.util import common_one_hot_encoding
+
+
+def test_per_layer_required_vertices_01():
+ d = Data()
+ d.add_vertex_type('Gene', 4)
+ d.add_vertex_type('Drug', 5)
+
+ d.add_edge_type('Gene-Gene', 0, 0, [ torch.tensor([
+ [0, 0, 0, 1],
+ [0, 0, 1, 0],
+ [1, 0, 0, 0],
+ [0, 1, 0, 0]
+ ]).to_sparse() ], dedicom_decoder)
+
+ d.add_edge_type('Gene-Drug', 0, 1, [ torch.tensor([
+ [0, 1, 0, 0, 1],
+ [0, 0, 1, 0, 0],
+ [1, 0, 0, 0, 1],
+ [0, 0, 1, 1, 0]
+ ]).to_sparse() ], dedicom_decoder)
+
+ d.add_edge_type('Drug-Drug', 1, 1, [ torch.tensor([
+ [0, 0, 1, 0, 1],
+ [0, 0, 0, 1, 1],
+ [1, 0, 0, 0, 0],
+ [0, 1, 0, 0, 1],
+ [1, 1, 0, 1, 0]
+ ]).to_sparse() ], dedicom_decoder)
+
+ batch = TrainingBatch(0, 1, 0, torch.tensor([
+ [0, 1]
+ ]))
+
+ res = _per_layer_required_vertices(d, batch, 5)
+ print('res:', res)
+
+
+def test_model_convolve_01():
+ d = Data()
+ d.add_vertex_type('Gene', 4)
+ d.add_vertex_type('Drug', 5)
+
+ d.add_edge_type('Gene-Gene', 0, 0, [ torch.tensor([
+ [0, 0, 0, 1],
+ [0, 0, 1, 0],
+ [1, 0, 0, 0],
+ [0, 1, 0, 0]
+ ], dtype=torch.float).to_sparse() ], dedicom_decoder)
+
+ d.add_edge_type('Gene-Drug', 0, 1, [ torch.tensor([
+ [0, 1, 0, 0, 1],
+ [0, 0, 1, 0, 0],
+ [1, 0, 0, 0, 1],
+ [0, 0, 1, 1, 0]
+ ], dtype=torch.float).to_sparse() ], dedicom_decoder)
+
+ d.add_edge_type('Drug-Drug', 1, 1, [ torch.tensor([
+ [0, 0, 0, 1, 0],
+ [0, 0, 0, 0, 1],
+ [0, 1, 0, 0, 0],
+ [1, 0, 0, 0, 0],
+ [0, 1, 0, 1, 0]
+ ], dtype=torch.float).to_sparse() ], dedicom_decoder)
+
+ model = Model(d, [9, 32, 64], keep_prob=1.0,
+ conv_activation = torch.sigmoid,
+ dec_activation = torch.sigmoid)
+
+ repr_1 = torch.eye(9)
+ repr_1[4:, 4:] = 0
+ repr_2 = torch.eye(9)
+ repr_2[:4, :4] = 0
+
+ in_layer_repr = [
+ repr_1[:4, :].to_sparse(),
+ repr_2[4:, :].to_sparse()
+ ]
+
+ _ = model.convolve(in_layer_repr)
+
+
+def test_model_decode_01():
+ d = Data()
+ d.add_vertex_type('Gene', 100)
+
+ gene_gene = torch.rand(100, 100).round()
+ gene_gene = gene_gene - torch.diag(torch.diag(gene_gene))
+ d.add_edge_type('Gene-Gene', 0, 0, [
+ gene_gene.to_sparse()
+ ], dedicom_decoder)
+
+ b = TrainingBatch(0, 0, 0, torch.tensor([
+ [0, 1],
+ [10, 51],
+ [50, 60],
+ [70, 90],
+ [98, 99]
+ ]), torch.ones(5))
+
+ in_repr = common_one_hot_encoding([100])
+
+ in_repr = [ in_repr[0].to_dense() ]
+
+ m = Model(d, [100], 1.0, torch.sigmoid, torch.sigmoid)
+
+ _ = m.decode(in_repr, b)
diff --git a/tests/triacontagon/test_normalize.py b/tests/triacontagon/test_normalize.py
new file mode 100644
index 0000000..e8de28f
--- /dev/null
+++ b/tests/triacontagon/test_normalize.py
@@ -0,0 +1,185 @@
+from triacontagon.normalize import add_eye_sparse, \
+ norm_adj_mat_one_node_type_sparse, \
+ norm_adj_mat_one_node_type_dense, \
+ norm_adj_mat_one_node_type, \
+ norm_adj_mat_two_node_types_sparse, \
+ norm_adj_mat_two_node_types_dense, \
+ norm_adj_mat_two_node_types
+import decagon_pytorch.normalize
+import torch
+import pytest
+import numpy as np
+from math import sqrt
+
+
+def test_add_eye_sparse_01():
+ adj_mat_dense = torch.rand((10, 10))
+ adj_mat_sparse = adj_mat_dense.to_sparse()
+
+ adj_mat_dense += torch.eye(10)
+ adj_mat_sparse = add_eye_sparse(adj_mat_sparse)
+
+ assert torch.all(adj_mat_sparse.to_dense() == adj_mat_dense)
+
+
+def test_add_eye_sparse_02():
+ adj_mat_dense = torch.rand((10, 20))
+ adj_mat_sparse = adj_mat_dense.to_sparse()
+
+ with pytest.raises(ValueError):
+ _ = add_eye_sparse(adj_mat_sparse)
+
+
+def test_add_eye_sparse_03():
+ adj_mat_dense = torch.rand((10, 10))
+
+ with pytest.raises(ValueError):
+ _ = add_eye_sparse(adj_mat_dense)
+
+
+def test_add_eye_sparse_04():
+ adj_mat_dense = np.random.rand(10, 10)
+
+ with pytest.raises(ValueError):
+ _ = add_eye_sparse(adj_mat_dense)
+
+
+def test_norm_adj_mat_one_node_type_sparse_01():
+ adj_mat = torch.rand((10, 10))
+ adj_mat = (adj_mat > .5).to(torch.float32)
+ adj_mat = adj_mat.to_sparse()
+ _ = norm_adj_mat_one_node_type_sparse(adj_mat)
+
+
+def test_norm_adj_mat_one_node_type_sparse_02():
+ adj_mat_dense = torch.rand((10, 10))
+ adj_mat_dense = (adj_mat_dense > .5).to(torch.float32)
+ adj_mat_sparse = adj_mat_dense.to_sparse()
+ adj_mat_sparse = norm_adj_mat_one_node_type_sparse(adj_mat_sparse)
+ adj_mat_dense = norm_adj_mat_one_node_type_dense(adj_mat_dense)
+ assert torch.all(adj_mat_sparse.to_dense() - adj_mat_dense < 0.000001)
+
+
+def test_norm_adj_mat_one_node_type_dense_01():
+ adj_mat = torch.rand((10, 10))
+ adj_mat = (adj_mat > .5)
+ _ = norm_adj_mat_one_node_type_dense(adj_mat)
+
+
+def test_norm_adj_mat_one_node_type_dense_02():
+ adj_mat = torch.tensor([
+ [0, 1, 1, 0], # 3
+ [1, 0, 1, 0], # 3
+ [1, 1, 0, 1], # 4
+ [0, 0, 1, 0] # 2
+ # 3 3 4 2
+ ])
+ expect_denom = np.array([
+ [ 3, 3, sqrt(3)*2, sqrt(6) ],
+ [ 3, 3, sqrt(3)*2, sqrt(6) ],
+ [ sqrt(3)*2, sqrt(3)*2, 4, sqrt(2)*2 ],
+ [ sqrt(6), sqrt(6), sqrt(2)*2, 2 ]
+ ], dtype=np.float32)
+ expect = (adj_mat.detach().cpu().numpy().astype(np.float32) + np.eye(4)) / expect_denom
+ # expect = np.array([
+ # [1/3, 1/3, 1/3, 0],
+ # [1/3, 1/3, 1/3, 0],
+ # [1/4, 1/4, 1/4, 1/4],
+ # [0, 0, 1/2, 1/2]
+ # ], dtype=np.float32)
+ res = decagon_pytorch.normalize.norm_adj_mat_one_node_type(adj_mat)
+ res = res.todense().astype(np.float32)
+ print('res:', res)
+ print('expect:', expect)
+ assert np.all(res - expect < 0.000001)
+
+
+def test_norm_adj_mat_one_node_type_dense_03():
+ # adj_mat = torch.rand((10, 10))
+ adj_mat = torch.tensor([
+ [0, 1, 1, 0, 0],
+ [1, 0, 1, 0, 1],
+ [1, 1, 0, .5, .5],
+ [0, 0, .5, 0, 1],
+ [0, 1, .5, 1, 0]
+ ])
+ # adj_mat = (adj_mat > .5)
+ adj_mat_dec = decagon_pytorch.normalize.norm_adj_mat_one_node_type(adj_mat)
+ adj_mat_ico = norm_adj_mat_one_node_type_dense(adj_mat)
+ adj_mat_dec = adj_mat_dec.todense()
+ adj_mat_ico = adj_mat_ico.detach().cpu().numpy()
+ print('adj_mat_dec:', adj_mat_dec)
+ print('adj_mat_ico:', adj_mat_ico)
+ assert np.all(adj_mat_dec - adj_mat_ico < 0.000001)
+
+
+def test_norm_adj_mat_two_node_types_sparse_01():
+ adj_mat = torch.rand((10, 20))
+ adj_mat = (adj_mat > .5)
+ adj_mat = adj_mat.to_sparse()
+ _ = norm_adj_mat_two_node_types_sparse(adj_mat)
+
+
+def test_norm_adj_mat_two_node_types_sparse_02():
+ adj_mat_dense = torch.rand((10, 20))
+ adj_mat_dense = (adj_mat_dense > .5)
+ adj_mat_sparse = adj_mat_dense.to_sparse()
+ adj_mat_sparse = norm_adj_mat_two_node_types_sparse(adj_mat_sparse)
+ adj_mat_dense = norm_adj_mat_two_node_types_dense(adj_mat_dense)
+ assert torch.all(adj_mat_sparse.to_dense() - adj_mat_dense < 0.000001)
+
+
+def test_norm_adj_mat_two_node_types_dense_01():
+ adj_mat = torch.rand((10, 20))
+ adj_mat = (adj_mat > .5)
+ _ = norm_adj_mat_two_node_types_dense(adj_mat)
+
+
+def test_norm_adj_mat_two_node_types_dense_02():
+ adj_mat = torch.tensor([
+ [0, 1, 1, 0], # 2
+ [1, 0, 1, 0], # 2
+ [1, 1, 0, 1], # 3
+ [0, 0, 1, 0] # 1
+ # 2 2 3 1
+ ])
+ expect_denom = np.array([
+ [ 2, 2, sqrt(6), sqrt(2) ],
+ [ 2, 2, sqrt(6), sqrt(2) ],
+ [ sqrt(6), sqrt(6), 3, sqrt(3) ],
+ [ sqrt(2), sqrt(2), sqrt(3), 1 ]
+ ], dtype=np.float32)
+ expect = adj_mat.detach().cpu().numpy().astype(np.float32) / expect_denom
+ res = decagon_pytorch.normalize.norm_adj_mat_two_node_types(adj_mat)
+ res = res.todense().astype(np.float32)
+ print('res:', res)
+ print('expect:', expect)
+ assert np.all(res - expect < 0.000001)
+
+
+def test_norm_adj_mat_two_node_types_dense_03():
+ adj_mat = torch.tensor([
+ [0, 1, 1, 0, 0],
+ [1, 0, 1, 0, 1],
+ [1, 1, 0, .5, .5],
+ [0, 0, .5, 0, 1],
+ [0, 1, .5, 1, 0]
+ ])
+ adj_mat_dec = decagon_pytorch.normalize.norm_adj_mat_two_node_types(adj_mat)
+ adj_mat_ico = norm_adj_mat_two_node_types_dense(adj_mat)
+ adj_mat_dec = adj_mat_dec.todense()
+ adj_mat_ico = adj_mat_ico.detach().cpu().numpy()
+ print('adj_mat_dec:', adj_mat_dec)
+ print('adj_mat_ico:', adj_mat_ico)
+ assert np.all(adj_mat_dec - adj_mat_ico < 0.000001)
+
+
+def test_norm_adj_mat_two_node_types_dense_04():
+ adj_mat = torch.rand((10, 20))
+ adj_mat_dec = decagon_pytorch.normalize.norm_adj_mat_two_node_types(adj_mat)
+ adj_mat_ico = norm_adj_mat_two_node_types_dense(adj_mat)
+ adj_mat_dec = adj_mat_dec.todense()
+ adj_mat_ico = adj_mat_ico.detach().cpu().numpy()
+ print('adj_mat_dec:', adj_mat_dec)
+ print('adj_mat_ico:', adj_mat_ico)
+ assert np.all(adj_mat_dec - adj_mat_ico < 0.000001)
diff --git a/tests/triacontagon/test_sampling.py b/tests/triacontagon/test_sampling.py
new file mode 100644
index 0000000..981cf01
--- /dev/null
+++ b/tests/triacontagon/test_sampling.py
@@ -0,0 +1,130 @@
+from triacontagon.data import Data
+from triacontagon.sampling import fixed_unigram_candidate_sampler, \
+ get_true_classes, \
+ negative_sample_adj_mat, \
+ negative_sample_data, \
+ get_edges_and_degrees
+import triacontagon.sampling
+from triacontagon.decode import dedicom_decoder
+import torch
+import time
+import pytest
+
+
+def test_fixed_unigram_candidate_sampler_01():
+ true_classes = torch.tensor([[-1],
+ [-1],
+ [ 3],
+ [ 2],
+ [-1]])
+ num_repeats = torch.tensor([0, 0, 1, 1, 0])
+ unigrams = torch.tensor([0., 0., 1., 1., 0.], dtype=torch.float64)
+ distortion = 0.75
+ res = fixed_unigram_candidate_sampler(true_classes, num_repeats,
+ unigrams, distortion)
+ print('res:', res)
+
+
+def test_fixed_unigram_candidate_sampler_02():
+ foo_bar = torch.tensor([
+ [0, 1, 0, 1],
+ [0, 0, 0, 1],
+ [0, 1, 0, 0],
+ [1, 0, 0, 0],
+ [0, 0, 1, 1]
+ ], dtype=torch.float32)
+
+ # bar_foo = foo_bar.transpose(0, 1).to_sparse().coalesce()
+ bar_foo = foo_bar.to_sparse().coalesce()
+
+ true_classes, row_count = get_true_classes(bar_foo)
+ print('true_classes:', true_classes)
+ print('row_count:', row_count)
+
+ edges_pos, degrees = get_edges_and_degrees(bar_foo)
+ print('degrees:', degrees)
+
+ res = fixed_unigram_candidate_sampler(true_classes, row_count,
+ degrees, 0.75)
+ print('res:', res)
+
+
+def test_get_true_classes_01():
+ adj_mat = torch.tensor([
+ [0, 1, 0, 1, 0],
+ [0, 0, 0, 0, 1],
+ [1, 1, 0, 0, 0],
+ [0, 0, 1, 0, 1],
+ [0, 1, 0, 0, 0]
+ ], dtype=torch.float).to_sparse()
+
+ true_classes, row_count = get_true_classes(adj_mat)
+ print('true_classes:', true_classes)
+
+ true_classes = torch.repeat_interleave(true_classes, row_count, dim=0)
+
+ assert torch.all(true_classes == torch.tensor([
+ [1, 3],
+ [1, 3],
+ [4, -1],
+ [0, 1],
+ [0, 1],
+ [2, 4],
+ [2, 4],
+ [1, -1]
+ ]))
+
+
+def test_get_true_classes_02():
+ adj_mat = torch.rand(2000, 2000).round().to_sparse()
+
+ t = time.time()
+ true_classes, row_count = get_true_classes(adj_mat)
+ print('Elapsed:', time.time() - t)
+
+ print('true_classes.shape:', true_classes.shape)
+
+
+def test_negative_sample_adj_mat_01():
+ adj_mat = torch.tensor([
+ [0, 1, 0, 1, 0],
+ [0, 0, 0, 0, 1],
+ [1, 1, 0, 0, 0],
+ [0, 0, 1, 0, 1],
+ [0, 1, 0, 0, 0]
+ ])
+
+ print('adj_mat:', adj_mat)
+
+ adj_mat_neg = negative_sample_adj_mat(adj_mat.to_sparse())
+
+ print('adj_mat_neg:', adj_mat_neg.to_dense())
+
+
+def test_negative_sample_data_01():
+ d = Data()
+ d.add_vertex_type('Gene', 5)
+
+ d.add_edge_type('Gene-Gene', 0, 0, [
+ torch.tensor([
+ [0, 1, 0, 1, 0],
+ [0, 0, 0, 0, 1],
+ [1, 1, 0, 0, 0],
+ [0, 0, 1, 0, 1],
+ [0, 1, 0, 0, 0]
+ ], dtype=torch.float).to_sparse()
+ ], dedicom_decoder)
+
+ d_neg = negative_sample_data(d)
+
+
+def test_fixed_unigram_candidate_sampler_new_01():
+ if 'fixed_unigram_candidate_sampler_new' not in dir(triacontagon.sampling):
+ pytest.skip('fixed_unigram_candidate_sampler_new not found')
+ x = (torch.rand((10, 10)) < .05).to(torch.float32).to_sparse()
+ true_classes, row_count = get_true_classes(x)
+ edges, degrees = get_edges_and_degrees(x)
+ # import pdb
+ # pdb.set_trace()
+ _ = triacontagon.sampling.fixed_unigram_candidate_sampler_new(true_classes,
+ row_count, degrees, 0.75)
diff --git a/tests/triacontagon/test_split.py b/tests/triacontagon/test_split.py
new file mode 100644
index 0000000..dc5459b
--- /dev/null
+++ b/tests/triacontagon/test_split.py
@@ -0,0 +1,249 @@
+from triacontagon.split import split_adj_mat, \
+ split_edge_type, \
+ split_data
+from triacontagon.util import _equal
+from triacontagon.data import EdgeType, \
+ Data
+from triacontagon.decode import dedicom_decoder
+import torch
+
+
+def test_split_adj_mat_01():
+ adj_mat = torch.tensor([
+ [0, 1, 0, 0, 1],
+ [0, 0, 1, 0, 1],
+ [1, 0, 0, 1, 0],
+ [0, 0, 1, 1, 0]
+ ]).to_sparse()
+
+ (res,) = split_adj_mat(adj_mat, (1.,))
+ assert torch.all(_equal(res, adj_mat))
+
+
+def test_split_adj_mat_02():
+ adj_mat = torch.tensor([
+ [0, 1, 0, 0, 1],
+ [0, 0, 1, 0, 1],
+ [1, 0, 0, 1, 0],
+ [0, 0, 1, 1, 0]
+ ]).to_sparse()
+
+ a, b = split_adj_mat(adj_mat, ( .5, .5 ))
+ assert torch.all(_equal(a+b, adj_mat))
+
+
+def test_split_adj_mat_03():
+ adj_mat = torch.tensor([
+ [0, 1, 0, 0, 1],
+ [0, 0, 1, 0, 1],
+ [1, 0, 0, 1, 0],
+ [0, 0, 1, 1, 0]
+ ]).to_sparse()
+
+ a, b, c = split_adj_mat(adj_mat, ( .8, .1, .1 ))
+ print('a:', a.to_dense(), 'b:', b.to_dense(), 'c:', c.to_dense())
+
+ assert torch.all(_equal(a+b+c, adj_mat))
+
+
+def test_split_edge_type_01():
+ et = EdgeType('Dummy', 0, 1, [
+ torch.tensor([
+ [0, 1, 0, 0, 0],
+ [0, 0, 1, 0, 1],
+ [1, 0, 0, 0, 1],
+ [0, 1, 0, 1, 0]
+ ]).to_sparse()
+ ], None, None)
+
+ res = split_edge_type(et, (1.,))
+
+ assert torch.all(_equal(et.adjacency_matrices[0],
+ res[0].adjacency_matrices[0]))
+
+
+def test_split_edge_type_02():
+ et = EdgeType('Dummy', 0, 1, [
+ torch.tensor([
+ [0, 1, 0, 0, 0],
+ [0, 0, 1, 0, 1],
+ [1, 0, 0, 0, 1],
+ [0, 1, 0, 1, 0]
+ ]).to_sparse()
+ ], None, None)
+
+ res = split_edge_type(et, (.5, .5))
+
+ assert torch.all(_equal(et.adjacency_matrices[0],
+ res[0].adjacency_matrices[0] + \
+ res[1].adjacency_matrices[0]))
+
+
+def test_split_edge_type_03():
+ et = EdgeType('Dummy', 0, 1, [
+ torch.tensor([
+ [0, 1, 0, 0, 0],
+ [0, 0, 1, 0, 1],
+ [1, 0, 0, 0, 1],
+ [0, 1, 0, 1, 0]
+ ]).to_sparse()
+ ], None, None)
+
+ res = split_edge_type(et, (.4, .4, .2))
+
+ assert torch.all(_equal(et.adjacency_matrices[0],
+ res[0].adjacency_matrices[0] + \
+ res[1].adjacency_matrices[0] + \
+ res[2].adjacency_matrices[0]))
+
+
+def test_split_edge_type_04():
+ et = EdgeType('Dummy', 0, 1, [
+ torch.tensor([
+ [0, 1, 0, 0, 0],
+ [0, 0, 1, 0, 1],
+ [1, 0, 0, 0, 1],
+ [0, 1, 0, 1, 0]
+ ]).to_sparse(),
+
+ torch.tensor([
+ [1, 0, 0, 0, 0],
+ [0, 1, 0, 1, 0],
+ [0, 0, 1, 1, 0],
+ [1, 0, 1, 0, 0]
+ ]).to_sparse()
+ ], None, None)
+
+ res = split_edge_type(et, (.4, .4, .2))
+
+ assert torch.all(_equal(et.adjacency_matrices[0],
+ res[0].adjacency_matrices[0] + \
+ res[1].adjacency_matrices[0] + \
+ res[2].adjacency_matrices[0]))
+
+ assert torch.all(_equal(et.adjacency_matrices[1],
+ res[0].adjacency_matrices[1] + \
+ res[1].adjacency_matrices[1] + \
+ res[2].adjacency_matrices[1]))
+
+
+def test_split_data_01():
+ data = Data()
+ data.add_vertex_type('Foo', 5)
+ data.add_vertex_type('Bar', 4)
+
+ foo_foo = torch.tensor([
+ [0, 1, 0, 1, 0],
+ [0, 0, 0, 1, 0],
+ [0, 1, 0, 0, 1],
+ [0, 1, 0, 0, 0],
+ [1, 0, 0, 1, 0]
+ ], dtype=torch.float32)
+ foo_foo = (foo_foo + foo_foo.transpose(0, 1)) / 2
+
+ foo_bar = torch.tensor([
+ [0, 1, 0, 1],
+ [0, 0, 0, 1],
+ [0, 1, 0, 0],
+ [1, 0, 0, 0],
+ [0, 0, 1, 1]
+ ], dtype=torch.float32)
+ bar_foo = foo_bar.transpose(0, 1)
+
+ bar_bar = torch.tensor([
+ [0, 0, 1, 0],
+ [1, 0, 0, 0],
+ [0, 1, 0, 1],
+ [0, 1, 0, 0],
+ ], dtype=torch.float32)
+ bar_bar = (bar_bar + bar_bar.transpose(0, 1)) / 2
+
+ data.add_edge_type('Foo-Foo', 0, 0, [
+ foo_foo.to_sparse().coalesce()
+ ], dedicom_decoder)
+ data.add_edge_type('Foo-Bar', 0, 1, [
+ foo_bar.to_sparse().coalesce()
+ ], dedicom_decoder)
+ data.add_edge_type('Bar-Foo', 1, 0, [
+ bar_foo.to_sparse().coalesce()
+ ], dedicom_decoder)
+ data.add_edge_type('Bar-Bar', 1, 1, [
+ bar_bar.to_sparse().coalesce()
+ ], dedicom_decoder)
+
+ (res,) = split_data(data, (1.,))
+
+ assert torch.all(_equal(res.edge_types[0, 0].adjacency_matrices[0],
+ data.edge_types[0, 0].adjacency_matrices[0]))
+
+ assert torch.all(_equal(res.edge_types[0, 1].adjacency_matrices[0],
+ data.edge_types[0, 1].adjacency_matrices[0]))
+
+ assert torch.all(_equal(res.edge_types[1, 0].adjacency_matrices[0],
+ data.edge_types[1, 0].adjacency_matrices[0]))
+
+ assert torch.all(_equal(res.edge_types[1, 1].adjacency_matrices[0],
+ data.edge_types[1, 1].adjacency_matrices[0]))
+
+
+def test_split_data_02():
+ data = Data()
+ data.add_vertex_type('Foo', 5)
+ data.add_vertex_type('Bar', 4)
+
+ foo_foo = torch.tensor([
+ [0, 1, 0, 1, 0],
+ [0, 0, 0, 1, 0],
+ [0, 1, 0, 0, 1],
+ [0, 1, 0, 0, 0],
+ [1, 0, 0, 1, 0]
+ ], dtype=torch.float32)
+ foo_foo = (foo_foo + foo_foo.transpose(0, 1)) / 2
+
+ foo_bar = torch.tensor([
+ [0, 1, 0, 1],
+ [0, 0, 0, 1],
+ [0, 1, 0, 0],
+ [1, 0, 0, 0],
+ [0, 0, 1, 1]
+ ], dtype=torch.float32)
+ bar_foo = foo_bar.transpose(0, 1)
+
+ bar_bar = torch.tensor([
+ [0, 0, 1, 0],
+ [1, 0, 0, 0],
+ [0, 1, 0, 1],
+ [0, 1, 0, 0],
+ ], dtype=torch.float32)
+ bar_bar = (bar_bar + bar_bar.transpose(0, 1)) / 2
+
+ data.add_edge_type('Foo-Foo', 0, 0, [
+ foo_foo.to_sparse().coalesce()
+ ], dedicom_decoder)
+ data.add_edge_type('Foo-Bar', 0, 1, [
+ foo_bar.to_sparse().coalesce()
+ ], dedicom_decoder)
+ data.add_edge_type('Bar-Foo', 1, 0, [
+ bar_foo.to_sparse().coalesce()
+ ], dedicom_decoder)
+ data.add_edge_type('Bar-Bar', 1, 1, [
+ bar_bar.to_sparse().coalesce()
+ ], dedicom_decoder)
+
+ a, b = split_data(data, (.5,.5))
+
+ assert torch.all(_equal(a.edge_types[0, 0].adjacency_matrices[0] + \
+ b.edge_types[0, 0].adjacency_matrices[0],
+ data.edge_types[0, 0].adjacency_matrices[0]))
+
+ assert torch.all(_equal(a.edge_types[0, 1].adjacency_matrices[0] + \
+ b.edge_types[0, 1].adjacency_matrices[0],
+ data.edge_types[0, 1].adjacency_matrices[0]))
+
+ assert torch.all(_equal(a.edge_types[1, 0].adjacency_matrices[0] + \
+ b.edge_types[1, 0].adjacency_matrices[0],
+ data.edge_types[1, 0].adjacency_matrices[0]))
+
+ assert torch.all(_equal(a.edge_types[1, 1].adjacency_matrices[0] + \
+ b.edge_types[1, 1].adjacency_matrices[0],
+ data.edge_types[1, 1].adjacency_matrices[0]))
diff --git a/tests/triacontagon/test_util.py b/tests/triacontagon/test_util.py
new file mode 100644
index 0000000..3593d1d
--- /dev/null
+++ b/tests/triacontagon/test_util.py
@@ -0,0 +1,135 @@
+from triacontagon.util import \
+ _clear_adjacency_matrix_except_rows, \
+ _sparse_diag_cat, \
+ _equal, \
+ common_one_hot_encoding
+from triacontagon.model import TrainingBatch
+from triacontagon.decode import dedicom_decoder
+from triacontagon.data import Data
+import torch
+import time
+
+
+def test_clear_adjacency_matrix_except_rows_01():
+ adj_mat = torch.tensor([
+ [0, 0, 1, 0, 0],
+ [0, 0, 0, 1, 1],
+ [1, 0, 1, 0, 0],
+ [1, 1, 0, 0, 0]
+ ], dtype=torch.uint8).to_sparse()
+
+ adj_mat = _sparse_diag_cat([ adj_mat, adj_mat ])
+
+ res = _clear_adjacency_matrix_except_rows(adj_mat,
+ torch.tensor([1, 3]), 4, 2)
+
+ res = res.to_dense()
+
+ truth = torch.tensor([
+ [0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
+ [0, 0, 0, 1, 1, 0, 0, 0, 0, 0],
+ [0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
+ [1, 1, 0, 0, 0, 0, 0, 0, 0, 0],
+ [0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
+ [0, 0, 0, 0, 0, 0, 0, 0, 1, 1],
+ [0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
+ [0, 0, 0, 0, 0, 1, 1, 0, 0, 0]
+ ], dtype=torch.uint8)
+
+ print('res:', res)
+
+ assert torch.all(res == truth)
+
+
+def test_clear_adjacency_matrix_except_rows_02():
+ adj_mat = torch.rand(6, 10).round().to(torch.uint8)
+
+ t = time.time()
+ res = _sparse_diag_cat([ adj_mat.to_sparse() ] * 130)
+ print('_sparse_diag_cat() took:', time.time() - t)
+
+ t = time.time()
+ res = _clear_adjacency_matrix_except_rows(res, torch.tensor([1, 3, 5]),
+ 6, 130)
+ print('_clear_adjacency_matrix_except_rows() took:', time.time() - t)
+
+ adj_mat[0] = adj_mat[2] = adj_mat[4] = \
+ torch.zeros(10)
+ truth = _sparse_diag_cat([ adj_mat.to_sparse() ] * 130)
+
+ assert _equal(res, truth).all()
+
+
+def test_clear_adjacency_matrix_except_rows_03():
+ adj_mat = torch.rand(6, 10).round().to(torch.uint8)
+
+ t = time.time()
+ res = _sparse_diag_cat([ adj_mat.to_sparse() ] * 1300)
+ print('_sparse_diag_cat() took:', time.time() - t)
+
+ t = time.time()
+ res = _clear_adjacency_matrix_except_rows(res, torch.tensor([1, 3, 5]),
+ 6, 1300)
+ print('_clear_adjacency_matrix_except_rows() took:', time.time() - t)
+
+ adj_mat[0] = adj_mat[2] = adj_mat[4] = \
+ torch.zeros(10)
+ truth = _sparse_diag_cat([ adj_mat.to_sparse() ] * 1300)
+
+ assert _equal(res, truth).all()
+
+
+def test_clear_adjacency_matrix_except_rows_04():
+ adj_mat = (torch.rand(2000, 2000) < 0.001).to(torch.uint8)
+
+ print('adj_mat.to_sparse():', adj_mat.to_sparse())
+
+ t = time.time()
+ res = _sparse_diag_cat([ adj_mat.to_sparse() ] * 1300)
+ print('_sparse_diag_cat() took:', time.time() - t)
+
+ t = time.time()
+ res = _clear_adjacency_matrix_except_rows(res, torch.tensor([1, 3, 5]),
+ 2000, 1300)
+ print('_clear_adjacency_matrix_except_rows() took:', time.time() - t)
+
+ adj_mat[0] = adj_mat[2] = adj_mat[4] = \
+ torch.zeros(2000)
+ adj_mat[6:] = torch.zeros(2000)
+ truth = _sparse_diag_cat([ adj_mat.to_sparse() ] * 1300)
+
+ assert _equal(res, truth).all()
+
+
+def test_clear_adjacency_matrix_except_rows_05():
+ if torch.cuda.device_count() == 0:
+ pytest.skip('Test requires CUDA')
+
+ device = torch.device('cuda:0')
+ adj_mat = (torch.rand(2000, 2000) < 0.001).to(torch.uint8).to(device)
+
+ print('adj_mat.to_sparse():', adj_mat.to_sparse())
+
+ t = time.time()
+ res = _sparse_diag_cat([ adj_mat.to_sparse() ] * 1300)
+ print('_sparse_diag_cat() took:', time.time() - t)
+
+ rows = torch.tensor(list(range(512)), device=device)
+
+ t = time.time()
+ res = _clear_adjacency_matrix_except_rows(res, rows,
+ 2000, 1300)
+ print('_clear_adjacency_matrix_except_rows() took:', time.time() - t)
+
+ adj_mat[512:] = torch.zeros(2000)
+ truth = _sparse_diag_cat([ adj_mat.to_sparse() ] * 1300)
+
+ assert _equal(res, truth).all()
+
+
+def test_common_one_hot_encoding_01():
+ in_repr = common_one_hot_encoding([2000, 200])
+
+ ref = torch.eye(2200)
+ assert torch.all(in_repr[0].to_dense() == ref[:2000, :])
+ assert torch.all(in_repr[1].to_dense() == ref[2000:, :])