from .data import Data, \ TrainingBatch, \ EdgeType from typing import Tuple from .util import _sparse_coo_tensor def split_adj_mat(adj_mat: torch.Tensor, ratios: List[float]): indices = adj_mat.indices() values = adj_mat.values() order = torch.randperm(indices.shape[1]) indices = indices[:, order] values = values[order] ofs = 0 res = [] for r in ratios: cnt = r * len(values) ind = indices[:, ofs:ofs+cnt] val = values[ofs:ofs+cnt] res.append(_sparse_coo_tensor(ind, val, adj_mat.shape)) ofs += cnt return res def split_edge_type(et: EdgeType, ratios: Tuple[float, float, float]): res = [ [] for _ in range(len(et.adjacency_matrices)) ] for adj_mat in et.adjacency_matrices: for i, new_adj_mat in enumerate(split_adj_mat(adj_mat, ratios)): res[i].append(new_adj_mat) return res def split_data(data: Data, ratios: List[float]): if not isinstance(data, Data): raise TypeError('data must be an instance of Data') ratios = list(ratios) if sum(ratios) != 1: raise ValueError('ratios must sum to 1') res = [ {} for _ in range(len(ratios)) ] for key, et in data.edge_types: for i, new_et in enumerate(split_edge_type(et, ratios)): res[i][key] = new_et res = [ Data(data.vertex_types, new_edge_types) \ for new_edge_types in res ] return res