| @@ -15,6 +15,25 @@ from typing import Type, \ | |||
| Dict, \ | |||
| Tuple | |||
| from .decode import DEDICOMDecoder | |||
| from dataclasses import dataclass | |||
| @dataclass | |||
| class RelationPredictions(object): | |||
| edges_pos: TrainValTest | |||
| edges_neg: TrainValTest | |||
| edges_back_pos: TrainValTest | |||
| edges_back_neg: TrainValTest | |||
| @dataclass | |||
| class RelationFamilyPredictions(object): | |||
| relation_types: List[RelationPredictions] | |||
| @dataclass | |||
| class Predictions(object): | |||
| relation_families: List[RelationFamilyPredictions] | |||
| class DecodeLayer(torch.nn.Module): | |||
| @@ -30,13 +49,16 @@ class DecodeLayer(torch.nn.Module): | |||
| if not isinstance(input_dim, list): | |||
| raise TypeError('input_dim must be a List') | |||
| if len(input_dim) != len(data.node_types): | |||
| raise ValueError('input_dim must have length equal to num_node_types') | |||
| if not all([ a == input_dim[0] for a in input_dim ]): | |||
| raise ValueError('All elements of input_dim must have the same value') | |||
| if not isinstance(data, PreparedData): | |||
| raise TypeError('data must be an instance of PreparedData') | |||
| self.input_dim = input_dim | |||
| self.input_dim = input_dim[0] | |||
| self.output_dim = 1 | |||
| self.data = data | |||
| self.keep_prob = keep_prob | |||
| @@ -47,42 +69,38 @@ class DecodeLayer(torch.nn.Module): | |||
| def build(self) -> None: | |||
| self.decoders = [] | |||
| for fam in self.data.relation_families: | |||
| for (node_type_row, node_type_column), rels in fam.relation_types.items(): | |||
| for r in rels: | |||
| pass | |||
| dec = fam.decoder_class() | |||
| dec = fam.decoder_class(self.input_dim, len(fam.relation_types), | |||
| self.keep_prob, self.activation) | |||
| self.decoders.append(dec) | |||
| for (node_type_row, node_type_column), rels in self.data.relation_types.items(): | |||
| if len(rels) == 0: | |||
| continue | |||
| if isinstance(self.decoder_class, dict): | |||
| if (node_type_row, node_type_column) in self.decoder_class: | |||
| decoder_class = self.decoder_class[node_type_row, node_type_column] | |||
| elif (node_type_column, node_type_row) in self.decoder_class: | |||
| decoder_class = self.decoder_class[node_type_column, node_type_row] | |||
| else: | |||
| raise KeyError('Decoder not specified for edge type: %s -- %s' % ( | |||
| self.data.node_types[node_type_row].name, | |||
| self.data.node_types[node_type_column].name)) | |||
| else: | |||
| decoder_class = self.decoder_class | |||
| self.decoders[node_type_row, node_type_column] = \ | |||
| decoder_class(self.input_dim[node_type_row], | |||
| num_relation_types = len(rels), | |||
| keep_prob = self.keep_prob, | |||
| activation = self.activation) | |||
| def forward(self, last_layer_repr: List[torch.Tensor]) -> Dict[Tuple[int, int], List[torch.Tensor]]: | |||
| res = {} | |||
| for (node_type_row, node_type_column), dec in self.decoders.items(): | |||
| inputs_row = last_layer_repr[node_type_row] | |||
| inputs_column = last_layer_repr[node_type_column] | |||
| pred_adj_matrices = [ dec(inputs_row, inputs_column, k) for k in range(dec.num_relation_types) ] | |||
| res[node_type_row, node_type_column] = pred_adj_matrices | |||
| def _get_tvt(self, r, edge_list_attr_names, row, column, k, last_layer_repr, dec): | |||
| pred = [] | |||
| for p in edge_list_attr_names: | |||
| tvt = [] | |||
| for t in ['train', 'val', 'test']: | |||
| # print('r:', r) | |||
| edges = getattr(getattr(r, p), t) | |||
| inputs_row = last_layer_repr[row][edges[:, 0]] | |||
| inputs_column = last_layer_repr[column][edges[:, 1]] | |||
| tvt.append(dec(inputs_row, inputs_column, k)) | |||
| tvt = TrainValTest(*tvt) | |||
| pred.append(tvt) | |||
| return pred | |||
| def forward(self, last_layer_repr: List[torch.Tensor]) -> List[List[torch.Tensor]]: | |||
| res = [] | |||
| for i, fam in enumerate(self.data.relation_families): | |||
| fam_pred = [] | |||
| for k, r in enumerate(fam.relation_types): | |||
| pred = [] | |||
| pred += self._get_tvt(r, ['edges_pos', 'edges_neg'], | |||
| r.node_type_row, r.node_type_column, k, last_layer_repr, self.decoders[i]) | |||
| pred += self._get_tvt(r, ['edges_back_pos', 'edges_back_neg'], | |||
| r.node_type_column, r.node_type_row, k, last_layer_repr, self.decoders[i]) | |||
| pred = RelationPredictions(*pred) | |||
| fam_pred.append(pred) | |||
| fam_pred = RelationFamilyPredictions(fam_pred) | |||
| res.append(fam_pred) | |||
| res = Predictions(res) | |||
| return res | |||
| @@ -35,6 +35,8 @@ class TrainValTest(object): | |||
| class PreparedRelationType(RelationTypeBase): | |||
| edges_pos: TrainValTest | |||
| edges_neg: TrainValTest | |||
| edges_back_pos: TrainValTest | |||
| edges_back_neg: TrainValTest | |||
| @dataclass | |||
| @@ -48,6 +50,10 @@ class PreparedData(object): | |||
| relation_families: List[PreparedRelationFamily] | |||
| def _empty_edge_list_tvt() -> TrainValTest: | |||
| return TrainValTest(*[ torch.zeros((0, 2), dtype=torch.long) for _ in range(3) ]) | |||
| def train_val_test_split_edges(edges: torch.Tensor, | |||
| ratios: TrainValTest) -> TrainValTest: | |||
| @@ -115,12 +121,15 @@ def prep_rel_one_node_type(r: RelationType, | |||
| adj_mat = r.adjacency_matrix | |||
| adj_mat_train, edges_pos, edges_neg = prepare_adj_mat(adj_mat, ratios) | |||
| adj_mat_back_train, edges_back_pos, edges_back_neg = \ | |||
| None, _empty_edge_list_tvt(), _empty_edge_list_tvt() | |||
| print('adj_mat_train:', adj_mat_train) | |||
| adj_mat_train = norm_adj_mat_one_node_type(adj_mat_train) | |||
| return PreparedRelationType(r.name, r.node_type_row, r.node_type_column, | |||
| adj_mat_train, None, edges_pos, edges_neg) | |||
| adj_mat_train, adj_mat_back_train, edges_pos, edges_neg, | |||
| edges_back_pos, edges_back_neg) | |||
| def prep_rel_two_node_types_sym(r: RelationType, | |||
| @@ -128,12 +137,14 @@ def prep_rel_two_node_types_sym(r: RelationType, | |||
| adj_mat = r.adjacency_matrix | |||
| adj_mat_train, edges_pos, edges_neg = prepare_adj_mat(adj_mat, ratios) | |||
| edges_back_pos, edges_back_neg = \ | |||
| _empty_edge_list_tvt(), _empty_edge_list_tvt() | |||
| return PreparedRelationType(r.name, r.node_type_row, | |||
| r.node_type_column, | |||
| norm_adj_mat_two_node_types(adj_mat_train), | |||
| norm_adj_mat_two_node_types(adj_mat_train.transpose(0, 1)), | |||
| edges_pos, edges_neg) | |||
| edges_pos, edges_neg, edges_back_pos, edges_back_neg) | |||
| def prep_rel_two_node_types_asym(r: RelationType, | |||
| @@ -144,23 +155,20 @@ def prep_rel_two_node_types_asym(r: RelationType, | |||
| prepare_adj_mat(r.adjacency_matrix, ratios) | |||
| else: | |||
| adj_mat_train, edges_pos, edges_neg = \ | |||
| None, torch.zeros((0, 2)), torch.zeros((0, 2)) | |||
| None, _empty_edge_list_tvt(), _empty_edge_list_tvt() | |||
| if r.adjacency_matrix_backward is not None: | |||
| adj_mat_back_train, edges_back_pos, edges_back_neg = \ | |||
| prepare_adj_mat(r.adjacency_matrix_backward, ratios) | |||
| else: | |||
| adj_mat_back_train, edges_back_pos, edges_back_neg = \ | |||
| None, torch.zeros((0, 2)), torch.zeros((0, 2)) | |||
| edges_pos = torch.cat((edges_pos, edges_back_pos), dim=0) | |||
| edges_neg = torch.cat((edges_neg, edges_back_neg), dim=0) | |||
| None, _empty_edge_list_tvt(), _empty_edge_list_tvt() | |||
| return PreparedRelationType(r.name, r.node_type_row, | |||
| r.node_type_column, | |||
| norm_adj_mat_two_node_types(adj_mat_train), | |||
| norm_adj_mat_two_node_types(adj_mat_back_train), | |||
| edges_pos, edges_neg) | |||
| edges_pos, edges_neg, edges_back_pos, edges_back_neg) | |||
| def prepare_relation_type(r: RelationType, | |||
| @@ -180,7 +188,9 @@ def prepare_relation_type(r: RelationType, | |||
| return prep_rel_two_node_types_asym(r, ratios) | |||
| def prepare_relation_family(fam: RelationFamily) -> PreparedRelationFamily: | |||
| def prepare_relation_family(fam: RelationFamily, | |||
| ratios: TrainValTest) -> PreparedRelationFamily: | |||
| relation_types = [] | |||
| for r in fam.relation_types: | |||
| @@ -196,7 +206,7 @@ def prepare_training(data: Data, ratios: TrainValTest) -> PreparedData: | |||
| if not isinstance(data, Data): | |||
| raise ValueError('data must be of class Data') | |||
| relation_families = [ prepare_relation_family(fam) \ | |||
| relation_families = [ prepare_relation_family(fam, ratios) \ | |||
| for fam in data.relation_families ] | |||
| return PreparedData(data.node_types, relation_families) | |||
| @@ -6,7 +6,10 @@ | |||
| from icosagon.input import OneHotInputLayer | |||
| from icosagon.convlayer import DecagonLayer | |||
| from icosagon.declayer import DecodeLayer | |||
| from icosagon.declayer import DecodeLayer, \ | |||
| Predictions, \ | |||
| RelationFamilyPredictions, \ | |||
| RelationPredictions | |||
| from icosagon.decode import DEDICOMDecoder | |||
| from icosagon.data import Data | |||
| from icosagon.trainprep import prepare_training, \ | |||
| @@ -17,21 +20,36 @@ import torch | |||
| def test_decode_layer_01(): | |||
| d = Data() | |||
| d.add_node_type('Dummy', 100) | |||
| fam = d.add_relation_family('Dummy-Dummy', 0, 0, False) | |||
| fam.add_relation_type('Dummy Relation 1', 0, 0, | |||
| torch.rand((100, 100), dtype=torch.float32).round().to_sparse()) | |||
| prep_d = prepare_training(d, TrainValTest(.8, .1, .1)) | |||
| in_layer = OneHotInputLayer(d) | |||
| d_layer = DecagonLayer(in_layer.output_dim, 32, d) | |||
| seq = torch.nn.Sequential(in_layer, d_layer) | |||
| last_layer_repr = seq(None) | |||
| dec = DecodeLayer(input_dim=d_layer.output_dim, data=prep_d, keep_prob=1., | |||
| decoder_class=DEDICOMDecoder, activation=lambda x: x) | |||
| pred_adj_matrices = dec(last_layer_repr) | |||
| assert isinstance(pred_adj_matrices, dict) | |||
| assert len(pred_adj_matrices) == 1 | |||
| assert isinstance(pred_adj_matrices[0, 0], list) | |||
| assert len(pred_adj_matrices[0, 0]) == 1 | |||
| activation=lambda x: x) | |||
| pred = dec(last_layer_repr) | |||
| assert isinstance(pred, Predictions) | |||
| assert isinstance(pred.relation_families, list) | |||
| assert len(pred.relation_families) == 1 | |||
| assert isinstance(pred.relation_families[0], RelationFamilyPredictions) | |||
| assert isinstance(pred.relation_families[0].relation_types, list) | |||
| assert len(pred.relation_families[0].relation_types) == 1 | |||
| assert isinstance(pred.relation_families[0].relation_types[0], RelationPredictions) | |||
| tmp = pred.relation_families[0].relation_types[0] | |||
| assert isinstance(tmp.edges_pos, TrainValTest) | |||
| assert isinstance(tmp.edges_neg, TrainValTest) | |||
| assert isinstance(tmp.edges_back_pos, TrainValTest) | |||
| assert isinstance(tmp.edges_back_neg, TrainValTest) | |||
| def test_decode_layer_02(): | |||