diff --git a/src/decagon_pytorch/data.py b/src/decagon_pytorch/data.py index 3c62469..0891c8b 100644 --- a/src/decagon_pytorch/data.py +++ b/src/decagon_pytorch/data.py @@ -3,33 +3,53 @@ from .decode import BilinearDecoder from .weights import init_glorot +class NodeType(object): + def __init__(self, name, count): + self.name = name + self.count = count + + +class RelationType(object): + def __init__(self, name, node_type_row, node_type_column, + adjacency_matrix): + self.name = name + self.node_type_row = node_type_row + self.node_type_column = node_type_column + self.adjacency_matrix = adjacency_matrix + + class Data(object): def __init__(self): self.node_types = [] - self.relation_types = [] - self.decoder_types = defaultdict(lambda: BilinearDecoder) - self.latent_node = [] + self.relation_types = defaultdict(list) + # self.decoder_types = defaultdict(lambda: BilinearDecoder) + # self.latent_node = [] - def add_node_type(self, name, count, latent_length): - self.node_types.append(name) - self.latent_node.append(init_glorot(count, latent_length)) + def add_node_type(self, name, count): # , latent_length): + self.node_types.append(NodeType(name, count)) + # self.latent_node.append(init_glorot(count, latent_length)) - def add_relation(self, node_type_row, node_type_column, adjacency_matrix, name): + def add_relation_type(self, name, node_type_row, node_type_column, adjacency_matrix): n = len(self.node_types) if node_type_row >= n or node_type_column >= n: raise ValueError('Node type index out of bounds, add node type first') - self.relation_types.append((node_type_row, node_type_column, adjacency_matrix, name)) - _ = self.decoder_types[(node_type_row, node_type_column)] + key = (node_type_row, node_type_column) + self.relation_types[key].append(RelationType(name, node_type_row, node_type_column, adjacency_matrix)) + # _ = self.decoder_types[(node_type_row, node_type_column)] - def set_decoder_type(self, node_type_row, node_type_column, decoder_class): - if (node_type_row, node_type_column) not in self.decoder_types: - raise ValueError('Relation type not found, add relation first') - self.decoder_types[(node_type_row, node_type_column)] = decoder_class + #def set_decoder_type(self, node_type_row, node_type_column, decoder_class): + # if (node_type_row, node_type_column) not in self.decoder_types: + # raise ValueError('Relation type not found, add relation first') + # self.decoder_types[(node_type_row, node_type_column)] = decoder_class def get_adjacency_matrices(self, node_type_row, node_type_column): - rels = list(filter(lambda a: a[0] == node_type_row and a[1] == node_type_column), self.relation_types) - if len(rels) == 0: - + # rels = list(filter(lambda a: a[0] == node_type_row and a[1] == node_type_column), self.relation_types) + key = (node_type_row, node_type_column) + if key not in self.relation_types: + raise ValueError('Relation type not found') + rels = self.relation_types[key] + rels = list(map(lambda a: a.adjacency_matrix, rels)) + return rels def __repr__(self): @@ -40,19 +60,25 @@ class Data(object): s += 'GNN Data with:\n' s += '- ' + str(n) + ' node type(s):\n' for nt in self.node_types: - s += ' - ' + nt + '\n' + s += ' - ' + nt.name + '\n' if len(self.relation_types) == 0: s += '- No relation types\n' return s.strip() - s += '- ' + str(len(self.relation_types)) + ' relation type(s):\n' + n = sum(map(len, self.relation_types)) + s += '- ' + str(n) + ' relation type(s):\n' for i in range(n): for j in range(n): - rels = list(filter(lambda a: a[0] == i and a[1] == j, self.relation_types)) - if len(rels) == 0: + key = (i, j) + if key not in self.relation_types: continue + rels = self.relation_types[key] + # rels = list(filter(lambda a: a[0] == i and a[1] == j, self.relation_types)) + #if len(rels) == 0: + # continue # dir = '<->' if i == j else '->' dir = '--' - s += ' - ' + self.node_types[i] + ' ' + dir + ' ' + self.node_types[j] + ' (' + self.decoder_types[(i, j)].__name__ + '):\n' + s += ' - ' + self.node_types[i].name + ' ' + dir + ' ' + self.node_types[j].name + ':\n' + #' (' + self.decoder_types[(i, j)].__name__ + '):\n' for r in rels: - s += ' - ' + r[3] + '\n' + s += ' - ' + r.name + '\n' return s.strip() diff --git a/src/decagon_pytorch/model.py b/src/decagon_pytorch/model.py index e69de29..0e6380b 100644 --- a/src/decagon_pytorch/model.py +++ b/src/decagon_pytorch/model.py @@ -0,0 +1,6 @@ +class Model(object): + def __init__(self, data): + self.data = data + + def build(self): + pass diff --git a/tests/decagon_pytorch/test_data.py b/tests/decagon_pytorch/test_data.py index b37c14b..fc6c111 100644 --- a/tests/decagon_pytorch/test_data.py +++ b/tests/decagon_pytorch/test_data.py @@ -4,12 +4,11 @@ from decagon_pytorch.decode import DEDICOMDecoder def test_data(): d = Data() - d.add_node_type('Gene') - d.add_node_type('Drug') - d.add_relation(1, 0, None, 'Target') - d.add_relation(0, 0, None, 'Interaction') - d.add_relation(1, 1, None, 'Side Effect: Nausea') - d.add_relation(1, 1, None, 'Side Effect: Infertility') - d.add_relation(1, 1, None, 'Side Effect: Death') - d.set_decoder_type(1, 1, DEDICOMDecoder) + d.add_node_type('Gene', 1000) + d.add_node_type('Drug', 100) + d.add_relation_type('Target', 1, 0, None) + d.add_relation_type('Interaction', 0, 0, None) + d.add_relation_type('Side Effect: Nausea', 1, 1, None) + d.add_relation_type('Side Effect: Infertility', 1, 1, None) + d.add_relation_type('Side Effect: Death', 1, 1, None) print(d)