|
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138 |
- import torch
- from typing import List
- from .trainprep import PreparedData
- from dataclasses import dataclass
- import random
- from collections import defaultdict
-
-
@dataclass
class TrainingBatch:
    """A batch of edges drawn from one relation type of one relation family.

    Producers build instances positionally, so the field order below is part
    of the interface and must not change.
    """
    # Position of the relation family in ``prep_d.relation_families``.
    relation_family_index: int
    # Position of the relation type inside that family's ``relation_types``.
    relation_type_index: int
    # Node types of the edges' row / column endpoints; for backward edge kinds
    # the producer is expected to have swapped them already — TODO confirm.
    node_type_row: int
    node_type_column: int
    # Selected edges; presumably an (n_edges, 2) index tensor — TODO confirm.
    edges: torch.Tensor
-
-
class FastBatcher:
    """Serve shuffled training edges in batches, one relation family at a time.

    :meth:`build` creates two pools for every relation family that has at
    least one relation type: a forward pool (``edges_pos`` + ``edges_neg`` of
    every relation type in the family) and a backward pool
    (``edges_back_pos`` + ``edges_back_neg``). Positive edges are labelled
    1.0 and negative edges 0.0. Each pool is shuffled once and stored in
    ``self.edges`` as a dict with the keys:

    * ``'fam_idx'`` -- index into ``prep_d.relation_families``;
    * ``'rel_idx'`` -- 1-D long tensor with, for every edge, the index of the
      relation type it came from (a pool mixes all relation types of its
      family, so a single int cannot describe it);
    * ``'back'`` -- ``True`` for the backward pool;
    * ``'edges'`` / ``'targets'`` -- the shuffled edges and their labels;
    * ``'ofs'`` -- always 0; kept for compatibility, iteration tracks its own
      offsets so the batcher can be iterated repeatedly.
    """

    def __init__(self,
                 prep_d: 'PreparedData',
                 batch_size: int) -> None:
        """Store the prepared data and build the shuffled edge pools.

        Args:
            prep_d: prepared training data; its ``relation_families`` are read.
            batch_size: maximum number of edges per yielded batch.

        Raises:
            TypeError: if ``prep_d`` is not a ``PreparedData`` instance.
            ValueError: if ``batch_size`` is smaller than 1 (iteration would
                otherwise never advance).
        """
        if not isinstance(prep_d, PreparedData):
            raise TypeError('prep_d must be an instance of PreparedData')

        self.prep_d = prep_d
        self.batch_size = int(batch_size)
        if self.batch_size < 1:
            raise ValueError('batch_size must be a positive integer')

        self.edges = None
        self.build()

    @staticmethod
    def _merge_and_shuffle(parts):
        """Concatenate ``(rel_idx, edges, label)`` parts and shuffle them jointly.

        Returns ``(edges, targets, rel_idx)`` permuted by one shared
        ``torch.randperm`` so that the three stay aligned row by row.
        """
        edges = torch.cat([part_edges for _, part_edges, _ in parts])
        targets = torch.cat([torch.full((len(part_edges),), float(label))
                             for _, part_edges, label in parts])
        rel_idx = torch.cat([torch.full((len(part_edges),), rel,
                                        dtype=torch.long)
                             for rel, part_edges, _ in parts])
        order = torch.randperm(len(edges))
        return edges[order], targets[order], rel_idx[order]

    def build(self) -> None:
        """(Re)build ``self.edges`` from the ``train`` splits of ``prep_d``."""
        self.edges = []
        for fam_idx, fam in enumerate(self.prep_d.relation_families):
            if not fam.relation_types:
                # torch.cat() rejects an empty list, and such a family has no
                # edges to serve anyway.
                continue

            forward_parts = []
            backward_parts = []
            for rel_idx, rel in enumerate(fam.relation_types):
                forward_parts.append((rel_idx, rel.edges_pos.train, 1.))
                forward_parts.append((rel_idx, rel.edges_neg.train, 0.))
                # Backward positives are positives too: label them 1, not 0.
                backward_parts.append((rel_idx, rel.edges_back_pos.train, 1.))
                backward_parts.append((rel_idx, rel.edges_back_neg.train, 0.))

            for back, parts in ((False, forward_parts), (True, backward_parts)):
                edges, targets, rel_idx = self._merge_and_shuffle(parts)
                self.edges.append({'fam_idx': fam_idx, 'rel_idx': rel_idx,
                                   'back': back, 'edges': edges,
                                   'targets': targets, 'ofs': 0})

    def __iter__(self):
        """Yield batches from randomly chosen, not yet exhausted pools.

        Each batch is a dict with the keys ``'fam_idx'``, ``'rel_idx'``,
        ``'back'``, ``'edges'`` and ``'targets'``; the tensor-valued entries
        hold the next (at most ``batch_size``) consecutive rows of one pool.
        Iteration ends once every pool has been served completely. Offsets are
        local to each iteration, so every ``iter()`` call is a fresh pass.
        """
        offsets = [0] * len(self.edges)
        while True:
            pending = [i for i, pool in enumerate(self.edges)
                       if offsets[i] < len(pool['edges'])]
            if not pending:
                return
            i = pending[torch.randint(0, len(pending), (1,)).item()]
            pool = self.edges[i]
            start = offsets[i]
            stop = start + self.batch_size
            offsets[i] = stop
            yield {'fam_idx': pool['fam_idx'],
                   'rel_idx': pool['rel_idx'][start:stop],
                   'back': pool['back'],
                   'edges': pool['edges'][start:stop],
                   'targets': pool['targets'][start:stop]}

    def __iter_old__(self):
        """Legacy sampler yielding :class:`TrainingBatch` objects.

        Repeatedly draws a family, one of its relation types and one of the
        four edge kinds at random, then yields the next ``batch_size`` edges
        of that kind (in a per-key random order). Node types are swapped for
        the backward edge kinds. Note that the yielded batches carry no
        positive/negative label. Stops once every (family, relation, edge
        kind) key has been drawn after being exhausted.
        """
        edge_types = ['edges_pos', 'edges_neg', 'edges_back_pos', 'edges_back_neg']

        offsets = {}
        orders = {}
        done = {}

        for fam_idx, fam in enumerate(self.prep_d.relation_families):
            for rel_idx, rel in enumerate(fam.relation_types):
                for et in edge_types:
                    done[fam_idx, rel_idx, et] = False

        # Families without relation types cannot be sampled from
        # (torch.randint(0, 0) raises). When every family is non-empty this
        # is range(len(families)), so the sampling distribution is unchanged.
        fam_choices = [fam_idx for fam_idx, fam
                       in enumerate(self.prep_d.relation_families)
                       if fam.relation_types]

        # `done` is what makes the loop terminate once everything is served.
        while not all(done.values()):
            pick = torch.randint(0, len(fam_choices), (1,)).item()
            fam_idx = fam_choices[pick]
            fam = self.prep_d.relation_families[fam_idx]

            rel_idx = torch.randint(0, len(fam.relation_types), (1,)).item()
            rel = fam.relation_types[rel_idx]

            et = random.choice(edge_types)
            edges = getattr(rel, et).train

            key = (fam_idx, rel_idx, et)
            if key not in orders:
                orders[key] = torch.randperm(len(edges))
                offsets[key] = 0

            order = orders[key]
            ofs = offsets[key]

            nt_row = rel.node_type_row
            nt_col = rel.node_type_column
            if 'back' in et:
                # Backward edge kinds swap the row/column node types.
                nt_row, nt_col = nt_col, nt_row

            if ofs < len(edges):
                offsets[key] += self.batch_size
                batch = edges[order[ofs:ofs + self.batch_size]]
                yield TrainingBatch(fam_idx, rel_idx, nt_row, nt_col, batch)
            else:
                done[key] = True
-
-
-
- class FastDecLayer(torch.nn.Module):
- def __init__(self, **kwargs):
- super().__init__(**kwargs)
-
- def forward(self,
- last_layer_repr: List[torch.Tensor],
- training_batch: TrainingBatch):
|