|
import math
from typing import List, Tuple

import torch

from .data import Data, EdgeType
from .util import _sparse_coo_tensor
-
-
def split_adj_mat(adj_mat: torch.Tensor, ratios: List[float]):
    """Randomly partition the stored entries of a sparse matrix.

    The stored (index, value) pairs of ``adj_mat`` are shuffled with
    ``torch.randperm`` and cut into consecutive chunks whose sizes are
    proportional to ``ratios``. Each chunk is rebuilt as a sparse matrix
    with the same shape as ``adj_mat`` and coalesced.

    Args:
        adj_mat: Sparse matrix exposing ``indices()`` and ``values()``.
            NOTE(review): PyTorch only allows these calls on coalesced
            sparse COO tensors -- confirm callers coalesce first.
        ratios: Fractions of the entries assigned to each split. They must
            sum to 1, up to floating-point rounding.

    Returns:
        A list with one coalesced sparse matrix per entry of ``ratios``.

    Raises:
        ValueError: If ``ratios`` does not sum to 1.
    """
    ratios = list(ratios)
    # Compare with a tolerance: float sums of valid ratios are frequently
    # off by one ulp, and an exact comparison would reject them.
    if not math.isclose(sum(ratios), 1.0):
        raise ValueError('Sum of ratios must be 1')

    indices = adj_mat.indices()
    values = adj_mat.values()

    # Apply the same random permutation to indices and values so that each
    # entry keeps its coordinates.
    order = torch.randperm(indices.shape[1])
    indices = indices[:, order]
    values = values[order]

    num_values = len(values)
    last = len(ratios) - 1

    ofs = 0
    res = []
    for i, r in enumerate(ratios):
        beg = int(ofs * num_values)
        ofs += r
        # The accumulated offset may land just below 1.0; truncating it
        # would silently drop trailing entries, so the final split always
        # runs to the end.
        end = num_values if i == last else int(ofs * num_values)

        ind = indices[:, beg:end]
        val = values[beg:end]
        res.append(_sparse_coo_tensor(ind, val, adj_mat.shape).coalesce())

    return res
-
-
def split_edge_type(et: EdgeType, ratios: Tuple[float, float, float]):
    """Split an edge type into one edge type per ratio.

    Every matrix in ``et.adjacency_matrices`` is partitioned independently
    with ``split_adj_mat``. The i-th returned edge type collects the i-th
    part of each matrix and copies ``name``, ``vertex_type_row``,
    ``vertex_type_column`` and ``decoder_factory`` from ``et``.

    Args:
        et: Edge type whose adjacency matrices are to be split.
        ratios: Fractions of the edges assigned to each split. They must
            sum to 1, up to floating-point rounding.

    Returns:
        A list of ``EdgeType`` objects, one per entry of ``ratios``.

    Raises:
        ValueError: If ``ratios`` does not sum to 1.
    """
    ratios = list(ratios)
    # Compare with a tolerance: float sums of valid ratios are frequently
    # off by one ulp, and an exact comparison would reject them.
    if not math.isclose(sum(ratios), 1.0):
        raise ValueError('Sum of ratios must be 1')

    # parts_per_matrix[m][i] is the i-th split of the m-th adjacency matrix.
    parts_per_matrix = [split_adj_mat(adj_mat, ratios)
                        for adj_mat in et.adjacency_matrices]

    # The final positional argument is passed as None for every split; its
    # meaning is defined by EdgeType, which is not visible here.
    return [EdgeType(et.name,
                     et.vertex_type_row,
                     et.vertex_type_column,
                     [parts[i] for parts in parts_per_matrix],
                     et.decoder_factory,
                     None)
            for i in range(len(ratios))]
-
-
def split_data(data: Data,
               ratios: List[float]):
    """Split a data set into several data sets by partitioning its edges.

    Each edge type in ``data.edge_types`` is split with ``split_edge_type``.
    The i-th returned ``Data`` object holds the i-th part of every edge type
    under the original key and shares ``data.vertex_types`` with the input
    (same object, not a copy).

    Args:
        data: Data set to split.
        ratios: Fractions of the edges assigned to each split. They must
            sum to 1, up to floating-point rounding.

    Returns:
        A list of ``Data`` objects, one per entry of ``ratios``.

    Raises:
        TypeError: If ``data`` is not a ``Data`` instance.
        ValueError: If ``ratios`` does not sum to 1.
    """
    if not isinstance(data, Data):
        raise TypeError('data must be an instance of Data')

    ratios = list(ratios)

    # Compare with a tolerance: float sums of valid ratios are frequently
    # off by one ulp, and an exact comparison would reject them.
    if not math.isclose(sum(ratios), 1.0):
        raise ValueError('ratios must sum to 1')

    # One edge-type mapping per split, filled key by key.
    edge_types_per_split = [{} for _ in ratios]

    for key, et in data.edge_types.items():
        for i, new_et in enumerate(split_edge_type(et, ratios)):
            edge_types_per_split[i][key] = new_et

    splits = []
    for new_edge_types in edge_types_per_split:
        d = Data()
        d.vertex_types = data.vertex_types
        d.edge_types = new_edge_types
        splits.append(d)

    return splits
|