IF YOU WOULD LIKE TO GET AN ACCOUNT, please write an email to s dot adaszewski at gmail dot com. User accounts are meant only to report issues and/or generate pull requests. This is a purpose-specific Git hosting for ADARED projects. Thank you for your understanding!
소스 검색

Add test_model_03().

master
Stanislaw Adaszewski 3 년 전
부모
커밋
cfb3d308a0
4개의 변경된 파일8개의 추가작업 그리고 4개의 파일을 삭제
  1. +1
    -1
      src/icosagon/convolve.py
  2. +1
    -1
      src/icosagon/declayer.py
  3. +1
    -1
      src/icosagon/trainprep.py
  4. +5
    -1
      tests/icosagon/test_model.py

+ 1
- 1
src/icosagon/convolve.py 파일 보기

@@ -16,7 +16,7 @@ class GraphConv(torch.nn.Module):
super().__init__(**kwargs)
self.in_channels = in_channels
self.out_channels = out_channels
self.weight = init_glorot(in_channels, out_channels)
self.weight = torch.nn.Parameter(init_glorot(in_channels, out_channels))
self.adjacency_matrix = adjacency_matrix


+ 1
- 1
src/icosagon/declayer.py 파일 보기

@@ -68,7 +68,7 @@ class DecodeLayer(torch.nn.Module):
self.build()
def build(self) -> None:
self.decoders = []
self.decoders = torch.nn.ModuleList()
for fam in self.data.relation_families:
dec = fam.decoder_class(self.input_dim, len(fam.relation_types),
self.keep_prob, self.activation)


+ 1
- 1
src/icosagon/trainprep.py 파일 보기

@@ -111,7 +111,7 @@ def prepare_adj_mat(adj_mat: torch.Tensor,
edges_neg = train_val_test_split_edges(edges_neg, ratios)
adj_mat_train = torch.sparse_coo_tensor(indices = edges_pos.train.transpose(0, 1),
values=torch.ones(len(edges_pos.train), shape=adj_mat.shape, dtype=adj_mat.dtype))
values=torch.ones(len(edges_pos.train)), size=adj_mat.shape, dtype=adj_mat.dtype)
return adj_mat_train, edges_pos, edges_neg


+ 5
- 1
tests/icosagon/test_model.py 파일 보기

@@ -81,4 +81,8 @@ def test_model_03():
assert isinstance(state_dict, dict)
# print(state_dict['param_groups'])
# print(list(m.seq.parameters()))
print(list(m.seq[1].parameters()))
assert len(list(m.seq[0].parameters())) == 1
assert len(list(m.seq[1].parameters())) == 1
assert len(list(m.seq[2].parameters())) == 1
assert len(list(m.seq[3].parameters())) == 2
# print(list(m.seq[1].parameters()))

불러오는 중...
취소
저장