@@ -19,7 +19,7 @@ def index(a, x):
     raise ValueError


-def load_data():
+def load_data(dev):
     path = '/pstore/data/data_science/ref/decagon'
     df_combo = pd.read_csv(os.path.join(path, 'bio-decagon-combo.csv'))
     df_effcat = pd.read_csv(os.path.join(path, 'bio-decagon-effectcategories.csv'))
@@ -57,7 +57,8 @@ def load_data():
     indices = torch.tensor(indices).transpose(0, 1)
     values = torch.ones(len(rows))
     print('indices.shape:', indices.shape, 'values.shape:', values.shape)
-    adj_mat = torch.sparse_coo_tensor(indices, values, size=(len(genes),) * 2)
+    adj_mat = torch.sparse_coo_tensor(indices, values, size=(len(genes),) * 2,
+        device=dev)
     adj_mat = (adj_mat + adj_mat.transpose(0, 1)) / 2
     print('adj_mat created')
     fam = data.add_relation_family('PPI', 0, 0, True)
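Side note, not part of the patch: a minimal sketch of the pattern these `load_data()` hunks rely on. `torch.sparse_coo_tensor` accepts CPU-built `indices`/`values` together with an explicit `device=`, so the adjacency matrix can be materialised directly on the GPU and symmetrised there. The names `dev`, `n_genes` and `edge_list` below are illustrative only.

```python
import torch

# Illustrative stand-in for the PPI edge list built in load_data().
dev = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
n_genes = 5
edge_list = [(0, 1), (1, 2), (3, 4)]

indices = torch.tensor(edge_list).transpose(0, 1)   # CPU LongTensor, shape (2, E)
values = torch.ones(indices.shape[1])               # CPU FloatTensor

# Passing device= places the sparse tensor on `dev` even though the
# ingredients were built on the CPU (same as the patched load_data()).
adj_mat = torch.sparse_coo_tensor(indices, values,
    size=(n_genes, n_genes), device=dev)

# Symmetrise on the same device; sparse addition, transpose and
# scalar division all stay on `dev`.
adj_mat = (adj_mat + adj_mat.transpose(0, 1)) / 2

assert adj_mat.device.type == dev.type
print(adj_mat.to_dense())
```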
@@ -70,7 +71,8 @@ def load_data():
     indices = list(zip(rows, cols))
     indices = torch.tensor(indices).transpose(0, 1)
     values = torch.ones(len(rows))
-    adj_mat = torch.sparse_coo_tensor(indices, values, size=(len(drugs), len(genes)))
+    adj_mat = torch.sparse_coo_tensor(indices, values, size=(len(drugs), len(genes)),
+        device=dev)
     fam = data.add_relation_family('Drug-Gene (Target)', 1, 0, True)
     rel = fam.add_relation_type('Drug-Gene (Target)', adj_mat)
     print('OK')
@@ -86,7 +88,8 @@ def load_data():
         indices = list(zip(rows, cols))
         indices = torch.tensor(indices).transpose(0, 1)
         values = torch.ones(len(rows))
-        adj_mat = torch.sparse_coo_tensor(indices, values, size=(len(drugs), len(drugs)))
+        adj_mat = torch.sparse_coo_tensor(indices, values, size=(len(drugs), len(drugs)),
+            device=dev)
         adj_mat = (adj_mat + adj_mat.transpose(0, 1)) / 2
         rel = fam.add_relation_type(df['Polypharmacy Side Effect'], adj_mat)
     print()
@@ -106,10 +109,13 @@ def _wrap(obj, method_name):


 def main():
-    data = load_data()
+    dev = torch.device('cuda:0')
+    data = load_data(dev)
     prep_d = prepare_training(data, TrainValTest(.8, .1, .1))
     _wrap(Model, 'build')
     model = Model(prep_d)
+    model = model.to(dev)
+    # model = torch.nn.DataParallel(model, ['cuda:0', 'cuda:1'])
     _wrap(TrainLoop, 'build')
     _wrap(TrainLoop, 'run_epoch')
     loop = TrainLoop(model, batch_size=1000000)
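For context, a condensed sketch of the flow `main()` now follows, reusing the script's own `load_data`, `prepare_training`, `TrainValTest`, `Model` and `TrainLoop`; the CPU fallback and the `run()` wrapper are my additions, not part of the patch.

```python
import torch

def run():
    # The patch hard-codes cuda:0; a guarded fallback keeps the script usable
    # on CPU-only machines (my assumption, not in the original main()).
    dev = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

    data = load_data(dev)                    # adjacency matrices built on `dev`
    prep_d = prepare_training(data, TrainValTest(.8, .1, .1))

    model = Model(prep_d)
    model = model.to(dev)                    # parameters follow the data onto `dev`

    loop = TrainLoop(model, batch_size=1000000)
    loop.run_epoch()
```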
@@ -8,6 +8,7 @@ import torch
 from .dropout import dropout
 from .weights import init_glorot
 from typing import List, Callable
+import pdb


 class GraphConv(torch.nn.Module):
@@ -44,6 +45,7 @@ class DropoutGraphConvActivation(torch.nn.Module):
         self.graph_conv = GraphConv(input_dim, output_dim, adjacency_matrix)

     def forward(self, x: torch.Tensor) -> torch.Tensor:
+        # pdb.set_trace()
         x = dropout(x, self.keep_prob)
         x = self.graph_conv(x)
         x = self.activation(x)
@@ -44,9 +44,11 @@ def add_eye_sparse(adj_mat: torch.Tensor) -> torch.Tensor:
     indices = adj_mat.indices()
     values = adj_mat.values()

-    eye_indices = torch.arange(adj_mat.shape[0], dtype=indices.dtype).view(1, -1)
+    eye_indices = torch.arange(adj_mat.shape[0], dtype=indices.dtype,
+        device=adj_mat.device).view(1, -1)
     eye_indices = torch.cat((eye_indices, eye_indices), 0)
-    eye_values = torch.ones(adj_mat.shape[0], dtype=values.dtype)
+    eye_values = torch.ones(adj_mat.shape[0], dtype=values.dtype,
+        device=adj_mat.device)

     indices = torch.cat((indices, eye_indices), 1)
     values = torch.cat((values, eye_values), 0)
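A standalone version of the same idea, runnable on its own: append an identity matrix to a square sparse COO tensor without ever leaving `adj_mat.device`. The helper name is mine; it mirrors `add_eye_sparse` but is not the library's exact code.

```python
import torch

def add_eye_sparse_sketch(adj_mat: torch.Tensor) -> torch.Tensor:
    # Append identity entries to a square sparse COO tensor, keeping every
    # intermediate tensor on adj_mat.device (the point of the hunk above).
    adj_mat = adj_mat.coalesce()
    indices = adj_mat.indices()
    values = adj_mat.values()

    n = adj_mat.shape[0]
    eye_indices = torch.arange(n, dtype=indices.dtype,
        device=adj_mat.device).view(1, -1)
    eye_indices = torch.cat((eye_indices, eye_indices), 0)
    eye_values = torch.ones(n, dtype=values.dtype, device=adj_mat.device)

    indices = torch.cat((indices, eye_indices), 1)
    values = torch.cat((values, eye_values), 0)
    return torch.sparse_coo_tensor(indices, values, size=adj_mat.shape,
        device=adj_mat.device)

# Usage: behaves the same on CPU and, if available, on CUDA.
a = torch.eye(3).to_sparse()
print(add_eye_sparse_sketch(a).to_dense())  # diagonal entries become 2
```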
@@ -72,7 +74,8 @@ def norm_adj_mat_one_node_type_dense(adj_mat: torch.Tensor) -> torch.Tensor:
     _check_dense(adj_mat)
     _check_square(adj_mat)

-    adj_mat = adj_mat + torch.eye(adj_mat.shape[0], dtype=adj_mat.dtype)
+    adj_mat = adj_mat + torch.eye(adj_mat.shape[0], dtype=adj_mat.dtype,
+        device=adj_mat.device)
     adj_mat = norm_adj_mat_two_node_types_dense(adj_mat)

     return adj_mat
@@ -96,9 +99,9 @@ def norm_adj_mat_two_node_types_sparse(adj_mat: torch.Tensor) -> torch.Tensor:
     adj_mat = adj_mat.coalesce()
     indices = adj_mat.indices()
     values = adj_mat.values()
-    degrees_row = torch.zeros(adj_mat.shape[0])
+    degrees_row = torch.zeros(adj_mat.shape[0], device=adj_mat.device)
     degrees_row = degrees_row.index_add(0, indices[0], values.to(degrees_row.dtype))
-    degrees_col = torch.zeros(adj_mat.shape[1])
+    degrees_col = torch.zeros(adj_mat.shape[1], device=adj_mat.device)
     degrees_col = degrees_col.index_add(0, indices[1], values.to(degrees_col.dtype))
     values = values.to(degrees_row.dtype) / torch.sqrt(degrees_row[indices[0]] * degrees_col[indices[1]])
     adj_mat = torch.sparse_coo_tensor(indices=indices, values=values, size=adj_mat.shape)
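For reference, these lines compute the usual symmetric normalisation (each entry divided by the square root of the product of its row and column degrees), now with the degree buffers allocated on `adj_mat.device`. Below is a small sketch with a dense cross-check on a toy matrix; the helper name is mine.

```python
import torch

def norm_two_node_types_sparse_sketch(adj_mat: torch.Tensor) -> torch.Tensor:
    # values[k] <- values[k] / sqrt(deg_row[i_k] * deg_col[j_k]),
    # with the degree vectors allocated on adj_mat.device as in the hunk above.
    adj_mat = adj_mat.coalesce()
    indices, values = adj_mat.indices(), adj_mat.values()
    deg_row = torch.zeros(adj_mat.shape[0], device=adj_mat.device)
    deg_row = deg_row.index_add(0, indices[0], values.to(deg_row.dtype))
    deg_col = torch.zeros(adj_mat.shape[1], device=adj_mat.device)
    deg_col = deg_col.index_add(0, indices[1], values.to(deg_col.dtype))
    values = values.to(deg_row.dtype) / torch.sqrt(deg_row[indices[0]] * deg_col[indices[1]])
    return torch.sparse_coo_tensor(indices=indices, values=values, size=adj_mat.shape)

# Cross-check against the dense formula D_row^-1/2 @ A @ D_col^-1/2.
a = torch.tensor([[0., 1., 1.], [1., 0., 0.]])
dense = torch.diag(a.sum(1) ** -0.5) @ a @ torch.diag(a.sum(0) ** -0.5)
assert torch.allclose(norm_two_node_types_sparse_sketch(a.to_sparse()).to_dense(), dense)
```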
@@ -18,8 +18,13 @@ def fixed_unigram_candidate_sampler(
     if isinstance(true_classes, torch.Tensor):
         true_classes = true_classes.detach().cpu().numpy()

+    if isinstance(unigrams, torch.Tensor):
+        unigrams = unigrams.detach().cpu().numpy()
+
     if len(true_classes.shape) != 2:
         raise ValueError('true_classes must be a 2D matrix with shape (num_samples, num_true)')

     num_samples = true_classes.shape[0]
     unigrams = np.array(unigrams)
     if distortion != 1.:
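The sampler works in NumPy, so torch inputs have to be detached and copied back to host memory first; without the new branch, a CUDA `degrees` tensor passed in as `unigrams` would fail at `np.array(unigrams)`. A small sketch of the conversion rule (names are mine):

```python
import numpy as np
import torch

def to_numpy(x):
    # Torch tensors cannot be handed to NumPy while they live on the GPU
    # (or while they still track gradients), hence detach().cpu() first.
    if isinstance(x, torch.Tensor):
        return x.detach().cpu().numpy()
    return np.asarray(x)

degrees = torch.tensor([4., 2., 1.])     # stand-in for the unigrams argument
if torch.cuda.is_available():
    degrees = degrees.cuda()             # the case the new branch handles
print(to_numpy(degrees) ** 0.75)         # distortion step now runs purely in NumPy
```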
@@ -83,9 +83,11 @@ def train_val_test_split_edges(edges: torch.Tensor,
 def get_edges_and_degrees(adj_mat: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
     if adj_mat.is_sparse:
         adj_mat = adj_mat.coalesce()
-        degrees = torch.zeros(adj_mat.shape[1], dtype=torch.int64)
+        degrees = torch.zeros(adj_mat.shape[1], dtype=torch.int64,
+            device=adj_mat.device)
         degrees = degrees.index_add(0, adj_mat.indices()[1],
-            torch.ones(adj_mat.indices().shape[1], dtype=torch.int64))
+            torch.ones(adj_mat.indices().shape[1], dtype=torch.int64,
+            device=adj_mat.device))
         edges_pos = adj_mat.indices().transpose(0, 1)
     else:
         degrees = adj_mat.sum(0)
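A standalone sketch of this degree computation: `index_add` expects the zeros buffer and the scattered ones to live on the same device, which is why both call sites now receive `device=adj_mat.device`. The helper name is mine.

```python
import torch

def degrees_sketch(adj_mat: torch.Tensor) -> torch.Tensor:
    # Column degrees of a sparse COO matrix, accumulated entirely on
    # adj_mat.device; mixing CPU and CUDA tensors in index_add would error.
    adj_mat = adj_mat.coalesce()
    degrees = torch.zeros(adj_mat.shape[1], dtype=torch.int64,
        device=adj_mat.device)
    return degrees.index_add(0, adj_mat.indices()[1],
        torch.ones(adj_mat.indices().shape[1], dtype=torch.int64,
            device=adj_mat.device))

a = torch.tensor([[0., 1., 1.], [0., 1., 0.]]).to_sparse()
print(degrees_sketch(a))  # tensor([0, 2, 1])
```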
@@ -102,7 +104,7 @@ def prepare_adj_mat(adj_mat: torch.Tensor,
     edges_pos, degrees = get_edges_and_degrees(adj_mat)

     neg_neighbors = fixed_unigram_candidate_sampler(
-        edges_pos[:, 1].view(-1, 1), degrees, 0.75)
+        edges_pos[:, 1].view(-1, 1), degrees, 0.75).to(adj_mat.device)
     print(edges_pos.dtype)
     print(neg_neighbors.dtype)
     edges_neg = torch.cat((edges_pos[:, 0].view(-1, 1), neg_neighbors.view(-1, 1)), 1)
@@ -111,7 +113,8 @@ def prepare_adj_mat(adj_mat: torch.Tensor,
     edges_neg = train_val_test_split_edges(edges_neg, ratios)

     adj_mat_train = torch.sparse_coo_tensor(indices = edges_pos.train.transpose(0, 1),
-        values=torch.ones(len(edges_pos.train)), size=adj_mat.shape, dtype=adj_mat.dtype)
+        values=torch.ones(len(edges_pos.train)), size=adj_mat.shape, dtype=adj_mat.dtype,
+        device=adj_mat.device)

     return adj_mat_train, edges_pos, edges_neg
@@ -4,6 +4,8 @@ from icosagon.trainprep import prepare_training, \
 from icosagon.model import Model
 from icosagon.trainloop import TrainLoop
 import torch
+import pytest
+import pdb


 def test_train_loop_01():
@@ -37,3 +39,32 @@ def test_train_loop_02():
     loop = TrainLoop(m)

     loop.run_epoch()
+
+
+def test_train_loop_03():
+    if torch.cuda.device_count() == 0:
+        pytest.skip('CUDA required for this test')
+
+    adj_mat = torch.rand(10, 10).round()
+    dev = torch.device('cuda:0')
+    adj_mat = adj_mat.to(dev)
+
+    d = Data()
+    d.add_node_type('Dummy', 10)
+    fam = d.add_relation_family('Dummy-Dummy', 0, 0, False)
+    fam.add_relation_type('Dummy Rel', adj_mat)
+
+    prep_d = prepare_training(d, TrainValTest(.8, .1, .1))
+    # pdb.set_trace()
+
+    m = Model(prep_d)
+    m = m.to(dev)
+
+    print(list(m.parameters()))
+
+    for prm in m.parameters():
+        assert prm.device == dev
+
+    loop = TrainLoop(m)
+
+    loop.run_epoch()
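One detail the new test leans on: `prm.device == dev` only holds because `dev` is spelled with an explicit index (`cuda:0`). A quick illustration of plain PyTorch behaviour, not code from the patch:

```python
import torch

# torch.device equality compares type AND index, so 'cuda' != 'cuda:0'.
assert torch.device('cuda:0') == torch.device('cuda', 0)
assert torch.device('cuda') != torch.device('cuda:0')

if torch.cuda.is_available():
    t = torch.zeros(1).to(torch.device('cuda:0'))
    # Tensors and parameters always report an indexed device, which is why
    # the test constructs dev as 'cuda:0' rather than bare 'cuda'.
    assert t.device == torch.device('cuda:0')
```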