IF YOU WOULD LIKE TO GET AN ACCOUNT, please write an email to s dot adaszewski at gmail dot com. User accounts are meant only to report issues and/or generate pull requests. This is a purpose-specific Git hosting for ADARED projects. Thank you for your understanding!
Kaynağa Gözat

Fix regressions in trainprep.

master
Stanislaw Adaszewski 4 yıl önce
ebeveyn
işleme
7eda6bdfb9
4 değiştirilmiş dosya ile 59 ekleme ve 43 silme
  1. +25
    -22
      src/icosagon/data.py
  2. +31
    -19
      src/icosagon/trainprep.py
  3. +2
    -1
      tests/icosagon/test_declayer.py
  4. +1
    -1
      tests/icosagon/test_trainprep.py

+ 25
- 22
src/icosagon/data.py Dosyayı Görüntüle

@@ -44,38 +44,41 @@ class NodeType(object):
@dataclass
class RelationType(object):
class RelationTypeBase(object):
name: str
node_type_row: int
node_type_column: int
adjacency_matrix: torch.Tensor
two_way: bool
@dataclass
class RelationType(RelationTypeBase):
hints: Dict[str, Any] = field(default_factory=dict)
class RelationFamily(object):
def __init__(self,
data: 'Data',
name: str,
node_type_row: int,
node_type_column: int,
is_symmetric: bool,
decoder_class: Type) -> None:
@dataclass
class RelationFamilyBase(object):
data: 'Data'
name: str
node_type_row: int
node_type_column: int
is_symmetric: bool
decoder_class: Type
if not is_symmetric and \
decoder_class != DEDICOMDecoder and \
decoder_class != BilinearDecoder:
raise TypeError('Family is assymetric but the specified decoder_class supports symmetric relations only')
self.data = data
self.name = name
self.node_type_row = node_type_row
self.node_type_column = node_type_column
self.is_symmetric = is_symmetric
self.decoder_class = decoder_class
@dataclass
class RelationFamily(RelationFamilyBase):
relation_types: Dict[Tuple[int, int], List[RelationType]] = None
def __post_init__(self) -> None:
if not self.is_symmetric and \
self.decoder_class != DEDICOMDecoder and \
self.decoder_class != BilinearDecoder:
raise TypeError('Family is assymetric but the specified decoder_class supports symmetric relations only')
self.relation_types = { (node_type_row, node_type_column): [],
(node_type_column, node_type_row): [] }
self.relation_types = { (self.node_type_row, self.node_type_column): [],
(self.node_type_column, self.node_type_row): [] }
def add_relation_type(self, name: str, node_type_row: int, node_type_column: int,
adjacency_matrix: torch.Tensor, adjacency_matrix_backward: torch.Tensor = None,
@@ -166,7 +169,7 @@ class RelationFamily(object):
class Data(object):
node_types: List[NodeType]
relation_types: Dict[Tuple[int, int], List[RelationType]]
relation_families: List[RelationFamily]
def __init__(self) -> None:
self.node_types = []


+ 31
- 19
src/icosagon/trainprep.py Dosyayı Görüntüle

@@ -6,13 +6,17 @@
from .sampling import fixed_unigram_candidate_sampler
import torch
from dataclasses import dataclass
from dataclasses import dataclass, \
field
from typing import Any, \
List, \
Tuple, \
Dict
from .data import NodeType, \
RelationType, \
RelationTypeBase, \
RelationFamily, \
RelationFamilyBase, \
Data
from collections import defaultdict
from .normalize import norm_adj_mat_one_node_type, \
@@ -28,25 +32,20 @@ class TrainValTest(object):
@dataclass
class PreparedEdges(object):
positive: TrainValTest
negative: TrainValTest
class PreparedRelationType(RelationTypeBase):
edges_pos: TrainValTest
edges_neg: TrainValTest
@dataclass
class PreparedRelationType(object):
name: str
node_type_row: int
node_type_column: int
adjacency_matrix: torch.Tensor
edges_pos: TrainValTest
edges_neg: TrainValTest
class PreparedRelationFamily(RelationFamilyBase):
relation_types: Dict[Tuple[int, int], List[PreparedRelationType]]
@dataclass
class PreparedData(object):
node_types: List[NodeType]
relation_types: Dict[Tuple[int, int], List[PreparedRelationType]]
relation_families: List[PreparedRelationFamily]
def train_val_test_split_edges(edges: torch.Tensor,
@@ -130,16 +129,29 @@ def prepare_relation_type(r: RelationType,
adj_mat_train = norm_adj_mat_two_node_types(adj_mat_train)
return PreparedRelationType(r.name, r.node_type_row, r.node_type_column,
adj_mat_train, edges_pos, edges_neg)
adj_mat_train, r.two_way, edges_pos, edges_neg)
def prepare_training(data: Data, ratios: TrainValTest) -> PreparedData:
if not isinstance(data, Data):
raise ValueError('data must be of class Data')
def prepare_relation_family(fam: RelationFamily) -> PreparedRelationFamily:
relation_types = { (fam.node_type_row, fam.node_type_column): [],
(fam.node_type_column, fam.node_type_row): [] }
relation_types = defaultdict(list)
for (node_type_row, node_type_column), rels in data.relation_types.items():
for (node_type_row, node_type_column), rels in fam.relation_types.items():
for r in rels:
relation_types[node_type_row, node_type_column].append(
prepare_relation_type(r, ratios))
return PreparedData(data.node_types, relation_types)
return PreparedRelationFamily(fam.data, fam.name,
fam.node_type_row, fam.node_type_column,
fam.is_symmetric, fam.decoder_class,
relation_types)
def prepare_training(data: Data, ratios: TrainValTest) -> PreparedData:
if not isinstance(data, Data):
raise ValueError('data must be of class Data')
relation_families = [ prepare_relation_family(fam) \
for fam in data.relation_families ]
return PreparedData(data.node_types, relation_families)

+ 2
- 1
tests/icosagon/test_declayer.py Dosyayı Görüntüle

@@ -17,7 +17,8 @@ import torch
def test_decode_layer_01():
d = Data()
d.add_node_type('Dummy', 100)
d.add_relation_type('Dummy Relation 1', 0, 0,
fam = d.add_relation_family('Dummy-Dummy', 0, 0, False)
fam.add_relation_type('Dummy Relation 1', 0, 0,
torch.rand((100, 100), dtype=torch.float32).round().to_sparse())
prep_d = prepare_training(d, TrainValTest(.8, .1, .1))
in_layer = OneHotInputLayer(d)


+ 1
- 1
tests/icosagon/test_trainprep.py Dosyayı Görüntüle

@@ -105,7 +105,7 @@ def test_prepare_adj_mat_02():
def test_prepare_relation_type_01():
adj_mat = (torch.rand((10, 10)) > .5)
r = RelationType('Test', 0, 0, adj_mat)
r = RelationType('Test', 0, 0, adj_mat, True)
ratios = TrainValTest(.8, .1, .1)
_ = prepare_relation_type(r, ratios)


Yükleniyor…
İptal
Kaydet