|
@@ -81,7 +81,7 @@ class FastGraphConv(torch.nn.Module): |
|
|
in_channels: List[int],
|
|
|
in_channels: List[int],
|
|
|
out_channels: List[int],
|
|
|
out_channels: List[int],
|
|
|
data: Union[Data, PreparedData],
|
|
|
data: Union[Data, PreparedData],
|
|
|
relation_family: Union[RelationFamily, PreparedRelationFamily]
|
|
|
|
|
|
|
|
|
relation_family: Union[RelationFamily, PreparedRelationFamily],
|
|
|
keep_prob: float = 1.,
|
|
|
keep_prob: float = 1.,
|
|
|
acivation: Callable[[torch.Tensor], torch.Tensor] = lambda x: x,
|
|
|
acivation: Callable[[torch.Tensor], torch.Tensor] = lambda x: x,
|
|
|
**kwargs) -> None:
|
|
|
**kwargs) -> None:
|
|
|