IF YOU WOULD LIKE TO GET AN ACCOUNT, please write an email to s dot adaszewski at gmail dot com. User accounts are meant only to report issues and/or generate pull requests. This is a purpose-specific Git hosting for ADARED projects. Thank you for your understanding!
ソースを参照

Attempt DualBatcher.

master
Stanislaw Adaszewski 4年前
コミット
0a061af526
2個のファイルの変更101行の追加0行の削除
  1. +61
    -0
      src/triacontagon/batch.py
  2. +40
    -0
      src/triacontagon/loop.py

+ 61
- 0
src/triacontagon/batch.py ファイルの表示

@@ -1,6 +1,7 @@
from .data import Data
from .model import TrainingBatch
import torch
from functools import reduce
def _shuffle(x: torch.Tensor) -> torch.Tensor:
@@ -8,6 +9,66 @@ def _shuffle(x: torch.Tensor) -> torch.Tensor:
return x[order]
def _same_data_org(pos_data: Data, neg_data: Data):
if len(pos_data.vertex_types) != len(neg_data.vertex_types):
return False
test = [ pos_data.vertex_types[i].name == neg_data.vertex_types[i].name \
and pos_data.vertex_types[i].count == neg_data.vertex_types[i].count \
for i in range(len(pos_data.vertex_types)) ]
if not all(test):
return False
if not set(pos_data.edge_types.keys()) == \
set(neg_data.edge_types.keys()):
return False
test = [ pos_data.edge_types[i].name == \
neg_data.edge_types[i].name \
and pos_data.edge_types[i].vertex_type_row == \
neg_data.edge_types[i].vertex_type_row \
and pos_data.edge_types[i].vertex_type_column == \
neg_data.edge_types[i].vertex_type_column \
and len(pos_data.edge_types[i].adjacency_matrices) == \
len(neg_data.edge_types[i].adjacency_matrices) \
for i in pos_data.edge_types.keys() ]
if not all(test):
return False
test = [ [ len(pos_data.edge_types[i].adjacency_matrices[k].values()) == \
len(neg_data.edge_types[i].adjacency_matrices[k].values()) \
for k in range(len(pos_data.edge_types[i])) ] \
for i in pos_data.edge_types.keys() ]
test = reduce(list.__add__, test)
if not all(test):
return False
return True
class DualBatcher(object):
def __init__(self, pos_data: Data, neg_data: Data,
batch_size: int=512, shuffle: bool=True) -> None:
if not isinstance(pos_data, Data):
raise TypeError('pos_data must be an instance of Data')
if not isinstance(neg_data, Data):
raise TypeError('neg_data must be an instance of Data')
if not _same_data_org(pos_data, neg_data):
raise ValueError('pos_data and neg_data must have the same organization')
self.pos_data = pos_data
self.neg_data = neg_data
self.batch_size = int(batch_size)
self.shuffle = bool(shuffle)
def __iter__(self):
class Batcher(object):
def __init__(self, data: Data, batch_size: int=512,
shuffle: bool=True) -> None:


+ 40
- 0
src/triacontagon/loop.py ファイルの表示

@@ -0,0 +1,40 @@
from .model import Model
from .batch import Batcher
class TrainLoop(object):
def __init__(self, model: Model,
pos_batcher: Batcher,
neg_batcher: Batcher,
max_epochs: int = 50) -> None:
if not isinstance(model, Model):
raise TypeError('model must be an instance of Model')
if not isinstance(pos_batcher, Batcher):
raise TypeError('pos_batcher must be an instance of Batcher')
if not isinstance(neg_batcher, Batcher):
raise TypeError('neg_batcher must be an instance of Batcher')
self.model = model
self.pos_batcher = pos_batcher
self.neg_batcher = neg_batcher
self.max_epochs = int(num_epochs)
def run_epoch(self) -> None:
pos_it = iter(self.pos_batcher)
neg_it = iter(self.neg_batcher)
while True:
try:
pos_batch = next(pos_it)
neg_batch = next(neg_it)
except StopIteration:
break
if len(pos_batch.edges) != len(neg_batch.edges):
raise ValueError('Positive and negative batch should have same length')
def run(self) -> None:
for epoch in range(self.max_epochs):
self.run_epoch()

読み込み中…
キャンセル
保存