"""Helpers for randomly splitting sparse adjacency data into subsets."""

import math
from typing import List, Tuple

import torch

from .data import Data, EdgeType
from .util import _sparse_coo_tensor


def _split_boundary(fraction: float, count: int) -> int:
    """Convert a cumulative fraction of ``count`` into a slice boundary.

    The position is truncated towards zero, except when it is within
    floating-point tolerance of an integer, in which case that integer is
    used. Without this, accumulated rounding error (for example
    ``0.7 + 0.2 + 0.1 == 0.9999999999999999``) would place the final
    boundary one element short and silently drop an entry.
    """
    position = fraction * count
    nearest = round(position)
    if math.isclose(position, nearest, rel_tol=1e-9, abs_tol=1e-9):
        return int(nearest)
    return int(position)


def split_adj_mat(adj_mat: torch.Tensor, ratios: List[float]):
    """Randomly partition the stored entries of a sparse matrix.

    The stored (index, value) pairs of ``adj_mat`` are shuffled with
    ``torch.randperm`` and then cut into consecutive chunks whose sizes are
    proportional to ``ratios``. Every chunk is rebuilt with
    ``_sparse_coo_tensor`` using the shape of ``adj_mat``.

    Args:
        adj_mat: Sparse tensor exposing ``indices()`` and ``values()``
            (torch only allows these calls on a coalesced sparse COO
            tensor).
        ratios: Fraction of the stored entries assigned to each chunk, in
            order. The ratios are not validated here.

    Returns:
        A list with one sparse tensor per entry of ``ratios``.
    """
    indices = adj_mat.indices()
    values = adj_mat.values()

    # Shuffle indices and values with the same permutation so that each
    # index column stays paired with its value.
    order = torch.randperm(indices.shape[1])
    indices = indices[:, order]
    values = values[order]

    total = len(values)
    cumulative = 0
    beg = 0
    res = []
    for ratio in ratios:
        cumulative += ratio
        end = _split_boundary(cumulative, total)
        res.append(_sparse_coo_tensor(indices[:, beg:end],
                                      values[beg:end],
                                      adj_mat.shape))
        # The next chunk starts exactly where this one ended.
        beg = end

    return res


def split_edge_type(et: EdgeType, ratios: Tuple[float, float, float]):
    """Split every adjacency matrix of an edge type according to ``ratios``.

    Args:
        et: Edge type whose ``adjacency_matrices`` are split.
        ratios: Fractions passed on to ``split_adj_mat``; one new edge type
            is produced per ratio.

    Returns:
        A list of ``EdgeType`` objects, one per ratio. The i-th one holds
        the i-th chunk of every adjacency matrix of ``et`` and reuses its
        name, vertex types and decoder factory; the last constructor
        argument is passed as ``None``.
    """
    # per_matrix[m][i] is the i-th chunk of the m-th adjacency matrix.
    per_matrix = [split_adj_mat(adj_mat, ratios)
                  for adj_mat in et.adjacency_matrices]

    return [
        EdgeType(et.name,
                 et.vertex_type_row,
                 et.vertex_type_column,
                 [chunks[i] for chunks in per_matrix],
                 et.decoder_factory,
                 None)
        for i in range(len(ratios))
    ]


def split_data(data: Data, ratios: List[float]):
    """Split all edge types of ``data`` into one ``Data`` object per ratio.

    Args:
        data: Source ``Data`` instance.
        ratios: Fractions of the edges assigned to each resulting object;
            they must sum to 1 (within floating-point tolerance).

    Returns:
        A list of ``Data`` objects, one per ratio. Each shares
        ``data.vertex_types`` and maps the original edge type keys to the
        corresponding split edge types.

    Raises:
        TypeError: If ``data`` is not a ``Data`` instance.
        ValueError: If ``ratios`` do not sum to 1.
    """
    if not isinstance(data, Data):
        raise TypeError('data must be an instance of Data')

    ratios = list(ratios)
    # Exact float comparison would reject valid inputs such as
    # [0.7, 0.2, 0.1], whose float sum is 0.9999999999999999.
    if not math.isclose(math.fsum(ratios), 1.0):
        raise ValueError('ratios must sum to 1')

    edge_types_per_split = [{} for _ in ratios]

    for key, et in data.edge_types.items():
        for i, new_et in enumerate(split_edge_type(et, ratios)):
            edge_types_per_split[i][key] = new_et

    splits = []
    for new_edge_types in edge_types_per_split:
        d = Data()
        # Assign the vertex types themselves (no trailing comma, which
        # would wrap them in a 1-tuple).
        d.vertex_types = data.vertex_types
        d.edge_types = new_edge_types
        splits.append(d)

    return splits