| @@ -44,38 +44,41 @@ class NodeType(object): | |||
| @dataclass | |||
| class RelationType(object): | |||
| class RelationTypeBase(object): | |||
| name: str | |||
| node_type_row: int | |||
| node_type_column: int | |||
| adjacency_matrix: torch.Tensor | |||
| two_way: bool | |||
| @dataclass | |||
| class RelationType(RelationTypeBase): | |||
| hints: Dict[str, Any] = field(default_factory=dict) | |||
| class RelationFamily(object): | |||
| def __init__(self, | |||
| data: 'Data', | |||
| name: str, | |||
| node_type_row: int, | |||
| node_type_column: int, | |||
| is_symmetric: bool, | |||
| decoder_class: Type) -> None: | |||
| @dataclass | |||
| class RelationFamilyBase(object): | |||
| data: 'Data' | |||
| name: str | |||
| node_type_row: int | |||
| node_type_column: int | |||
| is_symmetric: bool | |||
| decoder_class: Type | |||
| if not is_symmetric and \ | |||
| decoder_class != DEDICOMDecoder and \ | |||
| decoder_class != BilinearDecoder: | |||
| raise TypeError('Family is assymetric but the specified decoder_class supports symmetric relations only') | |||
| self.data = data | |||
| self.name = name | |||
| self.node_type_row = node_type_row | |||
| self.node_type_column = node_type_column | |||
| self.is_symmetric = is_symmetric | |||
| self.decoder_class = decoder_class | |||
| @dataclass | |||
| class RelationFamily(RelationFamilyBase): | |||
| relation_types: Dict[Tuple[int, int], List[RelationType]] = None | |||
| def __post_init__(self) -> None: | |||
| if not self.is_symmetric and \ | |||
| self.decoder_class != DEDICOMDecoder and \ | |||
| self.decoder_class != BilinearDecoder: | |||
| raise TypeError('Family is assymetric but the specified decoder_class supports symmetric relations only') | |||
| self.relation_types = { (node_type_row, node_type_column): [], | |||
| (node_type_column, node_type_row): [] } | |||
| self.relation_types = { (self.node_type_row, self.node_type_column): [], | |||
| (self.node_type_column, self.node_type_row): [] } | |||
| def add_relation_type(self, name: str, node_type_row: int, node_type_column: int, | |||
| adjacency_matrix: torch.Tensor, adjacency_matrix_backward: torch.Tensor = None, | |||
| @@ -166,7 +169,7 @@ class RelationFamily(object): | |||
| class Data(object): | |||
| node_types: List[NodeType] | |||
| relation_types: Dict[Tuple[int, int], List[RelationType]] | |||
| relation_families: List[RelationFamily] | |||
| def __init__(self) -> None: | |||
| self.node_types = [] | |||
| @@ -6,13 +6,17 @@ | |||
| from .sampling import fixed_unigram_candidate_sampler | |||
| import torch | |||
| from dataclasses import dataclass | |||
| from dataclasses import dataclass, \ | |||
| field | |||
| from typing import Any, \ | |||
| List, \ | |||
| Tuple, \ | |||
| Dict | |||
| from .data import NodeType, \ | |||
| RelationType, \ | |||
| RelationTypeBase, \ | |||
| RelationFamily, \ | |||
| RelationFamilyBase, \ | |||
| Data | |||
| from collections import defaultdict | |||
| from .normalize import norm_adj_mat_one_node_type, \ | |||
| @@ -28,25 +32,20 @@ class TrainValTest(object): | |||
| @dataclass | |||
| class PreparedEdges(object): | |||
| positive: TrainValTest | |||
| negative: TrainValTest | |||
| class PreparedRelationType(RelationTypeBase): | |||
| edges_pos: TrainValTest | |||
| edges_neg: TrainValTest | |||
| @dataclass | |||
| class PreparedRelationType(object): | |||
| name: str | |||
| node_type_row: int | |||
| node_type_column: int | |||
| adjacency_matrix: torch.Tensor | |||
| edges_pos: TrainValTest | |||
| edges_neg: TrainValTest | |||
| class PreparedRelationFamily(RelationFamilyBase): | |||
| relation_types: Dict[Tuple[int, int], List[PreparedRelationType]] | |||
| @dataclass | |||
| class PreparedData(object): | |||
| node_types: List[NodeType] | |||
| relation_types: Dict[Tuple[int, int], List[PreparedRelationType]] | |||
| relation_families: List[PreparedRelationFamily] | |||
| def train_val_test_split_edges(edges: torch.Tensor, | |||
| @@ -130,16 +129,29 @@ def prepare_relation_type(r: RelationType, | |||
| adj_mat_train = norm_adj_mat_two_node_types(adj_mat_train) | |||
| return PreparedRelationType(r.name, r.node_type_row, r.node_type_column, | |||
| adj_mat_train, edges_pos, edges_neg) | |||
| adj_mat_train, r.two_way, edges_pos, edges_neg) | |||
| def prepare_training(data: Data, ratios: TrainValTest) -> PreparedData: | |||
| if not isinstance(data, Data): | |||
| raise ValueError('data must be of class Data') | |||
| def prepare_relation_family(fam: RelationFamily) -> PreparedRelationFamily: | |||
| relation_types = { (fam.node_type_row, fam.node_type_column): [], | |||
| (fam.node_type_column, fam.node_type_row): [] } | |||
| relation_types = defaultdict(list) | |||
| for (node_type_row, node_type_column), rels in data.relation_types.items(): | |||
| for (node_type_row, node_type_column), rels in fam.relation_types.items(): | |||
| for r in rels: | |||
| relation_types[node_type_row, node_type_column].append( | |||
| prepare_relation_type(r, ratios)) | |||
| return PreparedData(data.node_types, relation_types) | |||
| return PreparedRelationFamily(fam.data, fam.name, | |||
| fam.node_type_row, fam.node_type_column, | |||
| fam.is_symmetric, fam.decoder_class, | |||
| relation_types) | |||
| def prepare_training(data: Data, ratios: TrainValTest) -> PreparedData: | |||
| if not isinstance(data, Data): | |||
| raise ValueError('data must be of class Data') | |||
| relation_families = [ prepare_relation_family(fam) \ | |||
| for fam in data.relation_families ] | |||
| return PreparedData(data.node_types, relation_families) | |||
| @@ -17,7 +17,8 @@ import torch | |||
| def test_decode_layer_01(): | |||
| d = Data() | |||
| d.add_node_type('Dummy', 100) | |||
| d.add_relation_type('Dummy Relation 1', 0, 0, | |||
| fam = d.add_relation_family('Dummy-Dummy', 0, 0, False) | |||
| fam.add_relation_type('Dummy Relation 1', 0, 0, | |||
| torch.rand((100, 100), dtype=torch.float32).round().to_sparse()) | |||
| prep_d = prepare_training(d, TrainValTest(.8, .1, .1)) | |||
| in_layer = OneHotInputLayer(d) | |||
| @@ -105,7 +105,7 @@ def test_prepare_adj_mat_02(): | |||
| def test_prepare_relation_type_01(): | |||
| adj_mat = (torch.rand((10, 10)) > .5) | |||
| r = RelationType('Test', 0, 0, adj_mat) | |||
| r = RelationType('Test', 0, 0, adj_mat, True) | |||
| ratios = TrainValTest(.8, .1, .1) | |||
| _ = prepare_relation_type(r, ratios) | |||