| @@ -44,38 +44,41 @@ class NodeType(object): | |||||
| @dataclass | @dataclass | ||||
| class RelationType(object): | |||||
| class RelationTypeBase(object): | |||||
| name: str | name: str | ||||
| node_type_row: int | node_type_row: int | ||||
| node_type_column: int | node_type_column: int | ||||
| adjacency_matrix: torch.Tensor | adjacency_matrix: torch.Tensor | ||||
| two_way: bool | two_way: bool | ||||
| @dataclass | |||||
| class RelationType(RelationTypeBase): | |||||
| hints: Dict[str, Any] = field(default_factory=dict) | hints: Dict[str, Any] = field(default_factory=dict) | ||||
| class RelationFamily(object): | |||||
| def __init__(self, | |||||
| data: 'Data', | |||||
| name: str, | |||||
| node_type_row: int, | |||||
| node_type_column: int, | |||||
| is_symmetric: bool, | |||||
| decoder_class: Type) -> None: | |||||
| @dataclass | |||||
| class RelationFamilyBase(object): | |||||
| data: 'Data' | |||||
| name: str | |||||
| node_type_row: int | |||||
| node_type_column: int | |||||
| is_symmetric: bool | |||||
| decoder_class: Type | |||||
| if not is_symmetric and \ | |||||
| decoder_class != DEDICOMDecoder and \ | |||||
| decoder_class != BilinearDecoder: | |||||
| raise TypeError('Family is assymetric but the specified decoder_class supports symmetric relations only') | |||||
| self.data = data | |||||
| self.name = name | |||||
| self.node_type_row = node_type_row | |||||
| self.node_type_column = node_type_column | |||||
| self.is_symmetric = is_symmetric | |||||
| self.decoder_class = decoder_class | |||||
| @dataclass | |||||
| class RelationFamily(RelationFamilyBase): | |||||
| relation_types: Dict[Tuple[int, int], List[RelationType]] = None | |||||
| def __post_init__(self) -> None: | |||||
| if not self.is_symmetric and \ | |||||
| self.decoder_class != DEDICOMDecoder and \ | |||||
| self.decoder_class != BilinearDecoder: | |||||
| raise TypeError('Family is assymetric but the specified decoder_class supports symmetric relations only') | |||||
| self.relation_types = { (node_type_row, node_type_column): [], | |||||
| (node_type_column, node_type_row): [] } | |||||
| self.relation_types = { (self.node_type_row, self.node_type_column): [], | |||||
| (self.node_type_column, self.node_type_row): [] } | |||||
| def add_relation_type(self, name: str, node_type_row: int, node_type_column: int, | def add_relation_type(self, name: str, node_type_row: int, node_type_column: int, | ||||
| adjacency_matrix: torch.Tensor, adjacency_matrix_backward: torch.Tensor = None, | adjacency_matrix: torch.Tensor, adjacency_matrix_backward: torch.Tensor = None, | ||||
| @@ -166,7 +169,7 @@ class RelationFamily(object): | |||||
| class Data(object): | class Data(object): | ||||
| node_types: List[NodeType] | node_types: List[NodeType] | ||||
| relation_types: Dict[Tuple[int, int], List[RelationType]] | |||||
| relation_families: List[RelationFamily] | |||||
| def __init__(self) -> None: | def __init__(self) -> None: | ||||
| self.node_types = [] | self.node_types = [] | ||||
| @@ -6,13 +6,17 @@ | |||||
| from .sampling import fixed_unigram_candidate_sampler | from .sampling import fixed_unigram_candidate_sampler | ||||
| import torch | import torch | ||||
| from dataclasses import dataclass | |||||
| from dataclasses import dataclass, \ | |||||
| field | |||||
| from typing import Any, \ | from typing import Any, \ | ||||
| List, \ | List, \ | ||||
| Tuple, \ | Tuple, \ | ||||
| Dict | Dict | ||||
| from .data import NodeType, \ | from .data import NodeType, \ | ||||
| RelationType, \ | RelationType, \ | ||||
| RelationTypeBase, \ | |||||
| RelationFamily, \ | |||||
| RelationFamilyBase, \ | |||||
| Data | Data | ||||
| from collections import defaultdict | from collections import defaultdict | ||||
| from .normalize import norm_adj_mat_one_node_type, \ | from .normalize import norm_adj_mat_one_node_type, \ | ||||
| @@ -28,25 +32,20 @@ class TrainValTest(object): | |||||
| @dataclass | @dataclass | ||||
| class PreparedEdges(object): | |||||
| positive: TrainValTest | |||||
| negative: TrainValTest | |||||
| class PreparedRelationType(RelationTypeBase): | |||||
| edges_pos: TrainValTest | |||||
| edges_neg: TrainValTest | |||||
| @dataclass | @dataclass | ||||
| class PreparedRelationType(object): | |||||
| name: str | |||||
| node_type_row: int | |||||
| node_type_column: int | |||||
| adjacency_matrix: torch.Tensor | |||||
| edges_pos: TrainValTest | |||||
| edges_neg: TrainValTest | |||||
| class PreparedRelationFamily(RelationFamilyBase): | |||||
| relation_types: Dict[Tuple[int, int], List[PreparedRelationType]] | |||||
| @dataclass | @dataclass | ||||
| class PreparedData(object): | class PreparedData(object): | ||||
| node_types: List[NodeType] | node_types: List[NodeType] | ||||
| relation_types: Dict[Tuple[int, int], List[PreparedRelationType]] | |||||
| relation_families: List[PreparedRelationFamily] | |||||
| def train_val_test_split_edges(edges: torch.Tensor, | def train_val_test_split_edges(edges: torch.Tensor, | ||||
| @@ -130,16 +129,29 @@ def prepare_relation_type(r: RelationType, | |||||
| adj_mat_train = norm_adj_mat_two_node_types(adj_mat_train) | adj_mat_train = norm_adj_mat_two_node_types(adj_mat_train) | ||||
| return PreparedRelationType(r.name, r.node_type_row, r.node_type_column, | return PreparedRelationType(r.name, r.node_type_row, r.node_type_column, | ||||
| adj_mat_train, edges_pos, edges_neg) | |||||
| adj_mat_train, r.two_way, edges_pos, edges_neg) | |||||
| def prepare_training(data: Data, ratios: TrainValTest) -> PreparedData: | |||||
| if not isinstance(data, Data): | |||||
| raise ValueError('data must be of class Data') | |||||
| def prepare_relation_family(fam: RelationFamily) -> PreparedRelationFamily: | |||||
| relation_types = { (fam.node_type_row, fam.node_type_column): [], | |||||
| (fam.node_type_column, fam.node_type_row): [] } | |||||
| relation_types = defaultdict(list) | |||||
| for (node_type_row, node_type_column), rels in data.relation_types.items(): | |||||
| for (node_type_row, node_type_column), rels in fam.relation_types.items(): | |||||
| for r in rels: | for r in rels: | ||||
| relation_types[node_type_row, node_type_column].append( | relation_types[node_type_row, node_type_column].append( | ||||
| prepare_relation_type(r, ratios)) | prepare_relation_type(r, ratios)) | ||||
| return PreparedData(data.node_types, relation_types) | |||||
| return PreparedRelationFamily(fam.data, fam.name, | |||||
| fam.node_type_row, fam.node_type_column, | |||||
| fam.is_symmetric, fam.decoder_class, | |||||
| relation_types) | |||||
| def prepare_training(data: Data, ratios: TrainValTest) -> PreparedData: | |||||
| if not isinstance(data, Data): | |||||
| raise ValueError('data must be of class Data') | |||||
| relation_families = [ prepare_relation_family(fam) \ | |||||
| for fam in data.relation_families ] | |||||
| return PreparedData(data.node_types, relation_families) | |||||
| @@ -17,7 +17,8 @@ import torch | |||||
| def test_decode_layer_01(): | def test_decode_layer_01(): | ||||
| d = Data() | d = Data() | ||||
| d.add_node_type('Dummy', 100) | d.add_node_type('Dummy', 100) | ||||
| d.add_relation_type('Dummy Relation 1', 0, 0, | |||||
| fam = d.add_relation_family('Dummy-Dummy', 0, 0, False) | |||||
| fam.add_relation_type('Dummy Relation 1', 0, 0, | |||||
| torch.rand((100, 100), dtype=torch.float32).round().to_sparse()) | torch.rand((100, 100), dtype=torch.float32).round().to_sparse()) | ||||
| prep_d = prepare_training(d, TrainValTest(.8, .1, .1)) | prep_d = prepare_training(d, TrainValTest(.8, .1, .1)) | ||||
| in_layer = OneHotInputLayer(d) | in_layer = OneHotInputLayer(d) | ||||
| @@ -105,7 +105,7 @@ def test_prepare_adj_mat_02(): | |||||
| def test_prepare_relation_type_01(): | def test_prepare_relation_type_01(): | ||||
| adj_mat = (torch.rand((10, 10)) > .5) | adj_mat = (torch.rand((10, 10)) > .5) | ||||
| r = RelationType('Test', 0, 0, adj_mat) | |||||
| r = RelationType('Test', 0, 0, adj_mat, True) | |||||
| ratios = TrainValTest(.8, .1, .1) | ratios = TrainValTest(.8, .1, .1) | ||||
| _ = prepare_relation_type(r, ratios) | _ = prepare_relation_type(r, ratios) | ||||