Source code for torchwrench.optim.utils

#!/usr/bin/env python
# -*- coding: utf-8 -*-

import logging
from typing import Any, Dict, Iterable, List, Optional, Tuple, Union

from torch import nn
from torch.nn.parameter import Parameter
from torch.optim.optimizer import Optimizer

pylog = logging.getLogger(__name__)


def get_lr(optim: Optimizer, idx: int = 0, key: str = "lr") -> float:
    """Get the learning rate of one parameter group of an optimizer.

    Args:
        optim: The optimizer to read from.
        idx: The index of the parameter group in ``optim.param_groups``.
            Negative indices count from the end. Defaults to 0 (first group).
        key: The group entry to read. Defaults to "lr".

    Returns:
        The value stored under ``key`` in the selected parameter group.

    Raises:
        IndexError: If ``idx`` is out of range for ``optim.param_groups``.
        KeyError: If the selected group has no entry named ``key``.
    """
    # Index the requested group directly instead of collecting the value of
    # every group first: only the selected group needs to define ``key``.
    return optim.param_groups[idx][key]
def get_lrs(optim: Optimizer, key: str = "lr") -> List[float]:
    """Get the learning rates of all parameter groups of an optimizer.

    Args:
        optim: The optimizer to read from.
        key: The group entry to read in each parameter group. Defaults to "lr".

    Returns:
        A list with one value per entry of ``optim.param_groups``, in the
        same order as the groups.

    Raises:
        KeyError: If any parameter group has no entry named ``key``.
    """
    return [group[key] for group in optim.param_groups]
def create_params_groups_bias(
    model_or_params: Union[nn.Module, Iterable[Tuple[str, Parameter]]],
    weight_decay: float,
    *,
    skip_list: Optional[Iterable[str]] = (),
    verbose: int = 2,
) -> List[Dict[str, Union[List[Parameter], float]]]:
    """Split parameters into 2 groups, with and without weight decay.

    Parameters that do not require gradients are left out of both groups.
    A parameter goes to the group without weight decay when it is 1-D,
    when it is a bias (its name is "bias" or ends with ".bias"), or when
    its name is listed in ``skip_list``. Every other parameter goes to the
    group with weight decay.

    Args:
        model_or_params: A module, whose ``named_parameters()`` are used, or
            directly an iterable of ``(name, parameter)`` pairs.
        weight_decay: The weight decay value assigned to the decayed group.
        skip_list: Names of parameters that must not be decayed.
            ``None`` is treated like an empty collection. Defaults to ().
        verbose: When >= 2, logs at debug level the name of each parameter
            placed in the group without weight decay. Defaults to 2.

    Returns:
        A list of two parameter groups, in this order:
        ``{"params": <decayed>, "weight_decay": weight_decay}`` and
        ``{"params": <not decayed>, "weight_decay": 0.0}``.

    Example:
        >>> from torch.optim import AdamW
        >>> model = nn.Linear(100, 10)
        >>> weight_decay = 0.01
        >>> param_groups = create_params_groups_bias(model, weight_decay=weight_decay)
        >>> optimizer = AdamW(param_groups, weight_decay=weight_decay)
    """
    named_params: Iterable[Tuple[str, Parameter]]
    if isinstance(model_or_params, nn.Module):
        # Going through ``Any`` avoids a typing error on named_parameters().
        untyped: Any = model_or_params.named_parameters()
        named_params = untyped
    else:
        named_params = model_or_params

    # A set gives constant-time membership tests inside the loop.
    skip_names = frozenset(skip_list) if skip_list is not None else frozenset()

    decay: List[Parameter] = []
    no_decay: List[Parameter] = []

    for name, param in named_params:
        if not param.requires_grad:
            continue

        # A bias owned by the root module is named exactly "bias", without
        # any "." prefix, so it must be matched explicitly.
        is_bias = name == "bias" or name.endswith(".bias")

        if len(param.shape) == 1 or is_bias or name in skip_names:
            no_decay.append(param)
            if verbose >= 2:
                pylog.debug("No wd for %s", name)
        else:
            decay.append(param)

    return [
        {"params": decay, "weight_decay": weight_decay},
        {"params": no_decay, "weight_decay": 0.0},
    ]