This post breaks down the process of implementing any optimizer in fast.ai, using Tiger as the example.

Primer on how to implement an optimizer in fast.ai

Setup

import matplotlib as mpl
mpl.rcParams['image.cmap'] = 'gray'

from __future__ import annotations
from fastai.torch_basics import *
from fastai.vision.all import *
from torcheval.metrics import MulticlassAccuracy
from torch import tensor
from datasets import load_dataset
xl, yl = 'image', 'label'
name   = "fashion_mnist"
dsd    = load_dataset(name, trust_remote_code=True)
# x pipeline: dataset image -> TensorImage; y pipeline: integer label -> category
class FashionMnistTransform(Transform):
  def __init__(self, xl): self.xl = xl
  def encodes(self, o): return TensorImage(o[self.xl])

class Lblr(Transform):
  def __init__(self, yl): self.yl = yl
  def encodes(self, o): return o[self.yl]

tfms = [[FashionMnistTransform(xl), IntToFloatTensor()],
        [Lblr(yl), Categorize()]]
dsets = Datasets(dsd['train'], tfms, splits=RandomSplitter(seed=42)(dsd['train']))
dls = dsets.dataloaders(bs=64)
x, y = dls.one_batch()
x.shape, y.shape
dls.show_batch()
def conv(ni, nf, ks=3, stride=2, act=True):
    res = nn.Conv2d(ni, nf, stride=stride, kernel_size=ks, padding=ks//2)
    if act: res = nn.Sequential(res, nn.ReLU())
    return res

class BasicModel(Module):
  def __init__(self):
    self.m = nn.Sequential(conv(1 ,8),
                           conv(8 ,16),
                           conv(16,32),
                           conv(32,64),
                           conv(64,10, act=False), nn.Flatten()
                           )

  def forward(self, x):
    x = x.unsqueeze(1)  # add a channel dim: (bs, 28, 28) -> (bs, 1, 28, 28)
    return self.m(x)
bm = BasicModel()
o = bm(x)
learn = Learner(dls, bm, metrics=accuracy)
learn.fit_one_cycle(3)

[Output: per-epoch train/valid loss and accuracy for the 3-epoch baseline run]

Delving deep into optimizers in fast.ai

fast.ai provides an Optimizer class which takes in params and a list of callbacks (cbs).

class Optimizer(_BaseOptimizer):
    "Base optimizer class for the fastai library, updating `params` with `cbs`"
    _keep_on_clear = ['force_train', 'do_wd']
    def __init__(self,
        params:Tensor|Iterable, # Model parameters
        cbs:callable|MutableSequence, # `Optimizer` step callbacks
        **defaults # Hyper parameters default values
    ):

params will be used to create the param_groups of the optimizer. cbs is a list of functions that will be composed when applying the step. For instance, you can compose a function making the SGD step with another one applying weight decay.

Additionally, each cb can have a defaults attribute that contains hyper-parameters and their default value. Those are all gathered at initialization, and new values can be passed to override those defaults with the defaults kwargs. The steppers will be called by Optimizer.step (which is the standard PyTorch name), and gradients can be cleared with Optimizer.zero_grad (also a standard PyTorch name).

Once the defaults have all been pulled off, they are copied as many times as there are param_groups and stored in hypers. To apply different hyper-parameters to different groups (differential learning rates, or no weight decay for certain layers for instance), you will need to adjust those values after the init.
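To make this concrete, here is a minimal sketch (not part of the original notebook) using a hypothetical do-nothing stepper called noop_step, just to show how defaults are gathered from the callback, overridden by the kwargs, and copied once per param group into hypers:

def noop_step(p, lr, **kwargs): pass          # hypothetical stepper that does nothing
noop_step.defaults = dict(lr=1e-2)            # hyper-parameter defaults live on the cb

p1 = torch.randn(3, requires_grad=True)       # two toy "layers", one per param group
p2 = torch.randn(3, requires_grad=True)
opt = Optimizer([[p1], [p2]], [noop_step], lr=0.1)   # the kwarg overrides the cb default
print(opt.hypers)                  # one dict per param group, each with lr=0.1
opt.hypers[1]['lr'] = 1e-3         # per-group adjustment after init, e.g. differential lrs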

As an example, here is how SGD with momentum is implemented in fastai: weight decay, the gradient moving average, and the parameter update are each a separate callback, composed by Optimizer.step. A quick sanity check on a toy parameter follows the code.

\begin{aligned} \text{Update Step:} \quad & \mathbf{v}_{t+1} = \mu \mathbf{v}_t - \eta \left( \nabla_{\mathbf{w}} \mathcal{L}(\mathbf{w}_t) + \lambda \mathbf{w}_t \right) \\ & \mathbf{w}_{t+1} = \mathbf{w}_t + \mathbf{v}_{t+1} \end{aligned}
def weight_decay(p, lr, wd, do_wd=True, **kwargs):
    "Weight decay as decaying `p` with `lr*wd`"
    if do_wd and wd!=0: p.data.mul_(1 - lr*wd)

weight_decay.defaults = dict(wd=0.)

def average_grad(p, mom, dampening=False, grad_avg=None, **kwargs):
    "Keeps track of the avg grads of `p` in `state` with `mom`."
    if grad_avg is None: grad_avg = torch.zeros_like(p.grad.data)
    damp = 1-mom if dampening else 1.
    grad_avg.mul_(mom).add_(p.grad.data, alpha=damp)
    return {'grad_avg': grad_avg}

average_grad.defaults = dict(mom=0.9)

def momentum_step(p, lr, grad_avg, **kwargs):
    "Step for SGD with momentum with `lr`"
    p.data.add_(grad_avg, alpha=-lr)

def SGD(
    params:Tensor|Iterable, # Model parameters
    lr:float|slice, # Default learning rate
    mom:float=0., # Gradient moving average (β1) coefficient
    wd:Real=0., # Optional weight decay (true or L2)
    decouple_wd:bool=True # Apply true weight decay or L2 regularization (SGD)
) -> Optimizer:
    "A SGD `Optimizer`"
    cbs = [weight_decay] if decouple_wd else [l2_reg]
    if mom != 0: cbs.append(average_grad)
    cbs.append(sgd_step if mom==0 else momentum_step)
    return Optimizer(params, cbs, lr=lr, mom=mom, wd=wd)
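As the promised sanity check, here is a small sketch (added here, not part of the fastai source above) that runs one SGD-with-momentum step on a single toy parameter, so you can see the callbacks compose in order:

p = torch.tensor([1.0], requires_grad=True)
p.grad = torch.tensor([0.5])                 # pretend backward() produced this gradient
opt = SGD([p], lr=0.1, mom=0.9)
opt.step()                                   # weight_decay -> average_grad -> momentum_step
print(p.data)                                # 1.0 - 0.1*0.5 = 0.95 (grad_avg == grad on the first step)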

Here are the equations involved for the Tight-fisted (Tiger) optimizer; a quick check of the sign-based update follows the implementation.

\begin{align*} m_t &= \beta m_{t-1} + (1 - \beta) g_t \\ \theta_t &= \theta_{t-1} - \eta_t \left[ \operatorname{sign}(m_t) + \lambda_t \theta_{t-1} \right] \end{align*}
def tiger_step(p, lr, grad_avg, wd, **kwargs):
    "Step for Tiger with `lr`: move `p` by the sign of the gradient moving average (`wd` is applied by the separate `weight_decay` cb)"
    p.data.add_(grad_avg.sign(), alpha=-lr)
    return p

def Tiger(
    params:Tensor|Iterable, # Model parameters
    lr:float|slice, # Default learning rate
    mom:float=0.945, # Gradient moving average (β) coefficient
    wd:Real=0.01, # Optional weight decay (true or L2)
    decouple_wd:bool=True # Apply weight decay
) -> Optimizer:
    "A Tight-fisted ( Tiger ) `Optimizer`"
    cbs = [weight_decay] if decouple_wd else [l2_reg]
    cbs += [partial(average_grad, dampening=True), tiger_step]
    return Optimizer(params, cbs, lr=lr, mom=mom, wd=wd)
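Because of the sign, every coordinate moves by exactly lr on each step (plus the decoupled weight-decay shrinkage), no matter how large or small the raw gradients are. A quick sketch (again, not from the original notebook) to see that on a toy parameter:

p = torch.tensor([1.0, -2.0], requires_grad=True)
p.grad = torch.tensor([0.3, -40.0])   # wildly different gradient magnitudes
opt = Tiger([p], lr=0.1)
opt.step()
print(p.data)                         # each element moved by ~0.1, independent of |grad|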
learn_tiger = Learner(dls, BasicModel(), metrics=accuracy, opt_func=Tiger)  # fresh model so Tiger trains from scratch
learn_tiger.lr_find()

[Output: learning-rate finder plot for the Tiger learner]

learn_tiger.fit_one_cycle(3, 3e-3)

[Output: per-epoch train/valid loss and accuracy for the 3-epoch Tiger run]