IF YOU WOULD LIKE TO GET AN ACCOUNT, please write an email to s dot adaszewski at gmail dot com. User accounts are meant only for reporting issues and/or generating pull requests. This is a purpose-specific Git hosting service for ADARED projects. Thank you for your understanding!
Browse Source

Add databatch.

master
Stanislaw Adaszewski 3 years ago
parent
commit
c210be41b3
2 changed files with 172 additions and 0 deletions
  1. +91
    -0
      src/icosagon/databatch.py
  2. +81
    -0
      tests/icosagon/test_databatch.py

+ 91
- 0
src/icosagon/databatch.py View File

@@ -0,0 +1,91 @@
from typing import Iterator

from icosagon.trainprep import PreparedData, \
    PreparedRelationFamily, \
    PreparedRelationType, \
    _empty_edge_list_tvt
class BatchedData(PreparedData):
    """PreparedData subclass used as the type of a single batch.

    It adds no fields or behaviour of its own; it exists so that objects
    returned by batched_data_skeleton() and yielded by DataBatcher can be
    told apart from a plain PreparedData by type.
    """

    def __init__(self, *args, **kwargs):
        # Pure pass-through: every argument goes to PreparedData unchanged.
        super(BatchedData, self).__init__(*args, **kwargs)
def batched_data_skeleton(data: PreparedData) -> BatchedData:
    """Return a BatchedData mirroring *data* but with fresh edge lists.

    Every relation family and relation type of *data* is re-created with
    the same names, node types and adjacency matrices (the matrix objects
    are passed through, not copied), while each of the four edge lists of
    every relation type is replaced by a new value obtained from
    _empty_edge_list_tvt().

    Raises:
        TypeError: if *data* is not a PreparedData instance.
    """
    if not isinstance(data, PreparedData):
        raise TypeError('data must be an instance of PreparedData')

    def _blank_relation(rel):
        # Same identity and adjacency data; four independent edge lists
        # (one _empty_edge_list_tvt() call per slot).
        return PreparedRelationType(
            rel.name,
            rel.node_type_row, rel.node_type_column,
            rel.adjacency_matrix, rel.adjacency_matrix_backward,
            _empty_edge_list_tvt(), _empty_edge_list_tvt(),
            _empty_edge_list_tvt(), _empty_edge_list_tvt())

    families = []
    for family in data.relation_families:
        blank_types = [_blank_relation(rel) for rel in family.relation_types]
        families.append(PreparedRelationFamily(
            family.data, family.name,
            family.node_type_row, family.node_type_column,
            family.is_symmetric, family.decoder_class,
            blank_types))

    return BatchedData(data.node_types, families)
class DataBatcher(object):
    """Iterate over the edges of a PreparedData in fixed-size batches.

    Iterating over a DataBatcher yields one BatchedData per batch. Each
    yielded object is a fresh skeleton built by batched_data_skeleton()
    in which a single edge list (one relation type, one edge kind, one
    train/val/test partition) has been set to a slice of at most
    ``batch_size`` edges; all other edge lists are left as the skeleton
    created them.
    """

    # Edge-list attributes of a relation type, in the order they are batched.
    _EDGE_TYPES = ('edges_pos', 'edges_neg', 'edges_back_pos', 'edges_back_neg')
    # Partitions of each edge list, in the order they are batched.
    _PART_TYPES = ('train', 'val', 'test')

    def __init__(self, data: PreparedData, batch_size: int) -> None:
        """Store the data to batch and the maximum number of edges per batch.

        Raises:
            TypeError: if *data* is not a PreparedData or *batch_size* is
                not an int.
            ValueError: if *batch_size* is not positive.
        """
        self._check_params(data, batch_size)
        self.data = data
        self.batch_size = batch_size

    def __iter__(self) -> Iterator[BatchedData]:
        """Yield one BatchedData per slice of ``batch_size`` edges.

        Batches are produced family by family, relation type by relation
        type, then by edge kind and by partition, each edge list being cut
        into consecutive slices.
        """
        for i, fam in enumerate(self.data.relation_families):
            for k, rel in enumerate(fam.relation_types):
                for edge_type in self._EDGE_TYPES:
                    for part_type in self._PART_TYPES:
                        edges = getattr(getattr(rel, edge_type), part_type)
                        for m in range(0, len(edges), self.batch_size):
                            # A new skeleton per batch, so that previously
                            # yielded batches are never mutated afterwards.
                            batched_data = batched_data_skeleton(self.data)
                            target = getattr(
                                batched_data.relation_families[i].relation_types[k],
                                edge_type)
                            setattr(target, part_type,
                                edges[m : m + self.batch_size])
                            yield batched_data

    @staticmethod
    def _check_params(data, batch_size):
        """Validate constructor arguments; see __init__ for what is raised."""
        if not isinstance(data, PreparedData):
            raise TypeError('data must be an instance of PreparedData')

        if not isinstance(batch_size, int):
            raise TypeError('batch_size must be an int')

        # range(0, n, 0) raises an unrelated-looking ValueError in the middle
        # of iteration and a negative step silently produces no batches at
        # all, so reject both up front.
        if batch_size <= 0:
            raise ValueError('batch_size must be a positive integer')

+ 81
- 0
tests/icosagon/test_databatch.py View File

@@ -0,0 +1,81 @@
from icosagon.databatch import DataBatcher, \
BatchedData
from icosagon.data import Data
from icosagon.trainprep import prepare_training, \
TrainValTest
import torch
def _some_data():
    """Build the Data fixture shared by the tests in this module.

    The fixture has two node types, 'Foo' (100 nodes) and 'Bar'
    (500 nodes), and one 'Foo-Bar' relation family holding a single
    'Foo-Bar' relation type with a random sparse adjacency matrix.
    """
    n_foo, n_bar = 100, 500

    fixture = Data()
    fixture.add_node_type('Foo', n_foo)
    fixture.add_node_type('Bar', n_bar)

    family = fixture.add_relation_family('Foo-Bar', 0, 1, True)
    # Rounding uniform samples from [0, 1) gives a random 0/1 matrix.
    adjacency = torch.rand(n_foo, n_bar).round().to_sparse()
    family.add_relation_type('Foo-Bar', adjacency)

    return fixture
def test_data_batcher_01():
    """Smoke test: a DataBatcher can be constructed from prepared data."""
    split = TrainValTest(.8, .1, .1)
    prepared = prepare_training(_some_data(), split)
    # Construction alone is the check; it must not raise.
    DataBatcher(prepared, 512)
def test_data_batcher_02():
    """Smoke test: a DataBatcher can be iterated to exhaustion."""
    split = TrainValTest(.8, .1, .1)
    prepared = prepare_training(_some_data(), split)
    # Only the absence of exceptions during iteration is checked.
    for _ in DataBatcher(prepared, 512):
        pass
def test_data_batcher_03():
    """Each batch must have exactly one non-empty edge list."""
    edge_types = ['edges_pos', 'edges_neg', 'edges_back_pos', 'edges_back_neg']
    part_types = ['train', 'val', 'test']

    def _edge_lists(batch):
        # Walk every (relation type, edge kind, partition) slot of a batch.
        for fam in batch.relation_families:
            for rel in fam.relation_types:
                for edge_type in edge_types:
                    for part_type in part_types:
                        yield getattr(getattr(rel, edge_type), part_type)

    prepared = prepare_training(_some_data(), TrainValTest(.8, .1, .1))
    for batch in DataBatcher(prepared, 512):
        non_empty = [edges for edges in _edge_lists(batch) if len(edges) > 0]
        assert len(non_empty) == 1
def test_data_batcher_04():
    """Edges summed over all batches must equal twice the adjacency sum."""
    edge_types = ['edges_pos', 'edges_neg', 'edges_back_pos', 'edges_back_neg']
    part_types = ['train', 'val', 'test']

    data = _some_data()
    prepared = prepare_training(data, TrainValTest(.8, .1, .1))

    # Count edges across every slot of every batch.
    total_edges = 0
    for batch in DataBatcher(prepared, 512):
        for fam in batch.relation_families:
            for rel in fam.relation_types:
                for edge_type in edge_types:
                    for part_type in part_types:
                        total_edges += len(getattr(getattr(rel, edge_type), part_type))

    adjacency = data.relation_families[0].relation_types[0].adjacency_matrix
    assert total_edges == torch.sum(adjacency._values()) * 2
def test_data_batcher_05():
    """No edge list in a batch exceeds the batch size, and none is all-empty."""
    batch_size = 512
    edge_types = ['edges_pos', 'edges_neg', 'edges_back_pos', 'edges_back_neg']
    part_types = ['train', 'val', 'test']

    def _edge_lists(batch):
        # Walk every (relation type, edge kind, partition) slot of a batch.
        for fam in batch.relation_families:
            for rel in fam.relation_types:
                for edge_type in edge_types:
                    for part_type in part_types:
                        yield getattr(getattr(rel, edge_type), part_type)

    prepared = prepare_training(_some_data(), TrainValTest(.8, .1, .1))
    for batch in DataBatcher(prepared, batch_size):
        sizes = [len(edges) for edges in _edge_lists(batch)]
        assert all(size <= 512 for size in sizes)
        assert any(size != 0 for size in sizes)
        print(sum(sizes))

Loading…
Cancel
Save