IF YOU WOULD LIKE TO GET AN ACCOUNT, please write an email to s dot adaszewski at gmail dot com. User accounts are meant only to report issues and/or generate pull requests. This is a purpose-specific Git hosting for ADARED projects. Thank you for your understanding!
浏览代码

Add support for multiple input_dim in MultiDGCA.

master
Stanislaw Adaszewski 4 年前
父节点
当前提交
2e5e1d69ac
共有 2 个文件被更改,包括 23 次插入7 次删除
  1. +20
    -4
      src/decagon_pytorch/convolve.py
  2. +3
    -3
      tests/decagon_pytorch/test_convolve.py

+ 20
- 4
src/decagon_pytorch/convolve.py 查看文件

@@ -228,6 +228,8 @@ class SparseMultiDGCA(torch.nn.Module):
self.sparse_dgca.append(SparseDropoutGraphConvActivation(input_dim, self.output_dim, adj_mat, self.keep_prob, self.activation))

def forward(self, x: List[torch.Tensor]) -> torch.Tensor:
if not isinstance(x, list):
raise ValueError('x must be a list of tensors')
out = torch.zeros(len(x[0]), self.output_dim, dtype=x[0].dtype)
for i, f in enumerate(self.sparse_dgca):
out += f(x[i])
@@ -273,12 +275,26 @@ class MultiDGCA(torch.nn.Module):
activation: Callable[[torch.Tensor], torch.Tensor]=torch.nn.functional.relu,
**kwargs) -> None:
super().__init__(**kwargs)
self.input_dim = input_dim
self.output_dim = output_dim
self.dgca = [ DropoutGraphConvActivation(input_dim, output_dim, adj_mat, keep_prob, activation) for adj_mat in adjacency_matrices ]
self.adjacency_matrices = adjacency_matrices
self.keep_prob = keep_prob
self.activation = activation
self.dgca = None
self.build()

def build(self):
if len(self.input_dim) != len(self.adjacency_matrices):
raise ValueError('input_dim must have the same length as adjacency_matrices')
self.dgca = []
for input_dim, adj_mat in zip(self.input_dim, self.adjacency_matrices):
self.dgca.append(DropoutGraphConvActivation(input_dim, self.output_dim, adj_mat, self.keep_prob, self.activation))

def forward(self, x: List[torch.Tensor]) -> List[torch.Tensor]:
out = torch.zeros(len(x), self.output_dim, dtype=x.dtype)
for f in self.dgca:
out += f(x)
if not isinstance(x, list):
raise ValueError('x must be a list of tensors')
out = torch.zeros(len(x[0]), self.output_dim, dtype=x[0].dtype)
for i, f in enumerate(self.dgca):
out += f(x[i])
out = torch.nn.functional.normalize(out, p=2, dim=1)
return out

+ 3
- 3
tests/decagon_pytorch/test_convolve.py 查看文件

@@ -236,10 +236,10 @@ def test_multi_dgca():
assert np.all(adjacency_matrices[i].numpy() == adjacency_matrices_sparse[i].to_dense().numpy())
torch.random.manual_seed(0)
multi_sparse = decagon_pytorch.convolve.SparseMultiDGCA([10,]*len(adjacency_matrices), 10, adjacency_matrices_sparse, keep_prob=keep_prob)
multi_sparse = decagon_pytorch.convolve.SparseMultiDGCA([10,] * len(adjacency_matrices), 10, adjacency_matrices_sparse, keep_prob=keep_prob)
torch.random.manual_seed(0)
multi = decagon_pytorch.convolve.MultiDGCA(10, 10, adjacency_matrices, keep_prob=keep_prob)
multi = decagon_pytorch.convolve.MultiDGCA([10,] * len(adjacency_matrices), 10, adjacency_matrices, keep_prob=keep_prob)
print('len(adjacency_matrices):', len(adjacency_matrices))
print('len(multi_sparse.sparse_dgca):', len(multi_sparse.sparse_dgca))
@@ -251,6 +251,6 @@ def test_multi_dgca():
# torch.random.manual_seed(0)
latent_sparse = multi_sparse([latent_sparse,] * len(adjacency_matrices))
# torch.random.manual_seed(0)
latent = multi(latent)
latent = multi([latent,] * len(adjacency_matrices))
assert np.all(latent_sparse.detach().numpy() == latent.detach().numpy())

正在加载...
取消
保存