|
@@ -20,7 +20,6 @@ class GraphConv(torch.nn.Module): |
|
|
self.weight = torch.nn.Parameter(init_glorot(in_channels, out_channels))
|
|
|
self.weight = torch.nn.Parameter(init_glorot(in_channels, out_channels))
|
|
|
self.adjacency_matrix = adjacency_matrix
|
|
|
self.adjacency_matrix = adjacency_matrix
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
|
|
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
|
|
x = torch.sparse.mm(x, self.weight) \
|
|
|
x = torch.sparse.mm(x, self.weight) \
|
|
|
if x.is_sparse \
|
|
|
if x.is_sparse \
|
|
|