|
- #
- # Copyright (C) Stanislaw Adaszewski, 2020
- # License: GPLv3
- #
-
-
- from .sampling import fixed_unigram_candidate_sampler
- import torch
- from dataclasses import dataclass
- from typing import Any, \
- List, \
- Tuple, \
- Dict
- from .data import NodeType
- from collections import defaultdict
-
-
@dataclass
class TrainValTest:
    """Generic three-slot record with ``train``, ``val`` and ``test`` fields.

    The fields are untyped (``Any``); within this module the record is used
    both for split ratios and for the edge tensors produced by a split.
    """

    # Value associated with the training subset.
    train: Any
    # Value associated with the validation subset.
    val: Any
    # Value associated with the test subset.
    test: Any
-
-
@dataclass
class PreparedEdges:
    """Pair of train/val/test records, one labelled positive, one negative."""

    # Train/val/test record for the positive side.
    positive: TrainValTest
    # Train/val/test record for the negative side.
    negative: TrainValTest
-
-
@dataclass
class PreparedRelationType:
    """Record describing one relation type after train/val/test preparation.

    Field order is part of the interface: instances are built positionally
    as (name, node_type_row, node_type_column, adj_mat_train, edges_pos,
    edges_neg).
    """

    # Name of the relation.
    name: str
    # Integer identifying the node type on the row side.
    node_type_row: int
    # Integer identifying the node type on the column side.
    node_type_column: int
    # Adjacency matrix tensor associated with the training split.
    adj_mat_train: torch.Tensor
    # Train/val/test record of positive edges.
    edges_pos: TrainValTest
    # Train/val/test record of negative edges.
    edges_neg: TrainValTest
-
-
@dataclass
class PreparedData:
    """Top-level record bundling node types with prepared relation types."""

    # Node type descriptors, passed through unchanged by prepare_training.
    node_types: List[NodeType]
    # Nested mapping: first key -> second key -> list of prepared relations.
    # NOTE(review): the keys look like (row node type, column node type)
    # indices - confirm against the data module.
    relation_types: Dict[int, Dict[int, List[PreparedRelationType]]]
-
-
def train_val_test_split_edges(edges: torch.Tensor,
                               ratios: TrainValTest) -> TrainValTest:
    """Randomly partition the rows of ``edges`` into train/val/test subsets.

    Args:
        edges: Tensor of shape (num_edges, 2); each row is one edge.
        ratios: Fractions of the edges assigned to each subset. The three
            fractions must add up to 1 (within floating-point tolerance).

    Returns:
        A TrainValTest whose ``train``, ``val`` and ``test`` fields hold
        disjoint row-subsets of a random permutation of ``edges``.

    Raises:
        ValueError: If ``edges`` is not a tensor of shape (num_edges, 2),
            ``ratios`` is not a TrainValTest, or the ratios do not sum to 1.
    """
    if not isinstance(edges, torch.Tensor):
        raise ValueError('edges must be a torch.Tensor')

    if len(edges.shape) != 2 or edges.shape[1] != 2:
        raise ValueError('edges shape must be (num_edges, 2)')

    if not isinstance(ratios, TrainValTest):
        raise ValueError('ratios must be a TrainValTest')

    # Compare with a tolerance rather than `!= 1.0`: binary floating point
    # makes e.g. 0.7 + 0.2 + 0.1 evaluate to 0.9999999999999999, which an
    # exact comparison would wrongly reject.
    tolerance = 1e-8
    if abs(ratios.train + ratios.val + ratios.test - 1.0) > tolerance:
        raise ValueError('Train, validation and test ratios must add up to 1')

    num_edges = len(edges)
    shuffled = edges[torch.randperm(num_edges), :]

    # Cut points are computed from cumulative ratios so the three slices
    # cover every row exactly once, whatever the rounding does.
    train_end = round(num_edges * ratios.train)
    val_end = round(num_edges * (ratios.train + ratios.val))

    return TrainValTest(shuffled[:train_end],
                        shuffled[train_end:val_end],
                        shuffled[val_end:])
-
-
def prepare_adj_mat(adj_mat: torch.Tensor,
                    ratios: TrainValTest) -> Tuple[TrainValTest, TrainValTest]:
    """Derive positive and negative edge lists from an adjacency matrix and
    split each of them into train/val/test subsets.

    Args:
        adj_mat: Adjacency matrix; every nonzero entry is treated as a
            positive edge (row index, column index).
        ratios: Train/val/test fractions forwarded to
            ``train_val_test_split_edges``.

    Returns:
        Tuple ``(edges_pos, edges_neg)`` of TrainValTest records. Positive
        and negative edges are shuffled and split independently.
    """
    # Column sums of the matrix, handed to the sampler as unigram weights.
    degrees = adj_mat.sum(0)
    # One (row, column) index pair per nonzero entry.
    edges_pos = torch.nonzero(adj_mat)

    # One negative column per positive edge (the original referenced an
    # undefined name `edges` here).
    # NOTE(review): presumably the sampler avoids the true column passed as
    # its first argument and weights candidates by degrees ** 0.75 - confirm
    # against the sampling module.
    neg_neighbors = fixed_unigram_candidate_sampler(edges_pos[:, 1],
                                                    len(edges_pos), degrees, 0.75)

    # Keep each positive edge's row index and pair it with the sampled column.
    # Both operands must be column vectors for concatenation along dim 1.
    edges_neg = torch.cat((edges_pos[:, 0].view(-1, 1),
                           neg_neighbors.view(-1, 1)), 1)

    edges_pos = train_val_test_split_edges(edges_pos, ratios)
    edges_neg = train_val_test_split_edges(edges_neg, ratios)

    return edges_pos, edges_neg
-
-
def prepare_relation(r, ratios):
    """Build the prepared (train/val/test) form of a single relation.

    Args:
        r: Relation object; this function reads its ``adjacency_matrix``,
            ``name``, ``node_type_row`` and ``node_type_column`` attributes.
        ratios: TrainValTest of split fractions, forwarded to
            ``prepare_adj_mat``.

    Returns:
        A PreparedRelationType holding the relation's identifying fields, a
        sparse adjacency matrix built from the training positive edges only,
        and the positive/negative edge splits.
    """
    adj_mat = r.adjacency_matrix
    edges_pos, edges_neg = prepare_adj_mat(adj_mat, ratios)

    # Rebuild an adjacency matrix from the training positives only.
    # sparse_coo_tensor expects indices as (ndim, nnz), hence the transpose.
    # The size is given explicitly so the result keeps the shape of the
    # original matrix even when trailing rows/columns have no training edge.
    train_edges = edges_pos.train
    adj_mat_train = torch.sparse_coo_tensor(
        indices=train_edges.transpose(0, 1),
        values=torch.ones(len(train_edges), dtype=adj_mat.dtype),
        size=adj_mat.shape)

    return PreparedRelationType(r.name, r.node_type_row, r.node_type_column,
                                adj_mat_train, edges_pos, edges_neg)
-
-
def prepare_training(data, ratios=None):
    """Prepare every relation in ``data`` for training.

    Args:
        data: Object exposing ``node_types`` and ``relation_types``.
            ``relation_types`` may be a mapping from
            ``(node_type_row, node_type_column)`` to a list of relations, or
            an iterable of such ``(key, relations)`` pairs.
        ratios: Optional TrainValTest of split fractions applied to every
            relation. When omitted, 0.8 / 0.1 / 0.1 is used.

    Returns:
        A PreparedData carrying ``data.node_types`` unchanged and a nested
        mapping row type -> column type -> list of prepared relations.
    """
    if ratios is None:
        # Default chosen so that existing single-argument calls keep working.
        ratios = TrainValTest(.8, .1, .1)

    # Iterating a mapping directly yields only its keys, so go through
    # .items() when it is available; plain iterables of pairs pass through.
    source = data.relation_types
    pairs = source.items() if hasattr(source, 'items') else source

    relation_types = defaultdict(lambda: defaultdict(list))
    for (node_type_row, node_type_column), rels in pairs:
        for r in rels:
            relation_types[node_type_row][node_type_column].append(
                prepare_relation(r, ratios))
    return PreparedData(data.node_types, relation_types)
|