diff --git a/src/icosagon/fastconv.py b/src/icosagon/fastconv.py index 02086c4..6f90eb8 100644 --- a/src/icosagon/fastconv.py +++ b/src/icosagon/fastconv.py @@ -81,7 +81,7 @@ class FastGraphConv(torch.nn.Module): in_channels: List[int], out_channels: List[int], data: Union[Data, PreparedData], - relation_family: Union[RelationFamily, PreparedRelationFamily] + relation_family: Union[RelationFamily, PreparedRelationFamily], keep_prob: float = 1., acivation: Callable[[torch.Tensor], torch.Tensor] = lambda x: x, **kwargs) -> None: