#
# For licensing see accompanying LICENSE file.
# Copyright (C) 2023 Apple Inc. All Rights Reserved.
#
import random
from typing import List, Optional
from torch import Tensor
from cvnets.layers.base_layer import BaseLayer
from utils.math_utils import bound_fn


class RandomApply(BaseLayer):
    """
    This layer randomly applies a subset of the modules in a list during training.
    All modules are applied during evaluation.

    Args:
        module_list (List): List of modules.
        keep_p (Optional[float]): Fraction of modules from the list to apply during
            training. Default: 0.8 (i.e., 80%).
    """
    def __init__(
        self, module_list: List, keep_p: Optional[float] = 0.8, *args, **kwargs
    ) -> None:
        super().__init__()
        n_modules = len(module_list)
        self.module_list = module_list

        # Indices of all modules except the first; the first module is always
        # applied, and keep_k modules are sampled from the remaining indices.
        self.module_indexes = [i for i in range(1, n_modules)]
        k = int(round(n_modules * keep_p))
        self.keep_k = bound_fn(min_val=1, max_val=n_modules, value=k)
    def forward(self, x: Tensor) -> Tensor:
        if self.training:
            # Always apply the first module, then keep_k randomly sampled
            # modules from the rest, preserving their original order.
            indexes = [0] + sorted(random.sample(self.module_indexes, k=self.keep_k))
            for idx in indexes:
                x = self.module_list[idx](x)
        else:
            # Apply every module during evaluation.
            for layer in self.module_list:
                x = layer(x)
        return x

def __repr__(self):
format_string = "{}(apply_k (N={})={}, ".format(
self.__class__.__name__, len(self.module_list), self.keep_k
)
for layer in self.module_list:
format_string += "\n\t {}".format(layer)
format_string += "\n)"
return format_string
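

# Illustrative usage sketch (not part of the original module): this assumes
# `BaseLayer` behaves like a standard `nn.Module` (so `RandomApply` is callable
# and `train()`/`eval()` toggle `self.training`), and uses hypothetical
# `nn.Linear` blocks and tensor shapes purely for demonstration.
if __name__ == "__main__":
    import torch
    from torch import nn

    # Four blocks; with keep_p=0.5, keep_k = round(4 * 0.5) = 2, so during
    # training the first block plus 2 randomly sampled blocks are applied.
    blocks = [nn.Linear(16, 16) for _ in range(4)]
    layer = RandomApply(module_list=blocks, keep_p=0.5)

    layer.train()
    y_train = layer(torch.randn(2, 16))  # random subset of blocks applied

    layer.eval()
    y_eval = layer(torch.randn(2, 16))  # all blocks applied in order

    print(layer)
    print(y_train.shape, y_eval.shape)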