#
# For licensing see accompanying LICENSE file.
# Copyright (C) 2023 Apple Inc. All Rights Reserved.
#
import argparse
import os
import re
from types import MethodType
from typing import Dict, List, Optional, Tuple, Union
import torch
from torch import Tensor
from utils import logger
from utils.common_utils import unwrap_model_fn
from utils.ddp_utils import is_master, is_start_rank_node


def clean_strip(
    obj: Union[str, List[str]], sep: Optional[str] = ",", strip: bool = True
) -> List[str]:
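    """Normalize ``obj`` into a list of non-empty strings.

    ``obj`` may be a list of strings or a single ``sep``-separated string.
    Whitespace is stripped when ``strip`` is True; empty entries are always
    dropped. A minimal doctest sketch of this behavior:

    >>> clean_strip("a, b, ,c")
    ['a', 'b', 'c']
    >>> clean_strip(["a ", "b"])
    ['a', 'b']
    """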
    # Allowing list of strings as input as well as comma-separated strings
    if isinstance(obj, list):
        strings = obj
    else:
        strings = obj.split(sep)

    if strip:
        strings = [x.strip() for x in strings]
    strings = [x for x in strings if x]
    return strings


def load_pretrained_model(
    model: torch.nn.Module, wt_loc: str, opts: argparse.Namespace, *args, **kwargs
) -> torch.nn.Module:
"""Helper function to load pre-trained weights.
Args:
model: Model whose weights will be loaded.
wt_loc: Path to file to load state_dict from.
opts: Input arguments.
Returns:
The model loaded with the given weights.
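
    Example (illustrative only; the checkpoint path and scope pattern below
    are hypothetical). Excluded scopes are automatically added to the ignored
    missing scopes, so dropping a classifier head needs only one option:

        opts = argparse.Namespace(**{"model.resume_exclude_scopes": "classifier.*"})
        model = load_pretrained_model(model, "checkpoint.pt", opts)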
"""
    if not os.path.isfile(wt_loc):
        logger.error("Pretrained file is not found here: {}".format(wt_loc))

    wts = torch.load(wt_loc, map_location="cpu")

    is_master_node = is_start_rank_node(opts)

    exclude_scopes = getattr(opts, "model.resume_exclude_scopes", "")
    exclude_scopes: List[str] = clean_strip(exclude_scopes)

    missing_scopes = getattr(opts, "model.ignore_missing_scopes", "")
    missing_scopes: List[str] = clean_strip(missing_scopes)

    rename_scopes_map: List[List[str]] = getattr(opts, "model.rename_scopes_map", [])
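    # Each [before, after] entry is applied to every checkpoint key with
    # re.sub. An illustrative (hypothetical) entry: ["^module\\.", ""], which
    # strips a leading DDP "module." prefix from all keys.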
    if rename_scopes_map:
        for entry in rename_scopes_map:
            if len(entry) != 2:
                raise ValueError(
                    "Every entry in model.rename_scopes_map must contain exactly two string elements"
                    " for before and after. Got {}.".format(str(entry))
                )

    # By default, adding scopes that we exclude to missing scopes.
    # If you excluded something, you can't expect it to be there.
    missing_scopes += exclude_scopes
    # remove unwanted scopes
    if exclude_scopes:
        for key in wts.copy():
            if any([re.match(x, key) for x in exclude_scopes]):
                del wts[key]

    if rename_scopes_map:
        for before, after in rename_scopes_map:
            wts = {re.sub(before, after, key): value for key, value in wts.items()}

    strict = not bool(missing_scopes)
    try:
        module = unwrap_model_fn(model)
        missing_keys, unexpected_keys = module.load_state_dict(wts, strict=strict)

        if unexpected_keys:
            raise Exception(
                "Found unexpected keys: {}."
                " You can ignore these keys using `model.resume_exclude_scopes`.".format(
                    ",".join(unexpected_keys)
                )
            )

        missing_keys = [
            key
            for key in missing_keys
            if not any([re.match(x, key) for x in missing_scopes])
        ]

        if missing_keys:
            raise Exception(
                "Missing keys detected. Did not find the following keys in pre-trained model: {}."
                " You can ignore the keys using `model.ignore_missing_scopes`.".format(
                    ",".join(missing_keys)
                )
            )

        if is_master_node:
            logger.log("Pretrained weights are loaded from {}".format(wt_loc))
    except Exception as e:
        if is_master_node:
            logger.error(
                "Unable to load pretrained weights from {}. Error: {}".format(wt_loc, e)
            )

    return model


def parameter_list(
    named_parameters,
    weight_decay: Optional[float] = 0.0,
    no_decay_bn_filter_bias: Optional[bool] = False,
    *args,
    **kwargs,
) -> List[Dict]:
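    """Build optimizer parameter groups, optionally exempting 1-D parameters
    (biases and normalization weights) from weight decay.

    ``named_parameters`` is a callable such as ``model.named_parameters``, or
    a list of such callables. A minimal doctest sketch:

    >>> import torch.nn as nn
    >>> groups = parameter_list(nn.Linear(4, 2).named_parameters, 1e-4, True)
    >>> [g["weight_decay"] for g in groups]
    [0.0001, 0.0]
    """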
    module_name = kwargs.get("module_name", "")
    with_decay = []
    without_decay = []
    with_decay_param_names = []
    without_decay_param_names = []

    if isinstance(named_parameters, list):
        for n_parameter in named_parameters:
            for p_name, param in n_parameter():
                if (
                    param.requires_grad
                    and len(param.shape) == 1
                    and no_decay_bn_filter_bias
                ):
                    # biases and normalization layer parameters are of len 1
                    without_decay.append(param)
                    without_decay_param_names.append(module_name + p_name)
                elif param.requires_grad:
                    with_decay.append(param)
                    with_decay_param_names.append(module_name + p_name)
    else:
        for p_name, param in named_parameters():
            if (
                param.requires_grad
                and len(param.shape) == 1
                and no_decay_bn_filter_bias
            ):
                # biases and normalization layer parameters are of len 1
                without_decay.append(param)
                without_decay_param_names.append(module_name + p_name)
            elif param.requires_grad:
                with_decay.append(param)
                with_decay_param_names.append(module_name + p_name)
    param_list = [
        {
            "params": with_decay,
            "weight_decay": weight_decay,
            "param_names": with_decay_param_names,
        }
    ]
    if len(without_decay) > 0:
        param_list.append(
            {
                "params": without_decay,
                "weight_decay": 0.0,
                "param_names": without_decay_param_names,
            }
        )
    return param_list


def freeze_module(module: torch.nn.Module, force_eval: bool = True) -> torch.nn.Module:
    """
    Sets ``requires_grad = False`` on all of the given module's parameters and
    puts the module in eval mode. By default, it also overrides the module's
    ``train`` method so that the module permanently stays in eval mode (i.e.,
    calling ``module.train(mode=True)`` has no effect).

    >>> module = nn.Linear(10, 20).train()
    >>> module.training
    True
    >>> module.weight.requires_grad
    True
    >>> freeze_module(module).train().training
    False
    >>> module.weight.requires_grad
    False
    """
    module.eval()
    for parameter in module.parameters():
        parameter.requires_grad = False

    if force_eval:

        def _force_train_in_eval(
            self: torch.nn.Module, mode: bool = True
        ) -> torch.nn.Module:
            # ignore train/eval calls: perpetually stays in eval
            return self

        module.train = MethodType(_force_train_in_eval, module)

    return module


def freeze_modules_based_on_opts(
    opts: argparse.Namespace, model: torch.nn.Module, verbose: bool = True
) -> torch.nn.Module:
    """
    Allows for freezing immediate modules and parameters of the model using
    --model.freeze-modules.

    --model.freeze-modules should be a list of strings or a comma-separated
    list of regex expressions.

    Examples of --model.freeze-modules:
        "conv.*"  # see example below: can freeze all (top-level) conv layers
        "^((?!classifier).)*$"  # freezes everything except for "classifier": useful for linear probing
        "conv1,layer1,layer2,layer3"  # freeze all layers up to layer3

    >>> model = nn.Sequential(OrderedDict([
    ...     ('conv1', nn.Conv2d(1, 20, 5)),
    ...     ('relu1', nn.ReLU()),
    ...     ('conv2', nn.Conv2d(20, 64, 5)),
    ...     ('relu2', nn.ReLU())
    ... ]))
    >>> opts = argparse.Namespace(**{"model.freeze_modules": "conv1"})
    >>> _ = freeze_modules_based_on_opts(opts, model)
    INFO - Freezing module: conv1
    >>> _ = model.train()
    >>> model.conv1.training
    False
    >>> model.conv2.training
    True
    """
    freeze_patterns = getattr(opts, "model.freeze_modules", "")
    freeze_patterns = clean_strip(freeze_patterns)

    verbose = verbose and is_master(opts)

    if freeze_patterns:
        # TODO: allow applying on all modules, not just immediate children? How?
        for name, module in model.named_children():
            if any([re.match(p, name) for p in freeze_patterns]):
                freeze_module(module)
                if verbose:
                    logger.info("Freezing module: {}".format(name))
        for name, param in model.named_parameters(recurse=False):
            if any([re.match(p, name) for p in freeze_patterns]):
                param.requires_grad = False
                if verbose:
                    logger.info("Freezing parameter: {}".format(name))
    if verbose and hasattr(model, "get_trainable_parameters"):
        param_list, _ = model.get_trainable_parameters()
        for params in param_list:
            if (
                not isinstance(params["param_names"], List)
                or not isinstance(params["params"], List)
                or not isinstance(params["weight_decay"], (float, int))
            ):
                param_types = {k: type(v) for k, v in params.items()}
                logger.error(
                    "Expected parameter format: {{ params: List, weight_decay: float, param_names: List }}. "
                    "Got: {}".format(param_types)
                )
        # Flatten all parameter names
        trainable_param_names = [p for x in param_list for p in x["param_names"]]
        logger.info("Trainable parameters: {}".format(trainable_param_names))

    return model


def get_tensor_sizes(data: Union[Dict, Tensor]) -> Union[List[str], List[Tuple[int]]]:
"""Utility function for extracting tensor shapes (for printing purposes only)."""
    if isinstance(data, Dict):
        tensor_sizes = []
        for k, v in data.items():
            size_ = get_tensor_sizes(v)
            if size_:
                tensor_sizes.append(f"{k}: {size_}")
        return tensor_sizes
    elif isinstance(data, Tensor):
        return [*data.shape]
    else:
        return []