IF YOU WOULD LIKE TO GET AN ACCOUNT, please write an email to s dot adaszewski at gmail dot com. User accounts are meant only to report issues and/or generate pull requests. This is a purpose-specific Git hosting for ADARED projects. Thank you for your understanding!
Quellcode durchsuchen

Add simple class to hold graph data.

master
Stanislaw Adaszewski vor 4 Jahren
Ursprung
Commit
ecabdc0540
4 geänderte Dateien mit 69 neuen und 0 gelöschten Zeilen
  1. +0
    -0
      src/decagon_pytorch/batch.py
  2. +38
    -0
      src/decagon_pytorch/data.py
  3. +18
    -0
      src/decagon_pytorch/normalize.py
  4. +13
    -0
      tests/decagon_pytorch/test_data.py

+ 0
- 0
src/decagon_pytorch/batch.py Datei anzeigen


+ 38
- 0
src/decagon_pytorch/data.py Datei anzeigen

@@ -0,0 +1,38 @@
class Data(object):
def __init__(self):
self.node_types = []
self.relation_types = []
def add_node_type(self, name):
self.node_types.append(name)
def add_relation(self, node_type_row, node_type_column, adjacency_matrix, name):
n = len(self.node_types)
if node_type_row >= n or node_type_column >= n:
raise ValueError('Node type index out of bounds, add node type first')
self.relation_types.append((node_type_row, node_type_column, adjacency_matrix, name))
def __repr__(self):
n = len(self.node_types)
if n == 0:
return 'Empty GNN Data'
s = ''
s += 'GNN Data with:\n'
s += '- ' + str(n) + ' node type(s):\n'
for nt in self.node_types:
s += ' - ' + nt + '\n'
if len(self.relation_types) == 0:
s += '- No relation types\n'
return s.strip()
s += '- ' + str(len(self.relation_types)) + ' relation type(s):\n'
for i in range(n):
for j in range(n):
rels = list(filter(lambda a: a[0] == i and a[1] == j, self.relation_types))
if len(rels) == 0:
continue
# dir = '<->' if i == j else '->'
dir = '--'
s += ' - ' + self.node_types[i] + ' ' + dir + ' ' + self.node_types[j] + ':\n'
for r in rels:
s += ' - ' + r[3] + '\n'
return s.strip()

+ 18
- 0
src/decagon_pytorch/normalize.py Datei anzeigen

@@ -0,0 +1,18 @@
import numpy as np
import scipy.sparse as sp
def normalize_adjacency_matrix(self, adj):
adj = sp.coo_matrix(adj)
if adj.shape[0] == adj.shape[1]:
adj_ = adj + sp.eye(adj.shape[0])
rowsum = np.array(adj_.sum(1))
degree_mat_inv_sqrt = sp.diags(np.power(rowsum, -0.5).flatten())
adj_normalized = adj_.dot(degree_mat_inv_sqrt).transpose().dot(degree_mat_inv_sqrt).tocoo()
else:
rowsum = np.array(adj.sum(1))
colsum = np.array(adj.sum(0))
rowdegree_mat_inv = sp.diags(np.nan_to_num(np.power(rowsum, -0.5)).flatten())
coldegree_mat_inv = sp.diags(np.nan_to_num(np.power(colsum, -0.5)).flatten())
adj_normalized = rowdegree_mat_inv.dot(adj).dot(coldegree_mat_inv).tocoo()
return preprocessing.sparse_to_tuple(adj_normalized)

+ 13
- 0
tests/decagon_pytorch/test_data.py Datei anzeigen

@@ -0,0 +1,13 @@
from decagon_pytorch.data import Data
def test_data():
d = Data()
d.add_node_type('Gene')
d.add_node_type('Drug')
d.add_relation(1, 0, None, 'Target')
d.add_relation(0, 0, None, 'Interaction')
d.add_relation(1, 1, None, 'Side Effect: Nausea')
d.add_relation(1, 1, None, 'Side Effect: Infertility')
d.add_relation(1, 1, None, 'Side Effect: Death')
print(d)

Laden…
Abbrechen
Speichern