├── .gitignore
├── README.md
└── looksam.py

/.gitignore:
--------------------------------------------------------------------------------
# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
*$py.class

# C extensions
*.so

.idea/.gitignore
.idea/deployment.xml
.idea/inspectionProfiles/profiles_settings.xml
.idea/looksam.iml
.idea/misc.xml
.idea/modules.xml
.idea/vcs.xml

# Distribution / packaging
.Python
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
wheels/
pip-wheel-metadata/
share/python-wheels/
*.egg-info/
.installed.cfg
*.egg
MANIFEST

# PyInstaller
# Usually these files are written by a python script from a template
# before PyInstaller builds the exe, so as to inject date/other infos into it.
*.manifest
*.spec

# Installer logs
pip-log.txt
pip-delete-this-directory.txt

# Unit test / coverage reports
htmlcov/
.tox/
.nox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*.cover
*.py,cover
.hypothesis/
.pytest_cache/

# Translations
*.mo
*.pot

# Django stuff:
*.log
local_settings.py
db.sqlite3
db.sqlite3-journal

# Flask stuff:
instance/
.webassets-cache

# Scrapy stuff:
.scrapy

# Sphinx documentation
docs/_build/

# PyBuilder
target/

# Jupyter Notebook
.ipynb_checkpoints

# IPython
profile_default/
ipython_config.py

# pyenv
.python-version

# pipenv
# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
# However, in case of collaboration, if having platform-specific dependencies or dependencies
# having no cross-platform support, pipenv may install dependencies that don't work, or not
# install all needed dependencies.
#Pipfile.lock

# PEP 582; used by e.g. github.com/David-OConnor/pyflow
__pypackages__/

# Celery stuff
celerybeat-schedule
celerybeat.pid

# SageMath parsed files
*.sage.py

# Environments
.env
.venv
env/
venv/
ENV/
env.bak/
venv.bak/

# Spyder project settings
.spyderproject
.spyproject

# Rope project settings
.ropeproject

# mkdocs documentation
/site

# mypy
.mypy_cache/
.dmypy.json
dmypy.json

# Pyre type checker
.pyre/
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
# LookSAM
~ in Pytorch ~

LookSAM is an accelerated [SAM](https://arxiv.org/pdf/2010.01412.pdf) algorithm. Instead of computing the inner gradient
ascent at every step, LookSAM computes it only periodically and reuses the stored direction that promotes convergence to
flat regions (a rough sketch of the update rule is given at the end of this README).

This is an unofficial repository for [Towards Efficient and Scalable Sharpness-Aware Minimization](https://arxiv.org/pdf/2203.02714.pdf).
Currently only the variant without layer-wise adaptive rates is implemented (the adaptive variant will be added soon).

The overridden `step` method accepts several arguments:
1. `t` is the index of the current training batch; a full SAM step is performed whenever `t % k == 0`;
2. `samples` are the input data;
3. `targets` are the ground-truth labels;
4. `zero_sam_grad` is a boolean flag to zero the gradients after the first (ascent) step, i.e. before the second forward/backward pass (see the discussion [here](https://github.com/rollovd/LookSAM/issues/3));
5. `zero_grad` is a boolean flag to zero the gradients after the second step.

[Unofficial SAM repo](https://github.com/davda54/sam/blob/main/README.md) is my inspiration :)

## Usage

```python
from looksam import LookSAM


model = YourModel()
criterion = YourCriterion()
base_optimizer = YourBaseOptimizer
loader = YourLoader()

optimizer = LookSAM(
    k=10,
    alpha=0.7,
    model=model,
    base_optimizer=base_optimizer,
    criterion=criterion,
    rho=0.1,
    **kwargs
)

...

model.train()

for train_index, (samples, targets) in enumerate(loader):
    ...

    loss = criterion(model(samples), targets)
    loss.backward()
    optimizer.step(
        t=train_index,
        samples=samples,
        targets=targets,
        zero_sam_grad=True,
        zero_grad=True
    )
    ...

```
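
## Update rule (sketch)

Roughly, this is what `looksam.py` does (notation follows the paper; the symbols below are illustrative and do not
map one-to-one onto the variable names in the code):

```
# every k-th batch (t % k == 0): full SAM step
g    = grad L(w)                                  # gradient at the current weights
w'   = w + rho * g / ||g||                        # ascent step inside the rho-ball
g_s  = grad L(w')                                 # SAM gradient, used for this update
g_v  = g_s - ||g_s|| * cos(g, g_s) * g / ||g||    # component of g_s orthogonal to g, stored

# remaining k-1 batches: reuse the stored direction
g   <- g + alpha * (||g|| / ||g_v||) * g_v        # cheap approximation of the SAM gradient
```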
--------------------------------------------------------------------------------
/looksam.py:
--------------------------------------------------------------------------------
import torch
import torch.nn as nn
from typing import Any, Callable


class LookSAM(torch.optim.Optimizer):

    def __init__(self,
                 k: int,
                 alpha: float,
                 model: nn.Module,
                 base_optimizer: torch.optim.Optimizer,
                 criterion: Callable[[torch.Tensor, torch.Tensor], torch.Tensor],
                 rho: float = 0.05,
                 **kwargs: Any
                 ):
        """
        LookSAM algorithm: https://arxiv.org/pdf/2203.02714.pdf
        Optimization algorithm capable of simultaneously minimizing the loss value and the loss sharpness to narrow
        the generalization gap.

        :param k: frequency of SAM's inner gradient ascent (a full SAM step is taken every k batches)
        :param model: your network
        :param criterion: your loss function
        :param base_optimizer: optimizer module (SGD, Adam, etc...)
        :param alpha: scaling factor for the reused gradient component (default: 0.7)
        :param rho: radius of the l_p ball (default: 0.05)

        :return: None

        Usage:
            model = YourModel()
            criterion = YourCriterion()
            base_optimizer = YourBaseOptimizer
            optimizer = LookSAM(k=k,
                                alpha=alpha,
                                model=model,
                                base_optimizer=base_optimizer,
                                criterion=criterion,
                                rho=rho,
                                **kwargs)

            ...

            for train_index, (samples, targets) in enumerate(loader):
                loss = criterion(model(samples), targets)
                loss.backward()
                optimizer.step(t=train_index, samples=samples, targets=targets, zero_sam_grad=True, zero_grad=True)

            ...
        """

        defaults = dict(alpha=alpha, rho=rho, **kwargs)
        self.model = model
        super(LookSAM, self).__init__(self.model.parameters(), defaults)

        self.k = k
        self.alpha = torch.tensor(alpha, requires_grad=False)
        self.criterion = criterion

        self.base_optimizer = base_optimizer(self.param_groups, **kwargs)
        self.param_groups = self.base_optimizer.param_groups

    @staticmethod
    def normalized(g):
        return g / g.norm(p=2)

    def step(self, t, samples, targets, zero_sam_grad=True, zero_grad=True):
        if not t % self.k:
            # Full SAM step: save the current weights and gradients, then perturb
            # the weights towards higher loss inside the rho-ball.
            group = self.param_groups[0]
            scale = group['rho'] / (self._grad_norm() + 1e-8)

            for index_p, p in enumerate(group['params']):
                if p.grad is None:
                    continue

                self.state[p]['old_p'] = p.data.clone()
                self.state[f'old_grad_p_{index_p}']['old_grad_p'] = p.grad.clone()

                with torch.no_grad():
                    e_w = p.grad * scale.to(p)
                    p.add_(e_w)

            if zero_sam_grad:
                self.zero_grad()

            # Second forward/backward pass at the perturbed weights gives the SAM gradient.
            self.criterion(self.model(samples), targets).backward()

        group = self.param_groups[0]
        for index_p, p in enumerate(group['params']):
            if p.grad is None:
                continue
            if not t % self.k:
                # Store g_v: the component of the SAM gradient orthogonal to the plain gradient.
                old_grad_p = self.state[f'old_grad_p_{index_p}']['old_grad_p']
                g_grad_norm = LookSAM.normalized(old_grad_p)
                g_s_grad_norm = LookSAM.normalized(p.grad)
                self.state[f'gv_{index_p}']['gv'] = torch.sub(p.grad, p.grad.norm(p=2) * torch.sum(
                    g_grad_norm * g_s_grad_norm) * g_grad_norm)

                # Restore the un-perturbed weights before the base optimizer update;
                # on the remaining steps the weights were never perturbed.
                p.data = self.state[p]['old_p']

            else:
                # Reuse the stored direction g_v to approximate the SAM gradient.
                with torch.no_grad():
                    gv = self.state[f'gv_{index_p}']['gv']
                    p.grad.add_(self.alpha.to(p) * (p.grad.norm(p=2) / (gv.norm(p=2) + 1e-8) * gv))

        self.base_optimizer.step()
        if zero_grad:
            self.zero_grad()

    def _grad_norm(self):
        shared_device = self.param_groups[0]['params'][0].device
        norm = torch.norm(
            torch.stack([
                p.grad.norm(p=2).to(shared_device) for group in self.param_groups for p in group['params']
                if p.grad is not None
            ]),
            p=2
        )

        return norm
--------------------------------------------------------------------------------
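
A minimal, self-contained smoke test of the optimizer; the toy model, data and hyper-parameters below are illustrative
only and are not part of the repository:

```python
# Illustrative smoke test -- the model, data and hyper-parameters are made up.
import torch
import torch.nn as nn

from looksam import LookSAM

model = nn.Linear(10, 2)                 # toy model
criterion = nn.CrossEntropyLoss()

optimizer = LookSAM(
    k=10,                                # full SAM step every 10 batches
    alpha=0.7,
    model=model,
    base_optimizer=torch.optim.SGD,      # any torch optimizer class can be plugged in here
    criterion=criterion,
    rho=0.05,
    lr=0.1,                              # forwarded to the base optimizer via **kwargs
)

model.train()
for t in range(100):
    samples = torch.randn(32, 10)        # random batch instead of a real loader
    targets = torch.randint(0, 2, (32,))
    loss = criterion(model(samples), targets)
    loss.backward()
    optimizer.step(t=t, samples=samples, targets=targets,
                   zero_sam_grad=True, zero_grad=True)
```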