IF YOU WOULD LIKE TO GET AN ACCOUNT, please write an email to s dot adaszewski at gmail dot com. User accounts are meant only for reporting issues and/or submitting pull requests. This is a purpose-specific Git hosting for ADARED projects. Thank you for your understanding!
You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

test_input.py 3.4KB

4 years ago
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109
  1. from icosagon.input import InputLayer, \
  2. OneHotInputLayer
  3. from icosagon.data import Data
  4. import torch
  5. import pytest
  6. def _some_data():
  7. d = Data()
  8. d.add_node_type('Gene', 1000)
  9. d.add_node_type('Drug', 100)
  10. fam = d.add_relation_family('Drug-Gene', 1, 0, True)
  11. fam.add_relation_type('Target', torch.rand(100, 1000))
  12. fam = d.add_relation_family('Gene-Gene', 0, 0, False)
  13. fam.add_relation_type('Interaction', torch.rand(1000, 1000))
  14. fam = d.add_relation_family('Drug-Drug', 1, 1, False)
  15. fam.add_relation_type('Side Effect: Nausea', torch.rand(100, 100))
  16. fam.add_relation_type('Side Effect: Infertility', torch.rand(100, 100))
  17. fam.add_relation_type('Side Effect: Death', torch.rand(100, 100))
  18. return d
  19. def test_input_layer_01():
  20. d = _some_data()
  21. for output_dim in [32, 64, 128]:
  22. layer = InputLayer(d, output_dim)
  23. assert layer.output_dim[0] == output_dim
  24. assert len(layer.node_reps) == 2
  25. assert layer.node_reps[0].shape == (1000, output_dim)
  26. assert layer.node_reps[1].shape == (100, output_dim)
  27. assert layer.data == d
  28. def test_input_layer_02():
  29. d = _some_data()
  30. layer = InputLayer(d, 32)
  31. res = layer(None)
  32. assert isinstance(res[0], torch.Tensor)
  33. assert isinstance(res[1], torch.Tensor)
  34. assert res[0].shape == (1000, 32)
  35. assert res[1].shape == (100, 32)
  36. assert torch.all(res[0] == layer.node_reps[0])
  37. assert torch.all(res[1] == layer.node_reps[1])
  38. def test_input_layer_03():
  39. if torch.cuda.device_count() == 0:
  40. pytest.skip('No CUDA devices on this host')
  41. d = _some_data()
  42. layer = InputLayer(d, 32)
  43. device = torch.device('cuda:0')
  44. layer = layer.to(device)
  45. print(list(layer.parameters()))
  46. # assert layer.device.type == 'cuda:0'
  47. assert layer.node_reps[0].device == device
  48. assert layer.node_reps[1].device == device
  49. def test_input_layer_04():
  50. d = _some_data()
  51. layer = InputLayer(d, 32)
  52. s = repr(layer)
  53. assert s.startswith('Icosagon input layer')
  54. def test_one_hot_input_layer_01():
  55. d = _some_data()
  56. layer = OneHotInputLayer(d)
  57. assert layer.output_dim == [1000, 100]
  58. assert len(layer.node_reps) == 2
  59. assert layer.node_reps[0].shape == (1000, 1000)
  60. assert layer.node_reps[1].shape == (100, 100)
  61. assert layer.data == d
  62. assert layer.is_sparse
  63. def test_one_hot_input_layer_02():
  64. d = _some_data()
  65. layer = OneHotInputLayer(d)
  66. res = layer(None)
  67. assert isinstance(res[0], torch.Tensor)
  68. assert isinstance(res[1], torch.Tensor)
  69. assert res[0].shape == (1000, 1000)
  70. assert res[1].shape == (100, 100)
  71. assert torch.all(res[0].to_dense() == layer.node_reps[0].to_dense())
  72. assert torch.all(res[1].to_dense() == layer.node_reps[1].to_dense())
  73. def test_one_hot_input_layer_03():
  74. if torch.cuda.device_count() == 0:
  75. pytest.skip('No CUDA devices on this host')
  76. d = _some_data()
  77. layer = OneHotInputLayer(d)
  78. device = torch.device('cuda:0')
  79. layer = layer.to(device)
  80. print(list(layer.parameters()))
  81. # assert layer.device.type == 'cuda:0'
  82. assert layer.node_reps[0].device == device
  83. assert layer.node_reps[1].device == device
  84. def test_one_hot_input_layer_04():
  85. d = _some_data()
  86. layer = OneHotInputLayer(d)
  87. s = repr(layer)
  88. assert s.startswith('Icosagon one-hot input layer')