IF YOU WOULD LIKE TO GET AN ACCOUNT, please send an email to s dot adaszewski at gmail dot com. User accounts are intended only for reporting issues and/or submitting pull requests. This is a purpose-specific Git hosting service for ADARED projects. Thank you for your understanding!
您最多可以选择25个主题。主题必须以字母或数字开头，可以包含连字符 (-)，并且长度不得超过35个字符。

50 行
1.2KB

  1. #
  2. # Copyright (C) Stanislaw Adaszewski, 2020
  3. # License: GPLv3
  4. #
import numpy as np
import scipy.sparse as sp
  6. class Batch(object):
  7. def __init__(self, adjacency_matrix):
  8. pass
  9. def get(size):
  10. pass
  11. def train_test_split(data, train_size=.8):
  12. pass
  13. class Minibatch(object):
  14. def __init__(self, data, node_type_row, node_type_column, size):
  15. self.data = data
  16. self.adjacency_matrix = data.get_adjacency_matrix(node_type_row, node_type_column)
  17. self.size = size
  18. self.order = np.random.permutation(adjacency_matrix.nnz)
  19. self.count = 0
  20. def reset(self):
  21. self.count = 0
  22. self.order = np.random.permutation(adjacency_matrix.nnz)
  23. def __iter__(self):
  24. adj_mat = self.adjacency_matrix
  25. size = self.size
  26. order = np.random.permutation(adj_mat.nnz)
  27. for i in range(0, len(order), size):
  28. row = adj_mat.row[i:i + size]
  29. col = adj_mat.col[i:i + size]
  30. data = adj_mat.data[i:i + size]
  31. adj_mat_batch = sp.coo_matrix((data, (row, col)), shape=adj_mat.shape)
  32. yield adj_mat_batch
  33. degree = self.adjacency_matrix.sum(1)
  34. def __len__(self):
  35. pass