IF YOU WOULD LIKE TO GET AN ACCOUNT, please write an email to s dot adaszewski at gmail dot com. User accounts are meant only to report issues and/or generate pull requests. This is a purpose-specific Git hosting for ADARED projects. Thank you for your understanding!
Browse Source

Start implementing Minibatch.

master
Stanislaw Adaszewski 4 years ago
parent
commit
71bf338491
3 changed files with 67 additions and 2 deletions
  1. +43
    -0
      src/decagon_pytorch/batch.py
  2. +22
    -2
      src/decagon_pytorch/data.py
  3. +2
    -0
      tests/decagon_pytorch/test_data.py

+ 43
- 0
src/decagon_pytorch/batch.py View File

@@ -0,0 +1,43 @@
import scipy.sparse as sp
class Batch(object):
def __init__(self, adjacency_matrix):
pass
def get(size):
pass
def train_test_split(data, train_size=.8):
pass
class Minibatch(object):
def __init__(self, data, node_type_row, node_type_column, size):
self.data = data
self.adjacency_matrix = data.get_adjacency_matrix(node_type_row, node_type_column)
self.size = size
self.order = np.random.permutation(adjacency_matrix.nnz)
self.count = 0
def reset(self):
self.count = 0
self.order = np.random.permutation(adjacency_matrix.nnz)
def __iter__(self):
adj_mat = self.adjacency_matrix
size = self.size
order = np.random.permutation(adj_mat.nnz)
for i in range(0, len(order), size):
row = adj_mat.row[i:i + size]
col = adj_mat.col[i:i + size]
data = adj_mat.data[i:i + size]
adj_mat_batch = sp.coo_matrix((data, (row, col)), shape=adj_mat.shape)
yield adj_mat_batch
degree = self.adjacency_matrix.sum(1)
def __len__(self):
pass

+ 22
- 2
src/decagon_pytorch/data.py View File

@@ -1,16 +1,36 @@
from collections import defaultdict
from .decode import BilinearDecoder
from .weights import init_glorot
class Data(object):
def __init__(self):
self.node_types = []
self.relation_types = []
self.decoder_types = defaultdict(lambda: BilinearDecoder)
self.latent_node = []
def add_node_type(self, name):
def add_node_type(self, name, count, latent_length):
self.node_types.append(name)
self.latent_node.append(init_glorot(count, latent_length))
def add_relation(self, node_type_row, node_type_column, adjacency_matrix, name):
n = len(self.node_types)
if node_type_row >= n or node_type_column >= n:
raise ValueError('Node type index out of bounds, add node type first')
self.relation_types.append((node_type_row, node_type_column, adjacency_matrix, name))
_ = self.decoder_types[(node_type_row, node_type_column)]
def set_decoder_type(self, node_type_row, node_type_column, decoder_class):
if (node_type_row, node_type_column) not in self.decoder_types:
raise ValueError('Relation type not found, add relation first')
self.decoder_types[(node_type_row, node_type_column)] = decoder_class
def get_adjacency_matrices(self, node_type_row, node_type_column):
rels = list(filter(lambda a: a[0] == node_type_row and a[1] == node_type_column), self.relation_types)
if len(rels) == 0:
def __repr__(self):
n = len(self.node_types)
@@ -32,7 +52,7 @@ class Data(object):
continue
# dir = '<->' if i == j else '->'
dir = '--'
s += ' - ' + self.node_types[i] + ' ' + dir + ' ' + self.node_types[j] + ':\n'
s += ' - ' + self.node_types[i] + ' ' + dir + ' ' + self.node_types[j] + ' (' + self.decoder_types[(i, j)].__name__ + '):\n'
for r in rels:
s += ' - ' + r[3] + '\n'
return s.strip()

+ 2
- 0
tests/decagon_pytorch/test_data.py View File

@@ -1,4 +1,5 @@
from decagon_pytorch.data import Data
from decagon_pytorch.decode import DEDICOMDecoder
def test_data():
@@ -10,4 +11,5 @@ def test_data():
d.add_relation(1, 1, None, 'Side Effect: Nausea')
d.add_relation(1, 1, None, 'Side Effect: Infertility')
d.add_relation(1, 1, None, 'Side Effect: Death')
d.set_decoder_type(1, 1, DEDICOMDecoder)
print(d)

Loading…
Cancel
Save