|  |  | @@ -86,3 +86,19 @@ def test_model_03(): | 
		
	
		
			
			|  |  |  | assert len(list(m.seq[2].parameters())) == 1 | 
		
	
		
			
			|  |  |  | assert len(list(m.seq[3].parameters())) == 2 | 
		
	
		
			
			|  |  |  | # print(list(m.seq[1].parameters())) | 
		
	
		
			
			|  |  |  |  | 
		
	
		
			
			|  |  |  |  | 
		
	
		
			
			|  |  |  | def test_model_04(): | 
		
	
		
			
			|  |  |  | d = Data() | 
		
	
		
			
			|  |  |  | d.add_node_type('Dummy', 10) | 
		
	
		
			
			|  |  |  | fam = d.add_relation_family('Dummy-Dummy', 0, 0, False) | 
		
	
		
			
			|  |  |  | mat = torch.rand(10, 10).round().to_sparse() | 
		
	
		
			
			|  |  |  | fam.add_relation_type('Dummy Rel 1', mat) | 
		
	
		
			
			|  |  |  | fam.add_relation_type('Dummy Rel 2', mat.clone()) | 
		
	
		
			
			|  |  |  |  | 
		
	
		
			
			|  |  |  | m = Model(d, ratios=TrainValTest(1., 0., 0.)) | 
		
	
		
			
			|  |  |  |  | 
		
	
		
			
			|  |  |  | assert len(list(m.seq[0].parameters())) == 1 | 
		
	
		
			
			|  |  |  | assert len(list(m.seq[1].parameters())) == 2 | 
		
	
		
			
			|  |  |  | assert len(list(m.seq[2].parameters())) == 2 | 
		
	
		
			
			|  |  |  | assert len(list(m.seq[3].parameters())) == 3 |