|
|
@@ -173,7 +173,7 @@ class Model(torch.nn.Module): |
|
|
for x in next_layer_repr ]
|
|
|
for x in next_layer_repr ]
|
|
|
|
|
|
|
|
|
cur_layer_repr = next_layer_repr
|
|
|
cur_layer_repr = next_layer_repr
|
|
|
return next_layer_repr
|
|
|
|
|
|
|
|
|
return cur_layer_repr
|
|
|
|
|
|
|
|
|
def decode(self, last_layer_repr: List[torch.Tensor],
|
|
|
def decode(self, last_layer_repr: List[torch.Tensor],
|
|
|
batch: TrainingBatch) -> torch.Tensor:
|
|
|
batch: TrainingBatch) -> torch.Tensor:
|
|
|
@@ -182,12 +182,27 @@ class Model(torch.nn.Module): |
|
|
vt_col = batch.vertex_type_column
|
|
|
vt_col = batch.vertex_type_column
|
|
|
rel_idx = batch.relation_type_index
|
|
|
rel_idx = batch.relation_type_index
|
|
|
global_interaction = \
|
|
|
global_interaction = \
|
|
|
self.dec_weights['%d-%d-global-interaction'] % (vt_row, vt_col)
|
|
|
|
|
|
|
|
|
self.dec_weights['%d-%d-global-interaction' % (vt_row, vt_col)]
|
|
|
local_variation = \
|
|
|
local_variation = \
|
|
|
self.dec_weights['%d-%d-local-variation-%d'] % (vt_row, vt_col, rel_idx)
|
|
|
|
|
|
|
|
|
self.dec_weights['%d-%d-local-variation-%d' % (vt_row, vt_col, rel_idx)]
|
|
|
|
|
|
|
|
|
in_row = dropout(last_layer_repr[vt_row], self.keep_prob)
|
|
|
|
|
|
in_col = dropout(last_layer_repr[vt_col], self.keep_prob)
|
|
|
|
|
|
|
|
|
in_row = last_layer_repr[vt_row]
|
|
|
|
|
|
in_col = last_layer_repr[vt_col]
|
|
|
|
|
|
|
|
|
|
|
|
if in_row.is_sparse or in_col.is_sparse:
|
|
|
|
|
|
raise ValueError('Inputs to Model.decode() must be dense')
|
|
|
|
|
|
|
|
|
|
|
|
in_row = in_row[batch.edges[:, 0]]
|
|
|
|
|
|
in_col = in_col[batch.edges[:, 1]]
|
|
|
|
|
|
|
|
|
|
|
|
in_row = dropout(in_row, self.keep_prob)
|
|
|
|
|
|
in_col = dropout(in_col, self.keep_prob)
|
|
|
|
|
|
|
|
|
|
|
|
# in_row = in_row.to_dense()
|
|
|
|
|
|
# in_col = in_col.to_dense()
|
|
|
|
|
|
|
|
|
|
|
|
print('in_row.is_sparse:', in_row.is_sparse)
|
|
|
|
|
|
print('in_col.is_sparse:', in_col.is_sparse)
|
|
|
|
|
|
|
|
|
x = torch.mm(in_row, local_variation)
|
|
|
x = torch.mm(in_row, local_variation)
|
|
|
x = torch.mm(x, global_interaction)
|
|
|
x = torch.mm(x, global_interaction)
|
|
|
@@ -197,7 +212,7 @@ class Model(torch.nn.Module): |
|
|
x = torch.flatten(x)
|
|
|
x = torch.flatten(x)
|
|
|
|
|
|
|
|
|
x = self.dec_activation(x)
|
|
|
x = self.dec_activation(x)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
return x
|
|
|
return x
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|