IF YOU WOULD LIKE TO GET AN ACCOUNT, please write an email to s dot adaszewski at gmail dot com. User accounts are meant only for reporting issues and/or submitting pull requests. This is a purpose-specific Git hosting service for ADARED projects. Thank you for your understanding!
您最多选择25个主题 主题必须以字母或数字开头,可以包含连字符 (-),并且长度不得超过35个字符

53 行
1.7KB

  1. from decagon_pytorch.layer import InputLayer
  2. from decagon_pytorch.data import Data
  3. import torch
  4. import pytest
  def _some_data():
      """Build the small ``Data`` fixture used by the ``InputLayer`` tests.

      Two node types are registered, in this order: 'Gene' with 1000 nodes,
      then 'Drug' with 100 nodes. Five relation types are then registered,
      referring to node types by the integers 0 and 1 (presumably the
      positional indices of the node types above -- confirm against
      ``Data.add_relation_type``).

      Returns:
          The populated ``Data`` instance.
      """
      data = Data()
      for type_name, count in (('Gene', 1000), ('Drug', 100)):
          data.add_node_type(type_name, count)

      # (name, first node-type index, second node-type index); the fourth
      # argument to add_relation_type is always None in this fixture.
      relations = (
          ('Target', 1, 0),
          ('Interaction', 0, 0),
          ('Side Effect: Nausea', 1, 1),
          ('Side Effect: Infertility', 1, 1),
          ('Side Effect: Death', 1, 1),
      )
      for rel_name, first_type, second_type in relations:
          data.add_relation_type(rel_name, first_type, second_type, None)
      return data
  def test_input_layer_01():
      """InputLayer stores its configuration and sizes one representation per node type.

      Repeated for several embedding widths; the fixture is built once and
      shared across all of them.
      """
      data = _some_data()
      for dim in (32, 64, 128):
          layer = InputLayer(data, dim)
          assert layer.dimensionality == dim
          # Two node types in the fixture -> two representation tensors.
          assert len(layer.node_reps) == 2
          # Row counts match the fixture's node counts (1000, then 100).
          for idx, num_nodes in enumerate((1000, 100)):
              assert layer.node_reps[idx].shape == (num_nodes, dim)
          assert layer.data == data
  def test_input_layer_02():
      """Calling an InputLayer returns tensors equal to its stored node representations."""
      data = _some_data()
      layer = InputLayer(data, 32)
      outputs = layer()

      # Each of the first two outputs is a plain torch.Tensor.
      for idx in (0, 1):
          assert isinstance(outputs[idx], torch.Tensor)

      # Row counts follow the fixture's node counts; width is the layer's 32.
      for idx, num_nodes in ((0, 1000), (1, 100)):
          assert outputs[idx].shape == (num_nodes, 32)

      # Element-wise equality with the layer's node_reps.
      for idx in (0, 1):
          assert torch.all(outputs[idx] == layer.node_reps[idx])
  def test_input_layer_03():
      """Moving an InputLayer to a CUDA device moves its state along with it.

      Skipped when PyTorch reports no CUDA devices on this host.
      """
      if torch.cuda.device_count() == 0:
          pytest.skip('No CUDA devices on this host')
      d = _some_data()
      layer = InputLayer(d, 32)
      device = torch.device('cuda:0')
      layer = layer.to(device)
      # Module.to() relocates every registered parameter; verify that instead
      # of printing the parameters, which only produced noisy stdout.
      assert all(p.device == device for p in layer.parameters())
      assert layer.node_reps[0].device == device
      assert layer.node_reps[1].device == device