| @@ -15,6 +15,25 @@ from typing import Type, \ | |||||
| Dict, \ | Dict, \ | ||||
| Tuple | Tuple | ||||
| from .decode import DEDICOMDecoder | from .decode import DEDICOMDecoder | ||||
| from dataclasses import dataclass | |||||
@dataclass
class RelationPredictions:
    """Decoder outputs for a single relation type.

    Each field is a TrainValTest bundle. The ``edges_back_*`` fields
    mirror the ``edges_*`` fields for the backward direction of the
    relation. Field order matters: instances are built positionally.
    """
    edges_pos: TrainValTest
    edges_neg: TrainValTest
    edges_back_pos: TrainValTest
    edges_back_neg: TrainValTest
@dataclass
class RelationFamilyPredictions:
    """Predictions for every relation type belonging to one relation family."""
    # One RelationPredictions entry per relation type of the family.
    relation_types: List[RelationPredictions]
@dataclass
class Predictions:
    """Top-level container: predictions grouped by relation family."""
    # One RelationFamilyPredictions entry per relation family.
    relation_families: List[RelationFamilyPredictions]
| class DecodeLayer(torch.nn.Module): | class DecodeLayer(torch.nn.Module): | ||||
| @@ -30,13 +49,16 @@ class DecodeLayer(torch.nn.Module): | |||||
| if not isinstance(input_dim, list): | if not isinstance(input_dim, list): | ||||
| raise TypeError('input_dim must be a List') | raise TypeError('input_dim must be a List') | ||||
| if len(input_dim) != len(data.node_types): | |||||
| raise ValueError('input_dim must have length equal to num_node_types') | |||||
| if not all([ a == input_dim[0] for a in input_dim ]): | if not all([ a == input_dim[0] for a in input_dim ]): | ||||
| raise ValueError('All elements of input_dim must have the same value') | raise ValueError('All elements of input_dim must have the same value') | ||||
| if not isinstance(data, PreparedData): | if not isinstance(data, PreparedData): | ||||
| raise TypeError('data must be an instance of PreparedData') | raise TypeError('data must be an instance of PreparedData') | ||||
| self.input_dim = input_dim | |||||
| self.input_dim = input_dim[0] | |||||
| self.output_dim = 1 | self.output_dim = 1 | ||||
| self.data = data | self.data = data | ||||
| self.keep_prob = keep_prob | self.keep_prob = keep_prob | ||||
| @@ -47,42 +69,38 @@ class DecodeLayer(torch.nn.Module): | |||||
def build(self) -> None:
    """Create one decoder per relation family and store them in order.

    For every family in ``self.data.relation_families`` the family's
    ``decoder_class`` is called with ``(self.input_dim,
    len(fam.relation_types), self.keep_prob, self.activation)``.
    The resulting decoders are stored in ``self.decoders`` at the same
    index as their family.

    Raises:
        TypeError: if a decoder is not a ``torch.nn.Module`` (raised by
            ``torch.nn.ModuleList``).
    """
    # A ModuleList (instead of a plain list) registers every decoder as a
    # submodule, so their parameters show up in parameters()/state_dict()
    # and follow .to(device). Integer indexing keeps working as before.
    # NOTE(review): assumes every decoder_class produces a torch.nn.Module
    # -- confirm for custom decoder classes.
    self.decoders = torch.nn.ModuleList()
    for fam in self.data.relation_families:
        dec = fam.decoder_class(self.input_dim, len(fam.relation_types),
                                self.keep_prob, self.activation)
        self.decoders.append(dec)
| def forward(self, last_layer_repr: List[torch.Tensor]) -> Dict[Tuple[int, int], List[torch.Tensor]]: | |||||
| res = {} | |||||
| for (node_type_row, node_type_column), dec in self.decoders.items(): | |||||
| inputs_row = last_layer_repr[node_type_row] | |||||
| inputs_column = last_layer_repr[node_type_column] | |||||
| pred_adj_matrices = [ dec(inputs_row, inputs_column, k) for k in range(dec.num_relation_types) ] | |||||
| res[node_type_row, node_type_column] = pred_adj_matrices | |||||
def _get_tvt(self, r, edge_list_attr_names, row, column, k, last_layer_repr, dec):
    """Decode the train/val/test edge lists of one relation type.

    For each attribute name in ``edge_list_attr_names`` the matching
    attribute of ``r`` is read, and its ``train``, ``val`` and ``test``
    edge lists are decoded in that order. Column 0 of an edge list indexes
    ``last_layer_repr[row]`` and column 1 indexes
    ``last_layer_repr[column]``; the gathered representations are passed
    to ``dec`` together with the relation index ``k``.

    Returns:
        A list with one TrainValTest of decoder outputs per attribute
        name, in the order the names were given.
    """
    def decode_split(attr_name, split_name):
        # Look up the edge list, gather endpoint representations, decode.
        edge_list = getattr(getattr(r, attr_name), split_name)
        return dec(last_layer_repr[row][edge_list[:, 0]],
                   last_layer_repr[column][edge_list[:, 1]],
                   k)

    return [
        TrainValTest(*[decode_split(attr_name, split_name)
                       for split_name in ('train', 'val', 'test')])
        for attr_name in edge_list_attr_names
    ]
def forward(self, last_layer_repr: List[torch.Tensor]) -> 'Predictions':
    """Decode edge predictions for every relation type of every family.

    Args:
        last_layer_repr: per-node-type representations, indexed by node
            type; rows are selected by the endpoint indices stored in the
            edge lists.

    Returns:
        A Predictions object. ``relation_families[i].relation_types[k]``
        holds the RelationPredictions for the k-th relation type of the
        i-th family, computed with ``self.decoders[i]``.
    """
    res = []
    for i, fam in enumerate(self.data.relation_families):
        # Decoders are stored in the same order as the relation families.
        dec = self.decoders[i]
        fam_pred = []
        for k, r in enumerate(fam.relation_types):
            # Forward direction: rows index node_type_row,
            # columns index node_type_column.
            pred = self._get_tvt(r, ['edges_pos', 'edges_neg'],
                                 r.node_type_row, r.node_type_column,
                                 k, last_layer_repr, dec)
            # Backward direction: the two node types swap roles.
            pred += self._get_tvt(r, ['edges_back_pos', 'edges_back_neg'],
                                  r.node_type_column, r.node_type_row,
                                  k, last_layer_repr, dec)
            # Positional order matches the RelationPredictions fields:
            # edges_pos, edges_neg, edges_back_pos, edges_back_neg.
            fam_pred.append(RelationPredictions(*pred))
        res.append(RelationFamilyPredictions(fam_pred))
    return Predictions(res)
| @@ -35,6 +35,8 @@ class TrainValTest(object): | |||||
| class PreparedRelationType(RelationTypeBase): | class PreparedRelationType(RelationTypeBase): | ||||
| edges_pos: TrainValTest | edges_pos: TrainValTest | ||||
| edges_neg: TrainValTest | edges_neg: TrainValTest | ||||
| edges_back_pos: TrainValTest | |||||
| edges_back_neg: TrainValTest | |||||
| @dataclass | @dataclass | ||||
| @@ -48,6 +50,10 @@ class PreparedData(object): | |||||
| relation_families: List[PreparedRelationFamily] | relation_families: List[PreparedRelationFamily] | ||||
def _empty_edge_list_tvt() -> TrainValTest:
    """Return a TrainValTest whose three members are empty edge lists.

    Each member is a separate ``(0, 2)`` tensor of dtype ``torch.long``.
    """
    empty_lists = []
    for _ in range(3):
        empty_lists.append(torch.zeros((0, 2), dtype=torch.long))
    return TrainValTest(*empty_lists)
| def train_val_test_split_edges(edges: torch.Tensor, | def train_val_test_split_edges(edges: torch.Tensor, | ||||
| ratios: TrainValTest) -> TrainValTest: | ratios: TrainValTest) -> TrainValTest: | ||||
| @@ -115,12 +121,15 @@ def prep_rel_one_node_type(r: RelationType, | |||||
| adj_mat = r.adjacency_matrix | adj_mat = r.adjacency_matrix | ||||
| adj_mat_train, edges_pos, edges_neg = prepare_adj_mat(adj_mat, ratios) | adj_mat_train, edges_pos, edges_neg = prepare_adj_mat(adj_mat, ratios) | ||||
| adj_mat_back_train, edges_back_pos, edges_back_neg = \ | |||||
| None, _empty_edge_list_tvt(), _empty_edge_list_tvt() | |||||
| print('adj_mat_train:', adj_mat_train) | print('adj_mat_train:', adj_mat_train) | ||||
| adj_mat_train = norm_adj_mat_one_node_type(adj_mat_train) | adj_mat_train = norm_adj_mat_one_node_type(adj_mat_train) | ||||
| return PreparedRelationType(r.name, r.node_type_row, r.node_type_column, | return PreparedRelationType(r.name, r.node_type_row, r.node_type_column, | ||||
| adj_mat_train, None, edges_pos, edges_neg) | |||||
| adj_mat_train, adj_mat_back_train, edges_pos, edges_neg, | |||||
| edges_back_pos, edges_back_neg) | |||||
| def prep_rel_two_node_types_sym(r: RelationType, | def prep_rel_two_node_types_sym(r: RelationType, | ||||
| @@ -128,12 +137,14 @@ def prep_rel_two_node_types_sym(r: RelationType, | |||||
| adj_mat = r.adjacency_matrix | adj_mat = r.adjacency_matrix | ||||
| adj_mat_train, edges_pos, edges_neg = prepare_adj_mat(adj_mat, ratios) | adj_mat_train, edges_pos, edges_neg = prepare_adj_mat(adj_mat, ratios) | ||||
| edges_back_pos, edges_back_neg = \ | |||||
| _empty_edge_list_tvt(), _empty_edge_list_tvt() | |||||
| return PreparedRelationType(r.name, r.node_type_row, | return PreparedRelationType(r.name, r.node_type_row, | ||||
| r.node_type_column, | r.node_type_column, | ||||
| norm_adj_mat_two_node_types(adj_mat_train), | norm_adj_mat_two_node_types(adj_mat_train), | ||||
| norm_adj_mat_two_node_types(adj_mat_train.transpose(0, 1)), | norm_adj_mat_two_node_types(adj_mat_train.transpose(0, 1)), | ||||
| edges_pos, edges_neg) | |||||
| edges_pos, edges_neg, edges_back_pos, edges_back_neg) | |||||
| def prep_rel_two_node_types_asym(r: RelationType, | def prep_rel_two_node_types_asym(r: RelationType, | ||||
| @@ -144,23 +155,20 @@ def prep_rel_two_node_types_asym(r: RelationType, | |||||
| prepare_adj_mat(r.adjacency_matrix, ratios) | prepare_adj_mat(r.adjacency_matrix, ratios) | ||||
| else: | else: | ||||
| adj_mat_train, edges_pos, edges_neg = \ | adj_mat_train, edges_pos, edges_neg = \ | ||||
| None, torch.zeros((0, 2)), torch.zeros((0, 2)) | |||||
| None, _empty_edge_list_tvt(), _empty_edge_list_tvt() | |||||
| if r.adjacency_matrix_backward is not None: | if r.adjacency_matrix_backward is not None: | ||||
| adj_mat_back_train, edges_back_pos, edges_back_neg = \ | adj_mat_back_train, edges_back_pos, edges_back_neg = \ | ||||
| prepare_adj_mat(r.adjacency_matrix_backward, ratios) | prepare_adj_mat(r.adjacency_matrix_backward, ratios) | ||||
| else: | else: | ||||
| adj_mat_back_train, edges_back_pos, edges_back_neg = \ | adj_mat_back_train, edges_back_pos, edges_back_neg = \ | ||||
| None, torch.zeros((0, 2)), torch.zeros((0, 2)) | |||||
| edges_pos = torch.cat((edges_pos, edges_back_pos), dim=0) | |||||
| edges_neg = torch.cat((edges_neg, edges_back_neg), dim=0) | |||||
| None, _empty_edge_list_tvt(), _empty_edge_list_tvt() | |||||
| return PreparedRelationType(r.name, r.node_type_row, | return PreparedRelationType(r.name, r.node_type_row, | ||||
| r.node_type_column, | r.node_type_column, | ||||
| norm_adj_mat_two_node_types(adj_mat_train), | norm_adj_mat_two_node_types(adj_mat_train), | ||||
| norm_adj_mat_two_node_types(adj_mat_back_train), | norm_adj_mat_two_node_types(adj_mat_back_train), | ||||
| edges_pos, edges_neg) | |||||
| edges_pos, edges_neg, edges_back_pos, edges_back_neg) | |||||
| def prepare_relation_type(r: RelationType, | def prepare_relation_type(r: RelationType, | ||||
| @@ -180,7 +188,9 @@ def prepare_relation_type(r: RelationType, | |||||
| return prep_rel_two_node_types_asym(r, ratios) | return prep_rel_two_node_types_asym(r, ratios) | ||||
| def prepare_relation_family(fam: RelationFamily) -> PreparedRelationFamily: | |||||
| def prepare_relation_family(fam: RelationFamily, | |||||
| ratios: TrainValTest) -> PreparedRelationFamily: | |||||
| relation_types = [] | relation_types = [] | ||||
| for r in fam.relation_types: | for r in fam.relation_types: | ||||
| @@ -196,7 +206,7 @@ def prepare_training(data: Data, ratios: TrainValTest) -> PreparedData: | |||||
| if not isinstance(data, Data): | if not isinstance(data, Data): | ||||
| raise ValueError('data must be of class Data') | raise ValueError('data must be of class Data') | ||||
| relation_families = [ prepare_relation_family(fam) \ | |||||
| relation_families = [ prepare_relation_family(fam, ratios) \ | |||||
| for fam in data.relation_families ] | for fam in data.relation_families ] | ||||
| return PreparedData(data.node_types, relation_families) | return PreparedData(data.node_types, relation_families) | ||||
| @@ -6,7 +6,10 @@ | |||||
| from icosagon.input import OneHotInputLayer | from icosagon.input import OneHotInputLayer | ||||
| from icosagon.convlayer import DecagonLayer | from icosagon.convlayer import DecagonLayer | ||||
| from icosagon.declayer import DecodeLayer | |||||
| from icosagon.declayer import DecodeLayer, \ | |||||
| Predictions, \ | |||||
| RelationFamilyPredictions, \ | |||||
| RelationPredictions | |||||
| from icosagon.decode import DEDICOMDecoder | from icosagon.decode import DEDICOMDecoder | ||||
| from icosagon.data import Data | from icosagon.data import Data | ||||
| from icosagon.trainprep import prepare_training, \ | from icosagon.trainprep import prepare_training, \ | ||||
| @@ -17,21 +20,36 @@ import torch | |||||
def test_decode_layer_01():
    """DecodeLayer returns a nested Predictions structure for one relation."""
    # Single node type with one random 100x100 relation.
    data = Data()
    data.add_node_type('Dummy', 100)
    family = data.add_relation_family('Dummy-Dummy', 0, 0, False)
    adjacency = torch.rand((100, 100), dtype=torch.float32).round().to_sparse()
    family.add_relation_type('Dummy Relation 1', 0, 0, adjacency)
    prepared = prepare_training(data, TrainValTest(.8, .1, .1))

    # Input layer followed by one convolution layer yields the
    # representations fed to the decoder.
    input_layer = OneHotInputLayer(data)
    conv_layer = DecagonLayer(input_layer.output_dim, 32, data)
    model = torch.nn.Sequential(input_layer, conv_layer)
    last_layer_repr = model(None)

    decode_layer = DecodeLayer(input_dim=conv_layer.output_dim, data=prepared,
                               keep_prob=1., activation=lambda x: x)
    pred = decode_layer(last_layer_repr)

    assert isinstance(pred, Predictions)

    families = pred.relation_families
    assert isinstance(families, list)
    assert len(families) == 1
    assert isinstance(families[0], RelationFamilyPredictions)

    relation_types = families[0].relation_types
    assert isinstance(relation_types, list)
    assert len(relation_types) == 1
    assert isinstance(relation_types[0], RelationPredictions)

    rel_pred = relation_types[0]
    assert isinstance(rel_pred.edges_pos, TrainValTest)
    assert isinstance(rel_pred.edges_neg, TrainValTest)
    assert isinstance(rel_pred.edges_back_pos, TrainValTest)
    assert isinstance(rel_pred.edges_back_neg, TrainValTest)
| def test_decode_layer_02(): | def test_decode_layer_02(): | ||||