diff --git a/src/decagon_pytorch/batch.py b/src/decagon_pytorch/batch.py
index e69de29..8be3fe6 100644
--- a/src/decagon_pytorch/batch.py
+++ b/src/decagon_pytorch/batch.py
@@ -0,0 +1,51 @@
+import numpy as np
+import scipy.sparse as sp
+
+
+class Batch(object):
+    def __init__(self, adjacency_matrix):
+        pass
+
+    def get(self, size):
+        pass
+
+
+def train_test_split(data, train_size=.8):
+    pass
+
+
+class Minibatch(object):
+    """Serve the stored entries of one adjacency matrix in shuffled batches."""
+
+    def __init__(self, data, node_type_row, node_type_column, size):
+        self.data = data
+        # NOTE(review): Data in this patch defines get_adjacency_matrices()
+        # (plural, list result) - confirm which accessor is meant here.
+        self.adjacency_matrix = data.get_adjacency_matrix(node_type_row, node_type_column)
+        self.size = size
+        self.order = np.random.permutation(self.adjacency_matrix.nnz)
+        self.count = 0
+
+    def reset(self):
+        """Rewind the counter and draw a new permutation of the entries."""
+        self.count = 0
+        self.order = np.random.permutation(self.adjacency_matrix.nnz)
+
+    def __iter__(self):
+        """Yield COO matrices holding at most `size` randomly chosen entries."""
+        # .row/.col/.data are read directly, so the matrix must be in COO format.
+        adj_mat = self.adjacency_matrix
+        size = self.size
+        order = np.random.permutation(adj_mat.nnz)
+        for i in range(0, len(order), size):
+            # Index through the permutation; plain slices would ignore it.
+            idx = order[i:i + size]
+            row = adj_mat.row[idx]
+            col = adj_mat.col[idx]
+            data = adj_mat.data[idx]
+            adj_mat_batch = sp.coo_matrix((data, (row, col)), shape=adj_mat.shape)
+            yield adj_mat_batch
+
+    def __len__(self):
+        """Return the number of batches one pass of __iter__ produces."""
+        return (self.adjacency_matrix.nnz + self.size - 1) // self.size
diff --git a/src/decagon_pytorch/data.py b/src/decagon_pytorch/data.py
index 736f558..3c62469 100644
--- a/src/decagon_pytorch/data.py
+++ b/src/decagon_pytorch/data.py
@@ -1,16 +1,42 @@
+from collections import defaultdict
+from .decode import BilinearDecoder
+from .weights import init_glorot
+
+
 class Data(object):
     def __init__(self):
         self.node_types = []
         self.relation_types = []
+        self.decoder_types = defaultdict(lambda: BilinearDecoder)
+        self.latent_node = []
 
-    def add_node_type(self, name):
+    def add_node_type(self, name, count, latent_length):
         self.node_types.append(name)
+        self.latent_node.append(init_glorot(count, latent_length))
 
     def add_relation(self, node_type_row, node_type_column, adjacency_matrix, name):
         n = len(self.node_types)
         if node_type_row >= n or node_type_column >= n:
             raise ValueError('Node type index out of bounds, add node type first')
         self.relation_types.append((node_type_row, node_type_column, adjacency_matrix, name))
+        # Reading the key makes the defaultdict register the default decoder.
+        _ = self.decoder_types[(node_type_row, node_type_column)]
+
+    def set_decoder_type(self, node_type_row, node_type_column, decoder_class):
+        """Override the decoder class used for an already added relation pair."""
+        if (node_type_row, node_type_column) not in self.decoder_types:
+            raise ValueError('Relation type not found, add relation first')
+        self.decoder_types[(node_type_row, node_type_column)] = decoder_class
+
+    def get_adjacency_matrices(self, node_type_row, node_type_column):
+        """Return the adjacency matrices of all relations between two node types."""
+        rels = [r for r in self.relation_types
+                if r[0] == node_type_row and r[1] == node_type_column]
+        # NOTE(review): the original left this branch empty; raising mirrors
+        # set_decoder_type() - confirm the intended behaviour.
+        if not rels:
+            raise ValueError('Relation type not found, add relation first')
+        return [r[2] for r in rels]
 
     def __repr__(self):
         n = len(self.node_types)
@@ -32,7 +58,7 @@ class Data(object):
                     continue
                 # dir = '<->' if i == j else '->'
                 dir = '--'
-                s += ' - ' + self.node_types[i] + ' ' + dir + ' ' + self.node_types[j] + ':\n'
+                s += ' - ' + self.node_types[i] + ' ' + dir + ' ' + self.node_types[j] + ' (' + self.decoder_types[(i, j)].__name__ + '):\n'
                 for r in rels:
                     s += ' - ' + r[3] + '\n'
         return s.strip()
diff --git a/tests/decagon_pytorch/test_data.py b/tests/decagon_pytorch/test_data.py
index b0c2b56..b37c14b 100644
--- a/tests/decagon_pytorch/test_data.py
+++ b/tests/decagon_pytorch/test_data.py
@@ -1,4 +1,5 @@
 from decagon_pytorch.data import Data
+from decagon_pytorch.decode import DEDICOMDecoder
 
 
 def test_data():
@@ -10,4 +11,5 @@ def test_data():
     d.add_relation(1, 1, None, 'Side Effect: Nausea')
     d.add_relation(1, 1, None, 'Side Effect: Infertility')
     d.add_relation(1, 1, None, 'Side Effect: Death')
+    d.set_decoder_type(1, 1, DEDICOMDecoder)
     print(d)