@@ -44,38 +44,41 @@ class NodeType(object): | |||||
@dataclass | @dataclass | ||||
class RelationType(object): | |||||
class RelationTypeBase(object): | |||||
name: str | name: str | ||||
node_type_row: int | node_type_row: int | ||||
node_type_column: int | node_type_column: int | ||||
adjacency_matrix: torch.Tensor | adjacency_matrix: torch.Tensor | ||||
two_way: bool | two_way: bool | ||||
@dataclass | |||||
class RelationType(RelationTypeBase): | |||||
hints: Dict[str, Any] = field(default_factory=dict) | hints: Dict[str, Any] = field(default_factory=dict) | ||||
class RelationFamily(object): | |||||
def __init__(self, | |||||
data: 'Data', | |||||
name: str, | |||||
node_type_row: int, | |||||
node_type_column: int, | |||||
is_symmetric: bool, | |||||
decoder_class: Type) -> None: | |||||
@dataclass | |||||
class RelationFamilyBase(object): | |||||
data: 'Data' | |||||
name: str | |||||
node_type_row: int | |||||
node_type_column: int | |||||
is_symmetric: bool | |||||
decoder_class: Type | |||||
if not is_symmetric and \ | |||||
decoder_class != DEDICOMDecoder and \ | |||||
decoder_class != BilinearDecoder: | |||||
raise TypeError('Family is assymetric but the specified decoder_class supports symmetric relations only') | |||||
self.data = data | |||||
self.name = name | |||||
self.node_type_row = node_type_row | |||||
self.node_type_column = node_type_column | |||||
self.is_symmetric = is_symmetric | |||||
self.decoder_class = decoder_class | |||||
@dataclass | |||||
class RelationFamily(RelationFamilyBase): | |||||
relation_types: Dict[Tuple[int, int], List[RelationType]] = None | |||||
def __post_init__(self) -> None: | |||||
if not self.is_symmetric and \ | |||||
self.decoder_class != DEDICOMDecoder and \ | |||||
self.decoder_class != BilinearDecoder: | |||||
raise TypeError('Family is assymetric but the specified decoder_class supports symmetric relations only') | |||||
self.relation_types = { (node_type_row, node_type_column): [], | |||||
(node_type_column, node_type_row): [] } | |||||
self.relation_types = { (self.node_type_row, self.node_type_column): [], | |||||
(self.node_type_column, self.node_type_row): [] } | |||||
def add_relation_type(self, name: str, node_type_row: int, node_type_column: int, | def add_relation_type(self, name: str, node_type_row: int, node_type_column: int, | ||||
adjacency_matrix: torch.Tensor, adjacency_matrix_backward: torch.Tensor = None, | adjacency_matrix: torch.Tensor, adjacency_matrix_backward: torch.Tensor = None, | ||||
@@ -166,7 +169,7 @@ class RelationFamily(object): | |||||
class Data(object): | class Data(object): | ||||
node_types: List[NodeType] | node_types: List[NodeType] | ||||
relation_types: Dict[Tuple[int, int], List[RelationType]] | |||||
relation_families: List[RelationFamily] | |||||
def __init__(self) -> None: | def __init__(self) -> None: | ||||
self.node_types = [] | self.node_types = [] | ||||
@@ -6,13 +6,17 @@ | |||||
from .sampling import fixed_unigram_candidate_sampler | from .sampling import fixed_unigram_candidate_sampler | ||||
import torch | import torch | ||||
from dataclasses import dataclass | |||||
from dataclasses import dataclass, \ | |||||
field | |||||
from typing import Any, \ | from typing import Any, \ | ||||
List, \ | List, \ | ||||
Tuple, \ | Tuple, \ | ||||
Dict | Dict | ||||
from .data import NodeType, \ | from .data import NodeType, \ | ||||
RelationType, \ | RelationType, \ | ||||
RelationTypeBase, \ | |||||
RelationFamily, \ | |||||
RelationFamilyBase, \ | |||||
Data | Data | ||||
from collections import defaultdict | from collections import defaultdict | ||||
from .normalize import norm_adj_mat_one_node_type, \ | from .normalize import norm_adj_mat_one_node_type, \ | ||||
@@ -28,25 +32,20 @@ class TrainValTest(object): | |||||
@dataclass | @dataclass | ||||
class PreparedEdges(object): | |||||
positive: TrainValTest | |||||
negative: TrainValTest | |||||
class PreparedRelationType(RelationTypeBase): | |||||
edges_pos: TrainValTest | |||||
edges_neg: TrainValTest | |||||
@dataclass | @dataclass | ||||
class PreparedRelationType(object): | |||||
name: str | |||||
node_type_row: int | |||||
node_type_column: int | |||||
adjacency_matrix: torch.Tensor | |||||
edges_pos: TrainValTest | |||||
edges_neg: TrainValTest | |||||
class PreparedRelationFamily(RelationFamilyBase): | |||||
relation_types: Dict[Tuple[int, int], List[PreparedRelationType]] | |||||
@dataclass | @dataclass | ||||
class PreparedData(object): | class PreparedData(object): | ||||
node_types: List[NodeType] | node_types: List[NodeType] | ||||
relation_types: Dict[Tuple[int, int], List[PreparedRelationType]] | |||||
relation_families: List[PreparedRelationFamily] | |||||
def train_val_test_split_edges(edges: torch.Tensor, | def train_val_test_split_edges(edges: torch.Tensor, | ||||
@@ -130,16 +129,29 @@ def prepare_relation_type(r: RelationType, | |||||
adj_mat_train = norm_adj_mat_two_node_types(adj_mat_train) | adj_mat_train = norm_adj_mat_two_node_types(adj_mat_train) | ||||
return PreparedRelationType(r.name, r.node_type_row, r.node_type_column, | return PreparedRelationType(r.name, r.node_type_row, r.node_type_column, | ||||
adj_mat_train, edges_pos, edges_neg) | |||||
adj_mat_train, r.two_way, edges_pos, edges_neg) | |||||
def prepare_training(data: Data, ratios: TrainValTest) -> PreparedData: | |||||
if not isinstance(data, Data): | |||||
raise ValueError('data must be of class Data') | |||||
def prepare_relation_family(fam: RelationFamily) -> PreparedRelationFamily: | |||||
relation_types = { (fam.node_type_row, fam.node_type_column): [], | |||||
(fam.node_type_column, fam.node_type_row): [] } | |||||
relation_types = defaultdict(list) | |||||
for (node_type_row, node_type_column), rels in data.relation_types.items(): | |||||
for (node_type_row, node_type_column), rels in fam.relation_types.items(): | |||||
for r in rels: | for r in rels: | ||||
relation_types[node_type_row, node_type_column].append( | relation_types[node_type_row, node_type_column].append( | ||||
prepare_relation_type(r, ratios)) | prepare_relation_type(r, ratios)) | ||||
return PreparedData(data.node_types, relation_types) | |||||
return PreparedRelationFamily(fam.data, fam.name, | |||||
fam.node_type_row, fam.node_type_column, | |||||
fam.is_symmetric, fam.decoder_class, | |||||
relation_types) | |||||
def prepare_training(data: Data, ratios: TrainValTest) -> PreparedData: | |||||
if not isinstance(data, Data): | |||||
raise ValueError('data must be of class Data') | |||||
relation_families = [ prepare_relation_family(fam) \ | |||||
for fam in data.relation_families ] | |||||
return PreparedData(data.node_types, relation_families) |
@@ -17,7 +17,8 @@ import torch | |||||
def test_decode_layer_01(): | def test_decode_layer_01(): | ||||
d = Data() | d = Data() | ||||
d.add_node_type('Dummy', 100) | d.add_node_type('Dummy', 100) | ||||
d.add_relation_type('Dummy Relation 1', 0, 0, | |||||
fam = d.add_relation_family('Dummy-Dummy', 0, 0, False) | |||||
fam.add_relation_type('Dummy Relation 1', 0, 0, | |||||
torch.rand((100, 100), dtype=torch.float32).round().to_sparse()) | torch.rand((100, 100), dtype=torch.float32).round().to_sparse()) | ||||
prep_d = prepare_training(d, TrainValTest(.8, .1, .1)) | prep_d = prepare_training(d, TrainValTest(.8, .1, .1)) | ||||
in_layer = OneHotInputLayer(d) | in_layer = OneHotInputLayer(d) | ||||
@@ -105,7 +105,7 @@ def test_prepare_adj_mat_02(): | |||||
def test_prepare_relation_type_01(): | def test_prepare_relation_type_01(): | ||||
adj_mat = (torch.rand((10, 10)) > .5) | adj_mat = (torch.rand((10, 10)) > .5) | ||||
r = RelationType('Test', 0, 0, adj_mat) | |||||
r = RelationType('Test', 0, 0, adj_mat, True) | |||||
ratios = TrainValTest(.8, .1, .1) | ratios = TrainValTest(.8, .1, .1) | ||||
_ = prepare_relation_type(r, ratios) | _ = prepare_relation_type(r, ratios) | ||||