IF YOU WOULD LIKE TO GET AN ACCOUNT, please write an email to s dot adaszewski at gmail dot com. User accounts are meant only to report issues and/or generate pull requests. This is a purpose-specific Git hosting for ADARED projects. Thank you for your understanding!
Преглед на файлове

Modify decoders to handle one relation at a time.

master
Stanislaw Adaszewski преди 4 години
родител
ревизия
53cd182e6b
променени са 2 файла, в които са добавени 43 реда и са изтрити 51 реда
  1. +39
    -47
      src/icosagon/decode.py
  2. +4
    -4
      tests/icosagon/test_decode.py

+ 39
- 47
src/icosagon/decode.py Целия файл

@@ -26,22 +26,20 @@ class DEDICOMDecoder(torch.nn.Module):
for _ in range(num_relation_types)
]
def forward(self, inputs_row, inputs_col):
outputs = []
for k in range(self.num_relation_types):
inputs_row = dropout(inputs_row, self.keep_prob)
inputs_col = dropout(inputs_col, self.keep_prob)
def forward(self, inputs_row, inputs_col, relation_index):
inputs_row = dropout(inputs_row, self.keep_prob)
inputs_col = dropout(inputs_col, self.keep_prob)
relation = torch.diag(self.local_variation[k])
relation = torch.diag(self.local_variation[relation_index])
product1 = torch.mm(inputs_row, relation)
product2 = torch.mm(product1, self.global_interaction)
product3 = torch.mm(product2, relation)
rec = torch.bmm(product3.view(product3.shape[0], 1, product3.shape[1]),
inputs_col.view(inputs_col.shape[0], inputs_col.shape[1], 1))
rec = torch.flatten(rec)
outputs.append(self.activation(rec))
return outputs
product1 = torch.mm(inputs_row, relation)
product2 = torch.mm(product1, self.global_interaction)
product3 = torch.mm(product2, relation)
rec = torch.bmm(product3.view(product3.shape[0], 1, product3.shape[1]),
inputs_col.view(inputs_col.shape[0], inputs_col.shape[1], 1))
rec = torch.flatten(rec)
return self.activation(rec)
class DistMultDecoder(torch.nn.Module):
@@ -60,20 +58,18 @@ class DistMultDecoder(torch.nn.Module):
for _ in range(num_relation_types)
]
def forward(self, inputs_row, inputs_col):
outputs = []
for k in range(self.num_relation_types):
inputs_row = dropout(inputs_row, self.keep_prob)
inputs_col = dropout(inputs_col, self.keep_prob)
def forward(self, inputs_row, inputs_col, relation_index):
inputs_row = dropout(inputs_row, self.keep_prob)
inputs_col = dropout(inputs_col, self.keep_prob)
relation = torch.diag(self.relation[relation_index])
relation = torch.diag(self.relation[k])
intermediate_product = torch.mm(inputs_row, relation)
rec = torch.bmm(intermediate_product.view(intermediate_product.shape[0], 1, intermediate_product.shape[1]),
inputs_col.view(inputs_col.shape[0], inputs_col.shape[1], 1))
rec = torch.flatten(rec)
intermediate_product = torch.mm(inputs_row, relation)
rec = torch.bmm(intermediate_product.view(intermediate_product.shape[0], 1, intermediate_product.shape[1]),
inputs_col.view(inputs_col.shape[0], inputs_col.shape[1], 1))
rec = torch.flatten(rec)
outputs.append(self.activation(rec))
return outputs
return self.activation(rec)
class BilinearDecoder(torch.nn.Module):
@@ -92,18 +88,16 @@ class BilinearDecoder(torch.nn.Module):
for _ in range(num_relation_types)
]
def forward(self, inputs_row, inputs_col):
outputs = []
for k in range(self.num_relation_types):
inputs_row = dropout(inputs_row, self.keep_prob)
inputs_col = dropout(inputs_col, self.keep_prob)
def forward(self, inputs_row, inputs_col, relation_index):
inputs_row = dropout(inputs_row, self.keep_prob)
inputs_col = dropout(inputs_col, self.keep_prob)
intermediate_product = torch.mm(inputs_row, self.relation[k])
rec = torch.bmm(intermediate_product.view(intermediate_product.shape[0], 1, intermediate_product.shape[1]),
inputs_col.view(inputs_col.shape[0], inputs_col.shape[1], 1))
rec = torch.flatten(rec)
outputs.append(self.activation(rec))
return outputs
intermediate_product = torch.mm(inputs_row, self.relation[relation_index])
rec = torch.bmm(intermediate_product.view(intermediate_product.shape[0], 1, intermediate_product.shape[1]),
inputs_col.view(inputs_col.shape[0], inputs_col.shape[1], 1))
rec = torch.flatten(rec)
return self.activation(rec)
class InnerProductDecoder(torch.nn.Module):
@@ -118,14 +112,12 @@ class InnerProductDecoder(torch.nn.Module):
self.activation = activation
def forward(self, inputs_row, inputs_col):
outputs = []
for k in range(self.num_relation_types):
inputs_row = dropout(inputs_row, self.keep_prob)
inputs_col = dropout(inputs_col, self.keep_prob)
def forward(self, inputs_row, inputs_col, _):
inputs_row = dropout(inputs_row, self.keep_prob)
inputs_col = dropout(inputs_col, self.keep_prob)
rec = torch.bmm(inputs_row.view(inputs_row.shape[0], 1, inputs_row.shape[1]),
inputs_col.view(inputs_col.shape[0], inputs_col.shape[1], 1))
rec = torch.flatten(rec)
rec = torch.bmm(inputs_row.view(inputs_row.shape[0], 1, inputs_row.shape[1]),
inputs_col.view(inputs_col.shape[0], inputs_col.shape[1], 1))
rec = torch.flatten(rec)
outputs.append(self.activation(rec))
return outputs
return self.activation(rec)

+ 4
- 4
tests/icosagon/test_decode.py Целия файл

@@ -15,7 +15,7 @@ def test_dedicom_decoder_01():
dec_2.global_interaction = dec_1.global_interaction
dec_2.local_variation = dec_1.local_variation
res_1 = dec_1(repr_, repr_)
res_1 = [ dec_1(repr_, repr_, k) for k in range(7) ]
res_2 = dec_2(repr_, repr_)
assert isinstance(res_1, list)
@@ -35,7 +35,7 @@ def test_dist_mult_decoder_01():
activation=torch.sigmoid)
dec_2.relation = dec_1.relation
res_1 = dec_1(repr_, repr_)
res_1 = [ dec_1(repr_, repr_, k) for k in range(7) ]
res_2 = dec_2(repr_, repr_)
assert isinstance(res_1, list)
@@ -55,7 +55,7 @@ def test_bilinear_decoder_01():
activation=torch.sigmoid)
dec_2.relation = dec_1.relation
res_1 = dec_1(repr_, repr_)
res_1 = [ dec_1(repr_, repr_, k) for k in range(7) ]
res_2 = dec_2(repr_, repr_)
assert isinstance(res_1, list)
@@ -74,7 +74,7 @@ def test_inner_product_decoder_01():
dec_2 = decagon_pytorch.decode.pairwise.InnerProductDecoder(32, 7, drop_prob=0.,
activation=torch.sigmoid)
res_1 = dec_1(repr_, repr_)
res_1 = [ dec_1(repr_, repr_, k) for k in range(7) ]
res_2 = dec_2(repr_, repr_)
assert isinstance(res_1, list)


Loading…
Отказ
Запис