"""Tests for ``GraphConv`` and ``DropoutGraphConvActivation``.

Each scenario is written once as a private ``_test_*`` helper that takes
``use_sparse``; a dense and a sparse public test then call that helper.
"""

import torch

from icosagon.convolve import (
    GraphConv,
    DropoutGraphConvActivation,
)
from icosagon.dropout import dropout


def _random_adj_mat(n_rows: int, n_cols: int, seed: int = 0) -> torch.Tensor:
    """Return a random ``(n_rows, n_cols)`` matrix whose entries are 0 or 1.

    Each entry is 1 with probability ~0.5. The values come from a private
    generator seeded with *seed*, so failures are reproducible and the
    global RNG state is left untouched.
    """
    gen = torch.Generator().manual_seed(seed)
    values = torch.rand((n_rows, n_cols), generator=gen)
    # Same result as "zero out entries < .5, then ceil": rand() is in [0, 1),
    # so every surviving entry rounds up to exactly 1.
    return (values >= .5).to(values.dtype)


def _as_adj(adj_mat: torch.Tensor, use_sparse: bool) -> torch.Tensor:
    """Return *adj_mat* converted with ``to_sparse()`` if *use_sparse*, else as-is."""
    return adj_mat.to_sparse() if use_sparse else adj_mat


def _assert_same(actual: torch.Tensor, expected: torch.Tensor) -> None:
    """Assert that *actual* and *expected* have equal shapes and equal elements.

    The shape check comes first: ``==`` broadcasts, so on its own it would
    accept a wrongly shaped result that happens to broadcast.
    """
    assert actual.shape == expected.shape, (actual.shape, expected.shape)
    assert torch.all(actual == expected)


def _test_graph_conv_01(use_sparse: bool):
    # Identity node features and identity weights: the expected output
    # is the adjacency matrix itself.
    adj_mat = _random_adj_mat(10, 20)
    node_reprs = torch.eye(20)
    graph_conv = GraphConv(20, 20, _as_adj(adj_mat, use_sparse))
    graph_conv.weight = torch.nn.Parameter(torch.eye(20))
    res = graph_conv(node_reprs)
    _assert_same(res, adj_mat)


def _test_graph_conv_02(use_sparse: bool):
    # Identity node features and weights scaled by 2: the expected output
    # is twice the adjacency matrix.
    adj_mat = _random_adj_mat(10, 20)
    node_reprs = torch.eye(20)
    graph_conv = GraphConv(20, 20, _as_adj(adj_mat, use_sparse))
    graph_conv.weight = torch.nn.Parameter(torch.eye(20) * 2)
    res = graph_conv(node_reprs)
    _assert_same(res, adj_mat * 2)


def _test_graph_conv_03(use_sparse: bool):
    # Hand-computed case. The trailing comment on each adjacency row is the
    # matching row of the weight matrix below. `expect` equals
    # adj_mat @ weight (node_reprs is the identity).
    adj_mat = torch.tensor([
        [1, 0, 1, 0, 1, 0],  # [1, 0, 0]
        [1, 0, 1, 0, 0, 1],  # [1, 0, 0]
        [1, 1, 0, 1, 0, 0],  # [0, 1, 0]
        [0, 0, 0, 1, 0, 1],  # [0, 1, 0]
        [1, 1, 1, 1, 1, 1],  # [0, 0, 1]
        [0, 0, 0, 1, 1, 1],  # [0, 0, 1]
    ], dtype=torch.float32)
    expect = torch.tensor([
        [1, 1, 1],
        [1, 1, 1],
        [2, 1, 0],
        [0, 1, 1],
        [2, 2, 2],
        [0, 1, 2],
    ], dtype=torch.float32)
    node_reprs = torch.eye(6)
    graph_conv = GraphConv(6, 3, _as_adj(adj_mat, use_sparse))
    graph_conv.weight = torch.nn.Parameter(torch.tensor([
        [1, 0, 0],
        [1, 0, 0],
        [0, 1, 0],
        [0, 1, 0],
        [0, 0, 1],
        [0, 0, 1],
    ], dtype=torch.float32))
    res = graph_conv(node_reprs)
    _assert_same(res, expect)


def test_graph_conv_dense_01():
    _test_graph_conv_01(use_sparse=False)


def test_graph_conv_dense_02():
    _test_graph_conv_02(use_sparse=False)


def test_graph_conv_dense_03():
    _test_graph_conv_03(use_sparse=False)


def test_graph_conv_sparse_01():
    _test_graph_conv_01(use_sparse=True)


def test_graph_conv_sparse_02():
    _test_graph_conv_02(use_sparse=True)


def test_graph_conv_sparse_03():
    _test_graph_conv_03(use_sparse=True)


def _test_dropout_graph_conv_activation_01(use_sparse: bool):
    # keep_prob=1 with an identity activation: expected to match a plain
    # GraphConv that shares the same weight Parameter.
    adj_mat = _random_adj_mat(10, 20)
    node_reprs = torch.eye(20)
    conv_1 = DropoutGraphConvActivation(20, 20, _as_adj(adj_mat, use_sparse),
                                        keep_prob=1.,
                                        activation=lambda x: x)
    conv_2 = GraphConv(20, 20, _as_adj(adj_mat, use_sparse))
    conv_2.weight = conv_1.graph_conv.weight
    res_1 = conv_1(node_reprs)
    res_2 = conv_2(node_reprs)
    # Printed so pytest shows the values if the assertion fails.
    print('res_1:', res_1.detach().cpu().numpy())
    print('res_2:', res_2.detach().cpu().numpy())
    _assert_same(res_1, res_2)


def _test_dropout_graph_conv_activation_02(use_sparse: bool):
    # keep_prob=1 with a doubling activation: expected to equal twice the
    # output of a plain GraphConv that shares the same weight Parameter.
    adj_mat = _random_adj_mat(10, 20)
    node_reprs = torch.eye(20)
    conv_1 = DropoutGraphConvActivation(20, 20, _as_adj(adj_mat, use_sparse),
                                        keep_prob=1.,
                                        activation=lambda x: x * 2)
    conv_2 = GraphConv(20, 20, _as_adj(adj_mat, use_sparse))
    conv_2.weight = conv_1.graph_conv.weight
    res_1 = conv_1(node_reprs)
    res_2 = conv_2(node_reprs)
    print('res_1:', res_1.detach().cpu().numpy())
    print('res_2:', res_2.detach().cpu().numpy())
    _assert_same(res_1, res_2 * 2)


def _test_dropout_graph_conv_activation_03(use_sparse: bool):
    # Real dropout: reseeding the global RNG before each path should make
    # conv_1's internal dropout draw the same mask as the explicit
    # dropout() call. This assumes conv_1 applies dropout(x, keep_prob)
    # to its input before convolving; TODO confirm against the
    # DropoutGraphConvActivation implementation.
    keep_prob = .5  # one value for both paths so they cannot drift apart
    adj_mat = _random_adj_mat(10, 20)
    node_reprs = torch.eye(20)
    conv_1 = DropoutGraphConvActivation(20, 20, _as_adj(adj_mat, use_sparse),
                                        keep_prob=keep_prob,
                                        activation=lambda x: x)
    conv_2 = GraphConv(20, 20, _as_adj(adj_mat, use_sparse))
    conv_2.weight = conv_1.graph_conv.weight
    torch.random.manual_seed(0)
    res_1 = conv_1(node_reprs)
    torch.random.manual_seed(0)
    res_2 = conv_2(dropout(node_reprs, keep_prob))
    print('res_1:', res_1.detach().cpu().numpy())
    print('res_2:', res_2.detach().cpu().numpy())
    _assert_same(res_1, res_2)


def test_dropout_graph_conv_activation_dense_01():
    _test_dropout_graph_conv_activation_01(False)


def test_dropout_graph_conv_activation_sparse_01():
    _test_dropout_graph_conv_activation_01(True)


def test_dropout_graph_conv_activation_dense_02():
    _test_dropout_graph_conv_activation_02(False)


def test_dropout_graph_conv_activation_sparse_02():
    _test_dropout_graph_conv_activation_02(True)


def test_dropout_graph_conv_activation_dense_03():
    _test_dropout_graph_conv_activation_03(False)


def test_dropout_graph_conv_activation_sparse_03():
    _test_dropout_graph_conv_activation_03(True)