| @@ -203,17 +203,25 @@ def test_decode_layer_05(): | |||||
| repr_dec_expect = dec(repr_conv[x], repr_conv[y], 0) | repr_dec_expect = dec(repr_conv[x], repr_conv[y], 0) | ||||
| repr_dec_expect = repr_dec_expect.view(10, 10) | repr_dec_expect = repr_dec_expect.view(10, 10) | ||||
| repr_dec = torch.mm(repr_conv, torch.transpose(repr_conv, 0, 1)) | repr_dec = torch.mm(repr_conv, torch.transpose(repr_conv, 0, 1)) | ||||
| # repr_dec = torch.flatten(repr_dec) | # repr_dec = torch.flatten(repr_dec) | ||||
| # repr_dec -= torch.eye(10) | # repr_dec -= torch.eye(10) | ||||
| #repr_dec_expect = torch.zeros((10, 10)) | |||||
| #x = prep_d.relation_families[0].relation_types[0].edges_pos.train | |||||
| #repr_dec_expect[x[:, 0], x[:, 1]] = pred.relation_families[0].relation_types[0].edges_pos.train | |||||
| #x = prep_d.relation_families[0].relation_types[0].edges_neg.train | |||||
| #repr_dec_expect[x[:, 0], x[:, 1]] = pred.relation_families[0].relation_types[0].edges_neg.train | |||||
| assert torch.all(torch.abs(repr_dec - repr_dec_expect) < 0.000001) | |||||
| repr_dec_expect = torch.zeros((10, 10)) | |||||
| x = prep_d.relation_families[0].relation_types[0].edges_pos.train | |||||
| repr_dec_expect[x[:, 0], x[:, 1]] = pred.relation_families[0].relation_types[0].edges_pos.train | |||||
| x = prep_d.relation_families[0].relation_types[0].edges_neg.train | |||||
| repr_dec_expect[x[:, 0], x[:, 1]] = pred.relation_families[0].relation_types[0].edges_neg.train | |||||
| print(repr_dec) | print(repr_dec) | ||||
| print(repr_dec_expect) | print(repr_dec_expect) | ||||
| repr_dec = torch.zeros((10, 10)) | |||||
| x = prep_d.relation_families[0].relation_types[0].edges_pos.train | |||||
| repr_dec[x[:, 0], x[:, 1]] = dec(repr_conv[x[:, 0]], repr_conv[x[:, 1]], 0) | |||||
| x = prep_d.relation_families[0].relation_types[0].edges_neg.train | |||||
| repr_dec[x[:, 0], x[:, 1]] = dec(repr_conv[x[:, 0]], repr_conv[x[:, 1]], 0) | |||||
| assert torch.all(torch.abs(repr_dec - repr_dec_expect) < 0.000001) | assert torch.all(torch.abs(repr_dec - repr_dec_expect) < 0.000001) | ||||
| #print(prep_rel.edges_pos.train) | #print(prep_rel.edges_pos.train) | ||||