|
@@ -228,6 +228,8 @@ class SparseMultiDGCA(torch.nn.Module): |
|
|
self.sparse_dgca.append(SparseDropoutGraphConvActivation(input_dim, self.output_dim, adj_mat, self.keep_prob, self.activation)) |
|
|
self.sparse_dgca.append(SparseDropoutGraphConvActivation(input_dim, self.output_dim, adj_mat, self.keep_prob, self.activation)) |
|
|
|
|
|
|
|
|
def forward(self, x: List[torch.Tensor]) -> torch.Tensor:
    """Sum the per-relation sparse graph convolutions over the input list.

    Args:
        x: list of input feature tensors, one per adjacency matrix; x[i] is
            fed to the i-th SparseDropoutGraphConvActivation layer.
            NOTE(review): assumes len(x) == len(self.sparse_dgca) and that
            every x[i] has the same first dimension — confirm against callers.

    Returns:
        Accumulated output tensor of shape (len(x[0]), self.output_dim).

    Raises:
        ValueError: if x is not a list.
    """
    if not isinstance(x, list):
        raise ValueError('x must be a list of tensors')
    # Accumulator matches the row count of the first input and the
    # configured output width; dtype is inherited from the input.
    out = torch.zeros(len(x[0]), self.output_dim, dtype=x[0].dtype)
    for i, f in enumerate(self.sparse_dgca):
        out += f(x[i])
    # Bug fix: the original fell off the end and implicitly returned None,
    # contradicting the -> torch.Tensor annotation.
    return out
|
@@ -273,12 +275,26 @@ class MultiDGCA(torch.nn.Module): |
|
|
activation: Callable[[torch.Tensor], torch.Tensor]=torch.nn.functional.relu, |
|
|
activation: Callable[[torch.Tensor], torch.Tensor]=torch.nn.functional.relu, |
|
|
**kwargs) -> None: |
|
|
**kwargs) -> None: |
|
|
super().__init__(**kwargs) |
|
|
super().__init__(**kwargs) |
|
|
|
|
|
self.input_dim = input_dim |
|
|
self.output_dim = output_dim |
|
|
self.output_dim = output_dim |
|
|
self.dgca = [ DropoutGraphConvActivation(input_dim, output_dim, adj_mat, keep_prob, activation) for adj_mat in adjacency_matrices ] |
|
|
|
|
|
|
|
|
self.adjacency_matrices = adjacency_matrices |
|
|
|
|
|
self.keep_prob = keep_prob |
|
|
|
|
|
self.activation = activation |
|
|
|
|
|
self.dgca = None |
|
|
|
|
|
self.build() |
|
|
|
|
|
|
|
|
|
|
|
def build(self):
    """Create one DropoutGraphConvActivation layer per adjacency matrix.

    Pairs each entry of self.input_dim with the corresponding adjacency
    matrix; both lists must therefore be the same length.

    Raises:
        ValueError: if input_dim and adjacency_matrices differ in length.
    """
    if len(self.input_dim) != len(self.adjacency_matrices):
        raise ValueError('input_dim must have the same length as adjacency_matrices')
    self.dgca = [
        DropoutGraphConvActivation(dim, self.output_dim, adj, self.keep_prob, self.activation)
        for dim, adj in zip(self.input_dim, self.adjacency_matrices)
    ]
|
|
|
|
|
|
|
|
def forward(self, x: List[torch.Tensor]) -> torch.Tensor:
    """Sum the per-relation graph convolutions and L2-normalize the result.

    Args:
        x: list of input feature tensors, one per DropoutGraphConvActivation
            layer in self.dgca.
            NOTE(review): assumes len(x) == len(self.dgca) and that every
            x[i] shares the same first dimension — confirm against callers.

    Returns:
        Row-wise L2-normalized tensor of shape (len(x[0]), self.output_dim).
        (Annotation fixed: the function returns a single tensor, not a list.)

    Raises:
        ValueError: if x is not a list.
    """
    # Bug fix: removed dead pre-guard code that did
    # `torch.zeros(len(x), self.output_dim, dtype=x.dtype)` — a list has no
    # .dtype, so that line raised AttributeError before the guard ran.
    if not isinstance(x, list):
        raise ValueError('x must be a list of tensors')
    out = torch.zeros(len(x[0]), self.output_dim, dtype=x[0].dtype)
    for i, f in enumerate(self.dgca):
        out += f(x[i])
    # Row-wise L2 normalization of the accumulated embeddings.
    out = torch.nn.functional.normalize(out, p=2, dim=1)
    return out