|
import math
from typing import List, Tuple

import torch

from .data import Data, \
    TrainingBatch, \
    EdgeType
from .util import _sparse_coo_tensor
-
-
def split_adj_mat(adj_mat: torch.Tensor, ratios: List[float]):
    """Randomly partition the stored entries of a sparse COO matrix.

    The non-zero entries of ``adj_mat`` are shuffled and then cut into
    consecutive chunks, one per ratio. Each chunk is wrapped in a new
    sparse tensor with the same shape as ``adj_mat``.

    Args:
        adj_mat: sparse COO tensor to split.
        ratios: fraction of the entries that goes into each output matrix.
            When the ratios sum to 1, every entry lands in exactly one
            output. When they sum to less, the remaining entries are dropped.

    Returns:
        A list with ``len(ratios)`` sparse tensors, in the order of ``ratios``.
    """
    # .indices()/.values() raise on an uncoalesced tensor; coalesce()
    # leaves an already-coalesced tensor unchanged.
    adj_mat = adj_mat.coalesce()
    indices = adj_mat.indices()
    values = adj_mat.values()

    n = values.shape[0]
    # Build the permutation on the same device as the data it indexes.
    order = torch.randperm(n, device=indices.device)

    indices = indices[:, order]
    values = values[order]

    res = []
    start = 0
    cumulative = 0.0
    for r in ratios:
        cumulative += r
        # Slice bounds must be ints. Rounding the *cumulative* boundary
        # (rather than each chunk size) keeps rounding errors from adding
        # up, so ratios summing to 1 end exactly at n and no edge is lost.
        end = min(n, round(cumulative * n))
        end = max(end, start)  # never slice backwards
        ind = indices[:, start:end]
        val = values[start:end]
        res.append(_sparse_coo_tensor(ind, val, adj_mat.shape))
        start = end

    return res
-
-
def split_edge_type(et: EdgeType, ratios: List[float]):
    """Split every adjacency matrix of an edge type according to ``ratios``.

    Args:
        et: edge type whose ``adjacency_matrices`` are split.
        ratios: fractions passed to ``split_adj_mat`` for each matrix.

    Returns:
        A list with one entry per ratio. Entry ``i`` is a list holding the
        ``i``-th part of each matrix in ``et.adjacency_matrices``, in the
        original matrix order.
    """
    # Materialize once: ratios is re-iterated for every adjacency matrix,
    # and a one-shot iterator would be exhausted after the first one.
    ratios = list(ratios)

    # One bucket per split. Sizing this by the number of matrices instead
    # mis-indexes whenever that count differs from len(ratios).
    res = [[] for _ in ratios]

    for adj_mat in et.adjacency_matrices:
        for i, new_adj_mat in enumerate(split_adj_mat(adj_mat, ratios)):
            res[i].append(new_adj_mat)

    return res
-
-
def split_data(data: Data,
        ratios: List[float]):
    """Split the edges of ``data`` into ``len(ratios)`` new ``Data`` objects.

    Each edge type's adjacency matrices are split with ``split_edge_type``.
    The ``i``-th result receives the ``i``-th part of every edge type.

    Args:
        data: the dataset to split.
        ratios: non-negative fractions that must sum to 1 (within float
            tolerance).

    Returns:
        A list of ``len(ratios)`` ``Data`` objects.

    Raises:
        TypeError: if ``data`` is not a ``Data`` instance.
        ValueError: if any ratio is negative or the ratios do not sum to 1.
    """
    if not isinstance(data, Data):
        raise TypeError('data must be an instance of Data')

    ratios = list(ratios)

    if any(r < 0 for r in ratios):
        raise ValueError('ratios must be non-negative')

    # Exact float equality rejects valid inputs such as [0.7, 0.2, 0.1],
    # whose float sum is 0.9999999999999999.
    if not math.isclose(sum(ratios), 1):
        raise ValueError('ratios must sum to 1')

    res = [{} for _ in ratios]

    # Iterating a dict yields only its keys, so `.items()` is needed to get
    # (key, edge type) pairs. Any other iterable is assumed to already
    # yield such pairs.
    # NOTE(review): confirm the type of Data.edge_types against data.py.
    edge_types = data.edge_types
    pairs = edge_types.items() if isinstance(edge_types, dict) else edge_types
    for key, et in pairs:
        # NOTE(review): each new_et is a list of split adjacency matrices,
        # not an EdgeType; confirm this is what Data's constructor expects.
        for i, new_et in enumerate(split_edge_type(et, ratios)):
            res[i][key] = new_et

    # NOTE(review): every split shares the same vertex_types object.
    return [Data(data.vertex_types, new_edge_types)
            for new_edge_types in res]
|