diff --git a/tests/icosagon/test_declayer.py b/tests/icosagon/test_declayer.py index 0a3c617..216b1ed 100644 --- a/tests/icosagon/test_declayer.py +++ b/tests/icosagon/test_declayer.py @@ -203,17 +203,25 @@ def test_decode_layer_05(): repr_dec_expect = dec(repr_conv[x], repr_conv[y], 0) repr_dec_expect = repr_dec_expect.view(10, 10) - repr_dec = torch.mm(repr_conv, torch.transpose(repr_conv, 0, 1)) # repr_dec = torch.flatten(repr_dec) # repr_dec -= torch.eye(10) - #repr_dec_expect = torch.zeros((10, 10)) - #x = prep_d.relation_families[0].relation_types[0].edges_pos.train - #repr_dec_expect[x[:, 0], x[:, 1]] = pred.relation_families[0].relation_types[0].edges_pos.train - #x = prep_d.relation_families[0].relation_types[0].edges_neg.train - #repr_dec_expect[x[:, 0], x[:, 1]] = pred.relation_families[0].relation_types[0].edges_neg.train + assert torch.all(torch.abs(repr_dec - repr_dec_expect) < 0.000001) + + repr_dec_expect = torch.zeros((10, 10)) + x = prep_d.relation_families[0].relation_types[0].edges_pos.train + repr_dec_expect[x[:, 0], x[:, 1]] = pred.relation_families[0].relation_types[0].edges_pos.train + x = prep_d.relation_families[0].relation_types[0].edges_neg.train + repr_dec_expect[x[:, 0], x[:, 1]] = pred.relation_families[0].relation_types[0].edges_neg.train print(repr_dec) print(repr_dec_expect) + + repr_dec = torch.zeros((10, 10)) + x = prep_d.relation_families[0].relation_types[0].edges_pos.train + repr_dec[x[:, 0], x[:, 1]] = dec(repr_conv[x[:, 0]], repr_conv[x[:, 1]], 0) + x = prep_d.relation_families[0].relation_types[0].edges_neg.train + repr_dec[x[:, 0], x[:, 1]] = dec(repr_conv[x[:, 0]], repr_conv[x[:, 1]], 0) + assert torch.all(torch.abs(repr_dec - repr_dec_expect) < 0.000001) #print(prep_rel.edges_pos.train)