From 7eda6bdfb90ef976a7d5891e82ee9e15b707ae7e Mon Sep 17 00:00:00 2001 From: Stanislaw Adaszewski Date: Thu, 11 Jun 2020 18:28:22 +0200 Subject: [PATCH] Fix regressions in trainprep. --- src/icosagon/data.py | 47 ++++++++++++++++-------------- src/icosagon/trainprep.py | 50 ++++++++++++++++++++------------ tests/icosagon/test_declayer.py | 3 +- tests/icosagon/test_trainprep.py | 2 +- 4 files changed, 59 insertions(+), 43 deletions(-) diff --git a/src/icosagon/data.py b/src/icosagon/data.py index 0690d29..6fc91cc 100644 --- a/src/icosagon/data.py +++ b/src/icosagon/data.py @@ -44,38 +44,41 @@ class NodeType(object): @dataclass -class RelationType(object): +class RelationTypeBase(object): name: str node_type_row: int node_type_column: int adjacency_matrix: torch.Tensor two_way: bool + + +@dataclass +class RelationType(RelationTypeBase): hints: Dict[str, Any] = field(default_factory=dict) -class RelationFamily(object): - def __init__(self, - data: 'Data', - name: str, - node_type_row: int, - node_type_column: int, - is_symmetric: bool, - decoder_class: Type) -> None: +@dataclass +class RelationFamilyBase(object): + data: 'Data' + name: str + node_type_row: int + node_type_column: int + is_symmetric: bool + decoder_class: Type - if not is_symmetric and \ - decoder_class != DEDICOMDecoder and \ - decoder_class != BilinearDecoder: - raise TypeError('Family is assymetric but the specified decoder_class supports symmetric relations only') - self.data = data - self.name = name - self.node_type_row = node_type_row - self.node_type_column = node_type_column - self.is_symmetric = is_symmetric - self.decoder_class = decoder_class +@dataclass +class RelationFamily(RelationFamilyBase): + relation_types: Dict[Tuple[int, int], List[RelationType]] = None + + def __post_init__(self) -> None: + if not self.is_symmetric and \ + self.decoder_class != DEDICOMDecoder and \ + self.decoder_class != BilinearDecoder: + raise TypeError('Family is assymetric but the specified decoder_class supports symmetric relations only') - self.relation_types = { (node_type_row, node_type_column): [], - (node_type_column, node_type_row): [] } + self.relation_types = { (self.node_type_row, self.node_type_column): [], + (self.node_type_column, self.node_type_row): [] } def add_relation_type(self, name: str, node_type_row: int, node_type_column: int, adjacency_matrix: torch.Tensor, adjacency_matrix_backward: torch.Tensor = None, @@ -166,7 +169,7 @@ class RelationFamily(object): class Data(object): node_types: List[NodeType] - relation_types: Dict[Tuple[int, int], List[RelationType]] + relation_families: List[RelationFamily] def __init__(self) -> None: self.node_types = [] diff --git a/src/icosagon/trainprep.py b/src/icosagon/trainprep.py index a505d9b..d1dcdb3 100644 --- a/src/icosagon/trainprep.py +++ b/src/icosagon/trainprep.py @@ -6,13 +6,17 @@ from .sampling import fixed_unigram_candidate_sampler import torch -from dataclasses import dataclass +from dataclasses import dataclass, \ + field from typing import Any, \ List, \ Tuple, \ Dict from .data import NodeType, \ RelationType, \ + RelationTypeBase, \ + RelationFamily, \ + RelationFamilyBase, \ Data from collections import defaultdict from .normalize import norm_adj_mat_one_node_type, \ @@ -28,25 +32,20 @@ class TrainValTest(object): @dataclass -class PreparedEdges(object): - positive: TrainValTest - negative: TrainValTest +class PreparedRelationType(RelationTypeBase): + edges_pos: TrainValTest + edges_neg: TrainValTest @dataclass -class PreparedRelationType(object): - name: str - node_type_row: int - node_type_column: int - adjacency_matrix: torch.Tensor - edges_pos: TrainValTest - edges_neg: TrainValTest +class PreparedRelationFamily(RelationFamilyBase): + relation_types: Dict[Tuple[int, int], List[PreparedRelationType]] @dataclass class PreparedData(object): node_types: List[NodeType] - relation_types: Dict[Tuple[int, int], List[PreparedRelationType]] + relation_families: List[PreparedRelationFamily] def train_val_test_split_edges(edges: torch.Tensor, @@ -130,16 +129,29 @@ def prepare_relation_type(r: RelationType, adj_mat_train = norm_adj_mat_two_node_types(adj_mat_train) return PreparedRelationType(r.name, r.node_type_row, r.node_type_column, - adj_mat_train, edges_pos, edges_neg) + adj_mat_train, r.two_way, edges_pos, edges_neg) -def prepare_training(data: Data, ratios: TrainValTest) -> PreparedData: - if not isinstance(data, Data): - raise ValueError('data must be of class Data') +def prepare_relation_family(fam: RelationFamily) -> PreparedRelationFamily: + relation_types = { (fam.node_type_row, fam.node_type_column): [], + (fam.node_type_column, fam.node_type_row): [] } - relation_types = defaultdict(list) - for (node_type_row, node_type_column), rels in data.relation_types.items(): + for (node_type_row, node_type_column), rels in fam.relation_types.items(): for r in rels: relation_types[node_type_row, node_type_column].append( prepare_relation_type(r, ratios)) - return PreparedData(data.node_types, relation_types) + + return PreparedRelationFamily(fam.data, fam.name, + fam.node_type_row, fam.node_type_column, + fam.is_symmetric, fam.decoder_class, + relation_types) + + +def prepare_training(data: Data, ratios: TrainValTest) -> PreparedData: + if not isinstance(data, Data): + raise ValueError('data must be of class Data') + + relation_families = [ prepare_relation_family(fam) \ + for fam in data.relation_families ] + + return PreparedData(data.node_types, relation_families) diff --git a/tests/icosagon/test_declayer.py b/tests/icosagon/test_declayer.py index 3536387..7da2ad1 100644 --- a/tests/icosagon/test_declayer.py +++ b/tests/icosagon/test_declayer.py @@ -17,7 +17,8 @@ import torch def test_decode_layer_01(): d = Data() d.add_node_type('Dummy', 100) - d.add_relation_type('Dummy Relation 1', 0, 0, + fam = d.add_relation_family('Dummy-Dummy', 0, 0, False) + fam.add_relation_type('Dummy Relation 1', 0, 0, torch.rand((100, 100), dtype=torch.float32).round().to_sparse()) prep_d = prepare_training(d, TrainValTest(.8, .1, .1)) in_layer = OneHotInputLayer(d) diff --git a/tests/icosagon/test_trainprep.py b/tests/icosagon/test_trainprep.py index 967bb1e..712d8c5 100644 --- a/tests/icosagon/test_trainprep.py +++ b/tests/icosagon/test_trainprep.py @@ -105,7 +105,7 @@ def test_prepare_adj_mat_02(): def test_prepare_relation_type_01(): adj_mat = (torch.rand((10, 10)) > .5) - r = RelationType('Test', 0, 0, adj_mat) + r = RelationType('Test', 0, 0, adj_mat, True) ratios = TrainValTest(.8, .1, .1) _ = prepare_relation_type(r, ratios)