diff --git a/src/icosagon/declayer.py b/src/icosagon/declayer.py index 7840ab5..db78f11 100644 --- a/src/icosagon/declayer.py +++ b/src/icosagon/declayer.py @@ -15,6 +15,25 @@ from typing import Type, \ Dict, \ Tuple from .decode import DEDICOMDecoder +from dataclasses import dataclass + + +@dataclass +class RelationPredictions(object): + edges_pos: TrainValTest + edges_neg: TrainValTest + edges_back_pos: TrainValTest + edges_back_neg: TrainValTest + + +@dataclass +class RelationFamilyPredictions(object): + relation_types: List[RelationPredictions] + + +@dataclass +class Predictions(object): + relation_families: List[RelationFamilyPredictions] class DecodeLayer(torch.nn.Module): @@ -30,13 +49,16 @@ class DecodeLayer(torch.nn.Module): if not isinstance(input_dim, list): raise TypeError('input_dim must be a List') + if len(input_dim) != len(data.node_types): + raise ValueError('input_dim must have length equal to num_node_types') + if not all([ a == input_dim[0] for a in input_dim ]): raise ValueError('All elements of input_dim must have the same value') if not isinstance(data, PreparedData): raise TypeError('data must be an instance of PreparedData') - self.input_dim = input_dim + self.input_dim = input_dim[0] self.output_dim = 1 self.data = data self.keep_prob = keep_prob @@ -47,42 +69,38 @@ class DecodeLayer(torch.nn.Module): def build(self) -> None: self.decoders = [] - for fam in self.data.relation_families: - for (node_type_row, node_type_column), rels in fam.relation_types.items(): - for r in rels: - pass - - dec = fam.decoder_class() + dec = fam.decoder_class(self.input_dim, len(fam.relation_types), + self.keep_prob, self.activation) self.decoders.append(dec) - for (node_type_row, node_type_column), rels in self.data.relation_types.items(): - if len(rels) == 0: - continue - - if isinstance(self.decoder_class, dict): - if (node_type_row, node_type_column) in self.decoder_class: - decoder_class = self.decoder_class[node_type_row, node_type_column] - elif (node_type_column, node_type_row) in self.decoder_class: - decoder_class = self.decoder_class[node_type_column, node_type_row] - else: - raise KeyError('Decoder not specified for edge type: %s -- %s' % ( - self.data.node_types[node_type_row].name, - self.data.node_types[node_type_column].name)) - else: - decoder_class = self.decoder_class - - self.decoders[node_type_row, node_type_column] = \ - decoder_class(self.input_dim[node_type_row], - num_relation_types = len(rels), - keep_prob = self.keep_prob, - activation = self.activation) - - def forward(self, last_layer_repr: List[torch.Tensor]) -> Dict[Tuple[int, int], List[torch.Tensor]]: - res = {} - for (node_type_row, node_type_column), dec in self.decoders.items(): - inputs_row = last_layer_repr[node_type_row] - inputs_column = last_layer_repr[node_type_column] - pred_adj_matrices = [ dec(inputs_row, inputs_column, k) for k in range(dec.num_relation_types) ] - res[node_type_row, node_type_column] = pred_adj_matrices + def _get_tvt(self, r, edge_list_attr_names, row, column, k, last_layer_repr, dec): + pred = [] + for p in edge_list_attr_names: + tvt = [] + for t in ['train', 'val', 'test']: + # print('r:', r) + edges = getattr(getattr(r, p), t) + inputs_row = last_layer_repr[row][edges[:, 0]] + inputs_column = last_layer_repr[column][edges[:, 1]] + tvt.append(dec(inputs_row, inputs_column, k)) + tvt = TrainValTest(*tvt) + pred.append(tvt) + return pred + + def forward(self, last_layer_repr: List[torch.Tensor]) -> List[List[torch.Tensor]]: + res = [] + for i, fam in enumerate(self.data.relation_families): + fam_pred = [] + for k, r in enumerate(fam.relation_types): + pred = [] + pred += self._get_tvt(r, ['edges_pos', 'edges_neg'], + r.node_type_row, r.node_type_column, k, last_layer_repr, self.decoders[i]) + pred += self._get_tvt(r, ['edges_back_pos', 'edges_back_neg'], + r.node_type_column, r.node_type_row, k, last_layer_repr, self.decoders[i]) + pred = RelationPredictions(*pred) + fam_pred.append(pred) + fam_pred = RelationFamilyPredictions(fam_pred) + res.append(fam_pred) + res = Predictions(res) return res diff --git a/src/icosagon/trainprep.py b/src/icosagon/trainprep.py index 5c7843d..3457791 100644 --- a/src/icosagon/trainprep.py +++ b/src/icosagon/trainprep.py @@ -35,6 +35,8 @@ class TrainValTest(object): class PreparedRelationType(RelationTypeBase): edges_pos: TrainValTest edges_neg: TrainValTest + edges_back_pos: TrainValTest + edges_back_neg: TrainValTest @dataclass @@ -48,6 +50,10 @@ class PreparedData(object): relation_families: List[PreparedRelationFamily] +def _empty_edge_list_tvt() -> TrainValTest: + return TrainValTest(*[ torch.zeros((0, 2), dtype=torch.long) for _ in range(3) ]) + + def train_val_test_split_edges(edges: torch.Tensor, ratios: TrainValTest) -> TrainValTest: @@ -115,12 +121,15 @@ def prep_rel_one_node_type(r: RelationType, adj_mat = r.adjacency_matrix adj_mat_train, edges_pos, edges_neg = prepare_adj_mat(adj_mat, ratios) + adj_mat_back_train, edges_back_pos, edges_back_neg = \ + None, _empty_edge_list_tvt(), _empty_edge_list_tvt() print('adj_mat_train:', adj_mat_train) adj_mat_train = norm_adj_mat_one_node_type(adj_mat_train) return PreparedRelationType(r.name, r.node_type_row, r.node_type_column, - adj_mat_train, None, edges_pos, edges_neg) + adj_mat_train, adj_mat_back_train, edges_pos, edges_neg, + edges_back_pos, edges_back_neg) def prep_rel_two_node_types_sym(r: RelationType, @@ -128,12 +137,14 @@ def prep_rel_two_node_types_sym(r: RelationType, adj_mat = r.adjacency_matrix adj_mat_train, edges_pos, edges_neg = prepare_adj_mat(adj_mat, ratios) + edges_back_pos, edges_back_neg = \ + _empty_edge_list_tvt(), _empty_edge_list_tvt() return PreparedRelationType(r.name, r.node_type_row, r.node_type_column, norm_adj_mat_two_node_types(adj_mat_train), norm_adj_mat_two_node_types(adj_mat_train.transpose(0, 1)), - edges_pos, edges_neg) + edges_pos, edges_neg, edges_back_pos, edges_back_neg) def prep_rel_two_node_types_asym(r: RelationType, @@ -144,23 +155,20 @@ def prep_rel_two_node_types_asym(r: RelationType, prepare_adj_mat(r.adjacency_matrix, ratios) else: adj_mat_train, edges_pos, edges_neg = \ - None, torch.zeros((0, 2)), torch.zeros((0, 2)) + None, _empty_edge_list_tvt(), _empty_edge_list_tvt() if r.adjacency_matrix_backward is not None: adj_mat_back_train, edges_back_pos, edges_back_neg = \ prepare_adj_mat(r.adjacency_matrix_backward, ratios) else: adj_mat_back_train, edges_back_pos, edges_back_neg = \ - None, torch.zeros((0, 2)), torch.zeros((0, 2)) - - edges_pos = torch.cat((edges_pos, edges_back_pos), dim=0) - edges_neg = torch.cat((edges_neg, edges_back_neg), dim=0) + None, _empty_edge_list_tvt(), _empty_edge_list_tvt() return PreparedRelationType(r.name, r.node_type_row, r.node_type_column, norm_adj_mat_two_node_types(adj_mat_train), norm_adj_mat_two_node_types(adj_mat_back_train), - edges_pos, edges_neg) + edges_pos, edges_neg, edges_back_pos, edges_back_neg) def prepare_relation_type(r: RelationType, @@ -180,7 +188,9 @@ def prepare_relation_type(r: RelationType, return prep_rel_two_node_types_asym(r, ratios) -def prepare_relation_family(fam: RelationFamily) -> PreparedRelationFamily: +def prepare_relation_family(fam: RelationFamily, + ratios: TrainValTest) -> PreparedRelationFamily: + relation_types = [] for r in fam.relation_types: @@ -196,7 +206,7 @@ def prepare_training(data: Data, ratios: TrainValTest) -> PreparedData: if not isinstance(data, Data): raise ValueError('data must be of class Data') - relation_families = [ prepare_relation_family(fam) \ + relation_families = [ prepare_relation_family(fam, ratios) \ for fam in data.relation_families ] return PreparedData(data.node_types, relation_families) diff --git a/tests/icosagon/test_declayer.py b/tests/icosagon/test_declayer.py index 7da2ad1..bab9a25 100644 --- a/tests/icosagon/test_declayer.py +++ b/tests/icosagon/test_declayer.py @@ -6,7 +6,10 @@ from icosagon.input import OneHotInputLayer from icosagon.convlayer import DecagonLayer -from icosagon.declayer import DecodeLayer +from icosagon.declayer import DecodeLayer, \ + Predictions, \ + RelationFamilyPredictions, \ + RelationPredictions from icosagon.decode import DEDICOMDecoder from icosagon.data import Data from icosagon.trainprep import prepare_training, \ @@ -17,21 +20,36 @@ import torch def test_decode_layer_01(): d = Data() d.add_node_type('Dummy', 100) + fam = d.add_relation_family('Dummy-Dummy', 0, 0, False) fam.add_relation_type('Dummy Relation 1', 0, 0, torch.rand((100, 100), dtype=torch.float32).round().to_sparse()) + prep_d = prepare_training(d, TrainValTest(.8, .1, .1)) in_layer = OneHotInputLayer(d) d_layer = DecagonLayer(in_layer.output_dim, 32, d) seq = torch.nn.Sequential(in_layer, d_layer) last_layer_repr = seq(None) + dec = DecodeLayer(input_dim=d_layer.output_dim, data=prep_d, keep_prob=1., - decoder_class=DEDICOMDecoder, activation=lambda x: x) - pred_adj_matrices = dec(last_layer_repr) - assert isinstance(pred_adj_matrices, dict) - assert len(pred_adj_matrices) == 1 - assert isinstance(pred_adj_matrices[0, 0], list) - assert len(pred_adj_matrices[0, 0]) == 1 + activation=lambda x: x) + pred = dec(last_layer_repr) + + assert isinstance(pred, Predictions) + + assert isinstance(pred.relation_families, list) + assert len(pred.relation_families) == 1 + assert isinstance(pred.relation_families[0], RelationFamilyPredictions) + + assert isinstance(pred.relation_families[0].relation_types, list) + assert len(pred.relation_families[0].relation_types) == 1 + assert isinstance(pred.relation_families[0].relation_types[0], RelationPredictions) + + tmp = pred.relation_families[0].relation_types[0] + assert isinstance(tmp.edges_pos, TrainValTest) + assert isinstance(tmp.edges_neg, TrainValTest) + assert isinstance(tmp.edges_back_pos, TrainValTest) + assert isinstance(tmp.edges_back_neg, TrainValTest) def test_decode_layer_02():