IF YOU WOULD LIKE TO GET AN ACCOUNT, please write an email to s dot adaszewski at gmail dot com. User accounts are meant only to report issues and/or generate pull requests. This is a purpose-specific Git hosting for ADARED projects. Thank you for your understanding!
Переглянути джерело

Ammend test_decode_layer_05.

master
Stanislaw Adaszewski 4 роки тому
джерело
коміт
dc8c51a8e6
1 змінених файлів з 14 додано та 6 видалено
  1. +14
    -6
      tests/icosagon/test_declayer.py

+ 14
- 6
tests/icosagon/test_declayer.py Переглянути файл

@@ -203,17 +203,25 @@ def test_decode_layer_05():
repr_dec_expect = dec(repr_conv[x], repr_conv[y], 0)
repr_dec_expect = repr_dec_expect.view(10, 10)
repr_dec = torch.mm(repr_conv, torch.transpose(repr_conv, 0, 1))
# repr_dec = torch.flatten(repr_dec)
# repr_dec -= torch.eye(10)
#repr_dec_expect = torch.zeros((10, 10))
#x = prep_d.relation_families[0].relation_types[0].edges_pos.train
#repr_dec_expect[x[:, 0], x[:, 1]] = pred.relation_families[0].relation_types[0].edges_pos.train
#x = prep_d.relation_families[0].relation_types[0].edges_neg.train
#repr_dec_expect[x[:, 0], x[:, 1]] = pred.relation_families[0].relation_types[0].edges_neg.train
assert torch.all(torch.abs(repr_dec - repr_dec_expect) < 0.000001)
repr_dec_expect = torch.zeros((10, 10))
x = prep_d.relation_families[0].relation_types[0].edges_pos.train
repr_dec_expect[x[:, 0], x[:, 1]] = pred.relation_families[0].relation_types[0].edges_pos.train
x = prep_d.relation_families[0].relation_types[0].edges_neg.train
repr_dec_expect[x[:, 0], x[:, 1]] = pred.relation_families[0].relation_types[0].edges_neg.train
print(repr_dec)
print(repr_dec_expect)
repr_dec = torch.zeros((10, 10))
x = prep_d.relation_families[0].relation_types[0].edges_pos.train
repr_dec[x[:, 0], x[:, 1]] = dec(repr_conv[x[:, 0]], repr_conv[x[:, 1]], 0)
x = prep_d.relation_families[0].relation_types[0].edges_neg.train
repr_dec[x[:, 0], x[:, 1]] = dec(repr_conv[x[:, 0]], repr_conv[x[:, 1]], 0)
assert torch.all(torch.abs(repr_dec - repr_dec_expect) < 0.000001)
#print(prep_rel.edges_pos.train)


Завантаження…
Відмінити
Зберегти