|  |  | @@ -1,5 +1,7 @@ | 
		
	
		
			
			|  |  |  | from triacontagon.split import split_adj_mat | 
		
	
		
			
			|  |  |  | from triacontagon.split import split_adj_mat, \ | 
		
	
		
			
			|  |  |  | split_edge_type | 
		
	
		
			
			|  |  |  | from triacontagon.util import _equal | 
		
	
		
			
			|  |  |  | from triacontagon.data import EdgeType | 
		
	
		
			
			|  |  |  | import torch | 
		
	
		
			
			|  |  |  |  | 
		
	
		
			
			|  |  |  |  | 
		
	
	
		
			
				|  |  | @@ -39,3 +41,84 @@ def test_split_adj_mat_03(): | 
		
	
		
			
			|  |  |  | print('a:', a.to_dense(), 'b:', b.to_dense(), 'c:', c.to_dense()) | 
		
	
		
			
			|  |  |  |  | 
		
	
		
			
			|  |  |  | assert torch.all(_equal(a+b+c, adj_mat)) | 
		
	
		
			
			|  |  |  |  | 
		
	
		
			
			|  |  |  |  | 
		
	
		
			
			|  |  |  | def test_split_edge_type_01(): | 
		
	
		
			
			|  |  |  | et = EdgeType('Dummy', 0, 1, [ | 
		
	
		
			
			|  |  |  | torch.tensor([ | 
		
	
		
			
			|  |  |  | [0, 1, 0, 0, 0], | 
		
	
		
			
			|  |  |  | [0, 0, 1, 0, 1], | 
		
	
		
			
			|  |  |  | [1, 0, 0, 0, 1], | 
		
	
		
			
			|  |  |  | [0, 1, 0, 1, 0] | 
		
	
		
			
			|  |  |  | ]).to_sparse() | 
		
	
		
			
			|  |  |  | ], None, None) | 
		
	
		
			
			|  |  |  |  | 
		
	
		
			
			|  |  |  | res = split_edge_type(et, (1.,)) | 
		
	
		
			
			|  |  |  |  | 
		
	
		
			
			|  |  |  | assert torch.all(_equal(et.adjacency_matrices[0], | 
		
	
		
			
			|  |  |  | res[0].adjacency_matrices[0])) | 
		
	
		
			
			|  |  |  |  | 
		
	
		
			
			|  |  |  |  | 
		
	
		
			
			|  |  |  | def test_split_edge_type_02(): | 
		
	
		
			
			|  |  |  | et = EdgeType('Dummy', 0, 1, [ | 
		
	
		
			
			|  |  |  | torch.tensor([ | 
		
	
		
			
			|  |  |  | [0, 1, 0, 0, 0], | 
		
	
		
			
			|  |  |  | [0, 0, 1, 0, 1], | 
		
	
		
			
			|  |  |  | [1, 0, 0, 0, 1], | 
		
	
		
			
			|  |  |  | [0, 1, 0, 1, 0] | 
		
	
		
			
			|  |  |  | ]).to_sparse() | 
		
	
		
			
			|  |  |  | ], None, None) | 
		
	
		
			
			|  |  |  |  | 
		
	
		
			
			|  |  |  | res = split_edge_type(et, (.5, .5)) | 
		
	
		
			
			|  |  |  |  | 
		
	
		
			
			|  |  |  | assert torch.all(_equal(et.adjacency_matrices[0], | 
		
	
		
			
			|  |  |  | res[0].adjacency_matrices[0] + \ | 
		
	
		
			
			|  |  |  | res[1].adjacency_matrices[0])) | 
		
	
		
			
			|  |  |  |  | 
		
	
		
			
			|  |  |  |  | 
		
	
		
			
			|  |  |  | def test_split_edge_type_03(): | 
		
	
		
			
			|  |  |  | et = EdgeType('Dummy', 0, 1, [ | 
		
	
		
			
			|  |  |  | torch.tensor([ | 
		
	
		
			
			|  |  |  | [0, 1, 0, 0, 0], | 
		
	
		
			
			|  |  |  | [0, 0, 1, 0, 1], | 
		
	
		
			
			|  |  |  | [1, 0, 0, 0, 1], | 
		
	
		
			
			|  |  |  | [0, 1, 0, 1, 0] | 
		
	
		
			
			|  |  |  | ]).to_sparse() | 
		
	
		
			
			|  |  |  | ], None, None) | 
		
	
		
			
			|  |  |  |  | 
		
	
		
			
			|  |  |  | res = split_edge_type(et, (.4, .4, .2)) | 
		
	
		
			
			|  |  |  |  | 
		
	
		
			
			|  |  |  | assert torch.all(_equal(et.adjacency_matrices[0], | 
		
	
		
			
			|  |  |  | res[0].adjacency_matrices[0] + \ | 
		
	
		
			
			|  |  |  | res[1].adjacency_matrices[0] + \ | 
		
	
		
			
			|  |  |  | res[2].adjacency_matrices[0])) | 
		
	
		
			
			|  |  |  |  | 
		
	
		
			
			|  |  |  |  | 
		
	
		
			
			|  |  |  | def test_split_edge_type_04(): | 
		
	
		
			
			|  |  |  | et = EdgeType('Dummy', 0, 1, [ | 
		
	
		
			
			|  |  |  | torch.tensor([ | 
		
	
		
			
			|  |  |  | [0, 1, 0, 0, 0], | 
		
	
		
			
			|  |  |  | [0, 0, 1, 0, 1], | 
		
	
		
			
			|  |  |  | [1, 0, 0, 0, 1], | 
		
	
		
			
			|  |  |  | [0, 1, 0, 1, 0] | 
		
	
		
			
			|  |  |  | ]).to_sparse(), | 
		
	
		
			
			|  |  |  |  | 
		
	
		
			
			|  |  |  | torch.tensor([ | 
		
	
		
			
			|  |  |  | [1, 0, 0, 0, 0], | 
		
	
		
			
			|  |  |  | [0, 1, 0, 1, 0], | 
		
	
		
			
			|  |  |  | [0, 0, 1, 1, 0], | 
		
	
		
			
			|  |  |  | [1, 0, 1, 0, 0] | 
		
	
		
			
			|  |  |  | ]).to_sparse() | 
		
	
		
			
			|  |  |  | ], None, None) | 
		
	
		
			
			|  |  |  |  | 
		
	
		
			
			|  |  |  | res = split_edge_type(et, (.4, .4, .2)) | 
		
	
		
			
			|  |  |  |  | 
		
	
		
			
			|  |  |  | assert torch.all(_equal(et.adjacency_matrices[0], | 
		
	
		
			
			|  |  |  | res[0].adjacency_matrices[0] + \ | 
		
	
		
			
			|  |  |  | res[1].adjacency_matrices[0] + \ | 
		
	
		
			
			|  |  |  | res[2].adjacency_matrices[0])) | 
		
	
		
			
			|  |  |  |  | 
		
	
		
			
			|  |  |  | assert torch.all(_equal(et.adjacency_matrices[1], | 
		
	
		
			
			|  |  |  | res[0].adjacency_matrices[1] + \ | 
		
	
		
			
			|  |  |  | res[1].adjacency_matrices[1] + \ | 
		
	
		
			
			|  |  |  | res[2].adjacency_matrices[1])) |