<h1 align="center"><b>LookSAM Optimizer</b></h1>
<h3 align="center"><b>Towards Efficient and Scalable Sharpness-Aware Minimization</b></h3>
<p align="center">
<i>~ in PyTorch ~</i>
</p>
LookSAM is an accelerated SAM algorithm. Instead of computing the inner gradient ascent at every step, LookSAM computes it only periodically and reuses the cached direction that promotes convergence to flat regions.
This is an unofficial repository for Towards Efficient and Scalable Sharpness-Aware Minimization. Currently only the algorithm without layer-wise adaptive learning rates is implemented (it will be added soon...).
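
For intuition, here is a minimal sketch of the gradient rule described in the paper, written for a single flattened gradient vector. The function name `looksam_update_direction` and its signature are illustrative only and are not the API of this repository:

```python
import torch

def looksam_update_direction(g, g_v, alpha=0.7, g_s=None):
    """Sketch of LookSAM's update direction for one flattened gradient vector.

    g   -- plain gradient at the current weights
    g_s -- SAM gradient at the perturbed weights (available only every k-th step)
    g_v -- cached component of g_s orthogonal to g, reused on the other steps
    """
    if g_s is not None:
        # Every k-th step: remove from g_s its projection onto g, cache the
        # remainder g_v, and descend along the full SAM gradient g_s.
        g_v = g_s - (torch.dot(g_s, g) / g.norm().pow(2)) * g
        return g_s, g_v
    # All other steps: reuse the cached g_v, rescaled to the magnitude of g,
    # so only a single forward-backward pass is needed.
    return g + alpha * (g.norm() / g_v.norm()) * g_v, g_v
```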
The rewritten `step` method accepts several arguments:

- `t` is the train index, i.e. the index of the current batch;
- `samples` are the input data;
- `targets` are the ground-truth labels;
- `zero_sam_grad` is a boolean that zeroes the gradients after the SAM condition (first step) (see the discussion here);
- `zero_grad` is a boolean that zeroes the gradients after the second step.
The unofficial SAM repo is my inspiration :)
Usage
```python
from looksam import LookSAM

model = YourModel()
criterion = YourCriterion()
base_optimizer = YourBaseOptimizer
loader = YourLoader()

optimizer = LookSAM(
    k=10,                           # period of the full SAM (ascent) step
    alpha=0.7,                      # scaling of the reused gradient component
    model=model,
    base_optimizer=base_optimizer,
    rho=0.1,                        # SAM neighborhood radius
    **kwargs
)

...

model.train()

for train_index, (samples, targets) in enumerate(loader):
    ...
    loss = criterion(model(samples), targets)
    loss.backward()

    optimizer.step(
        t=train_index,
        samples=samples,
        targets=targets,
        zero_sam_grad=True,
        zero_grad=True
    )

    ...
```
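
Any extra keyword arguments are presumably forwarded to the base optimizer when it is constructed. A hypothetical concrete setup with `torch.optim.SGD` might look like this (the `lr` and `momentum` values are purely illustrative):

```python
import torch

# Hypothetical instantiation: assumes **kwargs (here lr and momentum) are
# passed through to the base optimizer's constructor.
optimizer = LookSAM(
    k=10,
    alpha=0.7,
    model=model,
    base_optimizer=torch.optim.SGD,
    rho=0.1,
    lr=0.1,
    momentum=0.9,
)
```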
