#
# Copyright (C) Stanislaw Adaszewski, 2020
# License: GPLv3
#


from .sampling import fixed_unigram_candidate_sampler
import math
import torch
from dataclasses import dataclass
from typing import Any, \
    List, \
    Optional, \
    Tuple, \
    Dict
from .data import NodeType
from collections import defaultdict


@dataclass
class TrainValTest(object):
    # One value per data split. Used both for split ratios (numbers) and
    # for the resulting split edge tensors.
    train: Any
    val: Any
    test: Any


@dataclass
class PreparedEdges(object):
    positive: TrainValTest
    negative: TrainValTest


@dataclass
class PreparedRelationType(object):
    name: str
    node_type_row: int
    node_type_column: int
    # Sparse COO adjacency matrix built only from the training positive edges,
    # with the same shape as the relation's full adjacency matrix.
    adj_mat_train: torch.Tensor
    edges_pos: TrainValTest
    edges_neg: TrainValTest


@dataclass
class PreparedData(object):
    node_types: List[NodeType]
    # relation_types[node_type_row][node_type_column] -> prepared relations
    relation_types: Dict[int, Dict[int, List[PreparedRelationType]]]


def train_val_test_split_edges(edges: torch.Tensor,
        ratios: TrainValTest) -> TrainValTest:
    """Randomly shuffle ``edges`` and split them into train/val/test parts.

    Args:
        edges: Tensor of shape ``(num_edges, 2)``, one ``(row, column)``
            index pair per row.
        ratios: Fractions of edges for each split. They must be
            non-negative and sum to 1 (up to floating-point tolerance).

    Returns:
        A ``TrainValTest`` of three edge tensors. Together they hold a
        permutation of ``edges``. The shuffle uses torch's global RNG
        (``torch.randperm``).

    Raises:
        ValueError: If ``edges`` is not a 2-column ``torch.Tensor``, or if
            ``ratios`` is not a valid ``TrainValTest``.
    """
    if not isinstance(edges, torch.Tensor):
        raise ValueError('edges must be a torch.Tensor')

    if len(edges.shape) != 2 or edges.shape[1] != 2:
        raise ValueError('edges shape must be (num_edges, 2)')

    if not isinstance(ratios, TrainValTest):
        raise ValueError('ratios must be a TrainValTest')

    if min(ratios.train, ratios.val, ratios.test) < 0:
        raise ValueError('Train, validation and test ratios must be non-negative')

    # Compare with a tolerance. Sums of decimal fractions such as
    # 0.7 + 0.2 + 0.1 need not equal 1.0 exactly in binary floating point.
    if not math.isclose(ratios.train + ratios.val + ratios.test, 1.0):
        raise ValueError('Train, validation and test ratios must add up to 1')

    order = torch.randperm(len(edges))
    edges = edges[order, :]

    # Cumulative cut points: [0, n) train, [n, n_1) val, [n_1, end) test.
    n = round(len(edges) * ratios.train)
    n_1 = round(len(edges) * (ratios.train + ratios.val))

    edges_train = edges[:n]
    edges_val = edges[n:n_1]
    edges_test = edges[n_1:]

    return TrainValTest(edges_train, edges_val, edges_test)


def prepare_adj_mat(adj_mat: torch.Tensor,
        ratios: TrainValTest) -> Tuple[TrainValTest, TrainValTest]:
    """Extract positive edges, sample negative edges, and split both.

    Positive edges are the non-zero entries of ``adj_mat``. For each
    positive edge, one negative edge is formed. It keeps the positive
    edge's row index and takes a column index drawn by
    ``fixed_unigram_candidate_sampler``. That call receives the column
    sums of ``adj_mat`` as unigram weights and a distortion of 0.75.

    NOTE(review): assumes ``adj_mat`` is a dense 2-D tensor, since
    ``torch.nonzero`` and ``sum(0)`` are applied to it directly. TODO confirm.

    Args:
        adj_mat: Adjacency matrix of a single relation.
        ratios: Split ratios forwarded to ``train_val_test_split_edges``.

    Returns:
        ``(edges_pos, edges_neg)``, each a ``TrainValTest`` of
        ``(k, 2)`` edge tensors.
    """
    degrees = adj_mat.sum(0)
    edges_pos = torch.nonzero(adj_mat)

    # One negative column per positive edge. Sampling semantics are
    # defined in .sampling; presumably sampled columns avoid the true
    # classes given as the first argument (verify there).
    neg_neighbors = fixed_unigram_candidate_sampler(edges_pos[:, 1],
        len(edges_pos), degrees, 0.75)

    # Both operands must be (k, 1) columns to concatenate along dim 1.
    edges_neg = torch.cat((edges_pos[:, 0].view(-1, 1),
        neg_neighbors.view(-1, 1)), 1)

    edges_pos = train_val_test_split_edges(edges_pos, ratios)
    edges_neg = train_val_test_split_edges(edges_neg, ratios)

    return edges_pos, edges_neg


def prepare_relation(r, ratios: TrainValTest) -> PreparedRelationType:
    """Prepare one relation for training.

    Args:
        r: Relation object. Its ``name``, ``node_type_row``,
            ``node_type_column`` and ``adjacency_matrix`` attributes are read.
        ratios: Train/val/test split ratios for its edges.

    Returns:
        A ``PreparedRelationType``. Its ``adj_mat_train`` is a sparse COO
        tensor of ones at the training positive edges, with the same shape,
        dtype and device as ``r.adjacency_matrix``.
    """
    adj_mat = r.adjacency_matrix
    edges_pos, edges_neg = prepare_adj_mat(adj_mat, ratios)

    # Pass an explicit size. Otherwise the shape would be inferred from the
    # largest training index and could be smaller than the real matrix.
    adj_mat_train = torch.sparse_coo_tensor(
        indices=edges_pos.train.transpose(0, 1),
        values=torch.ones(len(edges_pos.train), dtype=adj_mat.dtype,
            device=adj_mat.device),
        size=adj_mat.shape)

    return PreparedRelationType(r.name, r.node_type_row, r.node_type_column,
        adj_mat_train, edges_pos, edges_neg)


def prepare_training(data,
        ratios: Optional[TrainValTest] = None) -> PreparedData:
    """Prepare every relation in ``data`` for training.

    Args:
        data: Object exposing ``node_types`` and ``relation_types``.
        ratios: Train/val/test split ratios applied to every relation. When
            omitted, a 0.8 / 0.1 / 0.1 split is used. This keeps the
            original single-argument call working.

    Returns:
        A ``PreparedData`` whose ``relation_types`` is a nested defaultdict
        indexed as ``[node_type_row][node_type_column]``.
    """
    if ratios is None:
        ratios = TrainValTest(0.8, 0.1, 0.1)

    relation_types = defaultdict(lambda: defaultdict(list))

    # NOTE(review): assumes data.relation_types is a mapping from
    # (node_type_row, node_type_column) pairs to lists of relations, as the
    # unpacking below implies. Verify against the Data class.
    for (node_type_row, node_type_column), rels in data.relation_types.items():
        for r in rels:
            relation_types[node_type_row][node_type_column].append(
                prepare_relation(r, ratios))

    return PreparedData(data.node_types, relation_types)