diff --git a/src/decagon_pytorch/batch.py b/src/decagon_pytorch/batch.py new file mode 100644 index 0000000..e69de29 diff --git a/src/decagon_pytorch/data.py b/src/decagon_pytorch/data.py new file mode 100644 index 0000000..736f558 --- /dev/null +++ b/src/decagon_pytorch/data.py @@ -0,0 +1,38 @@ +class Data(object): + def __init__(self): + self.node_types = [] + self.relation_types = [] + + def add_node_type(self, name): + self.node_types.append(name) + + def add_relation(self, node_type_row, node_type_column, adjacency_matrix, name): + n = len(self.node_types) + if node_type_row >= n or node_type_column >= n: + raise ValueError('Node type index out of bounds, add node type first') + self.relation_types.append((node_type_row, node_type_column, adjacency_matrix, name)) + + def __repr__(self): + n = len(self.node_types) + if n == 0: + return 'Empty GNN Data' + s = '' + s += 'GNN Data with:\n' + s += '- ' + str(n) + ' node type(s):\n' + for nt in self.node_types: + s += ' - ' + nt + '\n' + if len(self.relation_types) == 0: + s += '- No relation types\n' + return s.strip() + s += '- ' + str(len(self.relation_types)) + ' relation type(s):\n' + for i in range(n): + for j in range(n): + rels = list(filter(lambda a: a[0] == i and a[1] == j, self.relation_types)) + if len(rels) == 0: + continue + # dir = '<->' if i == j else '->' + dir = '--' + s += ' - ' + self.node_types[i] + ' ' + dir + ' ' + self.node_types[j] + ':\n' + for r in rels: + s += ' - ' + r[3] + '\n' + return s.strip() diff --git a/src/decagon_pytorch/normalize.py b/src/decagon_pytorch/normalize.py new file mode 100644 index 0000000..54ddb8e --- /dev/null +++ b/src/decagon_pytorch/normalize.py @@ -0,0 +1,18 @@ +import numpy as np +import scipy.sparse as sp + + +def normalize_adjacency_matrix(self, adj): + adj = sp.coo_matrix(adj) + if adj.shape[0] == adj.shape[1]: + adj_ = adj + sp.eye(adj.shape[0]) + rowsum = np.array(adj_.sum(1)) + degree_mat_inv_sqrt = sp.diags(np.power(rowsum, -0.5).flatten()) + adj_normalized = adj_.dot(degree_mat_inv_sqrt).transpose().dot(degree_mat_inv_sqrt).tocoo() + else: + rowsum = np.array(adj.sum(1)) + colsum = np.array(adj.sum(0)) + rowdegree_mat_inv = sp.diags(np.nan_to_num(np.power(rowsum, -0.5)).flatten()) + coldegree_mat_inv = sp.diags(np.nan_to_num(np.power(colsum, -0.5)).flatten()) + adj_normalized = rowdegree_mat_inv.dot(adj).dot(coldegree_mat_inv).tocoo() + return preprocessing.sparse_to_tuple(adj_normalized) diff --git a/tests/decagon_pytorch/test_data.py b/tests/decagon_pytorch/test_data.py new file mode 100644 index 0000000..b0c2b56 --- /dev/null +++ b/tests/decagon_pytorch/test_data.py @@ -0,0 +1,13 @@ +from decagon_pytorch.data import Data + + +def test_data(): + d = Data() + d.add_node_type('Gene') + d.add_node_type('Drug') + d.add_relation(1, 0, None, 'Target') + d.add_relation(0, 0, None, 'Interaction') + d.add_relation(1, 1, None, 'Side Effect: Nausea') + d.add_relation(1, 1, None, 'Side Effect: Infertility') + d.add_relation(1, 1, None, 'Side Effect: Death') + print(d)