Source code for torchwrench.optim.utils
#!/usr/bin/env python
# -*- coding: utf-8 -*-
import logging
from typing import Any, Dict, Iterable, List, Optional, Tuple, Union
from torch import nn
from torch.nn.parameter import Parameter
from torch.optim.optimizer import Optimizer
# Module-level logger, named after this module via __name__.
pylog = logging.getLogger(__name__)
[docs]
def get_lr(optim: Optimizer, idx: int = 0, key: str = "lr") -> float:
    """
    Get the learning rate of one parameter group of an optimizer.

    Args:
        optim: The optimizer to read from.
        idx: The index of the parameter group in ``optim.param_groups``. defaults to 0.
        key: The key under which the value is stored in the group. defaults to "lr".

    Returns:
        The value stored under ``key`` in the selected parameter group.

    Raises:
        IndexError: If ``idx`` is out of range for ``optim.param_groups``.
        KeyError: If the selected group has no entry for ``key``.
    """
    # Index the requested group directly: only that group has to contain `key`,
    # and no intermediate list over all groups is built.
    return optim.param_groups[idx][key]
[docs]
def get_lrs(optim: Optimizer, key: str = "lr") -> List[float]:
    """
    Collect the learning rate of every parameter group of an optimizer.

    Args:
        optim: The optimizer whose ``param_groups`` are read.
        key: The key looked up in each group. defaults to "lr".

    Returns:
        One value per parameter group, in the order of ``optim.param_groups``.
    """
    values: List[float] = []
    for param_group in optim.param_groups:
        values.append(param_group[key])
    return values
[docs]
def create_params_groups_bias(
    model_or_params: Union[nn.Module, Iterable[Tuple[str, Parameter]]],
    weight_decay: float,
    *,
    skip_list: Optional[Iterable[str]] = (),
    verbose: int = 2,
) -> List[Dict[str, Union[List[Parameter], float]]]:
    """Split parameters into 2 groups with or without weight decay for AdamW optimizer.

    A parameter is placed in the group without weight decay when it has exactly
    one dimension, when its name ends with ".bias", or when its name is listed
    in ``skip_list``. Parameters with ``requires_grad`` set to False are left
    out of both groups.

    Args:
        model_or_params: A module (its ``named_parameters()`` are used) or an
            iterable of ``(name, parameter)`` pairs.
        weight_decay: The weight decay value assigned to the decayed group.
        skip_list: Names of parameters that must not be decayed. None is
            treated as empty.
        verbose: When >= 2, a debug message is logged for each parameter
            placed in the group without weight decay.

    Returns:
        A list of two parameter groups: the first with ``weight_decay`` and the
        second with a weight decay of 0.0.

    Example
    =======
    ```
    >>> model = nn.Linear(100, 10)
    >>> weight_decay = 0.01
    >>> param_groups = create_params_groups_bias(model, weight_decay=weight_decay)
    >>> optimizer = AdamW(param_groups, weight_decay=weight_decay)
    ```
    """
    named_params: Iterable[Tuple[str, Parameter]]
    if isinstance(model_or_params, nn.Module):
        # named_parameters() annotations may not match Parameter exactly.
        named_params = model_or_params.named_parameters()  # type: ignore
    else:
        named_params = model_or_params

    # A frozenset gives constant-time membership tests for the names to skip.
    skipped_names = frozenset(skip_list) if skip_list is not None else frozenset()

    decay: List[Parameter] = []
    no_decay: List[Parameter] = []

    for name, param in named_params:
        if not param.requires_grad:
            continue

        if param.ndim == 1 or name.endswith(".bias") or name in skipped_names:
            no_decay.append(param)
            if verbose >= 2:
                pylog.debug("No wd for %s", name)
        else:
            decay.append(param)

    return [
        {"params": decay, "weight_decay": weight_decay},
        {"params": no_decay, "weight_decay": 0.0},
    ]