@@ -44,38 +44,41 @@ class NodeType(object): | |||
# Relation type records.
# RelationTypeBase declares the fields name, node_type_row, node_type_column,
# adjacency_matrix and two_way. RelationType derives from it and adds `hints`.
# NOTE(review): this span interleaves pre- and post-change lines of a diff;
# `class RelationType(object):` looks like the superseded header — confirm.
@dataclass | |||
class RelationType(object): | |||
class RelationTypeBase(object): | |||
name: str | |||
# presumably indices into the owning Data's node_types list — TODO confirm
node_type_row: int | |||
node_type_column: int | |||
adjacency_matrix: torch.Tensor | |||
two_way: bool | |||
@dataclass | |||
class RelationType(RelationTypeBase): | |||
# default_factory gives every instance its own empty dict rather than a
# single dict shared between instances.
hints: Dict[str, Any] = field(default_factory=dict) | |||
# Relation family definition.
# NOTE(review): this span interleaves two forms of the same record: a
# RelationFamily class with a hand-written __init__, and a RelationFamilyBase
# dataclass declaring the same six attributes as fields. The __init__ form
# looks like the superseded one — confirm against the repository.
class RelationFamily(object): | |||
def __init__(self, | |||
data: 'Data', | |||
name: str, | |||
node_type_row: int, | |||
node_type_column: int, | |||
is_symmetric: bool, | |||
decoder_class: Type) -> None: | |||
# Dataclass form: one field per constructor argument above.
@dataclass | |||
class RelationFamilyBase(object): | |||
data: 'Data' | |||
name: str | |||
node_type_row: int | |||
node_type_column: int | |||
is_symmetric: bool | |||
decoder_class: Type | |||
# A family that is not symmetric is accepted only when decoder_class is
# DEDICOMDecoder or BilinearDecoder; any other decoder raises TypeError.
# NOTE(review): 'assymetric' in the message is a misspelling of 'asymmetric'.
if not is_symmetric and \ | |||
decoder_class != DEDICOMDecoder and \ | |||
decoder_class != BilinearDecoder: | |||
raise TypeError('Family is assymetric but the specified decoder_class supports symmetric relations only') | |||
# Store the constructor arguments on the instance unchanged.
self.data = data | |||
self.name = name | |||
self.node_type_row = node_type_row | |||
self.node_type_column = node_type_column | |||
self.is_symmetric = is_symmetric | |||
self.decoder_class = decoder_class | |||
# RelationFamily: RelationFamilyBase plus a `relation_types` mapping keyed by
# (row node type, column node type) pairs, each holding a list of RelationType.
@dataclass | |||
class RelationFamily(RelationFamilyBase): | |||
# Declared with a None default; __post_init__ replaces it with a dict once
# the decoder check has passed.
relation_types: Dict[Tuple[int, int], List[RelationType]] = None | |||
# Called by the dataclass-generated __init__ after the fields are assigned.
# Raises TypeError when the family is not symmetric and decoder_class is
# neither DEDICOMDecoder nor BilinearDecoder.
# NOTE(review): 'assymetric' in the message is a misspelling of 'asymmetric'.
def __post_init__(self) -> None: | |||
if not self.is_symmetric and \ | |||
self.decoder_class != DEDICOMDecoder and \ | |||
self.decoder_class != BilinearDecoder: | |||
raise TypeError('Family is assymetric but the specified decoder_class supports symmetric relations only') | |||
# Seed an empty list for both orientations of the node type pair.
# NOTE(review): the first assignment uses bare node_type_row/node_type_column,
# which are not bound inside __post_init__; it looks like the pre-change diff
# line superseded by the self.-qualified assignment after it — confirm.
self.relation_types = { (node_type_row, node_type_column): [], | |||
(node_type_column, node_type_row): [] } | |||
self.relation_types = { (self.node_type_row, self.node_type_column): [], | |||
(self.node_type_column, self.node_type_row): [] } | |||
def add_relation_type(self, name: str, node_type_row: int, node_type_column: int, | |||
adjacency_matrix: torch.Tensor, adjacency_matrix_backward: torch.Tensor = None, | |||
@@ -166,7 +169,7 @@ class RelationFamily(object): | |||
class Data(object): | |||
node_types: List[NodeType] | |||
relation_types: Dict[Tuple[int, int], List[RelationType]] | |||
relation_families: List[RelationFamily] | |||
def __init__(self) -> None: | |||
self.node_types = [] | |||
@@ -6,13 +6,17 @@ | |||
from .sampling import fixed_unigram_candidate_sampler | |||
import torch | |||
from dataclasses import dataclass | |||
from dataclasses import dataclass, \ | |||
field | |||
from typing import Any, \ | |||
List, \ | |||
Tuple, \ | |||
Dict | |||
from .data import NodeType, \ | |||
RelationType, \ | |||
RelationTypeBase, \ | |||
RelationFamily, \ | |||
RelationFamilyBase, \ | |||
Data | |||
from collections import defaultdict | |||
from .normalize import norm_adj_mat_one_node_type, \ | |||
@@ -28,25 +32,20 @@ class TrainValTest(object): | |||
# PreparedRelationType extends RelationTypeBase with two TrainValTest fields,
# edges_pos and edges_neg.
# NOTE(review): diff residue — PreparedEdges (positive/negative) appears to be
# the removed class and PreparedRelationType(RelationTypeBase) the class that
# takes its place under the shared @dataclass decorator; confirm.
@dataclass | |||
class PreparedEdges(object): | |||
positive: TrainValTest | |||
negative: TrainValTest | |||
class PreparedRelationType(RelationTypeBase): | |||
# presumably the positive and negative edge splits — TODO confirm
edges_pos: TrainValTest | |||
edges_neg: TrainValTest | |||
# PreparedRelationFamily extends RelationFamilyBase with a relation_types
# mapping from (int, int) pairs to lists of PreparedRelationType.
# NOTE(review): diff residue — the PreparedRelationType(object) class with six
# explicit fields appears to be the removed definition; confirm.
@dataclass | |||
class PreparedRelationType(object): | |||
name: str | |||
node_type_row: int | |||
node_type_column: int | |||
adjacency_matrix: torch.Tensor | |||
edges_pos: TrainValTest | |||
edges_neg: TrainValTest | |||
class PreparedRelationFamily(RelationFamilyBase): | |||
# No default value, so it must be supplied when constructing the instance.
relation_types: Dict[Tuple[int, int], List[PreparedRelationType]] | |||
# PreparedData bundles the node types with the prepared relation data.
# NOTE(review): diff residue — `relation_types` and `relation_families` look
# like the old and new form of the second field respectively; the two-argument
# PreparedData(...) call in prepare_training matches a two-field class — confirm.
@dataclass | |||
class PreparedData(object): | |||
node_types: List[NodeType] | |||
relation_types: Dict[Tuple[int, int], List[PreparedRelationType]] | |||
relation_families: List[PreparedRelationFamily] | |||
def train_val_test_split_edges(edges: torch.Tensor, | |||
@@ -130,16 +129,29 @@ def prepare_relation_type(r: RelationType, | |||
adj_mat_train = norm_adj_mat_two_node_types(adj_mat_train) | |||
return PreparedRelationType(r.name, r.node_type_row, r.node_type_column, | |||
adj_mat_train, edges_pos, edges_neg) | |||
adj_mat_train, r.two_way, edges_pos, edges_neg) | |||
def prepare_relation_family(fam: RelationFamily,
        ratios: TrainValTest) -> PreparedRelationFamily:
    """Build the PreparedRelationFamily counterpart of ``fam``.

    Each relation type in ``fam.relation_types`` is run through
    ``prepare_relation_type(r, ratios)`` and stored under the same
    ``(node_type_row, node_type_column)`` key it had in ``fam``.

    ``ratios`` is an explicit parameter: this function lives at module
    level, so reading ``ratios`` without receiving it raised NameError
    as soon as the first relation type was processed.

    Args:
        fam: Relation family whose relation types are to be prepared.
        ratios: Passed unchanged to ``prepare_relation_type`` for every
            relation type in the family.

    Returns:
        A PreparedRelationFamily carrying over ``data``, ``name``,
        ``node_type_row``, ``node_type_column``, ``is_symmetric`` and
        ``decoder_class`` from ``fam``, plus the prepared relation types.
    """
    # Both orientations are seeded up front so the keys exist even when the
    # family holds no relation types for one of them.
    relation_types = {
        (fam.node_type_row, fam.node_type_column): [],
        (fam.node_type_column, fam.node_type_row): [],
    }

    for (node_type_row, node_type_column), rels in fam.relation_types.items():
        for r in rels:
            relation_types[node_type_row, node_type_column].append(
                prepare_relation_type(r, ratios))

    return PreparedRelationFamily(fam.data, fam.name,
        fam.node_type_row, fam.node_type_column,
        fam.is_symmetric, fam.decoder_class,
        relation_types)


def prepare_training(data: Data, ratios: TrainValTest) -> PreparedData:
    """Prepare every relation family of ``data`` for training.

    Args:
        data: The Data instance to prepare.
        ratios: Forwarded to ``prepare_relation_family`` for each family.

    Returns:
        A PreparedData holding ``data.node_types`` and one
        PreparedRelationFamily per entry of ``data.relation_families``,
        in the same order.

    Raises:
        ValueError: If ``data`` is not an instance of Data.
    """
    if not isinstance(data, Data):
        raise ValueError('data must be of class Data')

    # ratios is handed down explicitly; prepare_relation_family has no other
    # way to see it.
    relation_families = [prepare_relation_family(fam, ratios)
                         for fam in data.relation_families]

    return PreparedData(data.node_types, relation_families)
@@ -17,7 +17,8 @@ import torch | |||
def test_decode_layer_01(): | |||
d = Data() | |||
d.add_node_type('Dummy', 100) | |||
d.add_relation_type('Dummy Relation 1', 0, 0, | |||
fam = d.add_relation_family('Dummy-Dummy', 0, 0, False) | |||
fam.add_relation_type('Dummy Relation 1', 0, 0, | |||
torch.rand((100, 100), dtype=torch.float32).round().to_sparse()) | |||
prep_d = prepare_training(d, TrainValTest(.8, .1, .1)) | |||
in_layer = OneHotInputLayer(d) | |||
@@ -105,7 +105,7 @@ def test_prepare_adj_mat_02(): | |||
# Smoke test: prepare_relation_type is called on a RelationType built from a
# random 10x10 boolean matrix; the result is discarded, so the test only
# checks that the call does not raise.
def test_prepare_relation_type_01(): | |||
adj_mat = (torch.rand((10, 10)) > .5) | |||
# NOTE(review): diff residue — the four-argument call looks like the
# pre-change line; the five-argument call passes True for two_way, which
# RelationTypeBase declares without a default — confirm.
r = RelationType('Test', 0, 0, adj_mat) | |||
r = RelationType('Test', 0, 0, adj_mat, True) | |||
ratios = TrainValTest(.8, .1, .1) | |||
_ = prepare_relation_type(r, ratios) | |||