diff --git a/src/icosagon/decode.py b/src/icosagon/decode.py index dccf508..b69ec97 100644 --- a/src/icosagon/decode.py +++ b/src/icosagon/decode.py @@ -26,22 +26,20 @@ class DEDICOMDecoder(torch.nn.Module): for _ in range(num_relation_types) ] - def forward(self, inputs_row, inputs_col): - outputs = [] - for k in range(self.num_relation_types): - inputs_row = dropout(inputs_row, self.keep_prob) - inputs_col = dropout(inputs_col, self.keep_prob) + def forward(self, inputs_row, inputs_col, relation_index): + inputs_row = dropout(inputs_row, self.keep_prob) + inputs_col = dropout(inputs_col, self.keep_prob) - relation = torch.diag(self.local_variation[k]) + relation = torch.diag(self.local_variation[relation_index]) - product1 = torch.mm(inputs_row, relation) - product2 = torch.mm(product1, self.global_interaction) - product3 = torch.mm(product2, relation) - rec = torch.bmm(product3.view(product3.shape[0], 1, product3.shape[1]), - inputs_col.view(inputs_col.shape[0], inputs_col.shape[1], 1)) - rec = torch.flatten(rec) - outputs.append(self.activation(rec)) - return outputs + product1 = torch.mm(inputs_row, relation) + product2 = torch.mm(product1, self.global_interaction) + product3 = torch.mm(product2, relation) + rec = torch.bmm(product3.view(product3.shape[0], 1, product3.shape[1]), + inputs_col.view(inputs_col.shape[0], inputs_col.shape[1], 1)) + rec = torch.flatten(rec) + + return self.activation(rec) class DistMultDecoder(torch.nn.Module): @@ -60,20 +58,18 @@ class DistMultDecoder(torch.nn.Module): for _ in range(num_relation_types) ] - def forward(self, inputs_row, inputs_col): - outputs = [] - for k in range(self.num_relation_types): - inputs_row = dropout(inputs_row, self.keep_prob) - inputs_col = dropout(inputs_col, self.keep_prob) + def forward(self, inputs_row, inputs_col, relation_index): + inputs_row = dropout(inputs_row, self.keep_prob) + inputs_col = dropout(inputs_col, self.keep_prob) + + relation = torch.diag(self.relation[relation_index]) - relation = torch.diag(self.relation[k]) + intermediate_product = torch.mm(inputs_row, relation) + rec = torch.bmm(intermediate_product.view(intermediate_product.shape[0], 1, intermediate_product.shape[1]), + inputs_col.view(inputs_col.shape[0], inputs_col.shape[1], 1)) + rec = torch.flatten(rec) - intermediate_product = torch.mm(inputs_row, relation) - rec = torch.bmm(intermediate_product.view(intermediate_product.shape[0], 1, intermediate_product.shape[1]), - inputs_col.view(inputs_col.shape[0], inputs_col.shape[1], 1)) - rec = torch.flatten(rec) - outputs.append(self.activation(rec)) - return outputs + return self.activation(rec) class BilinearDecoder(torch.nn.Module): @@ -92,18 +88,16 @@ class BilinearDecoder(torch.nn.Module): for _ in range(num_relation_types) ] - def forward(self, inputs_row, inputs_col): - outputs = [] - for k in range(self.num_relation_types): - inputs_row = dropout(inputs_row, self.keep_prob) - inputs_col = dropout(inputs_col, self.keep_prob) + def forward(self, inputs_row, inputs_col, relation_index): + inputs_row = dropout(inputs_row, self.keep_prob) + inputs_col = dropout(inputs_col, self.keep_prob) - intermediate_product = torch.mm(inputs_row, self.relation[k]) - rec = torch.bmm(intermediate_product.view(intermediate_product.shape[0], 1, intermediate_product.shape[1]), - inputs_col.view(inputs_col.shape[0], inputs_col.shape[1], 1)) - rec = torch.flatten(rec) - outputs.append(self.activation(rec)) - return outputs + intermediate_product = torch.mm(inputs_row, self.relation[relation_index]) + rec = torch.bmm(intermediate_product.view(intermediate_product.shape[0], 1, intermediate_product.shape[1]), + inputs_col.view(inputs_col.shape[0], inputs_col.shape[1], 1)) + rec = torch.flatten(rec) + + return self.activation(rec) class InnerProductDecoder(torch.nn.Module): @@ -118,14 +112,12 @@ class InnerProductDecoder(torch.nn.Module): self.activation = activation - def forward(self, inputs_row, inputs_col): - outputs = [] - for k in range(self.num_relation_types): - inputs_row = dropout(inputs_row, self.keep_prob) - inputs_col = dropout(inputs_col, self.keep_prob) + def forward(self, inputs_row, inputs_col, _): + inputs_row = dropout(inputs_row, self.keep_prob) + inputs_col = dropout(inputs_col, self.keep_prob) + + rec = torch.bmm(inputs_row.view(inputs_row.shape[0], 1, inputs_row.shape[1]), + inputs_col.view(inputs_col.shape[0], inputs_col.shape[1], 1)) + rec = torch.flatten(rec) - rec = torch.bmm(inputs_row.view(inputs_row.shape[0], 1, inputs_row.shape[1]), - inputs_col.view(inputs_col.shape[0], inputs_col.shape[1], 1)) - rec = torch.flatten(rec) - outputs.append(self.activation(rec)) - return outputs + return self.activation(rec) diff --git a/tests/icosagon/test_decode.py b/tests/icosagon/test_decode.py index 001f411..c0ec4c3 100644 --- a/tests/icosagon/test_decode.py +++ b/tests/icosagon/test_decode.py @@ -15,7 +15,7 @@ def test_dedicom_decoder_01(): dec_2.global_interaction = dec_1.global_interaction dec_2.local_variation = dec_1.local_variation - res_1 = dec_1(repr_, repr_) + res_1 = [ dec_1(repr_, repr_, k) for k in range(7) ] res_2 = dec_2(repr_, repr_) assert isinstance(res_1, list) @@ -35,7 +35,7 @@ def test_dist_mult_decoder_01(): activation=torch.sigmoid) dec_2.relation = dec_1.relation - res_1 = dec_1(repr_, repr_) + res_1 = [ dec_1(repr_, repr_, k) for k in range(7) ] res_2 = dec_2(repr_, repr_) assert isinstance(res_1, list) @@ -55,7 +55,7 @@ def test_bilinear_decoder_01(): activation=torch.sigmoid) dec_2.relation = dec_1.relation - res_1 = dec_1(repr_, repr_) + res_1 = [ dec_1(repr_, repr_, k) for k in range(7) ] res_2 = dec_2(repr_, repr_) assert isinstance(res_1, list) @@ -74,7 +74,7 @@ def test_inner_product_decoder_01(): dec_2 = decagon_pytorch.decode.pairwise.InnerProductDecoder(32, 7, drop_prob=0., activation=torch.sigmoid) - res_1 = dec_1(repr_, repr_) + res_1 = [ dec_1(repr_, repr_, k) for k in range(7) ] res_2 = dec_2(repr_, repr_) assert isinstance(res_1, list)