Source code for cvnets.layers.activation.hard_swish

#
# For licensing see accompanying LICENSE file.
# Copyright (C) 2023 Apple Inc. All Rights Reserved.
#

from typing import Optional

from torch import Tensor, nn
from torch.nn import functional as F

from cvnets.layers.activation import register_act_fn


@register_act_fn(name="hard_swish")
class Hardswish(nn.Hardswish):
    """
    Applies the HardSwish function, as described in the paper
    `Searching for MobileNetv3 <https://arxiv.org/abs/1905.02244>`_

    HardSwish is defined as ``x * ReLU6(x + 3) / 6``:

    - it is ``0`` for ``x <= -3``;
    - it is ``x`` for ``x >= 3``;
    - it is ``x * (x + 3) / 6`` in between.

    Args:
        inplace: If True, the operation is allowed to overwrite ``input``.
            Defaults to False.
    """

    def __init__(self, inplace: Optional[bool] = False, *args, **kwargs) -> None:
        # Extra positional/keyword arguments are accepted and ignored
        # (presumably so all registered activations share one constructor
        # signature -- confirm against the registry's call site).
        super().__init__(inplace=inplace)

    def forward(self, input: Tensor, *args, **kwargs) -> Tensor:
        """
        Apply HardSwish element-wise to ``input``.

        Args:
            input: Input tensor.
            *args, **kwargs: Ignored.

        Returns:
            Tensor of the same shape as ``input``. When ``self.inplace`` is
            True, this is ``input`` itself, modified in place.
        """
        if hasattr(F, "hardswish"):
            return F.hardswish(input, self.inplace)

        # Fallback for PyTorch builds without F.hardswish. The hard-sigmoid
        # gate must be clamped to [0, 1], which needs ReLU6 rather than ReLU.
        # With plain ReLU the output grows quadratically for x > 3 instead of
        # becoming the identity.
        x_hard_sig = F.relu6(input + 3) / 6
        if self.inplace:
            # `input + 3` above produced a new tensor, so overwriting `input`
            # now cannot corrupt the gate that was computed from it.
            return input.mul_(x_hard_sig)
        return input * x_hard_sig