IF YOU WOULD LIKE TO GET AN ACCOUNT, please write an email to s dot adaszewski at gmail dot com. User accounts are meant only to report issues and/or generate pull requests. This is a purpose-specific Git hosting for ADARED projects. Thank you for your understanding!
Browse Source

Add test_model_03().

master
Stanislaw Adaszewski 3 years ago
parent
commit
cfb3d308a0
4 changed files with 8 additions and 4 deletions
  1. +1
    -1
      src/icosagon/convolve.py
  2. +1
    -1
      src/icosagon/declayer.py
  3. +1
    -1
      src/icosagon/trainprep.py
  4. +5
    -1
      tests/icosagon/test_model.py

+ 1
- 1
src/icosagon/convolve.py View File

@@ -16,7 +16,7 @@ class GraphConv(torch.nn.Module):
super().__init__(**kwargs)
self.in_channels = in_channels
self.out_channels = out_channels
self.weight = init_glorot(in_channels, out_channels)
self.weight = torch.nn.Parameter(init_glorot(in_channels, out_channels))
self.adjacency_matrix = adjacency_matrix


+ 1
- 1
src/icosagon/declayer.py View File

@@ -68,7 +68,7 @@ class DecodeLayer(torch.nn.Module):
self.build()
def build(self) -> None:
self.decoders = []
self.decoders = torch.nn.ModuleList()
for fam in self.data.relation_families:
dec = fam.decoder_class(self.input_dim, len(fam.relation_types),
self.keep_prob, self.activation)


+ 1
- 1
src/icosagon/trainprep.py View File

@@ -111,7 +111,7 @@ def prepare_adj_mat(adj_mat: torch.Tensor,
edges_neg = train_val_test_split_edges(edges_neg, ratios)
adj_mat_train = torch.sparse_coo_tensor(indices = edges_pos.train.transpose(0, 1),
values=torch.ones(len(edges_pos.train), shape=adj_mat.shape, dtype=adj_mat.dtype))
values=torch.ones(len(edges_pos.train)), size=adj_mat.shape, dtype=adj_mat.dtype)
return adj_mat_train, edges_pos, edges_neg


+ 5
- 1
tests/icosagon/test_model.py View File

@@ -81,4 +81,8 @@ def test_model_03():
assert isinstance(state_dict, dict)
# print(state_dict['param_groups'])
# print(list(m.seq.parameters()))
print(list(m.seq[1].parameters()))
assert len(list(m.seq[0].parameters())) == 1
assert len(list(m.seq[1].parameters())) == 1
assert len(list(m.seq[2].parameters())) == 1
assert len(list(m.seq[3].parameters())) == 2
# print(list(m.seq[1].parameters()))

Loading…
Cancel
Save