@@ -16,7 +16,7 @@ class GraphConv(torch.nn.Module): | |||||
super().__init__(**kwargs) | super().__init__(**kwargs) | ||||
self.in_channels = in_channels | self.in_channels = in_channels | ||||
self.out_channels = out_channels | self.out_channels = out_channels | ||||
self.weight = init_glorot(in_channels, out_channels) | |||||
self.weight = torch.nn.Parameter(init_glorot(in_channels, out_channels)) | |||||
self.adjacency_matrix = adjacency_matrix | self.adjacency_matrix = adjacency_matrix | ||||
@@ -68,7 +68,7 @@ class DecodeLayer(torch.nn.Module): | |||||
self.build() | self.build() | ||||
def build(self) -> None: | def build(self) -> None: | ||||
self.decoders = [] | |||||
self.decoders = torch.nn.ModuleList() | |||||
for fam in self.data.relation_families: | for fam in self.data.relation_families: | ||||
dec = fam.decoder_class(self.input_dim, len(fam.relation_types), | dec = fam.decoder_class(self.input_dim, len(fam.relation_types), | ||||
self.keep_prob, self.activation) | self.keep_prob, self.activation) | ||||
@@ -111,7 +111,7 @@ def prepare_adj_mat(adj_mat: torch.Tensor, | |||||
edges_neg = train_val_test_split_edges(edges_neg, ratios) | edges_neg = train_val_test_split_edges(edges_neg, ratios) | ||||
adj_mat_train = torch.sparse_coo_tensor(indices = edges_pos.train.transpose(0, 1), | adj_mat_train = torch.sparse_coo_tensor(indices = edges_pos.train.transpose(0, 1), | ||||
values=torch.ones(len(edges_pos.train), shape=adj_mat.shape, dtype=adj_mat.dtype)) | |||||
values=torch.ones(len(edges_pos.train)), size=adj_mat.shape, dtype=adj_mat.dtype) | |||||
return adj_mat_train, edges_pos, edges_neg | return adj_mat_train, edges_pos, edges_neg | ||||
@@ -81,4 +81,8 @@ def test_model_03(): | |||||
assert isinstance(state_dict, dict) | assert isinstance(state_dict, dict) | ||||
# print(state_dict['param_groups']) | # print(state_dict['param_groups']) | ||||
# print(list(m.seq.parameters())) | # print(list(m.seq.parameters())) | ||||
print(list(m.seq[1].parameters())) | |||||
assert len(list(m.seq[0].parameters())) == 1 | |||||
assert len(list(m.seq[1].parameters())) == 1 | |||||
assert len(list(m.seq[2].parameters())) == 1 | |||||
assert len(list(m.seq[3].parameters())) == 2 | |||||
# print(list(m.seq[1].parameters())) |