|
@@ -203,17 +203,25 @@ def test_decode_layer_05(): |
|
|
repr_dec_expect = dec(repr_conv[x], repr_conv[y], 0)
|
|
|
repr_dec_expect = dec(repr_conv[x], repr_conv[y], 0)
|
|
|
repr_dec_expect = repr_dec_expect.view(10, 10)
|
|
|
repr_dec_expect = repr_dec_expect.view(10, 10)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
repr_dec = torch.mm(repr_conv, torch.transpose(repr_conv, 0, 1))
|
|
|
repr_dec = torch.mm(repr_conv, torch.transpose(repr_conv, 0, 1))
|
|
|
# repr_dec = torch.flatten(repr_dec)
|
|
|
# repr_dec = torch.flatten(repr_dec)
|
|
|
# repr_dec -= torch.eye(10)
|
|
|
# repr_dec -= torch.eye(10)
|
|
|
#repr_dec_expect = torch.zeros((10, 10))
|
|
|
|
|
|
#x = prep_d.relation_families[0].relation_types[0].edges_pos.train
|
|
|
|
|
|
#repr_dec_expect[x[:, 0], x[:, 1]] = pred.relation_families[0].relation_types[0].edges_pos.train
|
|
|
|
|
|
#x = prep_d.relation_families[0].relation_types[0].edges_neg.train
|
|
|
|
|
|
#repr_dec_expect[x[:, 0], x[:, 1]] = pred.relation_families[0].relation_types[0].edges_neg.train
|
|
|
|
|
|
|
|
|
assert torch.all(torch.abs(repr_dec - repr_dec_expect) < 0.000001)
|
|
|
|
|
|
|
|
|
|
|
|
repr_dec_expect = torch.zeros((10, 10))
|
|
|
|
|
|
x = prep_d.relation_families[0].relation_types[0].edges_pos.train
|
|
|
|
|
|
repr_dec_expect[x[:, 0], x[:, 1]] = pred.relation_families[0].relation_types[0].edges_pos.train
|
|
|
|
|
|
x = prep_d.relation_families[0].relation_types[0].edges_neg.train
|
|
|
|
|
|
repr_dec_expect[x[:, 0], x[:, 1]] = pred.relation_families[0].relation_types[0].edges_neg.train
|
|
|
print(repr_dec)
|
|
|
print(repr_dec)
|
|
|
print(repr_dec_expect)
|
|
|
print(repr_dec_expect)
|
|
|
|
|
|
|
|
|
|
|
|
repr_dec = torch.zeros((10, 10))
|
|
|
|
|
|
x = prep_d.relation_families[0].relation_types[0].edges_pos.train
|
|
|
|
|
|
repr_dec[x[:, 0], x[:, 1]] = dec(repr_conv[x[:, 0]], repr_conv[x[:, 1]], 0)
|
|
|
|
|
|
x = prep_d.relation_families[0].relation_types[0].edges_neg.train
|
|
|
|
|
|
repr_dec[x[:, 0], x[:, 1]] = dec(repr_conv[x[:, 0]], repr_conv[x[:, 1]], 0)
|
|
|
|
|
|
|
|
|
assert torch.all(torch.abs(repr_dec - repr_dec_expect) < 0.000001)
|
|
|
assert torch.all(torch.abs(repr_dec - repr_dec_expect) < 0.000001)
|
|
|
|
|
|
|
|
|
#print(prep_rel.edges_pos.train)
|
|
|
#print(prep_rel.edges_pos.train)
|
|
|