IF YOU WOULD LIKE TO GET AN ACCOUNT, please write an email to s dot adaszewski at gmail dot com. User accounts are meant only to report issues and/or generate pull requests. This is a purpose-specific Git hosting for ADARED projects. Thank you for your understanding!
Browse Source

Changed Data architecture a bit.

master
Stanislaw Adaszewski 4 years ago
parent
commit
21b2565720
3 changed files with 61 additions and 30 deletions
  1. +48
    -22
      src/decagon_pytorch/data.py
  2. +6
    -0
      src/decagon_pytorch/model.py
  3. +7
    -8
      tests/decagon_pytorch/test_data.py

+ 48
- 22
src/decagon_pytorch/data.py View File

@@ -3,33 +3,53 @@ from .decode import BilinearDecoder
from .weights import init_glorot
class NodeType(object):
def __init__(self, name, count):
self.name = name
self.count = count
class RelationType(object):
def __init__(self, name, node_type_row, node_type_column,
adjacency_matrix):
self.name = name
self.node_type_row = node_type_row
self.node_type_column = node_type_column
self.adjacency_matrix = adjacency_matrix
class Data(object):
def __init__(self):
self.node_types = []
self.relation_types = []
self.decoder_types = defaultdict(lambda: BilinearDecoder)
self.latent_node = []
self.relation_types = defaultdict(list)
# self.decoder_types = defaultdict(lambda: BilinearDecoder)
# self.latent_node = []
def add_node_type(self, name, count, latent_length):
self.node_types.append(name)
self.latent_node.append(init_glorot(count, latent_length))
def add_node_type(self, name, count): # , latent_length):
self.node_types.append(NodeType(name, count))
# self.latent_node.append(init_glorot(count, latent_length))
def add_relation(self, node_type_row, node_type_column, adjacency_matrix, name):
def add_relation_type(self, name, node_type_row, node_type_column, adjacency_matrix):
n = len(self.node_types)
if node_type_row >= n or node_type_column >= n:
raise ValueError('Node type index out of bounds, add node type first')
self.relation_types.append((node_type_row, node_type_column, adjacency_matrix, name))
_ = self.decoder_types[(node_type_row, node_type_column)]
key = (node_type_row, node_type_column)
self.relation_types[key].append(RelationType(name, node_type_row, node_type_column, adjacency_matrix))
# _ = self.decoder_types[(node_type_row, node_type_column)]
def set_decoder_type(self, node_type_row, node_type_column, decoder_class):
if (node_type_row, node_type_column) not in self.decoder_types:
raise ValueError('Relation type not found, add relation first')
self.decoder_types[(node_type_row, node_type_column)] = decoder_class
#def set_decoder_type(self, node_type_row, node_type_column, decoder_class):
# if (node_type_row, node_type_column) not in self.decoder_types:
# raise ValueError('Relation type not found, add relation first')
# self.decoder_types[(node_type_row, node_type_column)] = decoder_class
def get_adjacency_matrices(self, node_type_row, node_type_column):
rels = list(filter(lambda a: a[0] == node_type_row and a[1] == node_type_column), self.relation_types)
if len(rels) == 0:
# rels = list(filter(lambda a: a[0] == node_type_row and a[1] == node_type_column), self.relation_types)
key = (node_type_row, node_type_column)
if key not in self.relation_types:
raise ValueError('Relation type not found')
rels = self.relation_types[key]
rels = list(map(lambda a: a.adjacency_matrix, rels))
return rels
def __repr__(self):
@@ -40,19 +60,25 @@ class Data(object):
s += 'GNN Data with:\n'
s += '- ' + str(n) + ' node type(s):\n'
for nt in self.node_types:
s += ' - ' + nt + '\n'
s += ' - ' + nt.name + '\n'
if len(self.relation_types) == 0:
s += '- No relation types\n'
return s.strip()
s += '- ' + str(len(self.relation_types)) + ' relation type(s):\n'
n = sum(map(len, self.relation_types))
s += '- ' + str(n) + ' relation type(s):\n'
for i in range(n):
for j in range(n):
rels = list(filter(lambda a: a[0] == i and a[1] == j, self.relation_types))
if len(rels) == 0:
key = (i, j)
if key not in self.relation_types:
continue
rels = self.relation_types[key]
# rels = list(filter(lambda a: a[0] == i and a[1] == j, self.relation_types))
#if len(rels) == 0:
# continue
# dir = '<->' if i == j else '->'
dir = '--'
s += ' - ' + self.node_types[i] + ' ' + dir + ' ' + self.node_types[j] + ' (' + self.decoder_types[(i, j)].__name__ + '):\n'
s += ' - ' + self.node_types[i].name + ' ' + dir + ' ' + self.node_types[j].name + ':\n'
#' (' + self.decoder_types[(i, j)].__name__ + '):\n'
for r in rels:
s += ' - ' + r[3] + '\n'
s += ' - ' + r.name + '\n'
return s.strip()

+ 6
- 0
src/decagon_pytorch/model.py View File

@@ -0,0 +1,6 @@
class Model(object):
def __init__(self, data):
self.data = data
def build(self):
pass

+ 7
- 8
tests/decagon_pytorch/test_data.py View File

@@ -4,12 +4,11 @@ from decagon_pytorch.decode import DEDICOMDecoder
def test_data():
d = Data()
d.add_node_type('Gene')
d.add_node_type('Drug')
d.add_relation(1, 0, None, 'Target')
d.add_relation(0, 0, None, 'Interaction')
d.add_relation(1, 1, None, 'Side Effect: Nausea')
d.add_relation(1, 1, None, 'Side Effect: Infertility')
d.add_relation(1, 1, None, 'Side Effect: Death')
d.set_decoder_type(1, 1, DEDICOMDecoder)
d.add_node_type('Gene', 1000)
d.add_node_type('Drug', 100)
d.add_relation_type('Target', 1, 0, None)
d.add_relation_type('Interaction', 0, 0, None)
d.add_relation_type('Side Effect: Nausea', 1, 1, None)
d.add_relation_type('Side Effect: Infertility', 1, 1, None)
d.add_relation_type('Side Effect: Death', 1, 1, None)
print(d)

Loading…
Cancel
Save