|  |  | @@ -1,4 +1,5 @@ | 
		
	
		
			
			|  |  |  | from icosagon.data import Data | 
		
	
		
			
			|  |  |  | from icosagon.data import Data, \ | 
		
	
		
			
			|  |  |  | _equal | 
		
	
		
			
			|  |  |  | from icosagon.trainprep import prepare_training, \ | 
		
	
		
			
			|  |  |  | TrainValTest | 
		
	
		
			
			|  |  |  | from icosagon.model import Model | 
		
	
	
		
			
				|  |  | @@ -78,6 +79,56 @@ def test_train_loop_03(): | 
		
	
		
			
			|  |  |  | loop.run_epoch() | 
		
	
		
			
			|  |  |  |  | 
		
	
		
			
			|  |  |  |  | 
		
	
		
			
			|  |  |  | def test_train_loop_04(): | 
		
	
		
			
			|  |  |  | adj_mat = torch.rand(10, 10).round() | 
		
	
		
			
			|  |  |  |  | 
		
	
		
			
			|  |  |  | d = Data() | 
		
	
		
			
			|  |  |  | d.add_node_type('Dummy', 10) | 
		
	
		
			
			|  |  |  | fam = d.add_relation_family('Dummy-Dummy', 0, 0, False) | 
		
	
		
			
			|  |  |  | fam.add_relation_type('Dummy Rel', adj_mat) | 
		
	
		
			
			|  |  |  |  | 
		
	
		
			
			|  |  |  | prep_d = prepare_training(d, TrainValTest(.8, .1, .1)) | 
		
	
		
			
			|  |  |  |  | 
		
	
		
			
			|  |  |  | m = Model(prep_d) | 
		
	
		
			
			|  |  |  |  | 
		
	
		
			
			|  |  |  | old_values = [] | 
		
	
		
			
			|  |  |  | for prm in m.parameters(): | 
		
	
		
			
			|  |  |  | old_values.append(prm.clone().detach()) | 
		
	
		
			
			|  |  |  |  | 
		
	
		
			
			|  |  |  | loop = TrainLoop(m) | 
		
	
		
			
			|  |  |  |  | 
		
	
		
			
			|  |  |  | loop.run_epoch() | 
		
	
		
			
			|  |  |  |  | 
		
	
		
			
			|  |  |  | for i, prm in enumerate(m.parameters()): | 
		
	
		
			
			|  |  |  | assert not prm.requires_grad or \ | 
		
	
		
			
			|  |  |  | not torch.all(_equal(prm, old_values[i])) | 
		
	
		
			
			|  |  |  |  | 
		
	
		
			
			|  |  |  |  | 
		
	
		
			
			|  |  |  | def test_train_loop_05(): | 
		
	
		
			
			|  |  |  | adj_mat = torch.rand(10, 10).round().to_sparse() | 
		
	
		
			
			|  |  |  |  | 
		
	
		
			
			|  |  |  | d = Data() | 
		
	
		
			
			|  |  |  | d.add_node_type('Dummy', 10) | 
		
	
		
			
			|  |  |  | fam = d.add_relation_family('Dummy-Dummy', 0, 0, False) | 
		
	
		
			
			|  |  |  | fam.add_relation_type('Dummy Rel', adj_mat) | 
		
	
		
			
			|  |  |  |  | 
		
	
		
			
			|  |  |  | prep_d = prepare_training(d, TrainValTest(.8, .1, .1)) | 
		
	
		
			
			|  |  |  |  | 
		
	
		
			
			|  |  |  | m = Model(prep_d) | 
		
	
		
			
			|  |  |  |  | 
		
	
		
			
			|  |  |  | old_values = [] | 
		
	
		
			
			|  |  |  | for prm in m.parameters(): | 
		
	
		
			
			|  |  |  | old_values.append(prm.clone().detach()) | 
		
	
		
			
			|  |  |  |  | 
		
	
		
			
			|  |  |  | loop = TrainLoop(m) | 
		
	
		
			
			|  |  |  |  | 
		
	
		
			
			|  |  |  | loop.run_epoch() | 
		
	
		
			
			|  |  |  |  | 
		
	
		
			
			|  |  |  | for i, prm in enumerate(m.parameters()): | 
		
	
		
			
			|  |  |  | assert not prm.requires_grad or \ | 
		
	
		
			
			|  |  |  | not torch.all(_equal(prm, old_values[i])) | 
		
	
		
			
			|  |  |  |  | 
		
	
		
			
			|  |  |  |  | 
		
	
		
			
			|  |  |  | def test_timing_01(): | 
		
	
		
			
			|  |  |  | adj_mat = (torch.rand(2000, 2000) < .001).to(torch.float32).to_sparse() | 
		
	
		
			
			|  |  |  | rep = torch.eye(2000).requires_grad_(True) | 
		
	
	
		
			
				|  |  | 
 |