├── .gitattributes ├── .gitignore ├── LICENSE ├── README.md ├── avg_checkpoints.py ├── clean_checkpoint.py ├── convert └── convert_from_mxnet.py ├── distributed_train.sh ├── hubconf.py ├── inference.py ├── notebooks ├── EffResNetComparison.ipynb └── GeneralizationToImageNetV2.ipynb ├── requirements.txt ├── results ├── README.md ├── results-imagenet-a.csv ├── results-imagenet.csv ├── results-imagenetv2-matched-frequency.csv └── results-sketch.csv ├── setup.py ├── sotabench.py ├── timm ├── __init__.py ├── data │ ├── __init__.py │ ├── auto_augment.py │ ├── config.py │ ├── constants.py │ ├── dataset.py │ ├── distributed_sampler.py │ ├── loader.py │ ├── mixup.py │ ├── random_erasing.py │ ├── tf_preprocessing.py │ ├── transforms.py │ └── transforms_factory.py ├── loss │ ├── __init__.py │ ├── cross_entropy.py │ └── jsd.py ├── models │ ├── __init__.py │ ├── densenet.py │ ├── dla.py │ ├── dpn.py │ ├── efficientnet.py │ ├── efficientnet_blocks.py │ ├── efficientnet_builder.py │ ├── factory.py │ ├── feature_hooks.py │ ├── gluon_resnet.py │ ├── gluon_xception.py │ ├── helpers.py │ ├── hrnet.py │ ├── inception_resnet_v2.py │ ├── inception_v3.py │ ├── inception_v4.py │ ├── layers │ │ ├── __init__.py │ │ ├── activations.py │ │ ├── adaptive_avgmax_pool.py │ │ ├── avg_pool2d_same.py │ │ ├── cbam.py │ │ ├── cond_conv2d.py │ │ ├── conv2d_same.py │ │ ├── conv_bn_act.py │ │ ├── create_attn.py │ │ ├── create_conv2d.py │ │ ├── drop.py │ │ ├── eca.py │ │ ├── helpers.py │ │ ├── median_pool.py │ │ ├── mixed_conv2d.py │ │ ├── padding.py │ │ ├── se.py │ │ ├── selective_kernel.py │ │ ├── split_batchnorm.py │ │ └── test_time_pool.py │ ├── mobilenetv3.py │ ├── nasnet.py │ ├── pnasnet.py │ ├── registry.py │ ├── res2net.py │ ├── resnet.py │ ├── selecsls.py │ ├── senet.py │ ├── sknet.py │ └── xception.py ├── optim │ ├── __init__.py │ ├── adamw.py │ ├── lookahead.py │ ├── nadam.py │ ├── novograd.py │ ├── nvnovograd.py │ ├── optim_factory.py │ ├── radam.py │ └── rmsprop_tf.py ├── scheduler │ ├── __init__.py │ ├── cosine_lr.py │ ├── plateau_lr.py │ ├── scheduler.py │ ├── scheduler_factory.py │ ├── step_lr.py │ └── tanh_lr.py ├── utils.py └── version.py ├── train.py └── validate.py /.gitattributes: -------------------------------------------------------------------------------- 1 | *.ipynb linguist-documentation 2 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | *.egg-info/ 24 | .installed.cfg 25 | *.egg 26 | MANIFEST 27 | 28 | # PyInstaller 29 | # Usually these files are written by a python script from a template 30 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 
31 | *.manifest 32 | *.spec 33 | 34 | # Installer logs 35 | pip-log.txt 36 | pip-delete-this-directory.txt 37 | 38 | # Unit test / coverage reports 39 | htmlcov/ 40 | .tox/ 41 | .coverage 42 | .coverage.* 43 | .cache 44 | nosetests.xml 45 | coverage.xml 46 | *.cover 47 | .hypothesis/ 48 | .pytest_cache/ 49 | 50 | # Translations 51 | *.mo 52 | *.pot 53 | 54 | # Django stuff: 55 | *.log 56 | local_settings.py 57 | db.sqlite3 58 | 59 | # Flask stuff: 60 | instance/ 61 | .webassets-cache 62 | 63 | # Scrapy stuff: 64 | .scrapy 65 | 66 | # Sphinx documentation 67 | docs/_build/ 68 | 69 | # PyBuilder 70 | target/ 71 | 72 | # Jupyter Notebook 73 | .ipynb_checkpoints 74 | 75 | # pyenv 76 | .python-version 77 | 78 | # celery beat schedule file 79 | celerybeat-schedule 80 | 81 | # SageMath parsed files 82 | *.sage.py 83 | 84 | # Environments 85 | .env 86 | .venv 87 | env/ 88 | venv/ 89 | ENV/ 90 | env.bak/ 91 | venv.bak/ 92 | 93 | # Spyder project settings 94 | .spyderproject 95 | .spyproject 96 | 97 | # Rope project settings 98 | .ropeproject 99 | 100 | # PyCharm 101 | .idea 102 | 103 | # PyTorch weights 104 | *.tar 105 | *.pth 106 | *.gz 107 | Untitled.ipynb 108 | Testing notebook.ipynb 109 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Sample [Neural-Backed Decision Trees](https://github.com/alvinwan/neural-backed-decision-trees) Integration with [pytorch-image-models](https://github.com/rwightman/pytorch-image-models) 2 | 3 | [Project Page](http://nbdt.alvinwan.com)  //  [Paper](http://nbdt.alvinwan.com/paper/)  //  [No-code Web Demo](http://nbdt.alvinwan.com/demo/)  //  [Colab Notebook](https://colab.research.google.com/github/alvinwan/neural-backed-decision-trees/blob/master/examples/load_pretrained_nbdts.ipynb) 4 | 5 | Wondering what neural-backed decision trees are? See the [Neural-Backed Decision Trees](https://github.com/alvinwan/neural-backed-decision-trees) repository. 6 | 7 | **Table of Contents** 8 | 9 | - [Explanation](#explanation) 10 | - [Training and Evaluation](#training-and-evaluation) 11 | - [Results](#results) 12 | 13 | ## Explanation 14 | 15 | The full diff between the original repository `pytorch-image-models` and the integrated version is [here, using Github's compare view](https://github.com/alvinwan/nbdt-pytorch-image-models/compare/nbdt). There are a total of 9 lines added: 16 | 17 | 1. Generate hierarchy (0 lines): Start by generating an induced hierarchy. We use a hierarchy induced from EfficientNet-B7. 18 | 19 | ```bash 20 | nbdt-hierarchy --dataset=Imagenet1000 --arch=efficientnet_b7b 21 | ``` 22 | 23 | 2. Wrap loss during training (3 lines): In `train.py`, we add the custom loss function. This is a wrapper around the existing loss functions. 24 | 25 | ```python 26 | from nbdt.loss import SoftTreeSupLoss 27 | train_loss_fn = SoftTreeSupLoss(criterion=train_loss_fn, dataset='Imagenet1000', tree_supervision_weight=10, hierarchy='induced-efficientnet_b7b') 28 | validate_loss_fn = SoftTreeSupLoss(criterion=validate_loss_fn, dataset='Imagenet1000', tree_supervision_weight=10, hierarchy='induced-efficientnet_b7b') 29 | ``` 30 | 31 | 3. Wrap model during inference (6 lines): In `validate.py`, we add NBDT inference. This is a wrapper around the existing model. We actually spend 4 lines adding and processing a custom `--nbdt` argument, so the actual logic for adding NBDT inference is only 2 lines. 
32 | 33 | ```python 34 | parser.add_argument('--nbdt', choices=('none', 'soft', 'hard'), default='none', 35 | help='Type of NBDT inference to run') 36 | ... 37 | from nbdt.model import SoftNBDT, HardNBDT 38 | if args.nbdt != 'none': 39 | cls = SoftNBDT if args.nbdt == 'soft' else HardNBDT 40 | model = cls(model=model, dataset='Imagenet1000', hierarchy='induced-efficientnet_b7b') 41 | ``` 42 | 43 | ## Training and Evaluation 44 | 45 | To reproduce our results, **make sure to check out the `nbdt` branch**. 46 | 47 | ```bash 48 | # 1. git clone the repository 49 | git clone git@github.com:alvinwan/nbdt-pytorch-image-models.git # or use the HTTPS URL if you don't have an SSH key set up with GitHub 50 | cd nbdt-pytorch-image-models 51 | 52 | # 2. install requirements 53 | pip install -r requirements.txt 54 | 55 | # 3. check out the branch with the NBDT integration 56 | git checkout nbdt 57 | ``` 58 | 59 | **Training**: For our ImageNet results, we use the hyperparameter settings reported for ImageNet-EdgeTPU-Small found in the original README: [EfficientNet-ES (EdgeTPU-Small)](https://github.com/rwightman/pytorch-image-models#efficientnet-es-edgetpu-small-with-randaugment---78066-top-1-93926-top-5). Note the accuracy reported at this link is the average of 8 checkpoints. However, we use only 1 checkpoint, so we compare against the best single-checkpoint 77.23% result for EfficientNet-ES reported in the official [EfficientNet-EdgeTPU](https://github.com/tensorflow/tpu/tree/master/models/official/efficientnet/edgetpu) repository. The hyperparameter settings reported in the first link are reproduced below: 60 | 61 | ```bash 62 | ./distributed_train.sh 8 /data/imagenetwhole/ilsvrc2012/ --model efficientnet_es -b 128 --sched step --epochs 450 --decay-epochs 2.4 --decay-rate .97 --opt rmsproptf --opt-eps .001 -j 8 --warmup-lr 1e-6 --weight-decay 1e-5 --drop 0.2 --drop-connect 0.2 --aa rand-m9-mstd0.5 --remode pixel --reprob 0.2 --amp --lr .064 63 | ``` 64 | 65 | **Validation**: To run inference, we use the following command. The majority of this command is typical for this repository. We simply add an extra `--nbdt` flag at the end for the type of NBDT we wish to run. 66 | 67 | ```bash 68 | python validate.py /data/imagenetwhole/ilsvrc2012/val/ --model efficientnet_es --checkpoint=./output/train/20200319-185245-efficientnet_es-224/model_best.pth.tar --nbdt=soft 69 | ``` 70 | 71 | ## Results 72 | 73 | NofE, shown below, was the strongest competing decision-tree-based method. Note that our NBDT-S outperforms NofE by ~14%. The accuracy of the original neural network, EfficientNet-ES, is also shown. Our decision tree's accuracy is within 2% of the original neural network's accuracy. 74 | 75 | | | NBDT-S (Ours) | NBDT-H (Ours) | NofE | EfficientNet-ES | 76 | |----------------|---------------|---------------|--------|-----------------| 77 | | ImageNet Top-1 | 75.30% | 74.79% | 61.29% | 77.23% | 78 | 79 | See the original Neural-Backed Decision Trees [results](https://github.com/alvinwan/neural-backed-decision-trees#results) for a full list of all baselines. You can download our pretrained model and all associated logs at [v1.0](https://github.com/alvinwan/nbdt-pytorch-image-models/releases/tag/1.0).
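
For intuition, the `SoftTreeSupLoss` wrapper used in the Explanation section follows a simple pattern: keep the original criterion and add a weighted tree-supervision term. Below is a minimal sketch of that wrapping pattern only -- it is *not* the actual `nbdt` implementation, whose extra term comes from scoring the logits against the induced hierarchy:

```python
import torch.nn as nn
import torch.nn.functional as F

class TreeSupLossSketch(nn.Module):
    """Illustrative stand-in: wraps an existing criterion and adds a
    weighted auxiliary term, mirroring how SoftTreeSupLoss wraps
    train_loss_fn/validate_loss_fn above."""

    def __init__(self, criterion, tree_supervision_weight=10.0):
        super().__init__()
        self.criterion = criterion
        self.tree_supervision_weight = tree_supervision_weight

    def tree_loss(self, outputs, targets):
        # placeholder: the real implementation traverses the induced
        # hierarchy instead of computing a plain cross entropy
        return F.cross_entropy(outputs, targets)

    def forward(self, outputs, targets):
        return self.criterion(outputs, targets) + \
            self.tree_supervision_weight * self.tree_loss(outputs, targets)
```

Because the wrapper exposes the same `(outputs, targets)` interface as the criterion it wraps, the rest of `train.py` is untouched -- which is why the training-side integration costs only 3 lines.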
80 | 81 | **For more information, return to the original [Neural-Backed Decision Trees](https://github.com/alvinwan/neural-backed-decision-trees) repository.** 82 | -------------------------------------------------------------------------------- /avg_checkpoints.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | """ Checkpoint Averaging Script 3 | 4 | This script averages all model weights for checkpoints in specified path that match 5 | the specified filter wildcard. All checkpoints must be from the exact same model. 6 | 7 | For any hope of decent results, the checkpoints should be from the same or child 8 | (via resumes) training session. This can be viewed as similar to maintaining running 9 | EMA (exponential moving average) of the model weights or performing SWA (stochastic 10 | weight averaging), but post-training. 11 | 12 | Hacked together by Ross Wightman (https://github.com/rwightman) 13 | """ 14 | import torch 15 | import argparse 16 | import os 17 | import glob 18 | import hashlib 19 | from timm.models.helpers import load_state_dict 20 | 21 | parser = argparse.ArgumentParser(description='PyTorch Checkpoint Averager') 22 | parser.add_argument('--input', default='', type=str, metavar='PATH', 23 | help='path to base input folder containing checkpoints') 24 | parser.add_argument('--filter', default='*.pth.tar', type=str, metavar='WILDCARD', 25 | help='checkpoint filter (path wildcard)') 26 | parser.add_argument('--output', default='./averaged.pth', type=str, metavar='PATH', 27 | help='output filename') 28 | parser.add_argument('--no-use-ema', dest='no_use_ema', action='store_true', 29 | help='Force not using ema version of weights (if present)') 30 | parser.add_argument('--no-sort', dest='no_sort', action='store_true', 31 | help='Do not sort and select by checkpoint metric, also makes "n" argument irrelevant') 32 | parser.add_argument('-n', type=int, default=10, metavar='N', 33 | help='Number of checkpoints to average') 34 | 35 | 36 | def checkpoint_metric(checkpoint_path): 37 | if not checkpoint_path or not os.path.isfile(checkpoint_path): 38 | return {} 39 | print("=> Extracting metric from checkpoint '{}'".format(checkpoint_path)) 40 | checkpoint = torch.load(checkpoint_path, map_location='cpu') 41 | metric = None 42 | if 'metric' in checkpoint: 43 | metric = checkpoint['metric'] 44 | return metric 45 | 46 | 47 | def main(): 48 | args = parser.parse_args() 49 | # by default use the EMA weights (if present) 50 | args.use_ema = not args.no_use_ema 51 | # by default sort by checkpoint metric (if present) and avg top n checkpoints 52 | args.sort = not args.no_sort 53 | 54 | if os.path.exists(args.output): 55 | print("Error: Output filename ({}) already exists.".format(args.output)) 56 | exit(1) 57 | 58 | pattern = args.input 59 | if not args.input.endswith(os.path.sep) and not args.filter.startswith(os.path.sep): 60 | pattern += os.path.sep 61 | pattern += args.filter 62 | checkpoints = glob.glob(pattern, recursive=True) 63 | 64 | if args.sort: 65 | checkpoint_metrics = [] 66 | for c in checkpoints: 67 | metric = checkpoint_metric(c) 68 | if metric is not None: 69 | checkpoint_metrics.append((metric, c)) 70 | checkpoint_metrics = list(sorted(checkpoint_metrics)) 71 | checkpoint_metrics = checkpoint_metrics[-args.n:] 72 | print("Selected checkpoints:") 73 | [print(m, c) for m, c in checkpoint_metrics] 74 | avg_checkpoints = [c for m, c in checkpoint_metrics] 75 | else: 76 | avg_checkpoints = checkpoints 77 | 
print("Selected checkpoints:") 78 | [print(c) for c in checkpoints] 79 | 80 | avg_state_dict = {} 81 | avg_counts = {} 82 | for c in avg_checkpoints: 83 | new_state_dict = load_state_dict(c, args.use_ema) 84 | if not new_state_dict: 85 | print("Error: Checkpoint ({}) doesn't exist".format(args.checkpoint)) 86 | continue 87 | 88 | for k, v in new_state_dict.items(): 89 | if k not in avg_state_dict: 90 | avg_state_dict[k] = v.clone().to(dtype=torch.float64) 91 | avg_counts[k] = 1 92 | else: 93 | avg_state_dict[k] += v.to(dtype=torch.float64) 94 | avg_counts[k] += 1 95 | 96 | for k, v in avg_state_dict.items(): 97 | v.div_(avg_counts[k]) 98 | 99 | # float32 overflow seems unlikely based on weights seen to date, but who knows 100 | float32_info = torch.finfo(torch.float32) 101 | final_state_dict = {} 102 | for k, v in avg_state_dict.items(): 103 | v = v.clamp(float32_info.min, float32_info.max) 104 | final_state_dict[k] = v.to(dtype=torch.float32) 105 | 106 | torch.save(final_state_dict, args.output) 107 | with open(args.output, 'rb') as f: 108 | sha_hash = hashlib.sha256(f.read()).hexdigest() 109 | print("=> Saved state_dict to '{}, SHA256: {}'".format(args.output, sha_hash)) 110 | 111 | 112 | if __name__ == '__main__': 113 | main() 114 | -------------------------------------------------------------------------------- /clean_checkpoint.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | """ Checkpoint Cleaning Script 3 | 4 | Takes training checkpoints with GPU tensors, optimizer state, extra dict keys, etc. 5 | and outputs a CPU tensor checkpoint with only the `state_dict` along with SHA256 6 | calculation for model zoo compatibility. 7 | 8 | Hacked together by Ross Wightman (https://github.com/rwightman) 9 | """ 10 | import torch 11 | import argparse 12 | import os 13 | import hashlib 14 | import shutil 15 | from collections import OrderedDict 16 | 17 | parser = argparse.ArgumentParser(description='PyTorch Checkpoint Cleaner') 18 | parser.add_argument('--checkpoint', default='', type=str, metavar='PATH', 19 | help='path to latest checkpoint (default: none)') 20 | parser.add_argument('--output', default='', type=str, metavar='PATH', 21 | help='output path') 22 | parser.add_argument('--use-ema', dest='use_ema', action='store_true', 23 | help='use ema version of weights if present') 24 | parser.add_argument('--clean-aux-bn', dest='clean_aux_bn', action='store_true', 25 | help='remove auxiliary batch norm layers (from SplitBN training) from checkpoint') 26 | 27 | _TEMP_NAME = './_checkpoint.pth' 28 | 29 | 30 | def main(): 31 | args = parser.parse_args() 32 | 33 | if os.path.exists(args.output): 34 | print("Error: Output filename ({}) already exists.".format(args.output)) 35 | exit(1) 36 | 37 | # Load an existing checkpoint to CPU, strip everything but the state_dict and re-save 38 | if args.checkpoint and os.path.isfile(args.checkpoint): 39 | print("=> Loading checkpoint '{}'".format(args.checkpoint)) 40 | checkpoint = torch.load(args.checkpoint, map_location='cpu') 41 | 42 | new_state_dict = OrderedDict() 43 | if isinstance(checkpoint, dict): 44 | state_dict_key = 'state_dict_ema' if args.use_ema else 'state_dict' 45 | if state_dict_key in checkpoint: 46 | state_dict = checkpoint[state_dict_key] 47 | else: 48 | state_dict = checkpoint 49 | else: 50 | assert False 51 | for k, v in state_dict.items(): 52 | if args.clean_aux_bn and 'aux_bn' in k: 53 | # If all aux_bn keys are removed, the SplitBN layers will end up as normal and 54 | # 
load with the unmodified model using BatchNorm2d. 55 | continue 56 | name = k[7:] if k.startswith('module') else k 57 | new_state_dict[name] = v 58 | print("=> Loaded state_dict from '{}'".format(args.checkpoint)) 59 | 60 | torch.save(new_state_dict, _TEMP_NAME) 61 | with open(_TEMP_NAME, 'rb') as f: 62 | sha_hash = hashlib.sha256(f.read()).hexdigest() 63 | 64 | if args.output: 65 | checkpoint_root, checkpoint_base = os.path.split(args.output) 66 | checkpoint_base = os.path.splitext(checkpoint_base)[0] 67 | else: 68 | checkpoint_root = '' 69 | checkpoint_base = os.path.splitext(args.checkpoint)[0] 70 | final_filename = '-'.join([checkpoint_base, sha_hash[:8]]) + '.pth' 71 | shutil.move(_TEMP_NAME, os.path.join(checkpoint_root, final_filename)) 72 | print("=> Saved state_dict to '{}, SHA256: {}'".format(final_filename, sha_hash)) 73 | else: 74 | print("Error: Checkpoint ({}) doesn't exist".format(args.checkpoint)) 75 | 76 | 77 | if __name__ == '__main__': 78 | main() 79 | -------------------------------------------------------------------------------- /convert/convert_from_mxnet.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import hashlib 3 | import os 4 | 5 | import mxnet as mx 6 | import gluoncv 7 | import torch 8 | from timm import create_model 9 | 10 | parser = argparse.ArgumentParser(description='Convert from MXNet') 11 | parser.add_argument('--model', default='all', type=str, metavar='MODEL', 12 | help='Name of model to convert (default: "all")') 13 | 14 | 15 | def convert(mxnet_name, torch_name): 16 | # download and load the pre-trained model 17 | net = gluoncv.model_zoo.get_model(mxnet_name, pretrained=True) 18 | 19 | # create corresponding torch model 20 | torch_net = create_model(torch_name) 21 | 22 | mxp = [(k, v) for k, v in net.collect_params().items() if 'running' not in k] 23 | torchp = list(torch_net.named_parameters()) 24 | torch_params = {} 25 | 26 | # convert parameters 27 | # NOTE: we are relying on the fact that the order of parameters 28 | # is usually exactly the same between these models, thus no key name mapping 29 | # is necessary. Asserts will trip if this is not the case.
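    # e.g. (illustrative) an mxnet BN scale ending in '_gamma' should pair with a
    # torch '.weight', and a '_beta' with a '.bias' -- the asserts below check this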
30 | for (tn, tv), (mn, mv) in zip(torchp, mxp): 31 | m_split = mn.split('_') 32 | t_split = tn.split('.') 33 | print(t_split, m_split) 34 | print(tv.shape, mv.shape) 35 | 36 | # ensure ordering of BN params match since their sizes are not specific 37 | if m_split[-1] == 'gamma': 38 | assert t_split[-1] == 'weight' 39 | if m_split[-1] == 'beta': 40 | assert t_split[-1] == 'bias' 41 | 42 | # ensure shapes match 43 | assert all(t == m for t, m in zip(tv.shape, mv.shape)) 44 | 45 | torch_tensor = torch.from_numpy(mv.data().asnumpy()) 46 | torch_params[tn] = torch_tensor 47 | 48 | # convert buffers (batch norm running stats) 49 | mxb = [(k, v) for k, v in net.collect_params().items() if any(x in k for x in ['running_mean', 'running_var'])] 50 | torchb = [(k, v) for k, v in torch_net.named_buffers() if 'num_batches' not in k] 51 | for (tn, tv), (mn, mv) in zip(torchb, mxb): 52 | print(tn, mn) 53 | print(tv.shape, mv.shape) 54 | 55 | # ensure ordering of BN params match since their sizes are not specific 56 | if 'running_var' in tn: 57 | assert 'running_var' in mn 58 | if 'running_mean' in tn: 59 | assert 'running_mean' in mn 60 | 61 | torch_tensor = torch.from_numpy(mv.data().asnumpy()) 62 | torch_params[tn] = torch_tensor 63 | 64 | torch_net.load_state_dict(torch_params) 65 | torch_filename = './%s.pth' % torch_name 66 | torch.save(torch_net.state_dict(), torch_filename) 67 | with open(torch_filename, 'rb') as f: 68 | sha_hash = hashlib.sha256(f.read()).hexdigest() 69 | final_filename = os.path.splitext(torch_filename)[0] + '-' + sha_hash[:8] + '.pth' 70 | os.rename(torch_filename, final_filename) 71 | print("=> Saved converted model to '{}, SHA256: {}'".format(final_filename, sha_hash)) 72 | 73 | 74 | def map_mx_to_torch_model(mx_name): 75 | torch_name = mx_name.lower() 76 | if torch_name.startswith('se_'): 77 | torch_name = torch_name.replace('se_', 'se') 78 | elif torch_name.startswith('senet_'): 79 | torch_name = torch_name.replace('senet_', 'senet') 80 | elif torch_name.startswith('inceptionv3'): 81 | torch_name = torch_name.replace('inceptionv3', 'inception_v3') 82 | torch_name = 'gluon_' + torch_name 83 | return torch_name 84 | 85 | 86 | ALL = ['resnet18_v1b', 'resnet34_v1b', 'resnet50_v1b', 'resnet101_v1b', 'resnet152_v1b', 87 | 'resnet50_v1c', 'resnet101_v1c', 'resnet152_v1c', 'resnet50_v1d', 'resnet101_v1d', 'resnet152_v1d', 88 | #'resnet50_v1e', 'resnet101_v1e', 'resnet152_v1e', 89 | 'resnet50_v1s', 'resnet101_v1s', 'resnet152_v1s', 'resnext50_32x4d', 'resnext101_32x4d', 'resnext101_64x4d', 90 | 'se_resnext50_32x4d', 'se_resnext101_32x4d', 'se_resnext101_64x4d', 'senet_154', 'inceptionv3'] 91 | 92 | 93 | def main(): 94 | args = parser.parse_args() 95 | 96 | if not args.model or args.model == 'all': 97 | for mx_model in ALL: 98 | torch_model = map_mx_to_torch_model(mx_model) 99 | convert(mx_model, torch_model) 100 | else: 101 | mx_model = args.model 102 | torch_model = map_mx_to_torch_model(mx_model) 103 | convert(mx_model, torch_model) 104 | 105 | 106 | if __name__ == '__main__': 107 | main() 108 | -------------------------------------------------------------------------------- /distributed_train.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | NUM_PROC=$1 3 | shift 4 | python -m torch.distributed.launch --nproc_per_node=$NUM_PROC train.py "$@" 5 | 6 | -------------------------------------------------------------------------------- /hubconf.py: -------------------------------------------------------------------------------- 1 | 
dependencies = ['torch'] 2 | from timm.models import registry 3 | 4 | globals().update(registry._model_entrypoints) 5 | -------------------------------------------------------------------------------- /inference.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | """PyTorch Inference Script 3 | 4 | An example inference script that outputs top-k class ids for images in a folder into a csv. 5 | 6 | Hacked together by Ross Wightman (https://github.com/rwightman) 7 | """ 8 | import os 9 | import time 10 | import argparse 11 | import logging 12 | import numpy as np 13 | import torch 14 | 15 | from timm.models import create_model, apply_test_time_pool 16 | from timm.data import Dataset, create_loader, resolve_data_config 17 | from timm.utils import AverageMeter, setup_default_logging 18 | 19 | torch.backends.cudnn.benchmark = True 20 | 21 | parser = argparse.ArgumentParser(description='PyTorch ImageNet Inference') 22 | parser.add_argument('data', metavar='DIR', 23 | help='path to dataset') 24 | parser.add_argument('--output_dir', metavar='DIR', default='./', 25 | help='path to output files') 26 | parser.add_argument('--model', '-m', metavar='MODEL', default='dpn92', 27 | help='model architecture (default: dpn92)') 28 | parser.add_argument('-j', '--workers', default=2, type=int, metavar='N', 29 | help='number of data loading workers (default: 2)') 30 | parser.add_argument('-b', '--batch-size', default=256, type=int, 31 | metavar='N', help='mini-batch size (default: 256)') 32 | parser.add_argument('--img-size', default=224, type=int, 33 | metavar='N', help='Input image dimension') 34 | parser.add_argument('--mean', type=float, nargs='+', default=None, metavar='MEAN', 35 | help='Override mean pixel value of dataset') 36 | parser.add_argument('--std', type=float, nargs='+', default=None, metavar='STD', 37 | help='Override std deviation of dataset') 38 | parser.add_argument('--interpolation', default='', type=str, metavar='NAME', 39 | help='Image resize interpolation type (overrides model)') 40 | parser.add_argument('--num-classes', type=int, default=1000, 41 | help='Number of classes in dataset') 42 | parser.add_argument('--log-freq', default=10, type=int, 43 | metavar='N', help='batch logging frequency (default: 10)') 44 | parser.add_argument('--checkpoint', default='', type=str, metavar='PATH', 45 | help='path to latest checkpoint (default: none)') 46 | parser.add_argument('--pretrained', dest='pretrained', action='store_true', 47 | help='use pre-trained model') 48 | parser.add_argument('--num-gpu', type=int, default=1, 49 | help='Number of GPUS to use') 50 | parser.add_argument('--no-test-pool', dest='no_test_pool', action='store_true', 51 | help='disable test time pool') 52 | parser.add_argument('--topk', default=5, type=int, 53 | metavar='N', help='Top-k to output to CSV') 54 | 55 | 56 | def main(): 57 | setup_default_logging() 58 | args = parser.parse_args() 59 | # might as well try to do something useful...
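    # i.e. fall back to downloading pretrained weights when no --checkpoint is given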
60 | args.pretrained = args.pretrained or not args.checkpoint 61 | 62 | # create model 63 | model = create_model( 64 | args.model, 65 | num_classes=args.num_classes, 66 | in_chans=3, 67 | pretrained=args.pretrained, 68 | checkpoint_path=args.checkpoint) 69 | 70 | logging.info('Model %s created, param count: %d' % 71 | (args.model, sum([m.numel() for m in model.parameters()]))) 72 | 73 | config = resolve_data_config(vars(args), model=model) 74 | model, test_time_pool = apply_test_time_pool(model, config, args) 75 | 76 | if args.num_gpu > 1: 77 | model = torch.nn.DataParallel(model, device_ids=list(range(args.num_gpu))).cuda() 78 | else: 79 | model = model.cuda() 80 | 81 | loader = create_loader( 82 | Dataset(args.data), 83 | input_size=config['input_size'], 84 | batch_size=args.batch_size, 85 | use_prefetcher=True, 86 | interpolation=config['interpolation'], 87 | mean=config['mean'], 88 | std=config['std'], 89 | num_workers=args.workers, 90 | crop_pct=1.0 if test_time_pool else config['crop_pct']) 91 | 92 | model.eval() 93 | 94 | k = min(args.topk, args.num_classes) 95 | batch_time = AverageMeter() 96 | end = time.time() 97 | topk_ids = [] 98 | with torch.no_grad(): 99 | for batch_idx, (input, _) in enumerate(loader): 100 | input = input.cuda() 101 | labels = model(input) 102 | topk = labels.topk(k)[1] 103 | topk_ids.append(topk.cpu().numpy()) 104 | 105 | # measure elapsed time 106 | batch_time.update(time.time() - end) 107 | end = time.time() 108 | 109 | if batch_idx % args.log_freq == 0: 110 | logging.info('Predict: [{0}/{1}] Time {batch_time.val:.3f} ({batch_time.avg:.3f})'.format( 111 | batch_idx, len(loader), batch_time=batch_time)) 112 | 113 | topk_ids = np.concatenate(topk_ids, axis=0) 114 | 115 | with open(os.path.join(args.output_dir, './topk_ids.csv'), 'w') as out_file: 116 | filenames = loader.dataset.filenames() 117 | for filename, label in zip(filenames, topk_ids): 118 | filename = os.path.basename(filename) 119 | # write all k ids rather than a hardcoded 5, so --topk != 5 works 120 | out_file.write('{0},{1}\n'.format( 121 | filename, ','.join(str(v) for v in label))) 122 | 123 | 124 | if __name__ == '__main__': 125 | main() 126 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | torch>=1.2.0 2 | torchvision>=0.4.0 3 | pyyaml 4 | -------------------------------------------------------------------------------- /results/README.md: -------------------------------------------------------------------------------- 1 | # Validation Results 2 | 3 | This folder contains validation results for the models in this collection that have pretrained weights. Since the focus for this repository is currently ImageNet-1k classification, all of the results are based on datasets compatible with ImageNet-1k classes. 4 | 5 | ## Datasets 6 | 7 | There are currently results for the ImageNet validation set and 3 additional test sets. 8 | 9 | ### ImageNet Validation - [`results-imagenet.csv`](results-imagenet.csv) 10 | 11 | * Source: http://image-net.org/challenges/LSVRC/2012/index 12 | * Paper: "ImageNet Large Scale Visual Recognition Challenge" - https://arxiv.org/abs/1409.0575 13 | 14 | The standard 50,000 image ImageNet-1k validation set. Model selection during training utilizes this validation set, so it is not a true test set. Question: Does anyone have the official ImageNet-1k test set classification labels now that challenges are done?
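
These CSVs are easy to join for quick comparisons across test sets. A minimal sketch using pandas -- the column names `model` and `top1` are assumptions here; check the actual CSV headers before relying on them:

```python
import pandas as pd

# assumed column names -- verify against the CSV headers
val = pd.read_csv('results-imagenet.csv')
v2 = pd.read_csv('results-imagenetv2-matched-frequency.csv')

# join on model name and compute the top-1 drop on the newer test set
merged = val.merge(v2, on='model', suffixes=('_val', '_v2'))
merged['top1_drop'] = merged['top1_val'] - merged['top1_v2']
print(merged.sort_values('top1_drop')[['model', 'top1_drop']].head())
```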
15 | 16 | ### ImageNetV2 Matched Frequency - [`results-imagenetv2-matched-frequency.csv`](results-imagenetv2-matched-frequency.csv) 17 | 18 | * Source: https://github.com/modestyachts/ImageNetV2 19 | * Paper: "Do ImageNet Classifiers Generalize to ImageNet?" - https://arxiv.org/abs/1902.10811 20 | 21 | An ImageNet test set of 10,000 images sampled from new images roughly 10 years after the original. Care was taken to replicate the original ImageNet curation/sampling process. 22 | 23 | ### ImageNet-Sketch - [`results-sketch.csv`](results-sketch.csv) 24 | 25 | * Source: https://github.com/HaohanWang/ImageNet-Sketch 26 | * Paper: "Learning Robust Global Representations by Penalizing Local Predictive Power" - https://arxiv.org/abs/1905.13549 27 | 28 | 50,000 non-photographic (or photos of such) images (sketches, doodles, mostly monochromatic) covering all 1000 ImageNet classes. 29 | 30 | ### ImageNet-Adversarial - [`results-imagenet-a.csv`](results-imagenet-a.csv) 31 | 32 | * Source: https://github.com/hendrycks/natural-adv-examples 33 | * Paper: "Natural Adversarial Examples" - https://arxiv.org/abs/1907.07174 34 | 35 | A collection of 7500 images covering 200 of the 1000 ImageNet classes. Images are naturally occurring adversarial examples that confuse typical ImageNet classifiers. This is a challenging dataset; a typical ResNet-50 will score 0% top-1. 36 | 37 | ## TODO 38 | * Add rank difference, and top-1/top-5 difference from ImageNet-1k validation for the 3 additional test sets 39 | * Explore adding a reduced version of ImageNet-C (Corruptions) and ImageNet-P (Perturbations) from https://github.com/hendrycks/robustness. The originals are huge and image-size specific. 40 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | """ Setup 2 | """ 3 | from setuptools import setup, find_packages 4 | from codecs import open 5 | from os import path 6 | 7 | here = path.abspath(path.dirname(__file__)) 8 | 9 | # Get the long description from the README file 10 | with open(path.join(here, 'README.md'), encoding='utf-8') as f: 11 | long_description = f.read() 12 | 13 | exec(open('timm/version.py').read()) 14 | setup( 15 | name='timm', 16 | version=__version__, 17 | description='(Unofficial) PyTorch Image Models', 18 | long_description=long_description, 19 | long_description_content_type='text/markdown', 20 | url='https://github.com/rwightman/pytorch-image-models', 21 | author='Ross Wightman', 22 | author_email='hello@rwightman.com', 23 | classifiers=[ 24 | # How mature is this project? Common values are 25 | # 3 - Alpha 26 | # 4 - Beta 27 | # 5 - Production/Stable 28 | 'Development Status :: 3 - Alpha', 29 | 'Intended Audience :: Education', 30 | 'Intended Audience :: Science/Research', 31 | 'License :: OSI Approved :: Apache Software License', 32 | 'Programming Language :: Python :: 3.6', 33 | 'Programming Language :: Python :: 3.7', 34 | 'Topic :: Scientific/Engineering', 35 | 'Topic :: Scientific/Engineering :: Artificial Intelligence', 36 | 'Topic :: Software Development', 37 | 'Topic :: Software Development :: Libraries', 38 | 'Topic :: Software Development :: Libraries :: Python Modules', 39 | ], 40 | 41 | # Note that this is a string of words separated by whitespace, not a list.
42 | keywords='pytorch pretrained models efficientnet mobilenetv3 mnasnet', 43 | packages=find_packages(exclude=['convert']), 44 | install_requires=['torch >= 1.0', 'torchvision'], 45 | python_requires='>=3.6', 46 | ) 47 | -------------------------------------------------------------------------------- /timm/__init__.py: -------------------------------------------------------------------------------- 1 | from .version import __version__ 2 | from .models import create_model, list_models, is_model, list_modules, model_entrypoint 3 | -------------------------------------------------------------------------------- /timm/data/__init__.py: -------------------------------------------------------------------------------- 1 | from .constants import * 2 | from .config import resolve_data_config 3 | from .dataset import Dataset, DatasetTar, AugMixDataset 4 | from .transforms import * 5 | from .loader import create_loader 6 | from .transforms_factory import create_transform 7 | from .mixup import mixup_batch, FastCollateMixup 8 | from .auto_augment import RandAugment, AutoAugment, rand_augment_ops, auto_augment_policy,\ 9 | rand_augment_transform, auto_augment_transform 10 | -------------------------------------------------------------------------------- /timm/data/config.py: -------------------------------------------------------------------------------- 1 | import logging 2 | from .constants import * 3 | 4 | 5 | def resolve_data_config(args, default_cfg={}, model=None, verbose=True): 6 | new_config = {} 7 | default_cfg = default_cfg 8 | if not default_cfg and model is not None and hasattr(model, 'default_cfg'): 9 | default_cfg = model.default_cfg 10 | 11 | # Resolve input/image size 12 | in_chans = 3 13 | if 'chans' in args and args['chans'] is not None: 14 | in_chans = args['chans'] 15 | 16 | input_size = (in_chans, 224, 224) 17 | if 'input_size' in args and args['input_size'] is not None: 18 | assert isinstance(args['input_size'], (tuple, list)) 19 | assert len(args['input_size']) == 3 20 | input_size = tuple(args['input_size']) 21 | in_chans = input_size[0] # input_size overrides in_chans 22 | elif 'img_size' in args and args['img_size'] is not None: 23 | assert isinstance(args['img_size'], int) 24 | input_size = (in_chans, args['img_size'], args['img_size']) 25 | elif 'input_size' in default_cfg: 26 | input_size = default_cfg['input_size'] 27 | new_config['input_size'] = input_size 28 | 29 | # resolve interpolation method 30 | new_config['interpolation'] = 'bicubic' 31 | if 'interpolation' in args and args['interpolation']: 32 | new_config['interpolation'] = args['interpolation'] 33 | elif 'interpolation' in default_cfg: 34 | new_config['interpolation'] = default_cfg['interpolation'] 35 | 36 | # resolve dataset + model mean for normalization 37 | new_config['mean'] = IMAGENET_DEFAULT_MEAN 38 | if 'model' in args: 39 | new_config['mean'] = get_mean_by_model(args['model']) 40 | if 'mean' in args and args['mean'] is not None: 41 | mean = tuple(args['mean']) 42 | if len(mean) == 1: 43 | mean = tuple(list(mean) * in_chans) 44 | else: 45 | assert len(mean) == in_chans 46 | new_config['mean'] = mean 47 | elif 'mean' in default_cfg: 48 | new_config['mean'] = default_cfg['mean'] 49 | 50 | # resolve dataset + model std deviation for normalization 51 | new_config['std'] = IMAGENET_DEFAULT_STD 52 | if 'model' in args: 53 | new_config['std'] = get_std_by_model(args['model']) 54 | if 'std' in args and args['std'] is not None: 55 | std = tuple(args['std']) 56 | if len(std) == 1: 57 | std = tuple(list(std) * 
in_chans) 58 | else: 59 | assert len(std) == in_chans 60 | new_config['std'] = std 61 | elif 'std' in default_cfg: 62 | new_config['std'] = default_cfg['std'] 63 | 64 | # resolve default crop percentage 65 | new_config['crop_pct'] = DEFAULT_CROP_PCT 66 | if 'crop_pct' in args and args['crop_pct'] is not None: 67 | new_config['crop_pct'] = args['crop_pct'] 68 | elif 'crop_pct' in default_cfg: 69 | new_config['crop_pct'] = default_cfg['crop_pct'] 70 | 71 | if verbose: 72 | logging.info('Data processing configuration for current model + dataset:') 73 | for n, v in new_config.items(): 74 | logging.info('\t%s: %s' % (n, str(v))) 75 | 76 | return new_config 77 | 78 | 79 | def get_mean_by_model(model_name): 80 | model_name = model_name.lower() 81 | if 'dpn' in model_name: 82 | return IMAGENET_DPN_MEAN 83 | elif 'ception' in model_name or ('nasnet' in model_name and 'mnasnet' not in model_name): 84 | return IMAGENET_INCEPTION_MEAN 85 | else: 86 | return IMAGENET_DEFAULT_MEAN 87 | 88 | 89 | def get_std_by_model(model_name): 90 | model_name = model_name.lower() 91 | if 'dpn' in model_name: 92 | return IMAGENET_DPN_STD 93 | elif 'ception' in model_name or ('nasnet' in model_name and 'mnasnet' not in model_name): 94 | return IMAGENET_INCEPTION_STD 95 | else: 96 | return IMAGENET_DEFAULT_STD 97 | -------------------------------------------------------------------------------- /timm/data/constants.py: -------------------------------------------------------------------------------- 1 | DEFAULT_CROP_PCT = 0.875 2 | IMAGENET_DEFAULT_MEAN = (0.485, 0.456, 0.406) 3 | IMAGENET_DEFAULT_STD = (0.229, 0.224, 0.225) 4 | IMAGENET_INCEPTION_MEAN = (0.5, 0.5, 0.5) 5 | IMAGENET_INCEPTION_STD = (0.5, 0.5, 0.5) 6 | IMAGENET_DPN_MEAN = (124 / 255, 117 / 255, 104 / 255) 7 | IMAGENET_DPN_STD = tuple([1 / (.0167 * 255)] * 3) 8 | -------------------------------------------------------------------------------- /timm/data/dataset.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | from __future__ import division 3 | from __future__ import print_function 4 | 5 | import torch.utils.data as data 6 | 7 | import os 8 | import re 9 | import torch 10 | import tarfile 11 | from PIL import Image 12 | 13 | 14 | IMG_EXTENSIONS = ['.png', '.jpg', '.jpeg'] 15 | 16 | 17 | def natural_key(string_): 18 | """See http://www.codinghorror.com/blog/archives/001018.html""" 19 | return [int(s) if s.isdigit() else s for s in re.split(r'(\d+)', string_.lower())] 20 | 21 | 22 | def find_images_and_targets(folder, types=IMG_EXTENSIONS, class_to_idx=None, leaf_name_only=True, sort=True): 23 | labels = [] 24 | filenames = [] 25 | for root, subdirs, files in os.walk(folder, topdown=False): 26 | rel_path = os.path.relpath(root, folder) if (root != folder) else '' 27 | label = os.path.basename(rel_path) if leaf_name_only else rel_path.replace(os.path.sep, '_') 28 | for f in files: 29 | base, ext = os.path.splitext(f) 30 | if ext.lower() in types: 31 | filenames.append(os.path.join(root, f)) 32 | labels.append(label) 33 | if class_to_idx is None: 34 | # building class index 35 | unique_labels = set(labels) 36 | sorted_labels = list(sorted(unique_labels, key=natural_key)) 37 | class_to_idx = {c: idx for idx, c in enumerate(sorted_labels)} 38 | images_and_targets = zip(filenames, [class_to_idx[l] for l in labels]) 39 | if sort: 40 | images_and_targets = sorted(images_and_targets, key=lambda k: natural_key(k[0])) 41 | return images_and_targets, class_to_idx 42 | 43 | 44 | def
load_class_map(filename, root=''): 45 | class_to_idx = {} 46 | class_map_path = filename 47 | if not os.path.exists(class_map_path): 48 | class_map_path = os.path.join(root, filename) 49 | assert os.path.exists(class_map_path), 'Cannot locate specified class map file (%s)' % filename 50 | class_map_ext = os.path.splitext(filename)[-1].lower() 51 | if class_map_ext == '.txt': 52 | with open(class_map_path) as f: 53 | class_to_idx = {v.strip(): k for k, v in enumerate(f)} 54 | else: 55 | assert False, 'Unsupported class map extension' 56 | return class_to_idx 57 | 58 | 59 | class Dataset(data.Dataset): 60 | 61 | def __init__( 62 | self, 63 | root, 64 | load_bytes=False, 65 | transform=None, 66 | class_map=''): 67 | 68 | class_to_idx = None 69 | if class_map: 70 | class_to_idx = load_class_map(class_map, root) 71 | images, class_to_idx = find_images_and_targets(root, class_to_idx=class_to_idx) 72 | if len(images) == 0: 73 | raise(RuntimeError("Found 0 images in subfolders of: " + root + "\n" 74 | "Supported image extensions are: " + ",".join(IMG_EXTENSIONS))) 75 | self.root = root 76 | self.samples = images 77 | self.imgs = self.samples # torchvision ImageFolder compat 78 | self.class_to_idx = class_to_idx 79 | self.load_bytes = load_bytes 80 | self.transform = transform 81 | 82 | def __getitem__(self, index): 83 | path, target = self.samples[index] 84 | img = open(path, 'rb').read() if self.load_bytes else Image.open(path).convert('RGB') 85 | if self.transform is not None: 86 | img = self.transform(img) 87 | if target is None: 88 | target = torch.zeros(1).long() 89 | return img, target 90 | 91 | def __len__(self): 92 | return len(self.imgs) 93 | 94 | def filenames(self, indices=[], basename=False): 95 | if indices: 96 | if basename: 97 | return [os.path.basename(self.samples[i][0]) for i in indices] 98 | else: 99 | return [self.samples[i][0] for i in indices] 100 | else: 101 | if basename: 102 | return [os.path.basename(x[0]) for x in self.samples] 103 | else: 104 | return [x[0] for x in self.samples] 105 | 106 | 107 | def _extract_tar_info(tarfile, class_to_idx=None, sort=True): 108 | files = [] 109 | labels = [] 110 | for ti in tarfile.getmembers(): 111 | if not ti.isfile(): 112 | continue 113 | dirname, basename = os.path.split(ti.path) 114 | label = os.path.basename(dirname) 115 | ext = os.path.splitext(basename)[1] 116 | if ext.lower() in IMG_EXTENSIONS: 117 | files.append(ti) 118 | labels.append(label) 119 | if class_to_idx is None: 120 | unique_labels = set(labels) 121 | sorted_labels = list(sorted(unique_labels, key=natural_key)) 122 | class_to_idx = {c: idx for idx, c in enumerate(sorted_labels)} 123 | tarinfo_and_targets = zip(files, [class_to_idx[l] for l in labels]) 124 | if sort: 125 | tarinfo_and_targets = sorted(tarinfo_and_targets, key=lambda k: natural_key(k[0].path)) 126 | return tarinfo_and_targets, class_to_idx 127 | 128 | 129 | class DatasetTar(data.Dataset): 130 | 131 | def __init__(self, root, load_bytes=False, transform=None, class_map=''): 132 | 133 | class_to_idx = None 134 | if class_map: 135 | class_to_idx = load_class_map(class_map, root) 136 | assert os.path.isfile(root) 137 | self.root = root 138 | with tarfile.open(root) as tf: # cannot keep this open across processes, reopen later 139 | self.samples, self.class_to_idx = _extract_tar_info(tf, class_to_idx) 140 | self.tarfile = None # lazy init in __getitem__ 141 | self.load_bytes = load_bytes 142 | self.transform = transform 143 | 144 | def __getitem__(self, index): 145 | if self.tarfile is None: 146 | 
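            # reopen lazily here, in the consuming worker process; see the note in
            # __init__ -- an open TarFile handle cannot be shared across processes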
self.tarfile = tarfile.open(self.root) 147 | tarinfo, target = self.samples[index] 148 | iob = self.tarfile.extractfile(tarinfo) 149 | img = iob.read() if self.load_bytes else Image.open(iob).convert('RGB') 150 | if self.transform is not None: 151 | img = self.transform(img) 152 | if target is None: 153 | target = torch.zeros(1).long() 154 | return img, target 155 | 156 | def __len__(self): 157 | return len(self.samples) 158 | 159 | 160 | class AugMixDataset(torch.utils.data.Dataset): 161 | """Dataset wrapper to perform AugMix or other clean/augmentation mixes""" 162 | 163 | def __init__(self, dataset, num_splits=2): 164 | self.augmentation = None 165 | self.normalize = None 166 | self.dataset = dataset 167 | if self.dataset.transform is not None: 168 | self._set_transforms(self.dataset.transform) 169 | self.num_splits = num_splits 170 | 171 | def _set_transforms(self, x): 172 | assert isinstance(x, (list, tuple)) and len(x) == 3, 'Expecting a tuple/list of 3 transforms' 173 | self.dataset.transform = x[0] 174 | self.augmentation = x[1] 175 | self.normalize = x[2] 176 | 177 | @property 178 | def transform(self): 179 | return self.dataset.transform 180 | 181 | @transform.setter 182 | def transform(self, x): 183 | self._set_transforms(x) 184 | 185 | def _normalize(self, x): 186 | return x if self.normalize is None else self.normalize(x) 187 | 188 | def __getitem__(self, i): 189 | x, y = self.dataset[i] # all splits share the same dataset base transform 190 | x_list = [self._normalize(x)] # first split only normalizes (this is the 'clean' split) 191 | # run the full augmentation on the remaining splits 192 | for _ in range(self.num_splits - 1): 193 | x_list.append(self._normalize(self.augmentation(x))) 194 | return tuple(x_list), y 195 | 196 | def __len__(self): 197 | return len(self.dataset) 198 | -------------------------------------------------------------------------------- /timm/data/distributed_sampler.py: -------------------------------------------------------------------------------- 1 | import math 2 | import torch 3 | from torch.utils.data import Sampler 4 | import torch.distributed as dist 5 | 6 | 7 | class OrderedDistributedSampler(Sampler): 8 | """Sampler that restricts data loading to a subset of the dataset. 9 | It is especially useful in conjunction with 10 | :class:`torch.nn.parallel.DistributedDataParallel`. In such case, each 11 | process can pass a DistributedSampler instance as a DataLoader sampler, 12 | and load a subset of the original dataset that is exclusive to it. 13 | .. note:: 14 | Dataset is assumed to be of constant size. 15 | Arguments: 16 | dataset: Dataset used for sampling. 17 | num_replicas (optional): Number of processes participating in 18 | distributed training. 19 | rank (optional): Rank of the current process within num_replicas. 
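        Example (illustrative): with len(dataset) == 10 and num_replicas == 3,
        num_samples == 4 and total_size == 12; the index list is padded to
        length 12 and rank r keeps indices[r:12:3], e.g. rank 0 -> [0, 3, 6, 9].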
20 | """ 21 | 22 | def __init__(self, dataset, num_replicas=None, rank=None): 23 | if num_replicas is None: 24 | if not dist.is_available(): 25 | raise RuntimeError("Requires distributed package to be available") 26 | num_replicas = dist.get_world_size() 27 | if rank is None: 28 | if not dist.is_available(): 29 | raise RuntimeError("Requires distributed package to be available") 30 | rank = dist.get_rank() 31 | self.dataset = dataset 32 | self.num_replicas = num_replicas 33 | self.rank = rank 34 | self.num_samples = int(math.ceil(len(self.dataset) * 1.0 / self.num_replicas)) 35 | self.total_size = self.num_samples * self.num_replicas 36 | 37 | def __iter__(self): 38 | indices = list(range(len(self.dataset))) 39 | 40 | # add extra samples to make it evenly divisible 41 | indices += indices[:(self.total_size - len(indices))] 42 | assert len(indices) == self.total_size 43 | 44 | # subsample 45 | indices = indices[self.rank:self.total_size:self.num_replicas] 46 | assert len(indices) == self.num_samples 47 | 48 | return iter(indices) 49 | 50 | def __len__(self): 51 | return self.num_samples 52 | -------------------------------------------------------------------------------- /timm/data/loader.py: -------------------------------------------------------------------------------- 1 | import torch.utils.data 2 | import numpy as np 3 | 4 | from .transforms_factory import create_transform 5 | from .constants import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD 6 | from .distributed_sampler import OrderedDistributedSampler 7 | from .random_erasing import RandomErasing 8 | from .mixup import FastCollateMixup 9 | 10 | 11 | def fast_collate(batch): 12 | """ A fast collation function optimized for uint8 images (np array or torch) and int64 targets (labels)""" 13 | assert isinstance(batch[0], tuple) 14 | batch_size = len(batch) 15 | if isinstance(batch[0][0], tuple): 16 | # This branch 'deinterleaves' and flattens tuples of input tensors into one tensor ordered by position 17 | # such that all tuple of position n will end up in a torch.split(tensor, batch_size) in nth position 18 | inner_tuple_size = len(batch[0][0]) 19 | flattened_batch_size = batch_size * inner_tuple_size 20 | targets = torch.zeros(flattened_batch_size, dtype=torch.int64) 21 | tensor = torch.zeros((flattened_batch_size, *batch[0][0][0].shape), dtype=torch.uint8) 22 | for i in range(batch_size): 23 | assert len(batch[i][0]) == inner_tuple_size # all input tensor tuples must be same length 24 | for j in range(inner_tuple_size): 25 | targets[i + j * batch_size] = batch[i][1] 26 | tensor[i + j * batch_size] += torch.from_numpy(batch[i][0][j]) 27 | return tensor, targets 28 | elif isinstance(batch[0][0], np.ndarray): 29 | targets = torch.tensor([b[1] for b in batch], dtype=torch.int64) 30 | assert len(targets) == batch_size 31 | tensor = torch.zeros((batch_size, *batch[0][0].shape), dtype=torch.uint8) 32 | for i in range(batch_size): 33 | tensor[i] += torch.from_numpy(batch[i][0]) 34 | return tensor, targets 35 | elif isinstance(batch[0][0], torch.Tensor): 36 | targets = torch.tensor([b[1] for b in batch], dtype=torch.int64) 37 | assert len(targets) == batch_size 38 | tensor = torch.zeros((batch_size, *batch[0][0].shape), dtype=torch.uint8) 39 | for i in range(batch_size): 40 | tensor[i].copy_(batch[i][0]) 41 | return tensor, targets 42 | else: 43 | assert False 44 | 45 | 46 | class PrefetchLoader: 47 | 48 | def __init__(self, 49 | loader, 50 | mean=IMAGENET_DEFAULT_MEAN, 51 | std=IMAGENET_DEFAULT_STD, 52 | fp16=False, 53 | re_prob=0., 54 | 
re_mode='const', 55 | re_count=1, 56 | re_num_splits=0): 57 | self.loader = loader 58 | self.mean = torch.tensor([x * 255 for x in mean]).cuda().view(1, 3, 1, 1) 59 | self.std = torch.tensor([x * 255 for x in std]).cuda().view(1, 3, 1, 1) 60 | self.fp16 = fp16 61 | if fp16: 62 | self.mean = self.mean.half() 63 | self.std = self.std.half() 64 | if re_prob > 0.: 65 | self.random_erasing = RandomErasing( 66 | probability=re_prob, mode=re_mode, max_count=re_count, num_splits=re_num_splits) 67 | else: 68 | self.random_erasing = None 69 | 70 | def __iter__(self): 71 | stream = torch.cuda.Stream() 72 | first = True 73 | 74 | for next_input, next_target in self.loader: 75 | with torch.cuda.stream(stream): 76 | next_input = next_input.cuda(non_blocking=True) 77 | next_target = next_target.cuda(non_blocking=True) 78 | if self.fp16: 79 | next_input = next_input.half().sub_(self.mean).div_(self.std) 80 | else: 81 | next_input = next_input.float().sub_(self.mean).div_(self.std) 82 | if self.random_erasing is not None: 83 | next_input = self.random_erasing(next_input) 84 | 85 | if not first: 86 | yield input, target 87 | else: 88 | first = False 89 | 90 | torch.cuda.current_stream().wait_stream(stream) 91 | input = next_input 92 | target = next_target 93 | 94 | yield input, target 95 | 96 | def __len__(self): 97 | return len(self.loader) 98 | 99 | @property 100 | def sampler(self): 101 | return self.loader.sampler 102 | 103 | @property 104 | def dataset(self): 105 | return self.loader.dataset 106 | 107 | @property 108 | def mixup_enabled(self): 109 | if isinstance(self.loader.collate_fn, FastCollateMixup): 110 | return self.loader.collate_fn.mixup_enabled 111 | else: 112 | return False 113 | 114 | @mixup_enabled.setter 115 | def mixup_enabled(self, x): 116 | if isinstance(self.loader.collate_fn, FastCollateMixup): 117 | self.loader.collate_fn.mixup_enabled = x 118 | 119 | 120 | def create_loader( 121 | dataset, 122 | input_size, 123 | batch_size, 124 | is_training=False, 125 | use_prefetcher=True, 126 | re_prob=0., 127 | re_mode='const', 128 | re_count=1, 129 | re_split=False, 130 | color_jitter=0.4, 131 | auto_augment=None, 132 | num_aug_splits=0, 133 | interpolation='bilinear', 134 | mean=IMAGENET_DEFAULT_MEAN, 135 | std=IMAGENET_DEFAULT_STD, 136 | num_workers=1, 137 | distributed=False, 138 | crop_pct=None, 139 | collate_fn=None, 140 | pin_memory=False, 141 | fp16=False, 142 | tf_preprocessing=False, 143 | ): 144 | re_num_splits = 0 145 | if re_split: 146 | # apply RE to second half of batch if no aug split otherwise line up with aug split 147 | re_num_splits = num_aug_splits or 2 148 | dataset.transform = create_transform( 149 | input_size, 150 | is_training=is_training, 151 | use_prefetcher=use_prefetcher, 152 | color_jitter=color_jitter, 153 | auto_augment=auto_augment, 154 | interpolation=interpolation, 155 | mean=mean, 156 | std=std, 157 | crop_pct=crop_pct, 158 | tf_preprocessing=tf_preprocessing, 159 | re_prob=re_prob, 160 | re_mode=re_mode, 161 | re_count=re_count, 162 | re_num_splits=re_num_splits, 163 | separate=num_aug_splits > 0, 164 | ) 165 | 166 | sampler = None 167 | if distributed: 168 | if is_training: 169 | sampler = torch.utils.data.distributed.DistributedSampler(dataset) 170 | else: 171 | # This will add extra duplicate entries to result in equal num 172 | # of samples per-process, will slightly alter validation results 173 | sampler = OrderedDistributedSampler(dataset) 174 | 175 | if collate_fn is None: 176 | collate_fn = fast_collate if use_prefetcher else 
torch.utils.data.dataloader.default_collate 177 | 178 | loader = torch.utils.data.DataLoader( 179 | dataset, 180 | batch_size=batch_size, 181 | shuffle=sampler is None and is_training, 182 | num_workers=num_workers, 183 | sampler=sampler, 184 | collate_fn=collate_fn, 185 | pin_memory=pin_memory, 186 | drop_last=is_training, 187 | ) 188 | if use_prefetcher: 189 | loader = PrefetchLoader( 190 | loader, 191 | mean=mean, 192 | std=std, 193 | fp16=fp16, 194 | re_prob=re_prob if is_training else 0., 195 | re_mode=re_mode, 196 | re_count=re_count, 197 | re_num_splits=re_num_splits 198 | ) 199 | 200 | return loader 201 | -------------------------------------------------------------------------------- /timm/data/mixup.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | 4 | 5 | def one_hot(x, num_classes, on_value=1., off_value=0., device='cuda'): 6 | x = x.long().view(-1, 1) 7 | return torch.full((x.size()[0], num_classes), off_value, device=device).scatter_(1, x, on_value) 8 | 9 | 10 | def mixup_target(target, num_classes, lam=1., smoothing=0.0, device='cuda'): 11 | off_value = smoothing / num_classes 12 | on_value = 1. - smoothing + off_value 13 | y1 = one_hot(target, num_classes, on_value=on_value, off_value=off_value, device=device) 14 | y2 = one_hot(target.flip(0), num_classes, on_value=on_value, off_value=off_value, device=device) 15 | return lam*y1 + (1. - lam)*y2 16 | 17 | 18 | def mixup_batch(input, target, alpha=0.2, num_classes=1000, smoothing=0.1, disable=False): 19 | lam = 1. 20 | if not disable: 21 | lam = np.random.beta(alpha, alpha) 22 | input = input.mul(lam).add_(1 - lam, input.flip(0)) 23 | target = mixup_target(target, num_classes, lam, smoothing) 24 | return input, target 25 | 26 | 27 | class FastCollateMixup: 28 | 29 | def __init__(self, mixup_alpha=1., label_smoothing=0.1, num_classes=1000): 30 | self.mixup_alpha = mixup_alpha 31 | self.label_smoothing = label_smoothing 32 | self.num_classes = num_classes 33 | self.mixup_enabled = True 34 | 35 | def __call__(self, batch): 36 | batch_size = len(batch) 37 | lam = 1. 
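        # lam ~ Beta(alpha, alpha): mixup_alpha == 1. draws lam uniformly from
        # [0, 1]; smaller alpha pushes lam toward 0 or 1 (milder mixing on average)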
38 | if self.mixup_enabled: 39 | lam = np.random.beta(self.mixup_alpha, self.mixup_alpha) 40 | 41 | target = torch.tensor([b[1] for b in batch], dtype=torch.int64) 42 | target = mixup_target(target, self.num_classes, lam, self.label_smoothing, device='cpu') 43 | 44 | tensor = torch.zeros((batch_size, *batch[0][0].shape), dtype=torch.uint8) 45 | for i in range(batch_size): 46 | mixed = batch[i][0].astype(np.float32) * lam + \ 47 | batch[batch_size - i - 1][0].astype(np.float32) * (1 - lam) 48 | np.round(mixed, out=mixed) 49 | tensor[i] += torch.from_numpy(mixed.astype(np.uint8)) 50 | 51 | return tensor, target 52 | -------------------------------------------------------------------------------- /timm/data/random_erasing.py: -------------------------------------------------------------------------------- 1 | import random 2 | import math 3 | import torch 4 | 5 | 6 | def _get_pixels(per_pixel, rand_color, patch_size, dtype=torch.float32, device='cuda'): 7 | # NOTE I've seen CUDA illegal memory access errors being caused by the normal_() 8 | # paths, flip the order so normal is run on CPU if this becomes a problem 9 | # Issue has been fixed in master https://github.com/pytorch/pytorch/issues/19508 10 | if per_pixel: 11 | return torch.empty(patch_size, dtype=dtype, device=device).normal_() 12 | elif rand_color: 13 | return torch.empty((patch_size[0], 1, 1), dtype=dtype, device=device).normal_() 14 | else: 15 | return torch.zeros((patch_size[0], 1, 1), dtype=dtype, device=device) 16 | 17 | 18 | class RandomErasing: 19 | """ Randomly selects a rectangle region in an image and erases its pixels. 20 | 'Random Erasing Data Augmentation' by Zhong et al. 21 | See https://arxiv.org/pdf/1708.04896.pdf 22 | 23 | This variant of RandomErasing is intended to be applied to either a batch 24 | or single image tensor after it has been normalized by dataset mean and std. 25 | Args: 26 | probability: Probability that the Random Erasing operation will be performed. 27 | min_area: Minimum percentage of erased area wrt input image area. 28 | max_area: Maximum percentage of erased area wrt input image area. 29 | min_aspect: Minimum aspect ratio of erased area. 30 | mode: pixel color mode, one of 'const', 'rand', or 'pixel' 31 | 'const' - erase block is constant color of 0 for all channels 32 | 'rand' - erase block is same per-channel random (normal) color 33 | 'pixel' - erase block is per-pixel random (normal) color 34 | max_count: maximum number of erasing blocks per image, area per box is scaled by count. 35 | per-image count is randomly chosen between 1 and this value. 
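        Example (illustrative usage on an already-normalized batch):
            erasing = RandomErasing(probability=0.5, mode='pixel', device='cpu')
            batch = erasing(batch)  # batch tensor of shape (N, C, H, W)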
36 | """ 37 | 38 | def __init__( 39 | self, 40 | probability=0.5, min_area=0.02, max_area=1/3, min_aspect=0.3, max_aspect=None, 41 | mode='const', min_count=1, max_count=None, num_splits=0, device='cuda'): 42 | self.probability = probability 43 | self.min_area = min_area 44 | self.max_area = max_area 45 | max_aspect = max_aspect or 1 / min_aspect 46 | self.log_aspect_ratio = (math.log(min_aspect), math.log(max_aspect)) 47 | self.min_count = min_count 48 | self.max_count = max_count or min_count 49 | self.num_splits = num_splits 50 | mode = mode.lower() 51 | self.rand_color = False 52 | self.per_pixel = False 53 | if mode == 'rand': 54 | self.rand_color = True # per block random normal 55 | elif mode == 'pixel': 56 | self.per_pixel = True # per pixel random normal 57 | else: 58 | assert not mode or mode == 'const' 59 | self.device = device 60 | 61 | def _erase(self, img, chan, img_h, img_w, dtype): 62 | if random.random() > self.probability: 63 | return 64 | area = img_h * img_w 65 | count = self.min_count if self.min_count == self.max_count else \ 66 | random.randint(self.min_count, self.max_count) 67 | for _ in range(count): 68 | for attempt in range(10): 69 | target_area = random.uniform(self.min_area, self.max_area) * area / count 70 | aspect_ratio = math.exp(random.uniform(*self.log_aspect_ratio)) 71 | h = int(round(math.sqrt(target_area * aspect_ratio))) 72 | w = int(round(math.sqrt(target_area / aspect_ratio))) 73 | if w < img_w and h < img_h: 74 | top = random.randint(0, img_h - h) 75 | left = random.randint(0, img_w - w) 76 | img[:, top:top + h, left:left + w] = _get_pixels( 77 | self.per_pixel, self.rand_color, (chan, h, w), 78 | dtype=dtype, device=self.device) 79 | break 80 | 81 | def __call__(self, input): 82 | if len(input.size()) == 3: 83 | self._erase(input, *input.size(), input.dtype) 84 | else: 85 | batch_size, chan, img_h, img_w = input.size() 86 | # skip first slice of batch if num_splits is set (for clean portion of samples) 87 | batch_start = batch_size // self.num_splits if self.num_splits > 1 else 0 88 | for i in range(batch_start, batch_size): 89 | self._erase(input[i], chan, img_h, img_w, input.dtype) 90 | return input 91 | -------------------------------------------------------------------------------- /timm/data/transforms.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torchvision.transforms.functional as F 3 | from PIL import Image 4 | import warnings 5 | import math 6 | import random 7 | import numpy as np 8 | 9 | 10 | class ToNumpy: 11 | 12 | def __call__(self, pil_img): 13 | np_img = np.array(pil_img, dtype=np.uint8) 14 | if np_img.ndim < 3: 15 | np_img = np.expand_dims(np_img, axis=-1) 16 | np_img = np.rollaxis(np_img, 2) # HWC to CHW 17 | return np_img 18 | 19 | 20 | class ToTensor: 21 | 22 | def __init__(self, dtype=torch.float32): 23 | self.dtype = dtype 24 | 25 | def __call__(self, pil_img): 26 | np_img = np.array(pil_img, dtype=np.uint8) 27 | if np_img.ndim < 3: 28 | np_img = np.expand_dims(np_img, axis=-1) 29 | np_img = np.rollaxis(np_img, 2) # HWC to CHW 30 | return torch.from_numpy(np_img).to(dtype=self.dtype) 31 | 32 | 33 | _pil_interpolation_to_str = { 34 | Image.NEAREST: 'PIL.Image.NEAREST', 35 | Image.BILINEAR: 'PIL.Image.BILINEAR', 36 | Image.BICUBIC: 'PIL.Image.BICUBIC', 37 | Image.LANCZOS: 'PIL.Image.LANCZOS', 38 | Image.HAMMING: 'PIL.Image.HAMMING', 39 | Image.BOX: 'PIL.Image.BOX', 40 | } 41 | 42 | 43 | def _pil_interp(method): 44 | if method == 'bicubic': 45 | return 
Image.BICUBIC 46 | elif method == 'lanczos': 47 | return Image.LANCZOS 48 | elif method == 'hamming': 49 | return Image.HAMMING 50 | else: 51 | # default bilinear, do we want to allow nearest? 52 | return Image.BILINEAR 53 | 54 | 55 | _RANDOM_INTERPOLATION = (Image.BILINEAR, Image.BICUBIC) 56 | 57 | 58 | class RandomResizedCropAndInterpolation: 59 | """Crop the given PIL Image to random size and aspect ratio with random interpolation. 60 | 61 | A crop of random size (default: of 0.08 to 1.0) of the original size and a random 62 | aspect ratio (default: of 3/4 to 4/3) of the original aspect ratio is made. This crop 63 | is finally resized to given size. 64 | This is popularly used to train the Inception networks. 65 | 66 | Args: 67 | size: expected output size of each edge 68 | scale: range of size of the origin size cropped 69 | ratio: range of aspect ratio of the origin aspect ratio cropped 70 | interpolation: Default: PIL.Image.BILINEAR 71 | """ 72 | 73 | def __init__(self, size, scale=(0.08, 1.0), ratio=(3. / 4., 4. / 3.), 74 | interpolation='bilinear'): 75 | if isinstance(size, tuple): 76 | self.size = size 77 | else: 78 | self.size = (size, size) 79 | if (scale[0] > scale[1]) or (ratio[0] > ratio[1]): 80 | warnings.warn("range should be of kind (min, max)") 81 | 82 | if interpolation == 'random': 83 | self.interpolation = _RANDOM_INTERPOLATION 84 | else: 85 | self.interpolation = _pil_interp(interpolation) 86 | self.scale = scale 87 | self.ratio = ratio 88 | 89 | @staticmethod 90 | def get_params(img, scale, ratio): 91 | """Get parameters for ``crop`` for a random sized crop. 92 | 93 | Args: 94 | img (PIL Image): Image to be cropped. 95 | scale (tuple): range of size of the origin size cropped 96 | ratio (tuple): range of aspect ratio of the origin aspect ratio cropped 97 | 98 | Returns: 99 | tuple: params (i, j, h, w) to be passed to ``crop`` for a random 100 | sized crop. 101 | """ 102 | area = img.size[0] * img.size[1] 103 | 104 | for attempt in range(10): 105 | target_area = random.uniform(*scale) * area 106 | log_ratio = (math.log(ratio[0]), math.log(ratio[1])) 107 | aspect_ratio = math.exp(random.uniform(*log_ratio)) 108 | 109 | w = int(round(math.sqrt(target_area * aspect_ratio))) 110 | h = int(round(math.sqrt(target_area / aspect_ratio))) 111 | 112 | if w <= img.size[0] and h <= img.size[1]: 113 | i = random.randint(0, img.size[1] - h) 114 | j = random.randint(0, img.size[0] - w) 115 | return i, j, h, w 116 | 117 | # Fallback to central crop 118 | in_ratio = img.size[0] / img.size[1] 119 | if in_ratio < min(ratio): 120 | w = img.size[0] 121 | h = int(round(w / min(ratio))) 122 | elif in_ratio > max(ratio): 123 | h = img.size[1] 124 | w = int(round(h * max(ratio))) 125 | else: # whole image 126 | w = img.size[0] 127 | h = img.size[1] 128 | i = (img.size[1] - h) // 2 129 | j = (img.size[0] - w) // 2 130 | return i, j, h, w 131 | 132 | def __call__(self, img): 133 | """ 134 | Args: 135 | img (PIL Image): Image to be cropped and resized. 136 | 137 | Returns: 138 | PIL Image: Randomly cropped and resized image. 
139 | """ 140 | i, j, h, w = self.get_params(img, self.scale, self.ratio) 141 | if isinstance(self.interpolation, (tuple, list)): 142 | interpolation = random.choice(self.interpolation) 143 | else: 144 | interpolation = self.interpolation 145 | return F.resized_crop(img, i, j, h, w, self.size, interpolation) 146 | 147 | def __repr__(self): 148 | if isinstance(self.interpolation, (tuple, list)): 149 | interpolate_str = ' '.join([_pil_interpolation_to_str[x] for x in self.interpolation]) 150 | else: 151 | interpolate_str = _pil_interpolation_to_str[self.interpolation] 152 | format_string = self.__class__.__name__ + '(size={0}'.format(self.size) 153 | format_string += ', scale={0}'.format(tuple(round(s, 4) for s in self.scale)) 154 | format_string += ', ratio={0}'.format(tuple(round(r, 4) for r in self.ratio)) 155 | format_string += ', interpolation={0})'.format(interpolate_str) 156 | return format_string 157 | 158 | 159 | -------------------------------------------------------------------------------- /timm/data/transforms_factory.py: -------------------------------------------------------------------------------- 1 | """ Transforms Factory 2 | Factory methods for building image transforms for use with TIMM (PyTorch Image Models) 3 | """ 4 | import math 5 | 6 | import torch 7 | from torchvision import transforms 8 | 9 | from timm.data.constants import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD, DEFAULT_CROP_PCT 10 | from timm.data.auto_augment import rand_augment_transform, augment_and_mix_transform, auto_augment_transform 11 | from timm.data.transforms import _pil_interp, RandomResizedCropAndInterpolation, ToNumpy, ToTensor 12 | from timm.data.random_erasing import RandomErasing 13 | 14 | 15 | def transforms_imagenet_train( 16 | img_size=224, 17 | scale=(0.08, 1.0), 18 | color_jitter=0.4, 19 | auto_augment=None, 20 | interpolation='random', 21 | use_prefetcher=False, 22 | mean=IMAGENET_DEFAULT_MEAN, 23 | std=IMAGENET_DEFAULT_STD, 24 | re_prob=0., 25 | re_mode='const', 26 | re_count=1, 27 | re_num_splits=0, 28 | separate=False, 29 | ): 30 | """ 31 | If separate==True, the transforms are returned as a tuple of 3 separate transforms 32 | for use in a mixing dataset that passes 33 | * all data through the first (primary) transform, called the 'clean' data 34 | * a portion of the data through the secondary transform 35 | * normalizes and converts the branches above with the third, final transform 36 | """ 37 | primary_tfl = [ 38 | RandomResizedCropAndInterpolation( 39 | img_size, scale=scale, interpolation=interpolation), 40 | transforms.RandomHorizontalFlip() 41 | ] 42 | 43 | secondary_tfl = [] 44 | if auto_augment: 45 | assert isinstance(auto_augment, str) 46 | if isinstance(img_size, tuple): 47 | img_size_min = min(img_size) 48 | else: 49 | img_size_min = img_size 50 | aa_params = dict( 51 | translate_const=int(img_size_min * 0.45), 52 | img_mean=tuple([min(255, round(255 * x)) for x in mean]), 53 | ) 54 | if interpolation and interpolation != 'random': 55 | aa_params['interpolation'] = _pil_interp(interpolation) 56 | if auto_augment.startswith('rand'): 57 | secondary_tfl += [rand_augment_transform(auto_augment, aa_params)] 58 | elif auto_augment.startswith('augmix'): 59 | aa_params['translate_pct'] = 0.3 60 | secondary_tfl += [augment_and_mix_transform(auto_augment, aa_params)] 61 | else: 62 | secondary_tfl += [auto_augment_transform(auto_augment, aa_params)] 63 | elif color_jitter is not None: 64 | # color jitter is enabled when not using AA 65 | if isinstance(color_jitter, (list, tuple)): 
66 | # color jitter should be a 3-tuple/list if spec brightness/contrast/saturation 67 | # or 4 if also augmenting hue 68 | assert len(color_jitter) in (3, 4) 69 | else: 70 | # if it's a scalar, duplicate for brightness, contrast, and saturation, no hue 71 | color_jitter = (float(color_jitter),) * 3 72 | secondary_tfl += [transforms.ColorJitter(*color_jitter)] 73 | 74 | final_tfl = [] 75 | if use_prefetcher: 76 | # prefetcher and collate will handle tensor conversion and norm 77 | final_tfl += [ToNumpy()] 78 | else: 79 | final_tfl += [ 80 | transforms.ToTensor(), 81 | transforms.Normalize( 82 | mean=torch.tensor(mean), 83 | std=torch.tensor(std)) 84 | ] 85 | if re_prob > 0.: 86 | final_tfl.append( 87 | RandomErasing(re_prob, mode=re_mode, max_count=re_count, num_splits=re_num_splits, device='cpu')) 88 | 89 | if separate: 90 | return transforms.Compose(primary_tfl), transforms.Compose(secondary_tfl), transforms.Compose(final_tfl) 91 | else: 92 | return transforms.Compose(primary_tfl + secondary_tfl + final_tfl) 93 | 94 | 95 | def transforms_imagenet_eval( 96 | img_size=224, 97 | crop_pct=None, 98 | interpolation='bilinear', 99 | use_prefetcher=False, 100 | mean=IMAGENET_DEFAULT_MEAN, 101 | std=IMAGENET_DEFAULT_STD): 102 | crop_pct = crop_pct or DEFAULT_CROP_PCT 103 | 104 | if isinstance(img_size, tuple): 105 | assert len(img_size) == 2 106 | if img_size[-1] == img_size[-2]: 107 | # fall-back to older behaviour so Resize scales to shortest edge if target is square 108 | scale_size = int(math.floor(img_size[0] / crop_pct)) 109 | else: 110 | scale_size = tuple([int(x / crop_pct) for x in img_size]) 111 | else: 112 | scale_size = int(math.floor(img_size / crop_pct)) 113 | 114 | tfl = [ 115 | transforms.Resize(scale_size, _pil_interp(interpolation)), 116 | transforms.CenterCrop(img_size), 117 | ] 118 | if use_prefetcher: 119 | # prefetcher and collate will handle tensor conversion and norm 120 | tfl += [ToNumpy()] 121 | else: 122 | tfl += [ 123 | transforms.ToTensor(), 124 | transforms.Normalize( 125 | mean=torch.tensor(mean), 126 | std=torch.tensor(std)) 127 | ] 128 | 129 | return transforms.Compose(tfl) 130 | 131 | 132 | def create_transform( 133 | input_size, 134 | is_training=False, 135 | use_prefetcher=False, 136 | color_jitter=0.4, 137 | auto_augment=None, 138 | interpolation='bilinear', 139 | mean=IMAGENET_DEFAULT_MEAN, 140 | std=IMAGENET_DEFAULT_STD, 141 | re_prob=0., 142 | re_mode='const', 143 | re_count=1, 144 | re_num_splits=0, 145 | crop_pct=None, 146 | tf_preprocessing=False, 147 | separate=False): 148 | 149 | if isinstance(input_size, tuple): 150 | img_size = input_size[-2:] 151 | else: 152 | img_size = input_size 153 | 154 | if tf_preprocessing and use_prefetcher: 155 | assert not separate, "Separate transforms not supported for TF preprocessing" 156 | from timm.data.tf_preprocessing import TfPreprocessTransform 157 | transform = TfPreprocessTransform( 158 | is_training=is_training, size=img_size, interpolation=interpolation) 159 | else: 160 | if is_training: 161 | transform = transforms_imagenet_train( 162 | img_size, 163 | color_jitter=color_jitter, 164 | auto_augment=auto_augment, 165 | interpolation=interpolation, 166 | use_prefetcher=use_prefetcher, 167 | mean=mean, 168 | std=std, 169 | re_prob=re_prob, 170 | re_mode=re_mode, 171 | re_count=re_count, 172 | re_num_splits=re_num_splits, 173 | separate=separate) 174 | else: 175 | assert not separate, "Separate transforms not supported for validation preprocessing" 176 | transform = transforms_imagenet_eval( 177 | img_size, 178 | 
interpolation=interpolation, 179 | use_prefetcher=use_prefetcher, 180 | mean=mean, 181 | std=std, 182 | crop_pct=crop_pct) 183 | 184 | return transform 185 | -------------------------------------------------------------------------------- /timm/loss/__init__.py: -------------------------------------------------------------------------------- 1 | from .cross_entropy import LabelSmoothingCrossEntropy, SoftTargetCrossEntropy 2 | from .jsd import JsdCrossEntropy -------------------------------------------------------------------------------- /timm/loss/cross_entropy.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | 6 | class LabelSmoothingCrossEntropy(nn.Module): 7 | """ 8 | NLL loss with label smoothing. 9 | """ 10 | def __init__(self, smoothing=0.1): 11 | """ 12 | Constructor for the LabelSmoothing module. 13 | :param smoothing: label smoothing factor 14 | """ 15 | super(LabelSmoothingCrossEntropy, self).__init__() 16 | assert smoothing < 1.0 17 | self.smoothing = smoothing 18 | self.confidence = 1. - smoothing 19 | 20 | def forward(self, x, target): 21 | logprobs = F.log_softmax(x, dim=-1) 22 | nll_loss = -logprobs.gather(dim=-1, index=target.unsqueeze(1)) 23 | nll_loss = nll_loss.squeeze(1) 24 | smooth_loss = -logprobs.mean(dim=-1) 25 | loss = self.confidence * nll_loss + self.smoothing * smooth_loss 26 | return loss.mean() 27 | 28 | 29 | class SoftTargetCrossEntropy(nn.Module): 30 | 31 | def __init__(self): 32 | super(SoftTargetCrossEntropy, self).__init__() 33 | 34 | def forward(self, x, target): 35 | loss = torch.sum(-target * F.log_softmax(x, dim=-1), dim=-1) 36 | return loss.mean() 37 | -------------------------------------------------------------------------------- /timm/loss/jsd.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | from .cross_entropy import LabelSmoothingCrossEntropy 6 | 7 | 8 | class JsdCrossEntropy(nn.Module): 9 | """ Jensen-Shannon Divergence + Cross-Entropy Loss 10 | 11 | Based on impl here: https://github.com/google-research/augmix/blob/master/imagenet.py 12 | From paper: 'AugMix: A Simple Data Processing Method to Improve Robustness and Uncertainty - 13 | https://arxiv.org/abs/1912.02781 14 | 15 | Hacked together by Ross Wightman 16 | """ 17 | def __init__(self, num_splits=3, alpha=12, smoothing=0.1): 18 | super().__init__() 19 | self.num_splits = num_splits 20 | self.alpha = alpha 21 | if smoothing is not None and smoothing > 0: 22 | self.cross_entropy_loss = LabelSmoothingCrossEntropy(smoothing) 23 | else: 24 | self.cross_entropy_loss = torch.nn.CrossEntropyLoss() 25 | 26 | def __call__(self, output, target): 27 | split_size = output.shape[0] // self.num_splits 28 | assert split_size * self.num_splits == output.shape[0] 29 | logits_split = torch.split(output, split_size) 30 | 31 | # Cross-entropy is only computed on clean images 32 | loss = self.cross_entropy_loss(logits_split[0], target[:split_size]) 33 | probs = [F.softmax(logits, dim=1) for logits in logits_split] 34 | 35 | # Clamp mixture distribution to avoid exploding KL divergence 36 | logp_mixture = torch.clamp(torch.stack(probs).mean(axis=0), 1e-7, 1).log() 37 | loss += self.alpha * sum([F.kl_div( 38 | logp_mixture, p_split, reduction='batchmean') for p_split in probs]) / len(probs) 39 | return loss 40 | 
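
The three loss classes above map onto distinct target layouts: `LabelSmoothingCrossEntropy` takes integer class indices, `SoftTargetCrossEntropy` takes dense target distributions (e.g. the mixed/smoothed one-hots produced by `mixup_target()` earlier in `timm/data/mixup.py`), and `JsdCrossEntropy` expects `num_splits` equal slices stacked along the batch dim with the clean split first. A minimal sketch of the expected shapes, with random tensors standing in for a real loader batch:

```python
import torch
from timm.loss import LabelSmoothingCrossEntropy, SoftTargetCrossEntropy, JsdCrossEntropy

num_classes, split_size = 1000, 8
logits = torch.randn(split_size, num_classes)
targets = torch.randint(0, num_classes, (split_size,))

# Integer targets, smoothed inside the loss.
print(LabelSmoothingCrossEntropy(smoothing=0.1)(logits, targets))

# Dense targets, e.g. from mixup_target().
soft_targets = torch.softmax(torch.randn(split_size, num_classes), dim=-1)
print(SoftTargetCrossEntropy()(logits, soft_targets))

# AugMix layout: [clean | aug1 | aug2] along the batch dim; cross-entropy is
# computed on the clean split only, plus a JSD consistency term across all three.
jsd_logits = torch.randn(3 * split_size, num_classes)
print(JsdCrossEntropy(num_splits=3, alpha=12, smoothing=0.1)(jsd_logits, targets.repeat(3)))
```

Note that `JsdCrossEntropy` asserts the batch divides evenly by `num_splits`, so the loader must produce batches in that layout.
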
-------------------------------------------------------------------------------- /timm/models/__init__.py: -------------------------------------------------------------------------------- 1 | from .inception_v4 import * 2 | from .inception_resnet_v2 import * 3 | from .densenet import * 4 | from .resnet import * 5 | from .dpn import * 6 | from .senet import * 7 | from .xception import * 8 | from .nasnet import * 9 | from .pnasnet import * 10 | from .selecsls import * 11 | from .efficientnet import * 12 | from .mobilenetv3 import * 13 | from .inception_v3 import * 14 | from .gluon_resnet import * 15 | from .gluon_xception import * 16 | from .res2net import * 17 | from .dla import * 18 | from .hrnet import * 19 | from .sknet import * 20 | 21 | from .registry import * 22 | from .factory import create_model 23 | from .helpers import load_checkpoint, resume_checkpoint 24 | from .layers import TestTimePoolHead, apply_test_time_pool 25 | from .layers import convert_splitbn_model 26 | -------------------------------------------------------------------------------- /timm/models/factory.py: -------------------------------------------------------------------------------- 1 | from .registry import is_model, is_model_in_modules, model_entrypoint 2 | from .helpers import load_checkpoint 3 | 4 | 5 | def create_model( 6 | model_name, 7 | pretrained=False, 8 | num_classes=1000, 9 | in_chans=3, 10 | checkpoint_path='', 11 | **kwargs): 12 | """Create a model 13 | 14 | Args: 15 | model_name (str): name of model to instantiate 16 | pretrained (bool): load pretrained ImageNet-1k weights if true 17 | num_classes (int): number of classes for final fully connected layer (default: 1000) 18 | in_chans (int): number of input channels / colors (default: 3) 19 | checkpoint_path (str): path of checkpoint to load after model is initialized 20 | 21 | Keyword Args: 22 | drop_rate (float): dropout rate for training (default: 0.0) 23 | global_pool (str): global pool type (default: 'avg') 24 | **: other kwargs are model specific 25 | """ 26 | margs = dict(pretrained=pretrained, num_classes=num_classes, in_chans=in_chans) 27 | 28 | # Only EfficientNet and MobileNetV3 models have support for batchnorm params or drop_connect_rate passed as args 29 | is_efficientnet = is_model_in_modules(model_name, ['efficientnet', 'mobilenetv3']) 30 | if not is_efficientnet: 31 | kwargs.pop('bn_tf', None) 32 | kwargs.pop('bn_momentum', None) 33 | kwargs.pop('bn_eps', None) 34 | 35 | # Parameters that aren't supported by all models should default to None in command line args, 36 | # remove them if they are present and not set so that non-supporting models don't break. 37 | if kwargs.get('drop_block_rate', None) is None: 38 | kwargs.pop('drop_block_rate', None) 39 | 40 | # handle backwards compat with drop_connect -> drop_path change 41 | drop_connect_rate = kwargs.pop('drop_connect_rate', None) 42 | if drop_connect_rate is not None and kwargs.get('drop_path_rate', None) is None: 43 | print("WARNING: 'drop_connect' as an argument is deprecated, please use 'drop_path'." 44 | " Setting drop_path to %f." 
% drop_connect_rate) 45 | kwargs['drop_path_rate'] = drop_connect_rate 46 | 47 | if kwargs.get('drop_path_rate', None) is None: 48 | kwargs.pop('drop_path_rate', None) 49 | 50 | if is_model(model_name): 51 | create_fn = model_entrypoint(model_name) 52 | model = create_fn(**margs, **kwargs) 53 | else: 54 | raise RuntimeError('Unknown model (%s)' % model_name) 55 | 56 | if checkpoint_path: 57 | load_checkpoint(model, checkpoint_path) 58 | 59 | return model 60 | -------------------------------------------------------------------------------- /timm/models/feature_hooks.py: -------------------------------------------------------------------------------- 1 | from collections import defaultdict, OrderedDict 2 | from functools import partial 3 | 4 | 5 | class FeatureHooks: 6 | 7 | def __init__(self, hooks, named_modules): 8 | # setup feature hooks 9 | modules = {k: v for k, v in named_modules} 10 | for h in hooks: 11 | hook_name = h['name'] 12 | m = modules[hook_name] 13 | hook_fn = partial(self._collect_output_hook, hook_name) 14 | if h['type'] == 'forward_pre': 15 | m.register_forward_pre_hook(hook_fn) 16 | elif h['type'] == 'forward': 17 | m.register_forward_hook(hook_fn) 18 | else: 19 | assert False, "Unsupported hook type" 20 | self._feature_outputs = defaultdict(OrderedDict) 21 | 22 | def _collect_output_hook(self, name, *args): 23 | x = args[-1] # tensor we want is last argument, output for fwd, input for fwd_pre 24 | if isinstance(x, tuple): 25 | x = x[0] # unwrap input tuple 26 | self._feature_outputs[x.device][name] = x 27 | 28 | def get_output(self, device): 29 | output = tuple(self._feature_outputs[device].values())[::-1] 30 | self._feature_outputs[device] = OrderedDict() # clear after reading 31 | return output 32 | -------------------------------------------------------------------------------- /timm/models/helpers.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.utils.model_zoo as model_zoo 3 | import os 4 | import logging 5 | from collections import OrderedDict 6 | 7 | 8 | def load_state_dict(checkpoint_path, use_ema=False): 9 | if checkpoint_path and os.path.isfile(checkpoint_path): 10 | checkpoint = torch.load(checkpoint_path, map_location='cpu') 11 | state_dict_key = 'state_dict' 12 | if isinstance(checkpoint, dict): 13 | if use_ema and 'state_dict_ema' in checkpoint: 14 | state_dict_key = 'state_dict_ema' 15 | if state_dict_key and state_dict_key in checkpoint: 16 | new_state_dict = OrderedDict() 17 | for k, v in checkpoint[state_dict_key].items(): 18 | # strip `module.` prefix 19 | name = k[7:] if k.startswith('module') else k 20 | new_state_dict[name] = v 21 | state_dict = new_state_dict 22 | else: 23 | state_dict = checkpoint 24 | logging.info("Loaded {} from checkpoint '{}'".format(state_dict_key, checkpoint_path)) 25 | return state_dict 26 | else: 27 | logging.error("No checkpoint found at '{}'".format(checkpoint_path)) 28 | raise FileNotFoundError() 29 | 30 | 31 | def load_checkpoint(model, checkpoint_path, use_ema=False): 32 | state_dict = load_state_dict(checkpoint_path, use_ema) 33 | model.load_state_dict(state_dict) 34 | 35 | 36 | def resume_checkpoint(model, checkpoint_path): 37 | other_state = {} 38 | resume_epoch = None 39 | if os.path.isfile(checkpoint_path): 40 | checkpoint = torch.load(checkpoint_path, map_location='cpu') 41 | if isinstance(checkpoint, dict) and 'state_dict' in checkpoint: 42 | new_state_dict = OrderedDict() 43 | for k, v in checkpoint['state_dict'].items(): 44 | name = k[7:] 
if k.startswith('module') else k 45 | new_state_dict[name] = v 46 | model.load_state_dict(new_state_dict) 47 | if 'optimizer' in checkpoint: 48 | other_state['optimizer'] = checkpoint['optimizer'] 49 | if 'amp' in checkpoint: 50 | other_state['amp'] = checkpoint['amp'] 51 | if 'epoch' in checkpoint: 52 | resume_epoch = checkpoint['epoch'] 53 | if 'version' in checkpoint and checkpoint['version'] > 1: 54 | resume_epoch += 1 # start at the next epoch, old checkpoints incremented before save 55 | logging.info("Loaded checkpoint '{}' (epoch {})".format(checkpoint_path, checkpoint['epoch'])) 56 | else: 57 | model.load_state_dict(checkpoint) 58 | logging.info("Loaded checkpoint '{}'".format(checkpoint_path)) 59 | return other_state, resume_epoch 60 | else: 61 | logging.error("No checkpoint found at '{}'".format(checkpoint_path)) 62 | raise FileNotFoundError() 63 | 64 | 65 | def load_pretrained(model, cfg=None, num_classes=1000, in_chans=3, filter_fn=None, strict=True): 66 | if cfg is None: 67 | cfg = getattr(model, 'default_cfg') 68 | if cfg is None or 'url' not in cfg or not cfg['url']: 69 | logging.warning("Pretrained model URL is invalid, using random initialization.") 70 | return 71 | 72 | state_dict = model_zoo.load_url(cfg['url'], progress=False, map_location='cpu') 73 | 74 | if in_chans == 1: 75 | conv1_name = cfg['first_conv'] 76 | logging.info('Converting first conv (%s) from 3 to 1 channel' % conv1_name) 77 | conv1_weight = state_dict[conv1_name + '.weight'] 78 | state_dict[conv1_name + '.weight'] = conv1_weight.sum(dim=1, keepdim=True) 79 | elif in_chans != 3: 80 | assert False, "Invalid in_chans for pretrained weights" 81 | 82 | classifier_name = cfg['classifier'] 83 | if num_classes == 1000 and cfg['num_classes'] == 1001: 84 | # special case for imagenet trained models with extra background class in pretrained weights 85 | classifier_weight = state_dict[classifier_name + '.weight'] 86 | state_dict[classifier_name + '.weight'] = classifier_weight[1:] 87 | classifier_bias = state_dict[classifier_name + '.bias'] 88 | state_dict[classifier_name + '.bias'] = classifier_bias[1:] 89 | elif num_classes != cfg['num_classes']: 90 | # completely discard fully connected for all other differences between pretrained and created model 91 | del state_dict[classifier_name + '.weight'] 92 | del state_dict[classifier_name + '.bias'] 93 | strict = False 94 | 95 | if filter_fn is not None: 96 | state_dict = filter_fn(state_dict) 97 | 98 | model.load_state_dict(state_dict, strict=strict) 99 | 100 | 101 | 102 | 103 | 104 | 105 | -------------------------------------------------------------------------------- /timm/models/inception_v3.py: -------------------------------------------------------------------------------- 1 | from torchvision.models import Inception3 2 | from .registry import register_model 3 | from .helpers import load_pretrained 4 | from timm.data import IMAGENET_DEFAULT_STD, IMAGENET_DEFAULT_MEAN, IMAGENET_INCEPTION_MEAN, IMAGENET_INCEPTION_STD 5 | 6 | __all__ = [] 7 | 8 | default_cfgs = { 9 | # original PyTorch weights, ported from Tensorflow but modified 10 | 'inception_v3': { 11 | 'url': 'https://download.pytorch.org/models/inception_v3_google-1a9a5a14.pth', 12 | 'input_size': (3, 299, 299), 13 | 'crop_pct': 0.875, 14 | 'interpolation': 'bicubic', 15 | 'mean': IMAGENET_INCEPTION_MEAN, # also works well enough with resnet defaults 16 | 'std': IMAGENET_INCEPTION_STD, # also works well enough with resnet defaults 17 | 'num_classes': 1000, 18 | 'first_conv': 'conv0', 19 | 'classifier': 'fc' 
20 | }, 21 | # my port of Tensorflow SLIM weights (http://download.tensorflow.org/models/inception_v3_2016_08_28.tar.gz) 22 | 'tf_inception_v3': { 23 | 'url': 'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_inception_v3-e0069de4.pth', 24 | 'input_size': (3, 299, 299), 25 | 'crop_pct': 0.875, 26 | 'interpolation': 'bicubic', 27 | 'mean': IMAGENET_INCEPTION_MEAN, 28 | 'std': IMAGENET_INCEPTION_STD, 29 | 'num_classes': 1001, 30 | 'first_conv': 'conv0', 31 | 'classifier': 'fc' 32 | }, 33 | # my port of Tensorflow adversarially trained Inception V3 from 34 | # http://download.tensorflow.org/models/adv_inception_v3_2017_08_18.tar.gz 35 | 'adv_inception_v3': { 36 | 'url': 'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/adv_inception_v3-9e27bd63.pth', 37 | 'input_size': (3, 299, 299), 38 | 'crop_pct': 0.875, 39 | 'interpolation': 'bicubic', 40 | 'mean': IMAGENET_INCEPTION_MEAN, 41 | 'std': IMAGENET_INCEPTION_STD, 42 | 'num_classes': 1001, 43 | 'first_conv': 'conv0', 44 | 'classifier': 'fc' 45 | }, 46 | # from gluon pretrained models, best performing in terms of accuracy/loss metrics 47 | # https://gluon-cv.mxnet.io/model_zoo/classification.html 48 | 'gluon_inception_v3': { 49 | 'url': 'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/gluon_inception_v3-9f746940.pth', 50 | 'input_size': (3, 299, 299), 51 | 'crop_pct': 0.875, 52 | 'interpolation': 'bicubic', 53 | 'mean': IMAGENET_DEFAULT_MEAN, # also works well with inception defaults 54 | 'std': IMAGENET_DEFAULT_STD, # also works well with inception defaults 55 | 'num_classes': 1000, 56 | 'first_conv': 'conv0', 57 | 'classifier': 'fc' 58 | } 59 | } 60 | 61 | 62 | def _assert_default_kwargs(kwargs): 63 | # for imported models (ie torchvision) without capability to change these params, 64 | # make sure they aren't being set to non-defaults 65 | assert kwargs.pop('global_pool', 'avg') == 'avg' 66 | assert kwargs.pop('drop_rate', 0.) == 0. 
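
Note that the TF-ported configs above declare `num_classes: 1001`; the background-class stripping in `load_pretrained()` (shown earlier in `timm/models/helpers.py`) reconciles them with the usual 1000-class head. A minimal sketch of exercising one of the variants registered below, assuming network access for the pretrained weight download:

```python
import torch
from timm.models import create_model  # re-exported from .factory in timm/models/__init__.py

# 'tf_inception_v3' ships 1001-class weights; load_pretrained() drops the extra
# background row/bias from the classifier so it matches num_classes=1000.
model = create_model('tf_inception_v3', pretrained=True, num_classes=1000)
model.eval()

with torch.no_grad():
    out = model(torch.randn(1, *model.default_cfg['input_size']))  # (3, 299, 299)
print(out.shape)  # torch.Size([1, 1000])
```
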
67 | 68 | 69 | @register_model 70 | def inception_v3(pretrained=False, num_classes=1000, in_chans=3, **kwargs): 71 | # original PyTorch weights, ported from Tensorflow but modified 72 | default_cfg = default_cfgs['inception_v3'] 73 | assert in_chans == 3 74 | _assert_default_kwargs(kwargs) 75 | model = Inception3(num_classes=num_classes, aux_logits=True, transform_input=False) 76 | if pretrained: 77 | load_pretrained(model, default_cfg, num_classes, in_chans) 78 | model.default_cfg = default_cfg 79 | return model 80 | 81 | 82 | @register_model 83 | def tf_inception_v3(pretrained=False, num_classes=1000, in_chans=3, **kwargs): 84 | # my port of Tensorflow SLIM weights (http://download.tensorflow.org/models/inception_v3_2016_08_28.tar.gz) 85 | default_cfg = default_cfgs['tf_inception_v3'] 86 | assert in_chans == 3 87 | _assert_default_kwargs(kwargs) 88 | model = Inception3(num_classes=num_classes, aux_logits=False, transform_input=False) 89 | if pretrained: 90 | load_pretrained(model, default_cfg, num_classes, in_chans) 91 | model.default_cfg = default_cfg 92 | return model 93 | 94 | 95 | @register_model 96 | def adv_inception_v3(pretrained=False, num_classes=1000, in_chans=3, **kwargs): 97 | # my port of Tensorflow adversarially trained Inception V3 from 98 | # http://download.tensorflow.org/models/adv_inception_v3_2017_08_18.tar.gz 99 | default_cfg = default_cfgs['adv_inception_v3'] 100 | assert in_chans == 3 101 | _assert_default_kwargs(kwargs) 102 | model = Inception3(num_classes=num_classes, aux_logits=False, transform_input=False) 103 | if pretrained: 104 | load_pretrained(model, default_cfg, num_classes, in_chans) 105 | model.default_cfg = default_cfg 106 | return model 107 | 108 | 109 | @register_model 110 | def gluon_inception_v3(pretrained=False, num_classes=1000, in_chans=3, **kwargs): 111 | # from gluon pretrained models, best performing in terms of accuracy/loss metrics 112 | # https://gluon-cv.mxnet.io/model_zoo/classification.html 113 | default_cfg = default_cfgs['gluon_inception_v3'] 114 | assert in_chans == 3 115 | _assert_default_kwargs(kwargs) 116 | model = Inception3(num_classes=num_classes, aux_logits=False, transform_input=False) 117 | if pretrained: 118 | load_pretrained(model, default_cfg, num_classes, in_chans) 119 | model.default_cfg = default_cfg 120 | return model 121 | -------------------------------------------------------------------------------- /timm/models/layers/__init__.py: -------------------------------------------------------------------------------- 1 | from .padding import get_padding 2 | from .avg_pool2d_same import AvgPool2dSame 3 | from .conv2d_same import Conv2dSame 4 | from .conv_bn_act import ConvBnAct 5 | from .mixed_conv2d import MixedConv2d 6 | from .cond_conv2d import CondConv2d, get_condconv_initializer 7 | from .create_conv2d import create_conv2d 8 | from .create_attn import create_attn 9 | from .selective_kernel import SelectiveKernelConv 10 | from .se import SEModule 11 | from .eca import EcaModule, CecaModule 12 | from .activations import * 13 | from .adaptive_avgmax_pool import \ 14 | adaptive_avgmax_pool2d, select_adaptive_pool2d, AdaptiveAvgMaxPool2d, SelectAdaptivePool2d 15 | from .drop import DropBlock2d, DropPath, drop_block_2d, drop_path 16 | from .test_time_pool import TestTimePoolHead, apply_test_time_pool 17 | from .split_batchnorm import SplitBatchNorm2d, convert_splitbn_model 18 | -------------------------------------------------------------------------------- /timm/models/layers/activations.py: 
--------------------------------------------------------------------------------
  1 | """ Activations
  2 | 
  3 | A collection of activation fns and modules with a common interface so that they can
  4 | easily be swapped. All have an `inplace` arg even if not used.
  5 | 
  6 | Hacked together by Ross Wightman
  7 | """
  8 | 
  9 | 
 10 | import torch
 11 | from torch import nn as nn
 12 | from torch.nn import functional as F
 13 | 
 14 | 
 15 | _USE_MEM_EFFICIENT_ISH = True
 16 | if _USE_MEM_EFFICIENT_ISH:
 17 |     # This version reduces memory overhead of Swish during training by
 18 |     # recomputing torch.sigmoid(x) in backward instead of saving it.
 19 |     @torch.jit.script
 20 |     def swish_jit_fwd(x):
 21 |         return x.mul(torch.sigmoid(x))
 22 | 
 23 | 
 24 |     @torch.jit.script
 25 |     def swish_jit_bwd(x, grad_output):
 26 |         x_sigmoid = torch.sigmoid(x)
 27 |         return grad_output * (x_sigmoid * (1 + x * (1 - x_sigmoid)))
 28 | 
 29 | 
 30 |     class SwishJitAutoFn(torch.autograd.Function):
 31 |         """ torch.jit.script optimised Swish
 32 |         Inspired by conversation between Jeremy Howard & Adam Paszke
 33 |         https://twitter.com/jeremyphoward/status/1188251041835315200
 34 |         """
 35 | 
 36 |         @staticmethod
 37 |         def forward(ctx, x):
 38 |             ctx.save_for_backward(x)
 39 |             return swish_jit_fwd(x)
 40 | 
 41 |         @staticmethod
 42 |         def backward(ctx, grad_output):
 43 |             x = ctx.saved_tensors[0]
 44 |             return swish_jit_bwd(x, grad_output)
 45 | 
 46 | 
 47 |     def swish(x, _inplace=False):
 48 |         return SwishJitAutoFn.apply(x)
 49 | 
 50 | 
 51 |     @torch.jit.script
 52 |     def mish_jit_fwd(x):
 53 |         return x.mul(torch.tanh(F.softplus(x)))
 54 | 
 55 | 
 56 |     @torch.jit.script
 57 |     def mish_jit_bwd(x, grad_output):
 58 |         x_sigmoid = torch.sigmoid(x)
 59 |         x_tanh_sp = F.softplus(x).tanh()
 60 |         return grad_output.mul(x_tanh_sp + x * x_sigmoid * (1 - x_tanh_sp * x_tanh_sp))
 61 | 
 62 | 
 63 |     class MishJitAutoFn(torch.autograd.Function):
 64 |         @staticmethod
 65 |         def forward(ctx, x):
 66 |             ctx.save_for_backward(x)
 67 |             return mish_jit_fwd(x)
 68 | 
 69 |         @staticmethod
 70 |         def backward(ctx, grad_output):
 71 |             x = ctx.saved_tensors[0]
 72 |             return mish_jit_bwd(x, grad_output)
 73 | 
 74 |     def mish(x, _inplace=False):
 75 |         return MishJitAutoFn.apply(x)
 76 | 
 77 | else:
 78 |     def swish(x, inplace: bool = False):
 79 |         """Swish - Described in: https://arxiv.org/abs/1710.05941
 80 |         """
 81 |         return x.mul_(x.sigmoid()) if inplace else x.mul(x.sigmoid())
 82 | 
 83 | 
 84 |     def mish(x, _inplace: bool = False):
 85 |         """Mish: A Self Regularized Non-Monotonic Neural Activation Function - https://arxiv.org/abs/1908.08681
 86 |         """
 87 |         return x.mul(F.softplus(x).tanh())
 88 | 
 89 | 
 90 | class Swish(nn.Module):
 91 |     def __init__(self, inplace: bool = False):
 92 |         super(Swish, self).__init__()
 93 |         self.inplace = inplace
 94 | 
 95 |     def forward(self, x):
 96 |         return swish(x, self.inplace)
 97 | 
 98 | 
 99 | class Mish(nn.Module):
100 |     def __init__(self, inplace: bool = False):
101 |         super(Mish, self).__init__()
102 |         self.inplace = inplace
103 | 
104 |     def forward(self, x):
105 |         return mish(x, self.inplace)
106 | 
107 | 
108 | def sigmoid(x, inplace: bool = False):
109 |     return x.sigmoid_() if inplace else x.sigmoid()
110 | 
111 | 
112 | # PyTorch has this, but not with a consistent inplace argument interface
113 | class Sigmoid(nn.Module):
114 |     def __init__(self, inplace: bool = False):
115 |         super(Sigmoid, self).__init__()
116 |         self.inplace = inplace
117 | 
118 |     def forward(self, x):
119 |         return x.sigmoid_() if self.inplace else x.sigmoid()
120 | 
121 | 
122 | def tanh(x, inplace: bool = False):
123 |     return x.tanh_() if inplace else x.tanh()
124 | 
125 | 
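The JIT autograd pairing above saves only `x` and recomputes `sigmoid(x)` in backward, trading a little compute for activation memory. A quick sketch checking the scripted Swish against its naive definition, with imports assuming the module path as laid out in this repo:

```python
import torch
from timm.models.layers.activations import swish, Swish, Mish

x = torch.randn(8, requires_grad=True)
y = swish(x)
assert torch.allclose(y, x * torch.sigmoid(x), atol=1e-6)  # same forward values
y.sum().backward()  # backward recomputes sigmoid(x) instead of loading a saved copy
print(x.grad.shape)  # torch.Size([8])

# The module forms keep the shared `inplace` interface even where it is a no-op.
print(Swish()(torch.zeros(2)), Mish()(torch.zeros(2)))  # both map 0 -> 0
```
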
126 | # PyTorch has this, but not with a consistent inplace argument interface
127 | class Tanh(nn.Module):
128 |     def __init__(self, inplace: bool = False):
129 |         super(Tanh, self).__init__()
130 |         self.inplace = inplace
131 | 
132 |     def forward(self, x):
133 |         return x.tanh_() if self.inplace else x.tanh()
134 | 
135 | 
136 | def hard_swish(x, inplace: bool = False):
137 |     inner = F.relu6(x + 3.).div_(6.)
138 |     return x.mul_(inner) if inplace else x.mul(inner)
139 | 
140 | 
141 | class HardSwish(nn.Module):
142 |     def __init__(self, inplace: bool = False):
143 |         super(HardSwish, self).__init__()
144 |         self.inplace = inplace
145 | 
146 |     def forward(self, x):
147 |         return hard_swish(x, self.inplace)
148 | 
149 | 
150 | def hard_sigmoid(x, inplace: bool = False):
151 |     if inplace:
152 |         return x.add_(3.).clamp_(0., 6.).div_(6.)
153 |     else:
154 |         return F.relu6(x + 3.) / 6.
155 | 
156 | 
157 | class HardSigmoid(nn.Module):
158 |     def __init__(self, inplace: bool = False):
159 |         super(HardSigmoid, self).__init__()
160 |         self.inplace = inplace
161 | 
162 |     def forward(self, x):
163 |         return hard_sigmoid(x, self.inplace)
164 | 
165 | 
--------------------------------------------------------------------------------
/timm/models/layers/adaptive_avgmax_pool.py:
--------------------------------------------------------------------------------
 1 | """ PyTorch selectable adaptive pooling
 2 | Adaptive pooling with the ability to select the type of pooling from:
 3 |     * 'avg' - Average pooling
 4 |     * 'max' - Max pooling
 5 |     * 'avgmax' - Sum of average and max pooling re-scaled by 0.5
 6 |     * 'catavgmax' - Concatenation of average and max pooling along feature dim, doubles feature dim
 7 | 
 8 | Both a functional and a nn.Module version of the pooling are provided.
 9 | 
10 | Author: Ross Wightman (rwightman)
11 | """
12 | import torch
13 | import torch.nn as nn
14 | import torch.nn.functional as F
15 | 
16 | 
17 | def adaptive_pool_feat_mult(pool_type='avg'):
18 |     if pool_type == 'catavgmax':  # concatenation doubles the feature dim
19 |         return 2
20 |     else:
21 |         return 1
22 | 
23 | 
24 | def adaptive_avgmax_pool2d(x, output_size=1):
25 |     x_avg = F.adaptive_avg_pool2d(x, output_size)
26 |     x_max = F.adaptive_max_pool2d(x, output_size)
27 |     return 0.5 * (x_avg + x_max)
28 | 
29 | 
30 | def adaptive_catavgmax_pool2d(x, output_size=1):
31 |     x_avg = F.adaptive_avg_pool2d(x, output_size)
32 |     x_max = F.adaptive_max_pool2d(x, output_size)
33 |     return torch.cat((x_avg, x_max), 1)
34 | 
35 | 
36 | def select_adaptive_pool2d(x, pool_type='avg', output_size=1):
37 |     """Selectable global pooling function with dynamic input kernel size
38 |     """
39 |     if pool_type == 'avg':
40 |         x = F.adaptive_avg_pool2d(x, output_size)
41 |     elif pool_type == 'avgmax':
42 |         x = adaptive_avgmax_pool2d(x, output_size)
43 |     elif pool_type == 'catavgmax':
44 |         x = adaptive_catavgmax_pool2d(x, output_size)
45 |     elif pool_type == 'max':
46 |         x = F.adaptive_max_pool2d(x, output_size)
47 |     else:
48 |         assert False, 'Invalid pool type: %s' % pool_type
49 |     return x
50 | 
51 | 
52 | class AdaptiveAvgMaxPool2d(nn.Module):
53 |     def __init__(self, output_size=1):
54 |         super(AdaptiveAvgMaxPool2d, self).__init__()
55 |         self.output_size = output_size
56 | 
57 |     def forward(self, x):
58 |         return adaptive_avgmax_pool2d(x, self.output_size)
59 | 
60 | 
61 | class AdaptiveCatAvgMaxPool2d(nn.Module):
62 |     def __init__(self, output_size=1):
63 |         super(AdaptiveCatAvgMaxPool2d, self).__init__()
64 |         self.output_size = output_size
65 | 
66 |     def forward(self, x):
67 |         return adaptive_catavgmax_pool2d(x, self.output_size)
68 | 
69 | 
70 | 
class SelectAdaptivePool2d(nn.Module): 71 | """Selectable global pooling layer with dynamic input kernel size 72 | """ 73 | def __init__(self, output_size=1, pool_type='avg'): 74 | super(SelectAdaptivePool2d, self).__init__() 75 | self.output_size = output_size 76 | self.pool_type = pool_type 77 | if pool_type == 'avgmax': 78 | self.pool = AdaptiveAvgMaxPool2d(output_size) 79 | elif pool_type == 'catavgmax': 80 | self.pool = AdaptiveCatAvgMaxPool2d(output_size) 81 | elif pool_type == 'max': 82 | self.pool = nn.AdaptiveMaxPool2d(output_size) 83 | else: 84 | if pool_type != 'avg': 85 | assert False, 'Invalid pool type: %s' % pool_type 86 | self.pool = nn.AdaptiveAvgPool2d(output_size) 87 | 88 | def forward(self, x): 89 | return self.pool(x) 90 | 91 | def feat_mult(self): 92 | return adaptive_pool_feat_mult(self.pool_type) 93 | 94 | def __repr__(self): 95 | return self.__class__.__name__ + ' (' \ 96 | + 'output_size=' + str(self.output_size) \ 97 | + ', pool_type=' + self.pool_type + ')' 98 | -------------------------------------------------------------------------------- /timm/models/layers/avg_pool2d_same.py: -------------------------------------------------------------------------------- 1 | """ AvgPool2d w/ Same Padding 2 | 3 | Hacked together by Ross Wightman 4 | """ 5 | import torch 6 | import torch.nn as nn 7 | import torch.nn.functional as F 8 | from typing import List 9 | import math 10 | 11 | from .helpers import tup_pair 12 | from .padding import pad_same 13 | 14 | 15 | def avg_pool2d_same(x, kernel_size: List[int], stride: List[int], padding: List[int] = (0, 0), 16 | ceil_mode: bool = False, count_include_pad: bool = True): 17 | x = pad_same(x, kernel_size, stride) 18 | return F.avg_pool2d(x, kernel_size, stride, (0, 0), ceil_mode, count_include_pad) 19 | 20 | 21 | class AvgPool2dSame(nn.AvgPool2d): 22 | """ Tensorflow like 'SAME' wrapper for 2D average pooling 23 | """ 24 | def __init__(self, kernel_size: int, stride=None, padding=0, ceil_mode=False, count_include_pad=True): 25 | kernel_size = tup_pair(kernel_size) 26 | stride = tup_pair(stride) 27 | super(AvgPool2dSame, self).__init__(kernel_size, stride, (0, 0), ceil_mode, count_include_pad) 28 | 29 | def forward(self, x): 30 | return avg_pool2d_same( 31 | x, self.kernel_size, self.stride, self.padding, self.ceil_mode, self.count_include_pad) 32 | -------------------------------------------------------------------------------- /timm/models/layers/cbam.py: -------------------------------------------------------------------------------- 1 | """ CBAM (sort-of) Attention 2 | 3 | Experimental impl of CBAM: Convolutional Block Attention Module: https://arxiv.org/abs/1807.06521 4 | 5 | WARNING: Results with these attention layers have been mixed. They can significantly reduce performance on 6 | some tasks, especially fine-grained it seems. I may end up removing this impl. 7 | 8 | Hacked together by Ross Wightman 9 | """ 10 | 11 | import torch 12 | from torch import nn as nn 13 | from .conv_bn_act import ConvBnAct 14 | 15 | 16 | class ChannelAttn(nn.Module): 17 | """ Original CBAM channel attention module, currently avg + max pool variant only. 
18 | """ 19 | def __init__(self, channels, reduction=16, act_layer=nn.ReLU): 20 | super(ChannelAttn, self).__init__() 21 | self.avg_pool = nn.AdaptiveAvgPool2d(1) 22 | self.max_pool = nn.AdaptiveMaxPool2d(1) 23 | self.fc1 = nn.Conv2d(channels, channels // reduction, 1, bias=False) 24 | self.act = act_layer(inplace=True) 25 | self.fc2 = nn.Conv2d(channels // reduction, channels, 1, bias=False) 26 | 27 | def forward(self, x): 28 | x_avg = self.avg_pool(x) 29 | x_max = self.max_pool(x) 30 | x_avg = self.fc2(self.act(self.fc1(x_avg))) 31 | x_max = self.fc2(self.act(self.fc1(x_max))) 32 | x_attn = x_avg + x_max 33 | return x * x_attn.sigmoid() 34 | 35 | 36 | class LightChannelAttn(ChannelAttn): 37 | """An experimental 'lightweight' that sums avg + max pool first 38 | """ 39 | def __init__(self, channels, reduction=16): 40 | super(LightChannelAttn, self).__init__(channels, reduction) 41 | 42 | def forward(self, x): 43 | x_pool = 0.5 * self.avg_pool(x) + 0.5 * self.max_pool(x) 44 | x_attn = self.fc2(self.act(self.fc1(x_pool))) 45 | return x * x_attn.sigmoid() 46 | 47 | 48 | class SpatialAttn(nn.Module): 49 | """ Original CBAM spatial attention module 50 | """ 51 | def __init__(self, kernel_size=7): 52 | super(SpatialAttn, self).__init__() 53 | self.conv = ConvBnAct(2, 1, kernel_size, act_layer=None) 54 | 55 | def forward(self, x): 56 | x_avg = torch.mean(x, dim=1, keepdim=True) 57 | x_max = torch.max(x, dim=1, keepdim=True)[0] 58 | x_attn = torch.cat([x_avg, x_max], dim=1) 59 | x_attn = self.conv(x_attn) 60 | return x * x_attn.sigmoid() 61 | 62 | 63 | class LightSpatialAttn(nn.Module): 64 | """An experimental 'lightweight' variant that sums avg_pool and max_pool results. 65 | """ 66 | def __init__(self, kernel_size=7): 67 | super(LightSpatialAttn, self).__init__() 68 | self.conv = ConvBnAct(1, 1, kernel_size, act_layer=None) 69 | 70 | def forward(self, x): 71 | x_avg = torch.mean(x, dim=1, keepdim=True) 72 | x_max = torch.max(x, dim=1, keepdim=True)[0] 73 | x_attn = 0.5 * x_avg + 0.5 * x_max 74 | x_attn = self.conv(x_attn) 75 | return x * x_attn.sigmoid() 76 | 77 | 78 | class CbamModule(nn.Module): 79 | def __init__(self, channels, spatial_kernel_size=7): 80 | super(CbamModule, self).__init__() 81 | self.channel = ChannelAttn(channels) 82 | self.spatial = SpatialAttn(spatial_kernel_size) 83 | 84 | def forward(self, x): 85 | x = self.channel(x) 86 | x = self.spatial(x) 87 | return x 88 | 89 | 90 | class LightCbamModule(nn.Module): 91 | def __init__(self, channels, spatial_kernel_size=7): 92 | super(LightCbamModule, self).__init__() 93 | self.channel = LightChannelAttn(channels) 94 | self.spatial = LightSpatialAttn(spatial_kernel_size) 95 | 96 | def forward(self, x): 97 | x = self.channel(x) 98 | x = self.spatial(x) 99 | return x 100 | 101 | -------------------------------------------------------------------------------- /timm/models/layers/cond_conv2d.py: -------------------------------------------------------------------------------- 1 | """ PyTorch Conditionally Parameterized Convolution (CondConv) 2 | 3 | Paper: CondConv: Conditionally Parameterized Convolutions for Efficient Inference 4 | (https://arxiv.org/abs/1904.04971) 5 | 6 | Hacked together by Ross Wightman 7 | """ 8 | 9 | import math 10 | from functools import partial 11 | import numpy as np 12 | import torch 13 | from torch import nn as nn 14 | from torch.nn import functional as F 15 | 16 | from .helpers import tup_pair 17 | from .conv2d_same import get_padding_value, conv2d_same 18 | 19 | 20 | def get_condconv_initializer(initializer, 
num_experts, expert_shape): 21 | def condconv_initializer(weight): 22 | """CondConv initializer function.""" 23 | num_params = np.prod(expert_shape) 24 | if (len(weight.shape) != 2 or weight.shape[0] != num_experts or 25 | weight.shape[1] != num_params): 26 | raise (ValueError( 27 | 'CondConv variables must have shape [num_experts, num_params]')) 28 | for i in range(num_experts): 29 | initializer(weight[i].view(expert_shape)) 30 | return condconv_initializer 31 | 32 | 33 | class CondConv2d(nn.Module): 34 | """ Conditionally Parameterized Convolution 35 | Inspired by: https://github.com/tensorflow/tpu/blob/master/models/official/efficientnet/condconv/condconv_layers.py 36 | 37 | Grouped convolution hackery for parallel execution of the per-sample kernel filters inspired by this discussion: 38 | https://github.com/pytorch/pytorch/issues/17983 39 | """ 40 | __constants__ = ['bias', 'in_channels', 'out_channels', 'dynamic_padding'] 41 | 42 | def __init__(self, in_channels, out_channels, kernel_size=3, 43 | stride=1, padding='', dilation=1, groups=1, bias=False, num_experts=4): 44 | super(CondConv2d, self).__init__() 45 | 46 | self.in_channels = in_channels 47 | self.out_channels = out_channels 48 | self.kernel_size = tup_pair(kernel_size) 49 | self.stride = tup_pair(stride) 50 | padding_val, is_padding_dynamic = get_padding_value( 51 | padding, kernel_size, stride=stride, dilation=dilation) 52 | self.dynamic_padding = is_padding_dynamic # if in forward to work with torchscript 53 | self.padding = tup_pair(padding_val) 54 | self.dilation = tup_pair(dilation) 55 | self.groups = groups 56 | self.num_experts = num_experts 57 | 58 | self.weight_shape = (self.out_channels, self.in_channels // self.groups) + self.kernel_size 59 | weight_num_param = 1 60 | for wd in self.weight_shape: 61 | weight_num_param *= wd 62 | self.weight = torch.nn.Parameter(torch.Tensor(self.num_experts, weight_num_param)) 63 | 64 | if bias: 65 | self.bias_shape = (self.out_channels,) 66 | self.bias = torch.nn.Parameter(torch.Tensor(self.num_experts, self.out_channels)) 67 | else: 68 | self.register_parameter('bias', None) 69 | 70 | self.reset_parameters() 71 | 72 | def reset_parameters(self): 73 | init_weight = get_condconv_initializer( 74 | partial(nn.init.kaiming_uniform_, a=math.sqrt(5)), self.num_experts, self.weight_shape) 75 | init_weight(self.weight) 76 | if self.bias is not None: 77 | fan_in = np.prod(self.weight_shape[1:]) 78 | bound = 1 / math.sqrt(fan_in) 79 | init_bias = get_condconv_initializer( 80 | partial(nn.init.uniform_, a=-bound, b=bound), self.num_experts, self.bias_shape) 81 | init_bias(self.bias) 82 | 83 | def forward(self, x, routing_weights): 84 | B, C, H, W = x.shape 85 | weight = torch.matmul(routing_weights, self.weight) 86 | new_weight_shape = (B * self.out_channels, self.in_channels // self.groups) + self.kernel_size 87 | weight = weight.view(new_weight_shape) 88 | bias = None 89 | if self.bias is not None: 90 | bias = torch.matmul(routing_weights, self.bias) 91 | bias = bias.view(B * self.out_channels) 92 | # move batch elements with channels so each batch element can be efficiently convolved with separate kernel 93 | x = x.view(1, B * C, H, W) 94 | if self.dynamic_padding: 95 | out = conv2d_same( 96 | x, weight, bias, stride=self.stride, padding=self.padding, 97 | dilation=self.dilation, groups=self.groups * B) 98 | else: 99 | out = F.conv2d( 100 | x, weight, bias, stride=self.stride, padding=self.padding, 101 | dilation=self.dilation, groups=self.groups * B) 102 | out = out.permute([1, 0, 2, 
3]).view(B, self.out_channels, out.shape[-2], out.shape[-1]) 103 | 104 | # Literal port (from TF definition) 105 | # x = torch.split(x, 1, 0) 106 | # weight = torch.split(weight, 1, 0) 107 | # if self.bias is not None: 108 | # bias = torch.matmul(routing_weights, self.bias) 109 | # bias = torch.split(bias, 1, 0) 110 | # else: 111 | # bias = [None] * B 112 | # out = [] 113 | # for xi, wi, bi in zip(x, weight, bias): 114 | # wi = wi.view(*self.weight_shape) 115 | # if bi is not None: 116 | # bi = bi.view(*self.bias_shape) 117 | # out.append(self.conv_fn( 118 | # xi, wi, bi, stride=self.stride, padding=self.padding, 119 | # dilation=self.dilation, groups=self.groups)) 120 | # out = torch.cat(out, 0) 121 | return out 122 | -------------------------------------------------------------------------------- /timm/models/layers/conv2d_same.py: -------------------------------------------------------------------------------- 1 | """ Conv2d w/ Same Padding 2 | 3 | Hacked together by Ross Wightman 4 | """ 5 | import torch 6 | import torch.nn as nn 7 | import torch.nn.functional as F 8 | from typing import Union, List, Tuple, Optional, Callable 9 | import math 10 | 11 | from .padding import get_padding, pad_same, is_static_pad 12 | 13 | 14 | def conv2d_same( 15 | x, weight: torch.Tensor, bias: Optional[torch.Tensor] = None, stride: Tuple[int, int] = (1, 1), 16 | padding: Tuple[int, int] = (0, 0), dilation: Tuple[int, int] = (1, 1), groups: int = 1): 17 | x = pad_same(x, weight.shape[-2:], stride, dilation) 18 | return F.conv2d(x, weight, bias, stride, (0, 0), dilation, groups) 19 | 20 | 21 | class Conv2dSame(nn.Conv2d): 22 | """ Tensorflow like 'SAME' convolution wrapper for 2D convolutions 23 | """ 24 | 25 | def __init__(self, in_channels, out_channels, kernel_size, stride=1, 26 | padding=0, dilation=1, groups=1, bias=True): 27 | super(Conv2dSame, self).__init__( 28 | in_channels, out_channels, kernel_size, stride, 0, dilation, groups, bias) 29 | 30 | def forward(self, x): 31 | return conv2d_same(x, self.weight, self.bias, self.stride, self.padding, self.dilation, self.groups) 32 | 33 | 34 | def get_padding_value(padding, kernel_size, **kwargs) -> Tuple[Tuple, bool]: 35 | dynamic = False 36 | if isinstance(padding, str): 37 | # for any string padding, the padding will be calculated for you, one of three ways 38 | padding = padding.lower() 39 | if padding == 'same': 40 | # TF compatible 'SAME' padding, has a performance and GPU memory allocation impact 41 | if is_static_pad(kernel_size, **kwargs): 42 | # static case, no extra overhead 43 | padding = get_padding(kernel_size, **kwargs) 44 | else: 45 | # dynamic 'SAME' padding, has runtime/GPU memory overhead 46 | padding = 0 47 | dynamic = True 48 | elif padding == 'valid': 49 | # 'VALID' padding, same as padding=0 50 | padding = 0 51 | else: 52 | # Default to PyTorch style 'same'-ish symmetric padding 53 | padding = get_padding(kernel_size, **kwargs) 54 | return padding, dynamic 55 | 56 | 57 | def create_conv2d_pad(in_chs, out_chs, kernel_size, **kwargs): 58 | padding = kwargs.pop('padding', '') 59 | kwargs.setdefault('bias', False) 60 | padding, is_dynamic = get_padding_value(padding, kernel_size, **kwargs) 61 | if is_dynamic: 62 | return Conv2dSame(in_chs, out_chs, kernel_size, **kwargs) 63 | else: 64 | return nn.Conv2d(in_chs, out_chs, kernel_size, padding=padding, **kwargs) 65 | 66 | 67 | -------------------------------------------------------------------------------- /timm/models/layers/conv_bn_act.py: 
--------------------------------------------------------------------------------
 1 | """ Conv2d + BN + Act
 2 | 
 3 | Hacked together by Ross Wightman
 4 | """
 5 | from torch import nn as nn
 6 | 
 7 | from timm.models.layers import get_padding
 8 | 
 9 | 
10 | class ConvBnAct(nn.Module):
11 |     def __init__(self, in_channels, out_channels, kernel_size=1, stride=1, dilation=1, groups=1,
12 |                  drop_block=None, act_layer=nn.ReLU, norm_layer=nn.BatchNorm2d):
13 |         super(ConvBnAct, self).__init__()
14 |         padding = get_padding(kernel_size, stride, dilation)  # assuming PyTorch style padding for this block
15 |         self.conv = nn.Conv2d(
16 |             in_channels, out_channels=out_channels, kernel_size=kernel_size, stride=stride,
17 |             padding=padding, dilation=dilation, groups=groups, bias=False)
18 |         self.bn = norm_layer(out_channels)
19 |         self.drop_block = drop_block
20 |         if act_layer is not None:
21 |             self.act = act_layer(inplace=True)
22 |         else:
23 |             self.act = None
24 | 
25 |     def forward(self, x):
26 |         x = self.conv(x)
27 |         x = self.bn(x)
28 |         if self.drop_block is not None:
29 |             x = self.drop_block(x)
30 |         if self.act is not None:
31 |             x = self.act(x)
32 |         return x
33 | 
--------------------------------------------------------------------------------
/timm/models/layers/create_attn.py:
--------------------------------------------------------------------------------
 1 | """ Select Attention Factory Method
 2 | 
 3 | Hacked together by Ross Wightman
 4 | """
 5 | import torch
 6 | from .se import SEModule
 7 | from .eca import EcaModule, CecaModule
 8 | from .cbam import CbamModule, LightCbamModule
 9 | 
10 | 
11 | def create_attn(attn_type, channels, **kwargs):
12 |     module_cls = None
13 |     if attn_type is not None:
14 |         if isinstance(attn_type, str):
15 |             attn_type = attn_type.lower()
16 |             if attn_type == 'se':
17 |                 module_cls = SEModule
18 |             elif attn_type == 'eca':
19 |                 module_cls = EcaModule
20 |             elif attn_type == 'ceca':
21 |                 module_cls = CecaModule
22 |             elif attn_type == 'cbam':
23 |                 module_cls = CbamModule
24 |             elif attn_type == 'lcbam':
25 |                 module_cls = LightCbamModule
26 |             else:
27 |                 assert False, "Invalid attn module (%s)" % attn_type
28 |         elif isinstance(attn_type, bool):
29 |             if attn_type:
30 |                 module_cls = SEModule
31 |         else:
32 |             module_cls = attn_type
33 |     if module_cls is not None:
34 |         return module_cls(channels, **kwargs)
35 |     return None
36 | 
--------------------------------------------------------------------------------
/timm/models/layers/create_conv2d.py:
--------------------------------------------------------------------------------
 1 | """ Create Conv2d Factory Method
 2 | 
 3 | Hacked together by Ross Wightman
 4 | """
 5 | 
 6 | from .mixed_conv2d import MixedConv2d
 7 | from .cond_conv2d import CondConv2d
 8 | from .conv2d_same import create_conv2d_pad
 9 | 
10 | 
11 | def create_conv2d(in_chs, out_chs, kernel_size, **kwargs):
12 |     """ Select a 2d convolution implementation based on arguments
13 |     Creates and returns one of torch.nn.Conv2d, Conv2dSame, MixedConv2d, or CondConv2d.
14 | 
15 |     Used extensively by EfficientNet, MobileNetv3 and related networks.
16 |     """
17 |     assert 'groups' not in kwargs  # only use 'depthwise' bool arg
18 |     if isinstance(kernel_size, list):
19 |         assert 'num_experts' not in kwargs  # MixNet + CondConv combo not supported currently
20 |         # We're going to use only lists for defining the MixedConv2d kernel groups,
21 |         # ints, tuples, other iterables will continue to pass to normal conv and specify h, w.
22 | m = MixedConv2d(in_chs, out_chs, kernel_size, **kwargs) 23 | else: 24 | depthwise = kwargs.pop('depthwise', False) 25 | groups = out_chs if depthwise else 1 26 | if 'num_experts' in kwargs and kwargs['num_experts'] > 0: 27 | m = CondConv2d(in_chs, out_chs, kernel_size, groups=groups, **kwargs) 28 | else: 29 | m = create_conv2d_pad(in_chs, out_chs, kernel_size, groups=groups, **kwargs) 30 | return m 31 | -------------------------------------------------------------------------------- /timm/models/layers/drop.py: -------------------------------------------------------------------------------- 1 | """ DropBlock, DropPath 2 | 3 | PyTorch implementations of DropBlock and DropPath (Stochastic Depth) regularization layers. 4 | 5 | Papers: 6 | DropBlock: A regularization method for convolutional networks (https://arxiv.org/abs/1810.12890) 7 | 8 | Deep Networks with Stochastic Depth (https://arxiv.org/abs/1603.09382) 9 | 10 | Code: 11 | DropBlock impl inspired by two Tensorflow impls that I liked: 12 | - https://github.com/tensorflow/tpu/blob/master/models/official/resnet/resnet_model.py#L74 13 | - https://github.com/clovaai/assembled-cnn/blob/master/nets/blocks.py 14 | 15 | Hacked together by Ross Wightman 16 | """ 17 | import torch 18 | import torch.nn as nn 19 | import torch.nn.functional as F 20 | import numpy as np 21 | import math 22 | 23 | 24 | def drop_block_2d( 25 | x, drop_prob: float = 0.1, training: bool = False, block_size: int = 7, 26 | gamma_scale: float = 1.0, drop_with_noise: bool = False): 27 | """ DropBlock. See https://arxiv.org/pdf/1810.12890.pdf 28 | 29 | DropBlock with an experimental Gaussian noise option. This layer has been tested on a few training 30 | runs with success, but needs further validation and possibly optimization for lower runtime impact. 31 | 32 | """ 33 | if drop_prob == 0. or not training: 34 | return x 35 | _, _, height, width = x.shape 36 | total_size = width * height 37 | clipped_block_size = min(block_size, min(width, height)) 38 | # seed_drop_rate, the gamma parameter 39 | seed_drop_rate = gamma_scale * drop_prob * total_size / clipped_block_size ** 2 / ( 40 | (width - block_size + 1) * 41 | (height - block_size + 1)) 42 | 43 | # Forces the block to be inside the feature map. 44 | h_i, w_i = torch.meshgrid(torch.arange(height).to(x.device), torch.arange(width).to(x.device)) # (H, W) grids, matching the (1, 1, height, width) reshape below for non-square inputs 45 | valid_block = ((w_i >= clipped_block_size // 2) & (w_i < width - (clipped_block_size - 1) // 2)) & \ 46 | ((h_i >= clipped_block_size // 2) & (h_i < height - (clipped_block_size - 1) // 2)) 47 | valid_block = torch.reshape(valid_block, (1, 1, height, width)).float() 48 | 49 | uniform_noise = torch.rand_like(x, dtype=torch.float32) 50 | block_mask = ((2 - seed_drop_rate - valid_block + uniform_noise) >= 1).float() 51 | block_mask = -F.max_pool2d( 52 | -block_mask, 53 | kernel_size=clipped_block_size, # clipped so the pooling window always fits the feature map 54 | stride=1, 55 | padding=clipped_block_size // 2) 56 | 57 | if drop_with_noise: 58 | normal_noise = torch.randn_like(x) 59 | x = x * block_mask + normal_noise * (1 - block_mask) 60 | else: 61 | normalize_scale = block_mask.numel() / (torch.sum(block_mask) + 1e-7) 62 | x = x * block_mask * normalize_scale 63 | return x 64 | 65 | 66 | class DropBlock2d(nn.Module): 67 | """ DropBlock.
See https://arxiv.org/pdf/1810.12890.pdf 68 | """ 69 | def __init__(self, 70 | drop_prob=0.1, 71 | block_size=7, 72 | gamma_scale=1.0, 73 | with_noise=False): 74 | super(DropBlock2d, self).__init__() 75 | self.drop_prob = drop_prob 76 | self.gamma_scale = gamma_scale 77 | self.block_size = block_size 78 | self.with_noise = with_noise 79 | 80 | def forward(self, x): 81 | return drop_block_2d(x, self.drop_prob, self.training, self.block_size, self.gamma_scale, self.with_noise) 82 | 83 | 84 | def drop_path(x, drop_prob: float = 0., training: bool = False): 85 | """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks). 86 | 87 | This is the same as the DropConnect impl I created for EfficientNet, etc. networks; however, 88 | the original name is misleading, as 'Drop Connect' is a different form of dropout from a separate paper... 89 | See discussion: https://github.com/tensorflow/tpu/issues/494#issuecomment-532968956 ... I've opted to 90 | change the layer and argument names to 'drop path' rather than mix DropConnect as a layer name and use 91 | 'survival rate' as the argument. 92 | 93 | """ 94 | if drop_prob == 0. or not training: 95 | return x 96 | keep_prob = 1 - drop_prob 97 | random_tensor = keep_prob + torch.rand((x.size()[0], 1, 1, 1), dtype=x.dtype, device=x.device) 98 | random_tensor.floor_() # binarize 99 | output = x.div(keep_prob) * random_tensor 100 | return output 101 | 102 | 103 | class DropPath(nn.Module): 104 | """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks). 105 | """ 106 | def __init__(self, drop_prob=None): 107 | super(DropPath, self).__init__() 108 | self.drop_prob = drop_prob 109 | 110 | def forward(self, x): 111 | return drop_path(x, self.drop_prob, self.training) 112 | -------------------------------------------------------------------------------- /timm/models/layers/eca.py: -------------------------------------------------------------------------------- 1 | """ 2 | ECA module from ECAnet 3 | 4 | paper: ECA-Net: Efficient Channel Attention for Deep Convolutional Neural Networks 5 | https://arxiv.org/abs/1910.03151 6 | 7 | Original ECA model borrowed from https://github.com/BangguWu/ECANet 8 | 9 | Modified circular ECA implementation and adaptation for use in timm package 10 | by Chris Ha https://github.com/VRandme 11 | 12 | Original License: 13 | 14 | MIT License 15 | 16 | Copyright (c) 2019 BangguWu, Qilong Wang 17 | 18 | Permission is hereby granted, free of charge, to any person obtaining a copy 19 | of this software and associated documentation files (the "Software"), to deal 20 | in the Software without restriction, including without limitation the rights 21 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 22 | copies of the Software, and to permit persons to whom the Software is 23 | furnished to do so, subject to the following conditions: 24 | 25 | The above copyright notice and this permission notice shall be included in all 26 | copies or substantial portions of the Software. 27 | 28 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 29 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 30 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.
IN NO EVENT SHALL THE 31 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 32 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 33 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 34 | SOFTWARE. 35 | """ 36 | import math 37 | from torch import nn 38 | import torch.nn.functional as F 39 | 40 | 41 | class EcaModule(nn.Module): 42 | """Constructs an ECA module. 43 | 44 | Args: 45 | channels: Number of channels of the input feature map; when given, it is used to 46 | adaptively select the kernel size (default=None). 47 | gamma, beta: parameters of the channel-to-kernel-size mapping function used when 48 | channels is given; refer to the original paper https://arxiv.org/pdf/1910.03151.pdf 49 | (if channels is not given, the explicit kernel_size is used instead.) 50 | kernel_size: Explicit convolution kernel size, used when channels is None (default=3) 51 | """ 52 | def __init__(self, channels=None, kernel_size=3, gamma=2, beta=1): 53 | super(EcaModule, self).__init__() 54 | assert kernel_size % 2 == 1 55 | 56 | if channels is not None: 57 | t = int(abs(math.log(channels, 2) + beta) / gamma) 58 | kernel_size = max(t if t % 2 else t + 1, 3) 59 | 60 | self.avg_pool = nn.AdaptiveAvgPool2d(1) 61 | self.conv = nn.Conv1d(1, 1, kernel_size=kernel_size, padding=(kernel_size - 1) // 2, bias=False) 62 | 63 | def forward(self, x): 64 | # Feature descriptor on the global spatial information 65 | y = self.avg_pool(x) 66 | # Reshape for convolution 67 | y = y.view(x.shape[0], 1, -1) 68 | # Two different branches of ECA module 69 | y = self.conv(y) 70 | # Multi-scale information fusion 71 | y = y.view(x.shape[0], -1, 1, 1).sigmoid() 72 | return x * y.expand_as(x) 73 | 74 | 75 | class CecaModule(nn.Module): 76 | """Constructs a circular ECA module. 77 | 78 | ECA module where the conv uses circular padding rather than zero padding. 79 | Unlike the spatial dimensions, the channels have no inherent ordering or 80 | locality. Although this module in essence applies such an assumption, there is 81 | no reason to prevent the channels at either "edge" from adapting to each other circularly. 82 | This fundamentally increases connectivity and may improve performance metrics 83 | (accuracy, robustness) without significantly impacting resource metrics 84 | (parameter size, throughput, latency, etc.) 85 | 86 | Args: 87 | channels: Number of channels of the input feature map; when given, it is used to 88 | adaptively select the kernel size (default=None). 89 | gamma, beta: parameters of the channel-to-kernel-size mapping function used when 90 | channels is given; refer to the original paper https://arxiv.org/pdf/1910.03151.pdf 91 | (if channels is not given, the explicit kernel_size is used instead.)
92 | kernel_size: Explicit convolution kernel size, used when channels is None (default=3) 93 | """ 94 | 95 | def __init__(self, channels=None, kernel_size=3, gamma=2, beta=1): 96 | super(CecaModule, self).__init__() 97 | assert kernel_size % 2 == 1 98 | 99 | if channels is not None: 100 | t = int(abs(math.log(channels, 2) + beta) / gamma) 101 | kernel_size = max(t if t % 2 else t + 1, 3) 102 | 103 | self.avg_pool = nn.AdaptiveAvgPool2d(1) 104 | # pytorch circular padding mode is buggy as of pytorch 1.4 105 | # see https://github.com/pytorch/pytorch/pull/17240 106 | 107 | # implement manual circular padding 108 | self.conv = nn.Conv1d(1, 1, kernel_size=kernel_size, padding=0, bias=False) 109 | self.padding = (kernel_size - 1) // 2 110 | 111 | def forward(self, x): 112 | # Feature descriptor on the global spatial information 113 | y = self.avg_pool(x) 114 | 115 | # Manually implement circular padding; F.pad with mode='circular' does not appear to be affected by the bug 116 | y = F.pad(y.view(x.shape[0], 1, -1), (self.padding, self.padding), mode='circular') 117 | 118 | # Two different branches of ECA module 119 | y = self.conv(y) 120 | 121 | # Multi-scale information fusion 122 | y = y.view(x.shape[0], -1, 1, 1).sigmoid() 123 | 124 | return x * y.expand_as(x) 125 | -------------------------------------------------------------------------------- /timm/models/layers/helpers.py: -------------------------------------------------------------------------------- 1 | """ Layer/Module Helpers 2 | 3 | Hacked together by Ross Wightman 4 | """ 5 | from itertools import repeat 6 | from torch._six import container_abcs 7 | 8 | 9 | # From PyTorch internals 10 | def _ntuple(n): 11 | def parse(x): 12 | if isinstance(x, container_abcs.Iterable): 13 | return x 14 | return tuple(repeat(x, n)) 15 | return parse 16 | 17 | 18 | tup_single = _ntuple(1) 19 | tup_pair = _ntuple(2) 20 | tup_triple = _ntuple(3) 21 | tup_quadruple = _ntuple(4) 22 | 23 | 24 | 25 | 26 | 27 | 28 | -------------------------------------------------------------------------------- /timm/models/layers/median_pool.py: -------------------------------------------------------------------------------- 1 | import math 2 | import torch 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | from torch.nn.modules.utils import _pair, _quadruple 6 | 7 | 8 | class MedianPool2d(nn.Module): 9 | """ Median pool (usable as median filter when stride=1) module.
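    Example (an illustrative sketch added editorially, not from the original docs):
        >>> import torch
        >>> pool = MedianPool2d(kernel_size=3, stride=1, same=True)
        >>> pool(torch.randn(2, 3, 8, 8)).shape
        torch.Size([2, 3, 8, 8])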
10 | 11 | Args: 12 | kernel_size: size of pooling kernel, int or 2-tuple 13 | stride: pool stride, int or 2-tuple 14 | padding: pool padding, int or 4-tuple (l, r, t, b) as in pytorch F.pad 15 | same: override padding and enforce same padding, boolean 16 | """ 17 | def __init__(self, kernel_size=3, stride=1, padding=0, same=False): 18 | super(MedianPool2d, self).__init__() 19 | self.k = _pair(kernel_size) 20 | self.stride = _pair(stride) 21 | self.padding = _quadruple(padding) # convert to l, r, t, b 22 | self.same = same 23 | 24 | def _padding(self, x): 25 | if self.same: 26 | ih, iw = x.size()[2:] 27 | if ih % self.stride[0] == 0: 28 | ph = max(self.k[0] - self.stride[0], 0) 29 | else: 30 | ph = max(self.k[0] - (ih % self.stride[0]), 0) 31 | if iw % self.stride[1] == 0: 32 | pw = max(self.k[1] - self.stride[1], 0) 33 | else: 34 | pw = max(self.k[1] - (iw % self.stride[1]), 0) 35 | pl = pw // 2 36 | pr = pw - pl 37 | pt = ph // 2 38 | pb = ph - pt 39 | padding = (pl, pr, pt, pb) 40 | else: 41 | padding = self.padding 42 | return padding 43 | 44 | def forward(self, x): 45 | x = F.pad(x, self._padding(x), mode='reflect') 46 | x = x.unfold(2, self.k[0], self.stride[0]).unfold(3, self.k[1], self.stride[1]) 47 | x = x.contiguous().view(x.size()[:4] + (-1,)).median(dim=-1)[0] 48 | return x 49 | -------------------------------------------------------------------------------- /timm/models/layers/mixed_conv2d.py: -------------------------------------------------------------------------------- 1 | """ PyTorch Mixed Convolution 2 | 3 | Paper: MixConv: Mixed Depthwise Convolutional Kernels (https://arxiv.org/abs/1907.09595) 4 | 5 | Hacked together by Ross Wightman 6 | """ 7 | 8 | import torch 9 | from torch import nn as nn 10 | 11 | from .conv2d_same import create_conv2d_pad 12 | 13 | 14 | def _split_channels(num_chan, num_groups): 15 | split = [num_chan // num_groups for _ in range(num_groups)] 16 | split[0] += num_chan - sum(split) 17 | return split 18 | 19 | 20 | class MixedConv2d(nn.ModuleDict): 21 | """ Mixed Grouped Convolution 22 | 23 | Based on MDConv and GroupedConv in MixNet impl: 24 | https://github.com/tensorflow/tpu/blob/master/models/official/mnasnet/mixnet/custom_layers.py 25 | """ 26 | def __init__(self, in_channels, out_channels, kernel_size=3, 27 | stride=1, padding='', dilation=1, depthwise=False, **kwargs): 28 | super(MixedConv2d, self).__init__() 29 | 30 | kernel_size = kernel_size if isinstance(kernel_size, list) else [kernel_size] 31 | num_groups = len(kernel_size) 32 | in_splits = _split_channels(in_channels, num_groups) 33 | out_splits = _split_channels(out_channels, num_groups) 34 | self.in_channels = sum(in_splits) 35 | self.out_channels = sum(out_splits) 36 | for idx, (k, in_ch, out_ch) in enumerate(zip(kernel_size, in_splits, out_splits)): 37 | conv_groups = out_ch if depthwise else 1 38 | # use add_module to keep key space clean 39 | self.add_module( 40 | str(idx), 41 | create_conv2d_pad( 42 | in_ch, out_ch, k, stride=stride, 43 | padding=padding, dilation=dilation, groups=conv_groups, **kwargs) 44 | ) 45 | self.splits = in_splits 46 | 47 | def forward(self, x): 48 | x_split = torch.split(x, self.splits, 1) 49 | x_out = [c(x_split[i]) for i, c in enumerate(self.values())] 50 | x = torch.cat(x_out, 1) 51 | return x 52 | -------------------------------------------------------------------------------- /timm/models/layers/padding.py: -------------------------------------------------------------------------------- 1 | """ Padding Helpers 2 | 3 | Hacked together by Ross 
Wightman 4 | """ 5 | import math 6 | from typing import List 7 | 8 | import torch.nn.functional as F 9 | 10 | 11 | # Calculate symmetric padding for a convolution 12 | def get_padding(kernel_size: int, stride: int = 1, dilation: int = 1, **_) -> int: 13 | padding = ((stride - 1) + dilation * (kernel_size - 1)) // 2 14 | return padding 15 | 16 | 17 | # Calculate asymmetric TensorFlow-like 'SAME' padding for a convolution 18 | def get_same_padding(x: int, k: int, s: int, d: int): 19 | return max((math.ceil(x / s) - 1) * s + (k - 1) * d + 1 - x, 0) 20 | 21 | 22 | # Can SAME padding for given args be done statically? 23 | def is_static_pad(kernel_size: int, stride: int = 1, dilation: int = 1, **_): 24 | return stride == 1 and (dilation * (kernel_size - 1)) % 2 == 0 25 | 26 | 27 | # Dynamically pad input x with 'SAME' padding for conv with specified args 28 | def pad_same(x, k: List[int], s: List[int], d: List[int] = (1, 1)): 29 | ih, iw = x.size()[-2:] 30 | pad_h, pad_w = get_same_padding(ih, k[0], s[0], d[0]), get_same_padding(iw, k[1], s[1], d[1]) 31 | if pad_h > 0 or pad_w > 0: 32 | x = F.pad(x, [pad_w // 2, pad_w - pad_w // 2, pad_h // 2, pad_h - pad_h // 2]) 33 | return x 34 | -------------------------------------------------------------------------------- /timm/models/layers/se.py: -------------------------------------------------------------------------------- 1 | from torch import nn as nn 2 | 3 | 4 | class SEModule(nn.Module): 5 | 6 | def __init__(self, channels, reduction=16, act_layer=nn.ReLU): 7 | super(SEModule, self).__init__() 8 | self.avg_pool = nn.AdaptiveAvgPool2d(1) 9 | reduction_channels = max(channels // reduction, 8) 10 | self.fc1 = nn.Conv2d( 11 | channels, reduction_channels, kernel_size=1, padding=0, bias=True) 12 | self.act = act_layer(inplace=True) 13 | self.fc2 = nn.Conv2d( 14 | reduction_channels, channels, kernel_size=1, padding=0, bias=True) 15 | 16 | def forward(self, x): 17 | x_se = self.avg_pool(x) 18 | x_se = self.fc1(x_se) 19 | x_se = self.act(x_se) 20 | x_se = self.fc2(x_se) 21 | return x * x_se.sigmoid() 22 | -------------------------------------------------------------------------------- /timm/models/layers/selective_kernel.py: -------------------------------------------------------------------------------- 1 | """ Selective Kernel Convolution/Attention 2 | 3 | Paper: Selective Kernel Networks (https://arxiv.org/abs/1903.06586) 4 | 5 | Hacked together by Ross Wightman 6 | """ 7 | 8 | import torch 9 | from torch import nn as nn 10 | 11 | from .conv_bn_act import ConvBnAct 12 | 13 | 14 | def _kernel_valid(k): 15 | if not isinstance(k, (list, tuple)): 16 | k = [k] 17 | for ki in k:  # validate every branch kernel size, not just the first 18 | assert ki >= 3 and ki % 2 19 | 20 | 21 | class SelectiveKernelAttn(nn.Module): 22 | def __init__(self, channels, num_paths=2, attn_channels=32, 23 | act_layer=nn.ReLU, norm_layer=nn.BatchNorm2d): 24 | """ Selective Kernel Attention Module 25 | 26 | Selective Kernel attention mechanism factored out into its own module.
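        Expects an input of shape (B, num_paths, C, H, W) and produces per-path attention
        weights of shape (B, num_paths, C, 1, 1) that softmax to one across the paths
        (shapes read directly off the forward pass below).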
27 | 28 | """ 29 | super(SelectiveKernelAttn, self).__init__() 30 | self.num_paths = num_paths 31 | self.pool = nn.AdaptiveAvgPool2d(1) 32 | self.fc_reduce = nn.Conv2d(channels, attn_channels, kernel_size=1, bias=False) 33 | self.bn = norm_layer(attn_channels) 34 | self.act = act_layer(inplace=True) 35 | self.fc_select = nn.Conv2d(attn_channels, channels * num_paths, kernel_size=1, bias=False) 36 | 37 | def forward(self, x): 38 | assert x.shape[1] == self.num_paths 39 | x = torch.sum(x, dim=1) 40 | x = self.pool(x) 41 | x = self.fc_reduce(x) 42 | x = self.bn(x) 43 | x = self.act(x) 44 | x = self.fc_select(x) 45 | B, C, H, W = x.shape 46 | x = x.view(B, self.num_paths, C // self.num_paths, H, W) 47 | x = torch.softmax(x, dim=1) 48 | return x 49 | 50 | 51 | class SelectiveKernelConv(nn.Module): 52 | 53 | def __init__(self, in_channels, out_channels, kernel_size=None, stride=1, dilation=1, groups=1, 54 | attn_reduction=16, min_attn_channels=32, keep_3x3=True, split_input=False, 55 | drop_block=None, act_layer=nn.ReLU, norm_layer=nn.BatchNorm2d): 56 | """ Selective Kernel Convolution Module 57 | 58 | As described in Selective Kernel Networks (https://arxiv.org/abs/1903.06586) with some modifications. 59 | 60 | The largest change is the input split, which divides the input channels across each convolution path; this can 61 | be viewed as a grouping of sorts, but the output channel counts expand to the module-level value. This keeps 62 | the parameter count from ballooning when the convolutions themselves don't have groups, but still provides 63 | a noteworthy increase in performance over similar param count models without this attention layer. -Ross W 64 | 65 | Args: 66 | in_channels (int): module input (feature) channel count 67 | out_channels (int): module output (feature) channel count 68 | kernel_size (int, list): kernel size for each convolution branch 69 | stride (int): stride for convolutions 70 | dilation (int): dilation for module as a whole, impacts dilation of each branch 71 | groups (int): number of groups for each branch 72 | attn_reduction (int, float): reduction factor for attention features 73 | min_attn_channels (int): minimum attention feature channels 74 | keep_3x3 (bool): keep all branch convolution kernels as 3x3, trading larger kernels for increased dilation 75 | split_input (bool): split input channels evenly across each convolution branch, keeps param count lower, 76 | can be viewed as grouping by path, output expands to module out_channels count 77 | drop_block (nn.Module): drop block module 78 | act_layer (nn.Module): activation layer to use 79 | norm_layer (nn.Module): batchnorm/norm layer to use 80 | """ 81 | super(SelectiveKernelConv, self).__init__() 82 | kernel_size = kernel_size or [3, 5] # default to one 3x3 and one 5x5 branch.
5x5 -> 3x3 + dilation 83 | _kernel_valid(kernel_size) 84 | if not isinstance(kernel_size, list): 85 | kernel_size = [kernel_size] * 2 86 | if keep_3x3: 87 | dilation = [dilation * (k - 1) // 2 for k in kernel_size] 88 | kernel_size = [3] * len(kernel_size) 89 | else: 90 | dilation = [dilation] * len(kernel_size) 91 | self.num_paths = len(kernel_size) 92 | self.in_channels = in_channels 93 | self.out_channels = out_channels 94 | self.split_input = split_input 95 | if self.split_input: 96 | assert in_channels % self.num_paths == 0 97 | in_channels = in_channels // self.num_paths 98 | groups = min(out_channels, groups) 99 | 100 | conv_kwargs = dict( 101 | stride=stride, groups=groups, drop_block=drop_block, act_layer=act_layer, norm_layer=norm_layer) 102 | self.paths = nn.ModuleList([ 103 | ConvBnAct(in_channels, out_channels, kernel_size=k, dilation=d, **conv_kwargs) 104 | for k, d in zip(kernel_size, dilation)]) 105 | 106 | attn_channels = max(int(out_channels / attn_reduction), min_attn_channels) 107 | self.attn = SelectiveKernelAttn(out_channels, self.num_paths, attn_channels) 108 | self.drop_block = drop_block 109 | 110 | def forward(self, x): 111 | if self.split_input: 112 | x_split = torch.split(x, self.in_channels // self.num_paths, 1) 113 | x_paths = [op(x_split[i]) for i, op in enumerate(self.paths)] 114 | else: 115 | x_paths = [op(x) for op in self.paths] 116 | x = torch.stack(x_paths, dim=1) 117 | x_attn = self.attn(x) 118 | x = x * x_attn 119 | x = torch.sum(x, dim=1) 120 | return x 121 | -------------------------------------------------------------------------------- /timm/models/layers/split_batchnorm.py: -------------------------------------------------------------------------------- 1 | """ Split BatchNorm 2 | 3 | A PyTorch BatchNorm layer that splits input batch into N equal parts and passes each through 4 | a separate BN layer. The first split is passed through the parent BN layers with weight/bias 5 | keys the same as the original BN. All other splits pass through BN sub-layers under the '.aux_bn' 6 | namespace. 
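A sketch of the split semantics (illustrative numbers): with a batch of 8 and num_splits=2,
rows 0-3 pass through the parent BN and rows 4-7 through aux_bn[0]; the outputs are then
concatenated back in the original order.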
7 | 8 | This allows easily removing the auxiliary BN layers after training to efficiently 9 | achieve the 'Auxiliary BatchNorm' as described in the AdvProp Paper, section 4.2, 10 | 'Disentangled Learning via An Auxiliary BN' 11 | 12 | Hacked together by Ross Wightman 13 | """ 14 | import torch 15 | import torch.nn as nn 16 | 17 | 18 | class SplitBatchNorm2d(torch.nn.BatchNorm2d): 19 | 20 | def __init__(self, num_features, eps=1e-5, momentum=0.1, affine=True, 21 | track_running_stats=True, num_splits=2): 22 | super().__init__(num_features, eps, momentum, affine, track_running_stats) 23 | assert num_splits > 1, 'Should have at least one aux BN layer (num_splits at least 2)' 24 | self.num_splits = num_splits 25 | self.aux_bn = nn.ModuleList([ 26 | nn.BatchNorm2d(num_features, eps, momentum, affine, track_running_stats) for _ in range(num_splits - 1)]) 27 | 28 | def forward(self, input: torch.Tensor): 29 | if self.training: # aux BN only relevant while training 30 | split_size = input.shape[0] // self.num_splits 31 | assert input.shape[0] == split_size * self.num_splits, "batch size must be evenly divisible by num_splits" 32 | split_input = input.split(split_size) 33 | x = [super().forward(split_input[0])] 34 | for i, a in enumerate(self.aux_bn): 35 | x.append(a(split_input[i + 1])) 36 | return torch.cat(x, dim=0) 37 | else: 38 | return super().forward(input) 39 | 40 | 41 | def convert_splitbn_model(module, num_splits=2): 42 | """ 43 | Recursively traverse module and its children to replace all instances of 44 | ``torch.nn.modules.batchnorm._BatchNorm`` with `SplitBatchNorm2d`. 45 | Args: 46 | module (torch.nn.Module): input module 47 | num_splits: number of separate batchnorm layers to split input across 48 | Example:: 49 | >>> # model is an instance of torch.nn.Module 50 | >>> model = timm.models.convert_splitbn_model(model, num_splits=2) 51 | """ 52 | mod = module 53 | if isinstance(module, torch.nn.modules.instancenorm._InstanceNorm): 54 | return module 55 | if isinstance(module, torch.nn.modules.batchnorm._BatchNorm): 56 | mod = SplitBatchNorm2d( 57 | module.num_features, module.eps, module.momentum, module.affine, 58 | module.track_running_stats, num_splits=num_splits) 59 | mod.running_mean = module.running_mean 60 | mod.running_var = module.running_var 61 | mod.num_batches_tracked = module.num_batches_tracked 62 | if module.affine: 63 | mod.weight.data = module.weight.data.clone().detach() 64 | mod.bias.data = module.bias.data.clone().detach() 65 | for aux in mod.aux_bn: 66 | aux.running_mean = module.running_mean.clone() 67 | aux.running_var = module.running_var.clone() 68 | aux.num_batches_tracked = module.num_batches_tracked.clone() 69 | if module.affine: 70 | aux.weight.data = module.weight.data.clone().detach() 71 | aux.bias.data = module.bias.data.clone().detach() 72 | for name, child in module.named_children(): 73 | mod.add_module(name, convert_splitbn_model(child, num_splits=num_splits)) 74 | del module 75 | return mod 76 | -------------------------------------------------------------------------------- /timm/models/layers/test_time_pool.py: -------------------------------------------------------------------------------- 1 | """ Test Time Pooling (Average-Max Pool) 2 | 3 | Hacked together by Ross Wightman 4 | """ 5 | 6 | import logging 7 | from torch import nn 8 | import torch.nn.functional as F 9 | from .adaptive_avgmax_pool import adaptive_avgmax_pool2d 10 | 11 | 12 | class TestTimePoolHead(nn.Module): 13 | def __init__(self, base, original_pool=7):
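        # Editorial overview: the base model's classifier is converted to an equivalent 1x1
        # conv below, so class logits are produced per spatial position; forward() then
        # average-pools at the original pool size and finishes with adaptive avg+max pooling.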
super(TestTimePoolHead, self).__init__() 15 | self.base = base 16 | self.original_pool = original_pool 17 | base_fc = self.base.get_classifier() 18 | if isinstance(base_fc, nn.Conv2d): 19 | self.fc = base_fc 20 | else: 21 | self.fc = nn.Conv2d( 22 | self.base.num_features, self.base.num_classes, kernel_size=1, bias=True) 23 | self.fc.weight.data.copy_(base_fc.weight.data.view(self.fc.weight.size())) 24 | self.fc.bias.data.copy_(base_fc.bias.data.view(self.fc.bias.size())) 25 | self.base.reset_classifier(0) # delete original fc layer 26 | 27 | def forward(self, x): 28 | x = self.base.forward_features(x) 29 | x = F.avg_pool2d(x, kernel_size=self.original_pool, stride=1) 30 | x = self.fc(x) 31 | x = adaptive_avgmax_pool2d(x, 1) 32 | return x.view(x.size(0), -1) 33 | 34 | 35 | def apply_test_time_pool(model, config, args): 36 | test_time_pool = False 37 | if not hasattr(model, 'default_cfg') or not model.default_cfg: 38 | return model, False 39 | if not args.no_test_pool and \ 40 | config['input_size'][-1] > model.default_cfg['input_size'][-1] and \ 41 | config['input_size'][-2] > model.default_cfg['input_size'][-2]: 42 | logging.info('Target input size %s > pretrained default %s, using test time pooling' % 43 | (str(config['input_size'][-2:]), str(model.default_cfg['input_size'][-2:]))) 44 | model = TestTimePoolHead(model, original_pool=model.default_cfg['pool_size']) 45 | test_time_pool = True 46 | return model, test_time_pool 47 | -------------------------------------------------------------------------------- /timm/models/registry.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import re 3 | import fnmatch 4 | from collections import defaultdict 5 | 6 | __all__ = ['list_models', 'is_model', 'model_entrypoint', 'list_modules', 'is_model_in_modules'] 7 | 8 | _module_to_models = defaultdict(set) # dict of sets to check membership of model in module 9 | _model_to_module = {} # mapping of model names to module names 10 | _model_entrypoints = {} # mapping of model names to entrypoint fns 11 | _model_has_pretrained = set() # set of model names that have pretrained weight url present 12 | 13 | 14 | def register_model(fn): 15 | # lookup containing module 16 | mod = sys.modules[fn.__module__] 17 | module_name_split = fn.__module__.split('.') 18 | module_name = module_name_split[-1] if len(module_name_split) else '' 19 | 20 | # add model to __all__ in module 21 | model_name = fn.__name__ 22 | if hasattr(mod, '__all__'): 23 | mod.__all__.append(model_name) 24 | else: 25 | mod.__all__ = [model_name] 26 | 27 | # add entries to registry dict/sets 28 | _model_entrypoints[model_name] = fn 29 | _model_to_module[model_name] = module_name 30 | _module_to_models[module_name].add(model_name) 31 | has_pretrained = False # check if model has a pretrained url to allow filtering on this 32 | if hasattr(mod, 'default_cfgs') and model_name in mod.default_cfgs: 33 | # this will catch all models that have entrypoint matching cfg key, but miss any aliasing 34 | # entrypoints or non-matching combos 35 | has_pretrained = 'url' in mod.default_cfgs[model_name] and 'http' in mod.default_cfgs[model_name]['url'] 36 | if has_pretrained: 37 | _model_has_pretrained.add(model_name) 38 | return fn 39 | 40 | 41 | def _natural_key(string_): 42 | return [int(s) if s.isdigit() else s for s in re.split(r'(\d+)', string_.lower())] 43 | 44 | 45 | def list_models(filter='', module='', pretrained=False): 46 | """ Return list of available model names, sorted alphabetically 47 | 48 | 
Args: 49 | filter (str) - Wildcard filter string that works with fnmatch 50 | module (str) - Limit model selection to a specific sub-module (e.g. 'gen_efficientnet') 51 | pretrained (bool) - Include only models with pretrained weights if True 52 | Example: 53 | list_models('gluon_resnet*') -- returns all models starting with 'gluon_resnet' 54 | list_models('*resnext*', 'resnet') -- returns all models with 'resnext' in the 'resnet' module 55 | """ 56 | if module: 57 | models = list(_module_to_models[module]) 58 | else: 59 | models = _model_entrypoints.keys() 60 | if filter: 61 | models = fnmatch.filter(models, filter) 62 | if pretrained: 63 | models = _model_has_pretrained.intersection(models) 64 | return list(sorted(models, key=_natural_key)) 65 | 66 | 67 | def is_model(model_name): 68 | """ Check if a model name exists 69 | """ 70 | return model_name in _model_entrypoints 71 | 72 | 73 | def model_entrypoint(model_name): 74 | """Fetch a model entrypoint for specified model name 75 | """ 76 | return _model_entrypoints[model_name] 77 | 78 | 79 | def list_modules(): 80 | """ Return list of module names that contain models / model entrypoints 81 | """ 82 | modules = _module_to_models.keys() 83 | return list(sorted(modules)) 84 | 85 | 86 | def is_model_in_modules(model_name, module_names): 87 | """Check if a model exists within a subset of modules 88 | Args: 89 | model_name (str) - name of model to check 90 | module_names (tuple, list, set) - names of modules to search in 91 | """ 92 | assert isinstance(module_names, (tuple, list, set)) 93 | return any(model_name in _module_to_models[n] for n in module_names) 94 | 95 | -------------------------------------------------------------------------------- /timm/models/xception.py: -------------------------------------------------------------------------------- 1 | """ 2 | Ported to pytorch thanks to [tstandley](https://github.com/tstandley/Xception-PyTorch) 3 | 4 | @author: tstandley 5 | Adapted by cadene 6 | 7 | Creates an Xception Model as defined in: 8 | 9 | Francois Chollet 10 | Xception: Deep Learning with Depthwise Separable Convolutions 11 | https://arxiv.org/pdf/1610.02357.pdf 12 | 13 | These weights were ported from the Keras implementation.
Achieves the following performance on the validation set: 14 | 15 | Loss:0.9173 Prec@1:78.892 Prec@5:94.292 16 | 17 | REMEMBER to set your image size to 3x299x299 for both test and validation 18 | 19 | normalize = transforms.Normalize(mean=[0.5, 0.5, 0.5], 20 | std=[0.5, 0.5, 0.5]) 21 | 22 | The resize parameter of the validation transform should be 333, and make sure to center crop at 299x299 23 | """ 24 | import math 25 | 26 | import torch 27 | import torch.nn as nn 28 | import torch.nn.functional as F 29 | 30 | from .registry import register_model 31 | from .helpers import load_pretrained 32 | from .layers import SelectAdaptivePool2d 33 | 34 | __all__ = ['Xception'] 35 | 36 | default_cfgs = { 37 | 'xception': { 38 | 'url': 'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-cadene/xception-43020ad28.pth', 39 | 'input_size': (3, 299, 299), 40 | 'crop_pct': 0.8975, 41 | 'interpolation': 'bicubic', 42 | 'mean': (0.5, 0.5, 0.5), 43 | 'std': (0.5, 0.5, 0.5), 44 | 'num_classes': 1000, 45 | 'first_conv': 'conv1', 46 | 'classifier': 'fc' 47 | # The resize parameter of the validation transform should be 333, and make sure to center crop at 299x299 48 | } 49 | } 50 | 51 | 52 | class SeparableConv2d(nn.Module): 53 | def __init__(self, in_channels, out_channels, kernel_size=1, stride=1, padding=0, dilation=1, bias=False): 54 | super(SeparableConv2d, self).__init__() 55 | 56 | self.conv1 = nn.Conv2d( 57 | in_channels, in_channels, kernel_size, stride, padding, dilation, groups=in_channels, bias=bias) 58 | self.pointwise = nn.Conv2d(in_channels, out_channels, 1, 1, 0, 1, 1, bias=bias) 59 | 60 | def forward(self, x): 61 | x = self.conv1(x) 62 | x = self.pointwise(x) 63 | return x 64 | 65 | 66 | class Block(nn.Module): 67 | def __init__(self, in_filters, out_filters, reps, strides=1, start_with_relu=True, grow_first=True): 68 | super(Block, self).__init__() 69 | 70 | if out_filters != in_filters or strides != 1: 71 | self.skip = nn.Conv2d(in_filters, out_filters, 1, stride=strides, bias=False) 72 | self.skipbn = nn.BatchNorm2d(out_filters) 73 | else: 74 | self.skip = None 75 | 76 | self.relu = nn.ReLU(inplace=True) 77 | rep = [] 78 | 79 | filters = in_filters 80 | if grow_first: 81 | rep.append(self.relu) 82 | rep.append(SeparableConv2d(in_filters, out_filters, 3, stride=1, padding=1, bias=False)) 83 | rep.append(nn.BatchNorm2d(out_filters)) 84 | filters = out_filters 85 | 86 | for i in range(reps - 1): 87 | rep.append(self.relu) 88 | rep.append(SeparableConv2d(filters, filters, 3, stride=1, padding=1, bias=False)) 89 | rep.append(nn.BatchNorm2d(filters)) 90 | 91 | if not grow_first: 92 | rep.append(self.relu) 93 | rep.append(SeparableConv2d(in_filters, out_filters, 3, stride=1, padding=1, bias=False)) 94 | rep.append(nn.BatchNorm2d(out_filters)) 95 | 96 | if not start_with_relu: 97 | rep = rep[1:] 98 | else: 99 | rep[0] = nn.ReLU(inplace=False) 100 | 101 | if strides != 1: 102 | rep.append(nn.MaxPool2d(3, strides, 1)) 103 | self.rep = nn.Sequential(*rep) 104 | 105 | def forward(self, inp): 106 | x = self.rep(inp) 107 | 108 | if self.skip is not None: 109 | skip = self.skip(inp) 110 | skip = self.skipbn(skip) 111 | else: 112 | skip = inp 113 | 114 | x += skip 115 | return x 116 | 117 | 118 | class Xception(nn.Module): 119 | """ 120 | Xception optimized for the ImageNet dataset, as specified in 121 | https://arxiv.org/pdf/1610.02357.pdf 122 | """ 123 | 124 | def __init__(self, num_classes=1000, in_chans=3, drop_rate=0., global_pool='avg'): 125 | """ Constructor 126 | Args: 127 | 
num_classes: number of classes 128 | """ 129 | super(Xception, self).__init__() 130 | self.drop_rate = drop_rate 131 | self.global_pool = global_pool 132 | self.num_classes = num_classes 133 | self.num_features = 2048 134 | 135 | self.conv1 = nn.Conv2d(in_chans, 32, 3, 2, 0, bias=False) 136 | self.bn1 = nn.BatchNorm2d(32) 137 | self.relu = nn.ReLU(inplace=True) 138 | 139 | self.conv2 = nn.Conv2d(32, 64, 3, bias=False) 140 | self.bn2 = nn.BatchNorm2d(64) 141 | # do relu here 142 | 143 | self.block1 = Block(64, 128, 2, 2, start_with_relu=False, grow_first=True) 144 | self.block2 = Block(128, 256, 2, 2, start_with_relu=True, grow_first=True) 145 | self.block3 = Block(256, 728, 2, 2, start_with_relu=True, grow_first=True) 146 | 147 | self.block4 = Block(728, 728, 3, 1, start_with_relu=True, grow_first=True) 148 | self.block5 = Block(728, 728, 3, 1, start_with_relu=True, grow_first=True) 149 | self.block6 = Block(728, 728, 3, 1, start_with_relu=True, grow_first=True) 150 | self.block7 = Block(728, 728, 3, 1, start_with_relu=True, grow_first=True) 151 | 152 | self.block8 = Block(728, 728, 3, 1, start_with_relu=True, grow_first=True) 153 | self.block9 = Block(728, 728, 3, 1, start_with_relu=True, grow_first=True) 154 | self.block10 = Block(728, 728, 3, 1, start_with_relu=True, grow_first=True) 155 | self.block11 = Block(728, 728, 3, 1, start_with_relu=True, grow_first=True) 156 | 157 | self.block12 = Block(728, 1024, 2, 2, start_with_relu=True, grow_first=False) 158 | 159 | self.conv3 = SeparableConv2d(1024, 1536, 3, 1, 1) 160 | self.bn3 = nn.BatchNorm2d(1536) 161 | 162 | # do relu here 163 | self.conv4 = SeparableConv2d(1536, self.num_features, 3, 1, 1) 164 | self.bn4 = nn.BatchNorm2d(self.num_features) 165 | 166 | self.global_pool = SelectAdaptivePool2d(pool_type=global_pool) 167 | self.fc = nn.Linear(self.num_features * self.global_pool.feat_mult(), num_classes) 168 | 169 | # ------- init weights ------- 170 | for m in self.modules(): 171 | if isinstance(m, nn.Conv2d): 172 | nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu') 173 | elif isinstance(m, nn.BatchNorm2d): 174 | m.weight.data.fill_(1) 175 | m.bias.data.zero_() 176 | 177 | def get_classifier(self): 178 | return self.fc 179 | 180 | def reset_classifier(self, num_classes, global_pool='avg'): 181 | self.num_classes = num_classes 182 | self.global_pool = SelectAdaptivePool2d(pool_type=global_pool) 183 | del self.fc 184 | self.fc = nn.Linear(self.num_features * self.global_pool.feat_mult(), num_classes) if num_classes else None 185 | 186 | def forward_features(self, x): 187 | x = self.conv1(x) 188 | x = self.bn1(x) 189 | x = self.relu(x) 190 | 191 | x = self.conv2(x) 192 | x = self.bn2(x) 193 | x = self.relu(x) 194 | 195 | x = self.block1(x) 196 | x = self.block2(x) 197 | x = self.block3(x) 198 | x = self.block4(x) 199 | x = self.block5(x) 200 | x = self.block6(x) 201 | x = self.block7(x) 202 | x = self.block8(x) 203 | x = self.block9(x) 204 | x = self.block10(x) 205 | x = self.block11(x) 206 | x = self.block12(x) 207 | 208 | x = self.conv3(x) 209 | x = self.bn3(x) 210 | x = self.relu(x) 211 | 212 | x = self.conv4(x) 213 | x = self.bn4(x) 214 | x = self.relu(x) 215 | return x 216 | 217 | def forward(self, x): 218 | x = self.forward_features(x) 219 | x = self.global_pool(x).flatten(1) 220 | if self.drop_rate: 221 | x = F.dropout(x, self.drop_rate, training=self.training) 222 | x = self.fc(x) 223 | return x 224 | 225 | 226 | @register_model 227 | def xception(pretrained=False, num_classes=1000, in_chans=3, **kwargs): 228 | 
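    # Typical usage goes through the registry + factory (a sketch; assumes the standard
    # timm factory API, e.g. timm.create_model from timm/models/factory.py):
    #   model = timm.create_model('xception', pretrained=True)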
default_cfg = default_cfgs['xception'] 229 | model = Xception(num_classes=num_classes, in_chans=in_chans, **kwargs) 230 | model.default_cfg = default_cfg 231 | if pretrained: 232 | load_pretrained(model, default_cfg, num_classes, in_chans) 233 | 234 | return model 235 | -------------------------------------------------------------------------------- /timm/optim/__init__.py: -------------------------------------------------------------------------------- 1 | from .nadam import Nadam 2 | from .rmsprop_tf import RMSpropTF 3 | from .adamw import AdamW 4 | from .radam import RAdam 5 | from .novograd import NovoGrad 6 | from .nvnovograd import NvNovoGrad 7 | from .lookahead import Lookahead 8 | from .optim_factory import create_optimizer 9 | -------------------------------------------------------------------------------- /timm/optim/adamw.py: -------------------------------------------------------------------------------- 1 | """ AdamW Optimizer 2 | Impl copied from PyTorch master 3 | """ 4 | import math 5 | import torch 6 | from torch.optim.optimizer import Optimizer 7 | 8 | 9 | class AdamW(Optimizer): 10 | r"""Implements AdamW algorithm. 11 | 12 | The original Adam algorithm was proposed in `Adam: A Method for Stochastic Optimization`_. 13 | The AdamW variant was proposed in `Decoupled Weight Decay Regularization`_. 14 | 15 | Arguments: 16 | params (iterable): iterable of parameters to optimize or dicts defining 17 | parameter groups 18 | lr (float, optional): learning rate (default: 1e-3) 19 | betas (Tuple[float, float], optional): coefficients used for computing 20 | running averages of gradient and its square (default: (0.9, 0.999)) 21 | eps (float, optional): term added to the denominator to improve 22 | numerical stability (default: 1e-8) 23 | weight_decay (float, optional): weight decay coefficient (default: 1e-2) 24 | amsgrad (boolean, optional): whether to use the AMSGrad variant of this 25 | algorithm from the paper `On the Convergence of Adam and Beyond`_ 26 | (default: False) 27 | 28 | .. _Adam\: A Method for Stochastic Optimization: 29 | https://arxiv.org/abs/1412.6980 30 | .. _Decoupled Weight Decay Regularization: 31 | https://arxiv.org/abs/1711.05101 32 | .. _On the Convergence of Adam and Beyond: 33 | https://openreview.net/forum?id=ryQu7f-RZ 34 | """ 35 | 36 | def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8, 37 | weight_decay=1e-2, amsgrad=False): 38 | if not 0.0 <= lr: 39 | raise ValueError("Invalid learning rate: {}".format(lr)) 40 | if not 0.0 <= eps: 41 | raise ValueError("Invalid epsilon value: {}".format(eps)) 42 | if not 0.0 <= betas[0] < 1.0: 43 | raise ValueError("Invalid beta parameter at index 0: {}".format(betas[0])) 44 | if not 0.0 <= betas[1] < 1.0: 45 | raise ValueError("Invalid beta parameter at index 1: {}".format(betas[1])) 46 | defaults = dict(lr=lr, betas=betas, eps=eps, 47 | weight_decay=weight_decay, amsgrad=amsgrad) 48 | super(AdamW, self).__init__(params, defaults) 49 | 50 | def __setstate__(self, state): 51 | super(AdamW, self).__setstate__(state) 52 | for group in self.param_groups: 53 | group.setdefault('amsgrad', False) 54 | 55 | def step(self, closure=None): 56 | """Performs a single optimization step. 57 | 58 | Arguments: 59 | closure (callable, optional): A closure that reevaluates the model 60 | and returns the loss. 
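
        Note: weight decay here is decoupled - it is applied directly to the weights
        (p *= 1 - lr * weight_decay) before the Adam update below, per
        'Decoupled Weight Decay Regularization'.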
61 | """ 62 | loss = None 63 | if closure is not None: 64 | loss = closure() 65 | 66 | for group in self.param_groups: 67 | for p in group['params']: 68 | if p.grad is None: 69 | continue 70 | 71 | # Perform stepweight decay 72 | p.data.mul_(1 - group['lr'] * group['weight_decay']) 73 | 74 | # Perform optimization step 75 | grad = p.grad.data 76 | if grad.is_sparse: 77 | raise RuntimeError('Adam does not support sparse gradients, please consider SparseAdam instead') 78 | amsgrad = group['amsgrad'] 79 | 80 | state = self.state[p] 81 | 82 | # State initialization 83 | if len(state) == 0: 84 | state['step'] = 0 85 | # Exponential moving average of gradient values 86 | state['exp_avg'] = torch.zeros_like(p.data) 87 | # Exponential moving average of squared gradient values 88 | state['exp_avg_sq'] = torch.zeros_like(p.data) 89 | if amsgrad: 90 | # Maintains max of all exp. moving avg. of sq. grad. values 91 | state['max_exp_avg_sq'] = torch.zeros_like(p.data) 92 | 93 | exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq'] 94 | if amsgrad: 95 | max_exp_avg_sq = state['max_exp_avg_sq'] 96 | beta1, beta2 = group['betas'] 97 | 98 | state['step'] += 1 99 | bias_correction1 = 1 - beta1 ** state['step'] 100 | bias_correction2 = 1 - beta2 ** state['step'] 101 | 102 | # Decay the first and second moment running average coefficient 103 | exp_avg.mul_(beta1).add_(1 - beta1, grad) 104 | exp_avg_sq.mul_(beta2).addcmul_(1 - beta2, grad, grad) 105 | if amsgrad: 106 | # Maintains the maximum of all 2nd moment running avg. till now 107 | torch.max(max_exp_avg_sq, exp_avg_sq, out=max_exp_avg_sq) 108 | # Use the max. for normalizing running avg. of gradient 109 | denom = (max_exp_avg_sq.sqrt() / math.sqrt(bias_correction2)).add_(group['eps']) 110 | else: 111 | denom = (exp_avg_sq.sqrt() / math.sqrt(bias_correction2)).add_(group['eps']) 112 | 113 | step_size = group['lr'] / bias_correction1 114 | 115 | p.data.addcdiv_(-step_size, exp_avg, denom) 116 | 117 | return loss 118 | -------------------------------------------------------------------------------- /timm/optim/lookahead.py: -------------------------------------------------------------------------------- 1 | """ Lookahead Optimizer Wrapper. 
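Maintains a set of slow weights; every k steps the slow weights move toward the fast
(inner-optimizer) weights by factor alpha, and the fast weights are reset to the slow ones.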
2 | Implementation modified from: https://github.com/alphadl/lookahead.pytorch 3 | Paper: `Lookahead Optimizer: k steps forward, 1 step back` - https://arxiv.org/abs/1907.08610 4 | """ 5 | import torch 6 | from torch.optim.optimizer import Optimizer 7 | from collections import defaultdict 8 | 9 | 10 | class Lookahead(Optimizer): 11 | def __init__(self, base_optimizer, alpha=0.5, k=6): 12 | if not 0.0 <= alpha <= 1.0: 13 | raise ValueError(f'Invalid slow update rate: {alpha}') 14 | if not 1 <= k: 15 | raise ValueError(f'Invalid lookahead steps: {k}') 16 | defaults = dict(lookahead_alpha=alpha, lookahead_k=k, lookahead_step=0) 17 | self.base_optimizer = base_optimizer 18 | self.param_groups = self.base_optimizer.param_groups 19 | self.defaults = base_optimizer.defaults 20 | self.defaults.update(defaults) 21 | self.state = defaultdict(dict) 22 | # manually add our defaults to the param groups 23 | for name, default in defaults.items(): 24 | for group in self.param_groups: 25 | group.setdefault(name, default) 26 | 27 | def update_slow(self, group): 28 | for fast_p in group["params"]: 29 | if fast_p.grad is None: 30 | continue 31 | param_state = self.state[fast_p] 32 | if 'slow_buffer' not in param_state: 33 | param_state['slow_buffer'] = torch.empty_like(fast_p.data) 34 | param_state['slow_buffer'].copy_(fast_p.data) 35 | slow = param_state['slow_buffer'] 36 | slow.add_(group['lookahead_alpha'], fast_p.data - slow) 37 | fast_p.data.copy_(slow) 38 | 39 | def sync_lookahead(self): 40 | for group in self.param_groups: 41 | self.update_slow(group) 42 | 43 | def step(self, closure=None): 44 | #assert id(self.param_groups) == id(self.base_optimizer.param_groups) 45 | loss = self.base_optimizer.step(closure) 46 | for group in self.param_groups: 47 | group['lookahead_step'] += 1 48 | if group['lookahead_step'] % group['lookahead_k'] == 0: 49 | self.update_slow(group) 50 | return loss 51 | 52 | def state_dict(self): 53 | fast_state_dict = self.base_optimizer.state_dict() 54 | slow_state = { 55 | (id(k) if isinstance(k, torch.Tensor) else k): v 56 | for k, v in self.state.items() 57 | } 58 | fast_state = fast_state_dict['state'] 59 | param_groups = fast_state_dict['param_groups'] 60 | return { 61 | 'state': fast_state, 62 | 'slow_state': slow_state, 63 | 'param_groups': param_groups, 64 | } 65 | 66 | def load_state_dict(self, state_dict): 67 | fast_state_dict = { 68 | 'state': state_dict['state'], 69 | 'param_groups': state_dict['param_groups'], 70 | } 71 | self.base_optimizer.load_state_dict(fast_state_dict) 72 | 73 | # We want to restore the slow state, but share param_groups reference 74 | # with base_optimizer. 
This is a bit redundant but least code 75 | slow_state_new = False 76 | if 'slow_state' not in state_dict: 77 | print('Loading state_dict from optimizer without Lookahead applied.') 78 | state_dict['slow_state'] = defaultdict(dict) 79 | slow_state_new = True 80 | slow_state_dict = { 81 | 'state': state_dict['slow_state'], 82 | 'param_groups': state_dict['param_groups'], # this is pointless but saves code 83 | } 84 | super(Lookahead, self).load_state_dict(slow_state_dict) 85 | self.param_groups = self.base_optimizer.param_groups # make both ref same container 86 | if slow_state_new: 87 | # reapply defaults to catch missing lookahead specific ones 88 | for name, default in self.defaults.items(): 89 | for group in self.param_groups: 90 | group.setdefault(name, default) 91 | -------------------------------------------------------------------------------- /timm/optim/nadam.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch.optim import Optimizer 3 | 4 | 5 | class Nadam(Optimizer): 6 | """Implements Nadam algorithm (a variant of Adam based on Nesterov momentum). 7 | 8 | It has been proposed in `Incorporating Nesterov Momentum into Adam`__. 9 | 10 | Arguments: 11 | params (iterable): iterable of parameters to optimize or dicts defining 12 | parameter groups 13 | lr (float, optional): learning rate (default: 2e-3) 14 | betas (Tuple[float, float], optional): coefficients used for computing 15 | running averages of gradient and its square 16 | eps (float, optional): term added to the denominator to improve 17 | numerical stability (default: 1e-8) 18 | weight_decay (float, optional): weight decay (L2 penalty) (default: 0) 19 | schedule_decay (float, optional): momentum schedule decay (default: 4e-3) 20 | 21 | __ http://cs229.stanford.edu/proj2015/054_report.pdf 22 | __ http://www.cs.toronto.edu/~fritz/absps/momentum.pdf 23 | 24 | Originally taken from: https://github.com/pytorch/pytorch/pull/1408 25 | NOTE: Has potential issues but does work well on some problems. 26 | """ 27 | 28 | def __init__(self, params, lr=2e-3, betas=(0.9, 0.999), eps=1e-8, 29 | weight_decay=0, schedule_decay=4e-3): 30 | defaults = dict(lr=lr, betas=betas, eps=eps, 31 | weight_decay=weight_decay, schedule_decay=schedule_decay) 32 | super(Nadam, self).__init__(params, defaults) 33 | 34 | def step(self, closure=None): 35 | """Performs a single optimization step. 36 | 37 | Arguments: 38 | closure (callable, optional): A closure that reevaluates the model 39 | and returns the loss. 40 | """ 41 | loss = None 42 | if closure is not None: 43 | loss = closure() 44 | 45 | for group in self.param_groups: 46 | for p in group['params']: 47 | if p.grad is None: 48 | continue 49 | grad = p.grad.data 50 | state = self.state[p] 51 | 52 | # State initialization 53 | if len(state) == 0: 54 | state['step'] = 0 55 | state['m_schedule'] = 1. 56 | state['exp_avg'] = grad.new().resize_as_(grad).zero_() 57 | state['exp_avg_sq'] = grad.new().resize_as_(grad).zero_() 58 | 59 | # Warming momentum schedule 60 | m_schedule = state['m_schedule'] 61 | schedule_decay = group['schedule_decay'] 62 | exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq'] 63 | beta1, beta2 = group['betas'] 64 | eps = group['eps'] 65 | state['step'] += 1 66 | t = state['step'] 67 | 68 | if group['weight_decay'] != 0: 69 | grad = grad.add(group['weight_decay'], p.data) 70 | 71 | momentum_cache_t = beta1 * \ 72 | (1. - 0.5 * (0.96 ** (t * schedule_decay))) 73 | momentum_cache_t_1 = beta1 * \ 74 | (1. 
- 0.5 * (0.96 ** ((t + 1) * schedule_decay))) 75 | m_schedule_new = m_schedule * momentum_cache_t 76 | m_schedule_next = m_schedule * momentum_cache_t * momentum_cache_t_1 77 | state['m_schedule'] = m_schedule_new 78 | 79 | # Decay the first and second moment running average coefficient 80 | exp_avg.mul_(beta1).add_(1. - beta1, grad) 81 | exp_avg_sq.mul_(beta2).addcmul_(1. - beta2, grad, grad) 82 | exp_avg_sq_prime = exp_avg_sq / (1. - beta2 ** t) 83 | denom = exp_avg_sq_prime.sqrt_().add_(eps) 84 | 85 | p.data.addcdiv_(-group['lr'] * (1. - momentum_cache_t) / (1. - m_schedule_new), grad, denom) 86 | p.data.addcdiv_(-group['lr'] * momentum_cache_t_1 / (1. - m_schedule_next), exp_avg, denom) 87 | 88 | return loss 89 | -------------------------------------------------------------------------------- /timm/optim/novograd.py: -------------------------------------------------------------------------------- 1 | """NovoGrad Optimizer. 2 | Original impl by Masashi Kimura (Convergence Lab): https://github.com/convergence-lab/novograd 3 | Paper: `Stochastic Gradient Methods with Layer-wise Adaptive Moments for Training of Deep Networks` 4 | - https://arxiv.org/abs/1905.11286 5 | """ 6 | 7 | import torch 8 | from torch.optim.optimizer import Optimizer 9 | import math 10 | 11 | 12 | class NovoGrad(Optimizer): 13 | def __init__(self, params, grad_averaging=False, lr=0.1, betas=(0.95, 0.98), eps=1e-8, weight_decay=0): 14 | defaults = dict(lr=lr, betas=betas, eps=eps, weight_decay=weight_decay) 15 | super(NovoGrad, self).__init__(params, defaults) 16 | self._lr = lr 17 | self._beta1 = betas[0] 18 | self._beta2 = betas[1] 19 | self._eps = eps 20 | self._wd = weight_decay 21 | self._grad_averaging = grad_averaging 22 | 23 | self._momentum_initialized = False 24 | 25 | def step(self, closure=None): 26 | loss = None 27 | if closure is not None: 28 | loss = closure() 29 | 30 | if not self._momentum_initialized: 31 | for group in self.param_groups: 32 | for p in group['params']: 33 | if p.grad is None: 34 | continue 35 | state = self.state[p] 36 | grad = p.grad.data 37 | if grad.is_sparse: 38 | raise RuntimeError('NovoGrad does not support sparse gradients') 39 | 40 | v = torch.norm(grad)**2 41 | m = grad/(torch.sqrt(v) + self._eps) + self._wd * p.data 42 | state['step'] = 0 43 | state['v'] = v 44 | state['m'] = m 45 | state['grad_ema'] = None 46 | self._momentum_initialized = True 47 | 48 | for group in self.param_groups: 49 | for p in group['params']: 50 | if p.grad is None: 51 | continue 52 | state = self.state[p] 53 | state['step'] += 1 54 | 55 | step, v, m = state['step'], state['v'], state['m'] 56 | grad_ema = state['grad_ema'] 57 | 58 | grad = p.grad.data 59 | g2 = torch.norm(grad)**2 60 | grad_ema = g2 if grad_ema is None else grad_ema * \ 61 | self._beta2 + g2 * (1. - self._beta2) 62 | grad *= 1.0 / (torch.sqrt(grad_ema) + self._eps) 63 | 64 | if self._grad_averaging: 65 | grad *= (1. - self._beta1) 66 | 67 | g2 = torch.norm(grad)**2 68 | v = self._beta2*v + (1. 
- self._beta2)*g2 69 | m = self._beta1*m + (grad / (torch.sqrt(v) + self._eps) + self._wd * p.data) 70 | bias_correction1 = 1 - self._beta1 ** step 71 | bias_correction2 = 1 - self._beta2 ** step 72 | step_size = group['lr'] * math.sqrt(bias_correction2) / bias_correction1 73 | 74 | state['v'], state['m'] = v, m 75 | state['grad_ema'] = grad_ema 76 | p.data.add_(-step_size, m) 77 | return loss 78 | -------------------------------------------------------------------------------- /timm/optim/nvnovograd.py: -------------------------------------------------------------------------------- 1 | """ Nvidia NovoGrad Optimizer. 2 | Original impl by Nvidia from Jasper example: 3 | - https://github.com/NVIDIA/DeepLearningExamples/blob/master/PyTorch/SpeechRecognition/Jasper 4 | Paper: `Stochastic Gradient Methods with Layer-wise Adaptive Moments for Training of Deep Networks` 5 | - https://arxiv.org/abs/1905.11286 6 | """ 7 | 8 | import torch 9 | from torch.optim.optimizer import Optimizer 10 | import math 11 | 12 | 13 | class NvNovoGrad(Optimizer): 14 | """ 15 | Implements Novograd algorithm. 16 | 17 | Args: 18 | params (iterable): iterable of parameters to optimize or dicts defining 19 | parameter groups 20 | lr (float, optional): learning rate (default: 1e-3) 21 | betas (Tuple[float, float], optional): coefficients used for computing 22 | running averages of gradient and its square (default: (0.95, 0.98)) 23 | eps (float, optional): term added to the denominator to improve 24 | numerical stability (default: 1e-8) 25 | weight_decay (float, optional): weight decay (L2 penalty) (default: 0) 26 | grad_averaging: gradient averaging 27 | amsgrad (boolean, optional): whether to use the AMSGrad variant of this 28 | algorithm from the paper `On the Convergence of Adam and Beyond`_ 29 | (default: False) 30 | """ 31 | 32 | def __init__(self, params, lr=1e-3, betas=(0.95, 0.98), eps=1e-8, 33 | weight_decay=0, grad_averaging=False, amsgrad=False): 34 | if not 0.0 <= lr: 35 | raise ValueError("Invalid learning rate: {}".format(lr)) 36 | if not 0.0 <= eps: 37 | raise ValueError("Invalid epsilon value: {}".format(eps)) 38 | if not 0.0 <= betas[0] < 1.0: 39 | raise ValueError("Invalid beta parameter at index 0: {}".format(betas[0])) 40 | if not 0.0 <= betas[1] < 1.0: 41 | raise ValueError("Invalid beta parameter at index 1: {}".format(betas[1])) 42 | defaults = dict(lr=lr, betas=betas, eps=eps, 43 | weight_decay=weight_decay, 44 | grad_averaging=grad_averaging, 45 | amsgrad=amsgrad) 46 | 47 | super(NvNovoGrad, self).__init__(params, defaults) 48 | 49 | def __setstate__(self, state): 50 | super(NvNovoGrad, self).__setstate__(state) 51 | for group in self.param_groups: 52 | group.setdefault('amsgrad', False) 53 | 54 | def step(self, closure=None): 55 | """Performs a single optimization step. 56 | 57 | Arguments: 58 | closure (callable, optional): A closure that reevaluates the model 59 | and returns the loss. 
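
        Note: unlike Adam, the second-moment state here is a single scalar per parameter
        tensor (an EMA of the squared gradient norm), per the NovoGrad paper - see the
        norm computation below.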
60 | """ 61 | loss = None 62 | if closure is not None: 63 | loss = closure() 64 | 65 | for group in self.param_groups: 66 | for p in group['params']: 67 | if p.grad is None: 68 | continue 69 | grad = p.grad.data 70 | if grad.is_sparse: 71 | raise RuntimeError('Sparse gradients are not supported.') 72 | amsgrad = group['amsgrad'] 73 | 74 | state = self.state[p] 75 | 76 | # State initialization 77 | if len(state) == 0: 78 | state['step'] = 0 79 | # Exponential moving average of gradient values 80 | state['exp_avg'] = torch.zeros_like(p.data) 81 | # Exponential moving average of squared gradient values 82 | state['exp_avg_sq'] = torch.zeros([]).to(state['exp_avg'].device) 83 | if amsgrad: 84 | # Maintains max of all exp. moving avg. of sq. grad. values 85 | state['max_exp_avg_sq'] = torch.zeros([]).to(state['exp_avg'].device) 86 | 87 | exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq'] 88 | if amsgrad: 89 | max_exp_avg_sq = state['max_exp_avg_sq'] 90 | beta1, beta2 = group['betas'] 91 | 92 | state['step'] += 1 93 | 94 | norm = torch.sum(torch.pow(grad, 2)) 95 | 96 | if exp_avg_sq == 0: 97 | exp_avg_sq.copy_(norm) 98 | else: 99 | exp_avg_sq.mul_(beta2).add_(1 - beta2, norm) 100 | 101 | if amsgrad: 102 | # Maintains the maximum of all 2nd moment running avg. till now 103 | torch.max(max_exp_avg_sq, exp_avg_sq, out=max_exp_avg_sq) 104 | # Use the max. for normalizing running avg. of gradient 105 | denom = max_exp_avg_sq.sqrt().add_(group['eps']) 106 | else: 107 | denom = exp_avg_sq.sqrt().add_(group['eps']) 108 | 109 | grad.div_(denom) 110 | if group['weight_decay'] != 0: 111 | grad.add_(group['weight_decay'], p.data) 112 | if group['grad_averaging']: 113 | grad.mul_(1 - beta1) 114 | exp_avg.mul_(beta1).add_(grad) 115 | 116 | p.data.add_(-group['lr'], exp_avg) 117 | 118 | return loss 119 | -------------------------------------------------------------------------------- /timm/optim/optim_factory.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import optim as optim 3 | from timm.optim import Nadam, RMSpropTF, AdamW, RAdam, NovoGrad, NvNovoGrad, Lookahead 4 | try: 5 | from apex.optimizers import FusedNovoGrad, FusedAdam, FusedLAMB, FusedSGD 6 | has_apex = True 7 | except ImportError: 8 | has_apex = False 9 | 10 | 11 | def add_weight_decay(model, weight_decay=1e-5, skip_list=()): 12 | decay = [] 13 | no_decay = [] 14 | for name, param in model.named_parameters(): 15 | if not param.requires_grad: 16 | continue # frozen weights 17 | if len(param.shape) == 1 or name.endswith(".bias") or name in skip_list: 18 | no_decay.append(param) 19 | else: 20 | decay.append(param) 21 | return [ 22 | {'params': no_decay, 'weight_decay': 0.}, 23 | {'params': decay, 'weight_decay': weight_decay}] 24 | 25 | 26 | def create_optimizer(args, model, filter_bias_and_bn=True): 27 | opt_lower = args.opt.lower() 28 | weight_decay = args.weight_decay 29 | if 'adamw' in opt_lower or 'radam' in opt_lower: 30 | # Compensate for the way current AdamW and RAdam optimizers apply LR to the weight-decay 31 | # I don't believe they follow the paper or original Torch7 impl which schedules weight 32 | # decay based on the ratio of current_lr/initial_lr 33 | weight_decay /= args.lr 34 | if weight_decay and filter_bias_and_bn: 35 | parameters = add_weight_decay(model, weight_decay) 36 | weight_decay = 0. 
37 | else: 38 | parameters = model.parameters() 39 | 40 | if 'fused' in opt_lower: 41 | assert has_apex and torch.cuda.is_available(), 'APEX and CUDA required for fused optimizers' 42 | 43 | opt_split = opt_lower.split('_') 44 | opt_lower = opt_split[-1] 45 | if opt_lower == 'sgd': 46 | optimizer = optim.SGD( 47 | parameters, lr=args.lr, momentum=args.momentum, weight_decay=weight_decay, nesterov=True) 48 | elif opt_lower == 'adam': 49 | optimizer = optim.Adam( 50 | parameters, lr=args.lr, weight_decay=weight_decay, eps=args.opt_eps) 51 | elif opt_lower == 'adamw': 52 | optimizer = AdamW( 53 | parameters, lr=args.lr, weight_decay=weight_decay, eps=args.opt_eps) 54 | elif opt_lower == 'nadam': 55 | optimizer = Nadam( 56 | parameters, lr=args.lr, weight_decay=weight_decay, eps=args.opt_eps) 57 | elif opt_lower == 'radam': 58 | optimizer = RAdam( 59 | parameters, lr=args.lr, weight_decay=weight_decay, eps=args.opt_eps) 60 | elif opt_lower == 'adadelta': 61 | optimizer = optim.Adadelta( 62 | parameters, lr=args.lr, weight_decay=weight_decay, eps=args.opt_eps) 63 | elif opt_lower == 'rmsprop': 64 | optimizer = optim.RMSprop( 65 | parameters, lr=args.lr, alpha=0.9, eps=args.opt_eps, 66 | momentum=args.momentum, weight_decay=weight_decay) 67 | elif opt_lower == 'rmsproptf': 68 | optimizer = RMSpropTF( 69 | parameters, lr=args.lr, alpha=0.9, eps=args.opt_eps, 70 | momentum=args.momentum, weight_decay=weight_decay) 71 | elif opt_lower == 'novograd': 72 | optimizer = NovoGrad(parameters, lr=args.lr, weight_decay=weight_decay, eps=args.opt_eps) 73 | elif opt_lower == 'nvnovograd': 74 | optimizer = NvNovoGrad(parameters, lr=args.lr, weight_decay=weight_decay, eps=args.opt_eps) 75 | elif opt_lower == 'fusedsgd': 76 | optimizer = FusedSGD( 77 | parameters, lr=args.lr, momentum=args.momentum, weight_decay=weight_decay, nesterov=True) 78 | elif opt_lower == 'fusedadam': 79 | optimizer = FusedAdam( 80 | parameters, lr=args.lr, adam_w_mode=False, weight_decay=weight_decay, eps=args.opt_eps) 81 | elif opt_lower == 'fusedadamw': 82 | optimizer = FusedAdam( 83 | parameters, lr=args.lr, adam_w_mode=True, weight_decay=weight_decay, eps=args.opt_eps) 84 | elif opt_lower == 'fusedlamb': 85 | optimizer = FusedLAMB(parameters, lr=args.lr, weight_decay=weight_decay, eps=args.opt_eps) 86 | elif opt_lower == 'fusednovograd': 87 | optimizer = FusedNovoGrad( 88 | parameters, lr=args.lr, betas=(0.95, 0.98), weight_decay=weight_decay, eps=args.opt_eps) 89 | else: 90 | assert False, "Invalid optimizer: " + opt_lower 91 | raise ValueError 92 | 93 | if len(opt_split) > 1: 94 | if opt_split[0] == 'lookahead': 95 | optimizer = Lookahead(optimizer) 96 | 97 | return optimizer 98 | -------------------------------------------------------------------------------- /timm/optim/radam.py: -------------------------------------------------------------------------------- 1 | """RAdam Optimizer.
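`create_optimizer` above pulls its hyper-parameters off an argparse-style namespace, so it can also be driven outside `train.py`. A minimal sketch; the attribute names are the ones the factory actually reads, while the concrete values are illustrative:

```python
# Illustrative: calling create_optimizer with a hand-built namespace.
from types import SimpleNamespace

import torch
from timm.optim.optim_factory import create_optimizer

args = SimpleNamespace(opt='lookahead_rmsproptf', lr=0.016, weight_decay=1e-5,
                       momentum=0.9, opt_eps=1e-3)
model = torch.nn.Linear(8, 2)
# With filter_bias_and_bn=True (the default), weight decay is split off
# 1-d params and biases via add_weight_decay(); the 'lookahead_' prefix then
# wraps the base RMSpropTF optimizer in Lookahead.
optimizer = create_optimizer(args, model)
```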
2 | Implementation lifted from: https://github.com/LiyuanLucasLiu/RAdam 3 | Paper: `On the Variance of the Adaptive Learning Rate and Beyond` - https://arxiv.org/abs/1908.03265 4 | """ 5 | import math 6 | import torch 7 | from torch.optim.optimizer import Optimizer, required 8 | 9 | 10 | class RAdam(Optimizer): 11 | 12 | def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8, weight_decay=0): 13 | defaults = dict(lr=lr, betas=betas, eps=eps, weight_decay=weight_decay) 14 | self.buffer = [[None, None, None] for ind in range(10)] 15 | super(RAdam, self).__init__(params, defaults) 16 | 17 | def __setstate__(self, state): 18 | super(RAdam, self).__setstate__(state) 19 | 20 | def step(self, closure=None): 21 | 22 | loss = None 23 | if closure is not None: 24 | loss = closure() 25 | 26 | for group in self.param_groups: 27 | 28 | for p in group['params']: 29 | if p.grad is None: 30 | continue 31 | grad = p.grad.data.float() 32 | if grad.is_sparse: 33 | raise RuntimeError('RAdam does not support sparse gradients') 34 | 35 | p_data_fp32 = p.data.float() 36 | 37 | state = self.state[p] 38 | 39 | if len(state) == 0: 40 | state['step'] = 0 41 | state['exp_avg'] = torch.zeros_like(p_data_fp32) 42 | state['exp_avg_sq'] = torch.zeros_like(p_data_fp32) 43 | else: 44 | state['exp_avg'] = state['exp_avg'].type_as(p_data_fp32) 45 | state['exp_avg_sq'] = state['exp_avg_sq'].type_as(p_data_fp32) 46 | 47 | exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq'] 48 | beta1, beta2 = group['betas'] 49 | 50 | exp_avg_sq.mul_(beta2).addcmul_(1 - beta2, grad, grad) 51 | exp_avg.mul_(beta1).add_(1 - beta1, grad) 52 | 53 | state['step'] += 1 54 | buffered = self.buffer[int(state['step'] % 10)] 55 | if state['step'] == buffered[0]: 56 | N_sma, step_size = buffered[1], buffered[2] 57 | else: 58 | buffered[0] = state['step'] 59 | beta2_t = beta2 ** state['step'] 60 | N_sma_max = 2 / (1 - beta2) - 1 61 | N_sma = N_sma_max - 2 * state['step'] * beta2_t / (1 - beta2_t) 62 | buffered[1] = N_sma 63 | 64 | # more conservative since it's an approximated value 65 | if N_sma >= 5: 66 | step_size = group['lr'] * math.sqrt( 67 | (1 - beta2_t) * (N_sma - 4) / (N_sma_max - 4) * (N_sma - 2) / N_sma * N_sma_max / ( 68 | N_sma_max - 2)) / (1 - beta1 ** state['step']) 69 | else: 70 | step_size = group['lr'] / (1 - beta1 ** state['step']) 71 | buffered[2] = step_size 72 | 73 | if group['weight_decay'] != 0: 74 | p_data_fp32.add_(-group['weight_decay'] * group['lr'], p_data_fp32) 75 | 76 | # more conservative since it's an approximated value 77 | if N_sma >= 5: 78 | denom = exp_avg_sq.sqrt().add_(group['eps']) 79 | p_data_fp32.addcdiv_(-step_size, exp_avg, denom) 80 | else: 81 | p_data_fp32.add_(-step_size, exp_avg) 82 | 83 | p.data.copy_(p_data_fp32) 84 | 85 | return loss 86 | 87 | 88 | class PlainRAdam(Optimizer): 89 | 90 | def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8, weight_decay=0): 91 | defaults = dict(lr=lr, betas=betas, eps=eps, weight_decay=weight_decay) 92 | 93 | super(PlainRAdam, self).__init__(params, defaults) 94 | 95 | def __setstate__(self, state): 96 | super(PlainRAdam, self).__setstate__(state) 97 | 98 | def step(self, closure=None): 99 | 100 | loss = None 101 | if closure is not None: 102 | loss = closure() 103 | 104 | for group in self.param_groups: 105 | 106 | for p in group['params']: 107 | if p.grad is None: 108 | continue 109 | grad = p.grad.data.float() 110 | if grad.is_sparse: 111 | raise RuntimeError('RAdam does not support sparse gradients') 112 | 113 | p_data_fp32 = 
p.data.float() 114 | 115 | state = self.state[p] 116 | 117 | if len(state) == 0: 118 | state['step'] = 0 119 | state['exp_avg'] = torch.zeros_like(p_data_fp32) 120 | state['exp_avg_sq'] = torch.zeros_like(p_data_fp32) 121 | else: 122 | state['exp_avg'] = state['exp_avg'].type_as(p_data_fp32) 123 | state['exp_avg_sq'] = state['exp_avg_sq'].type_as(p_data_fp32) 124 | 125 | exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq'] 126 | beta1, beta2 = group['betas'] 127 | 128 | exp_avg_sq.mul_(beta2).addcmul_(1 - beta2, grad, grad) 129 | exp_avg.mul_(beta1).add_(1 - beta1, grad) 130 | 131 | state['step'] += 1 132 | beta2_t = beta2 ** state['step'] 133 | N_sma_max = 2 / (1 - beta2) - 1 134 | N_sma = N_sma_max - 2 * state['step'] * beta2_t / (1 - beta2_t) 135 | 136 | if group['weight_decay'] != 0: 137 | p_data_fp32.add_(-group['weight_decay'] * group['lr'], p_data_fp32) 138 | 139 | # more conservative since it's an approximated value 140 | if N_sma >= 5: 141 | step_size = group['lr'] * math.sqrt( 142 | (1 - beta2_t) * (N_sma - 4) / (N_sma_max - 4) * (N_sma - 2) / N_sma * N_sma_max / ( 143 | N_sma_max - 2)) / (1 - beta1 ** state['step']) 144 | denom = exp_avg_sq.sqrt().add_(group['eps']) 145 | p_data_fp32.addcdiv_(-step_size, exp_avg, denom) 146 | else: 147 | step_size = group['lr'] / (1 - beta1 ** state['step']) 148 | p_data_fp32.add_(-step_size, exp_avg) 149 | 150 | p.data.copy_(p_data_fp32) 151 | 152 | return loss 153 | -------------------------------------------------------------------------------- /timm/optim/rmsprop_tf.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch.optim import Optimizer 3 | 4 | 5 | class RMSpropTF(Optimizer): 6 | """Implements RMSprop algorithm (TensorFlow style epsilon) 7 | 8 | NOTE: This is a direct cut-and-paste of PyTorch RMSprop with eps applied before sqrt 9 | to closer match Tensorflow for matching hyper-params. 10 | 11 | Proposed by G. Hinton in his 12 | `course <http://www.cs.toronto.edu/~tijmen/csc321/slides/lecture_slides_lec6.pdf>`_. 13 | 14 | The centered version first appears in `Generating Sequences 15 | With Recurrent Neural Networks <https://arxiv.org/abs/1308.0850>`_.
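A minimal sketch of using `RAdam` above as a drop-in Adam replacement; the model, data, and values are illustrative:

```python
# Illustrative: RAdam behaves like Adam once the variance rectification kicks in.
import torch
from timm.optim import RAdam

model = torch.nn.Linear(8, 2)
optimizer = RAdam(model.parameters(), lr=1e-3, betas=(0.9, 0.999), weight_decay=1e-2)

loss = model(torch.randn(4, 8)).pow(2).mean()
loss.backward()
optimizer.step()  # early steps (N_sma < 5) fall back to an un-rectified momentum-style update
```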
16 | 17 | Arguments: 18 | params (iterable): iterable of parameters to optimize or dicts defining 19 | parameter groups 20 | lr (float, optional): learning rate (default: 1e-2) 21 | momentum (float, optional): momentum factor (default: 0) 22 | alpha (float, optional): smoothing (decay) constant (default: 0.9) 23 | eps (float, optional): term added to the denominator to improve 24 | numerical stability (default: 1e-10) 25 | centered (bool, optional) : if ``True``, compute the centered RMSProp, 26 | the gradient is normalized by an estimation of its variance 27 | weight_decay (float, optional): weight decay (L2 penalty) (default: 0) 28 | decoupled_decay (bool, optional): decoupled weight decay as per https://arxiv.org/abs/1711.05101 29 | lr_in_momentum (bool, optional): learning rate scaling is included in the momentum buffer 30 | update as per defaults in Tensorflow 31 | 32 | """ 33 | 34 | def __init__(self, params, lr=1e-2, alpha=0.9, eps=1e-10, weight_decay=0, momentum=0., centered=False, 35 | decoupled_decay=False, lr_in_momentum=True): 36 | if not 0.0 <= lr: 37 | raise ValueError("Invalid learning rate: {}".format(lr)) 38 | if not 0.0 <= eps: 39 | raise ValueError("Invalid epsilon value: {}".format(eps)) 40 | if not 0.0 <= momentum: 41 | raise ValueError("Invalid momentum value: {}".format(momentum)) 42 | if not 0.0 <= weight_decay: 43 | raise ValueError("Invalid weight_decay value: {}".format(weight_decay)) 44 | if not 0.0 <= alpha: 45 | raise ValueError("Invalid alpha value: {}".format(alpha)) 46 | 47 | defaults = dict(lr=lr, momentum=momentum, alpha=alpha, eps=eps, centered=centered, weight_decay=weight_decay, 48 | decoupled_decay=decoupled_decay, lr_in_momentum=lr_in_momentum) 49 | super(RMSpropTF, self).__init__(params, defaults) 50 | 51 | def __setstate__(self, state): 52 | super(RMSpropTF, self).__setstate__(state) 53 | for group in self.param_groups: 54 | group.setdefault('momentum', 0) 55 | group.setdefault('centered', False) 56 | 57 | def step(self, closure=None): 58 | """Performs a single optimization step. 59 | 60 | Arguments: 61 | closure (callable, optional): A closure that reevaluates the model 62 | and returns the loss. 63 | """ 64 | loss = None 65 | if closure is not None: 66 | loss = closure() 67 | 68 | for group in self.param_groups: 69 | for p in group['params']: 70 | if p.grad is None: 71 | continue 72 | grad = p.grad.data 73 | if grad.is_sparse: 74 | raise RuntimeError('RMSprop does not support sparse gradients') 75 | state = self.state[p] 76 | 77 | # State initialization 78 | if len(state) == 0: 79 | state['step'] = 0 80 | state['square_avg'] = torch.ones_like(p.data) # PyTorch inits to zero 81 | if group['momentum'] > 0: 82 | state['momentum_buffer'] = torch.zeros_like(p.data) 83 | if group['centered']: 84 | state['grad_avg'] = torch.zeros_like(p.data) 85 | 86 | square_avg = state['square_avg'] 87 | one_minus_alpha = 1. 
- group['alpha'] 88 | 89 | state['step'] += 1 90 | 91 | if group['weight_decay'] != 0: 92 | if 'decoupled_decay' in group and group['decoupled_decay']: 93 | p.data.add_(-group['weight_decay'], p.data) 94 | else: 95 | grad = grad.add(group['weight_decay'], p.data) 96 | 97 | # Tensorflow order of ops for updating squared avg 98 | square_avg.add_(one_minus_alpha, grad.pow(2) - square_avg) 99 | # square_avg.mul_(alpha).addcmul_(1 - alpha, grad, grad) # PyTorch original 100 | 101 | if group['centered']: 102 | grad_avg = state['grad_avg'] 103 | grad_avg.add_(one_minus_alpha, grad - grad_avg) 104 | # grad_avg.mul_(alpha).add_(1 - alpha, grad) # PyTorch original 105 | avg = square_avg.addcmul(-1, grad_avg, grad_avg).add(group['eps']).sqrt_() # eps moved in sqrt 106 | else: 107 | avg = square_avg.add(group['eps']).sqrt_() # eps moved in sqrt 108 | 109 | if group['momentum'] > 0: 110 | buf = state['momentum_buffer'] 111 | # Tensorflow accumulates the LR scaling in the momentum buffer 112 | if 'lr_in_momentum' in group and group['lr_in_momentum']: 113 | buf.mul_(group['momentum']).addcdiv_(group['lr'], grad, avg) 114 | p.data.add_(-buf) 115 | else: 116 | # PyTorch scales the param update by LR 117 | buf.mul_(group['momentum']).addcdiv_(grad, avg) 118 | p.data.add_(-group['lr'], buf) 119 | else: 120 | p.data.addcdiv_(-group['lr'], grad, avg) 121 | 122 | return loss 123 | -------------------------------------------------------------------------------- /timm/scheduler/__init__.py: -------------------------------------------------------------------------------- 1 | from .cosine_lr import CosineLRScheduler 2 | from .plateau_lr import PlateauLRScheduler 3 | from .step_lr import StepLRScheduler 4 | from .tanh_lr import TanhLRScheduler 5 | from .scheduler_factory import create_scheduler 6 | -------------------------------------------------------------------------------- /timm/scheduler/cosine_lr.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import math 3 | import numpy as np 4 | import torch 5 | 6 | from .scheduler import Scheduler 7 | 8 | 9 | logger = logging.getLogger(__name__) 10 | 11 | 12 | class CosineLRScheduler(Scheduler): 13 | """ 14 | Cosine decay with restarts. 15 | This is described in the paper https://arxiv.org/abs/1608.03983. 
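A minimal sketch of constructing `RMSpropTF` above; the large `eps` reflects the pre-sqrt placement noted in the docstring, and the concrete values are illustrative, not recommendations:

```python
# Illustrative: TF-style RMSprop construction.
import torch
from timm.optim import RMSpropTF

model = torch.nn.Linear(8, 2)
# eps is added *before* the sqrt here, so useful values are much larger than
# PyTorch's usual 1e-8 (e.g. 1e-3) when porting TF hyper-parameters.
optimizer = RMSpropTF(model.parameters(), lr=0.016, alpha=0.9, eps=1e-3,
                      momentum=0.9, weight_decay=1e-5)
```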
16 | 17 | Inspiration from 18 | https://github.com/allenai/allennlp/blob/master/allennlp/training/learning_rate_schedulers/cosine.py 19 | """ 20 | 21 | def __init__(self, 22 | optimizer: torch.optim.Optimizer, 23 | t_initial: int, 24 | t_mul: float = 1., 25 | lr_min: float = 0., 26 | decay_rate: float = 1., 27 | warmup_t=0, 28 | warmup_lr_init=0, 29 | warmup_prefix=False, 30 | cycle_limit=0, 31 | t_in_epochs=True, 32 | noise_range_t=None, 33 | noise_pct=0.67, 34 | noise_std=1.0, 35 | noise_seed=42, 36 | initialize=True) -> None: 37 | super().__init__( 38 | optimizer, param_group_field="lr", 39 | noise_range_t=noise_range_t, noise_pct=noise_pct, noise_std=noise_std, noise_seed=noise_seed, 40 | initialize=initialize) 41 | 42 | assert t_initial > 0 43 | assert lr_min >= 0 44 | if t_initial == 1 and t_mul == 1 and decay_rate == 1: 45 | logger.warning("Cosine annealing scheduler will have no effect on the learning " 46 | "rate since t_initial = t_mul = decay_rate = 1.") 47 | self.t_initial = t_initial 48 | self.t_mul = t_mul 49 | self.lr_min = lr_min 50 | self.decay_rate = decay_rate 51 | self.cycle_limit = cycle_limit 52 | self.warmup_t = warmup_t 53 | self.warmup_lr_init = warmup_lr_init 54 | self.warmup_prefix = warmup_prefix 55 | self.t_in_epochs = t_in_epochs 56 | if self.warmup_t: 57 | self.warmup_steps = [(v - warmup_lr_init) / self.warmup_t for v in self.base_values] 58 | super().update_groups(self.warmup_lr_init) 59 | else: 60 | self.warmup_steps = [1 for _ in self.base_values] 61 | 62 | def _get_lr(self, t): 63 | if t < self.warmup_t: 64 | lrs = [self.warmup_lr_init + t * s for s in self.warmup_steps] 65 | else: 66 | if self.warmup_prefix: 67 | t = t - self.warmup_t 68 | 69 | if self.t_mul != 1: 70 | i = math.floor(math.log(1 - t / self.t_initial * (1 - self.t_mul), self.t_mul)) 71 | t_i = self.t_mul ** i * self.t_initial 72 | t_curr = t - (1 - self.t_mul ** i) / (1 - self.t_mul) * self.t_initial 73 | else: 74 | i = t // self.t_initial 75 | t_i = self.t_initial 76 | t_curr = t - (self.t_initial * i) 77 | 78 | gamma = self.decay_rate ** i 79 | lr_min = self.lr_min * gamma 80 | lr_max_values = [v * gamma for v in self.base_values] 81 | 82 | if self.cycle_limit == 0 or (self.cycle_limit > 0 and i < self.cycle_limit): 83 | lrs = [ 84 | lr_min + 0.5 * (lr_max - lr_min) * (1 + math.cos(math.pi * t_curr / t_i)) for lr_max in lr_max_values 85 | ] 86 | else: 87 | lrs = [self.lr_min for _ in self.base_values] 88 | 89 | return lrs 90 | 91 | def get_epoch_values(self, epoch: int): 92 | if self.t_in_epochs: 93 | return self._get_lr(epoch) 94 | else: 95 | return None 96 | 97 | def get_update_values(self, num_updates: int): 98 | if not self.t_in_epochs: 99 | return self._get_lr(num_updates) 100 | else: 101 | return None 102 | 103 | def get_cycle_length(self, cycles=0): 104 | if not cycles: 105 | cycles = self.cycle_limit 106 | assert cycles > 0 107 | if self.t_mul == 1.0: 108 | return self.t_initial * cycles 109 | else: 110 | return int(math.floor(-self.t_initial * (self.t_mul ** cycles - 1) / (1 - self.t_mul))) 111 | -------------------------------------------------------------------------------- /timm/scheduler/plateau_lr.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | from .scheduler import Scheduler 4 | 5 | 6 | class PlateauLRScheduler(Scheduler): 7 | """Decay the LR by a factor every time the validation loss plateaus.""" 8 | 9 | def __init__(self, 10 | optimizer, 11 | decay_rate=0.1, 12 | patience_t=10, 13 | verbose=True, 14 |
threshold=1e-4, 15 | cooldown_t=0, 16 | warmup_t=0, 17 | warmup_lr_init=0, 18 | lr_min=0, 19 | mode='min', 20 | initialize=True, 21 | ): 22 | super().__init__(optimizer, 'lr', initialize=initialize) 23 | 24 | self.lr_scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau( 25 | self.optimizer, 26 | patience=patience_t, 27 | factor=decay_rate, 28 | verbose=verbose, 29 | threshold=threshold, 30 | cooldown=cooldown_t, 31 | mode=mode, 32 | min_lr=lr_min 33 | ) 34 | 35 | self.warmup_t = warmup_t 36 | self.warmup_lr_init = warmup_lr_init 37 | if self.warmup_t: 38 | self.warmup_steps = [(v - warmup_lr_init) / self.warmup_t for v in self.base_values] 39 | super().update_groups(self.warmup_lr_init) 40 | else: 41 | self.warmup_steps = [1 for _ in self.base_values] 42 | 43 | def state_dict(self): 44 | return { 45 | 'best': self.lr_scheduler.best, 46 | 'last_epoch': self.lr_scheduler.last_epoch, 47 | } 48 | 49 | def load_state_dict(self, state_dict): 50 | self.lr_scheduler.best = state_dict['best'] 51 | if 'last_epoch' in state_dict: 52 | self.lr_scheduler.last_epoch = state_dict['last_epoch'] 53 | 54 | # override the base class step fn completely 55 | def step(self, epoch, metric=None): 56 | if epoch <= self.warmup_t: 57 | lrs = [self.warmup_lr_init + epoch * s for s in self.warmup_steps] 58 | super().update_groups(lrs) 59 | else: 60 | self.lr_scheduler.step(metric, epoch) 61 | -------------------------------------------------------------------------------- /timm/scheduler/scheduler.py: -------------------------------------------------------------------------------- 1 | from typing import Dict, Any 2 | 3 | import torch 4 | 5 | 6 | class Scheduler: 7 | """ Parameter Scheduler Base Class 8 | A scheduler base class that can be used to schedule any optimizer parameter groups. 9 | 10 | Unlike the builtin PyTorch schedulers, this is intended to be consistently called 11 | * At the END of each epoch, before incrementing the epoch count, to calculate next epoch's value 12 | * At the END of each optimizer update, after incrementing the update count, to calculate next update's value 13 | 14 | The schedulers built on this should try to remain as stateless as possible (for simplicity). 15 | 16 | This family of schedulers is attempting to avoid the confusion of the meaning of 'last_epoch' 17 | and -1 values for special behaviour. All epoch and update counts must be tracked in the training 18 | code and explicitly passed in to the schedulers on the corresponding step or step_update call. 
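A minimal sketch of the calling convention this base class prescribes, using `StepLRScheduler` (exported by `timm.scheduler` above) on a toy model; the loop sizes and hyper-parameters are illustrative:

```python
# Illustrative: the explicit step / step_update pattern expected by Scheduler.
import torch
from timm.scheduler import StepLRScheduler

model = torch.nn.Linear(8, 2)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
scheduler = StepLRScheduler(optimizer, decay_t=2, decay_rate=0.5,
                            warmup_t=1, warmup_lr_init=1e-4)

num_updates = 0
for epoch in range(4):
    for _ in range(2):  # stand-in for iterating a real data loader
        loss = model(torch.randn(4, 8)).pow(2).mean()
        loss.backward()
        optimizer.step()
        optimizer.zero_grad()
        num_updates += 1
        # end of each optimizer update; a no-op here since t_in_epochs=True
        scheduler.step_update(num_updates=num_updates)
    scheduler.step(epoch + 1)  # end of each epoch: sets the *next* epoch's value
```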
19 | 20 | Based on ideas from: 21 | * https://github.com/pytorch/fairseq/tree/master/fairseq/optim/lr_scheduler 22 | * https://github.com/allenai/allennlp/tree/master/allennlp/training/learning_rate_schedulers 23 | """ 24 | 25 | def __init__(self, 26 | optimizer: torch.optim.Optimizer, 27 | param_group_field: str, 28 | noise_range_t=None, 29 | noise_type='normal', 30 | noise_pct=0.67, 31 | noise_std=1.0, 32 | noise_seed=None, 33 | initialize: bool = True) -> None: 34 | self.optimizer = optimizer 35 | self.param_group_field = param_group_field 36 | self._initial_param_group_field = f"initial_{param_group_field}" 37 | if initialize: 38 | for i, group in enumerate(self.optimizer.param_groups): 39 | if param_group_field not in group: 40 | raise KeyError(f"{param_group_field} missing from param_groups[{i}]") 41 | group.setdefault(self._initial_param_group_field, group[param_group_field]) 42 | else: 43 | for i, group in enumerate(self.optimizer.param_groups): 44 | if self._initial_param_group_field not in group: 45 | raise KeyError(f"{self._initial_param_group_field} missing from param_groups[{i}]") 46 | self.base_values = [group[self._initial_param_group_field] for group in self.optimizer.param_groups] 47 | self.metric = None # any point to having this for all? 48 | self.noise_range_t = noise_range_t 49 | self.noise_pct = noise_pct 50 | self.noise_type = noise_type 51 | self.noise_std = noise_std 52 | self.noise_seed = noise_seed if noise_seed is not None else 42 53 | self.update_groups(self.base_values) 54 | 55 | def state_dict(self) -> Dict[str, Any]: 56 | return {key: value for key, value in self.__dict__.items() if key != 'optimizer'} 57 | 58 | def load_state_dict(self, state_dict: Dict[str, Any]) -> None: 59 | self.__dict__.update(state_dict) 60 | 61 | def get_epoch_values(self, epoch: int): 62 | return None 63 | 64 | def get_update_values(self, num_updates: int): 65 | return None 66 | 67 | def step(self, epoch: int, metric: float = None) -> None: 68 | self.metric = metric 69 | values = self.get_epoch_values(epoch) 70 | if values is not None: 71 | values = self._add_noise(values, epoch) 72 | self.update_groups(values) 73 | 74 | def step_update(self, num_updates: int, metric: float = None): 75 | self.metric = metric 76 | values = self.get_update_values(num_updates) 77 | if values is not None: 78 | values = self._add_noise(values, num_updates) 79 | self.update_groups(values) 80 | 81 | def update_groups(self, values): 82 | if not isinstance(values, (list, tuple)): 83 | values = [values] * len(self.optimizer.param_groups) 84 | for param_group, value in zip(self.optimizer.param_groups, values): 85 | param_group[self.param_group_field] = value 86 | 87 | def _add_noise(self, lrs, t): 88 | if self.noise_range_t is not None: 89 | if isinstance(self.noise_range_t, (list, tuple)): 90 | apply_noise = self.noise_range_t[0] <= t < self.noise_range_t[1] 91 | else: 92 | apply_noise = t >= self.noise_range_t 93 | if apply_noise: 94 | g = torch.Generator() 95 | g.manual_seed(self.noise_seed + t) 96 | if self.noise_type == 'normal': 97 | while True: 98 | # resample if noise out of percent limit, brute force but shouldn't spin much 99 | noise = torch.randn(1, generator=g).item() 100 | if abs(noise) < self.noise_pct: 101 | break 102 | else: 103 | noise = 2 * (torch.rand(1, generator=g).item() - 0.5) * self.noise_pct 104 | lrs = [v + v * noise for v in lrs] 105 | return lrs 106 | -------------------------------------------------------------------------------- /timm/scheduler/scheduler_factory.py: 
-------------------------------------------------------------------------------- 1 | from .cosine_lr import CosineLRScheduler 2 | from .tanh_lr import TanhLRScheduler 3 | from .step_lr import StepLRScheduler 4 | from .plateau_lr import PlateauLRScheduler 5 | 6 | 7 | def create_scheduler(args, optimizer): 8 | num_epochs = args.epochs 9 | 10 | if args.lr_noise is not None: 11 | if isinstance(args.lr_noise, (list, tuple)): 12 | noise_range = [n * num_epochs for n in args.lr_noise] 13 | if len(noise_range) == 1: 14 | noise_range = noise_range[0] 15 | else: 16 | noise_range = args.lr_noise * num_epochs 17 | else: 18 | noise_range = None 19 | 20 | lr_scheduler = None 21 | #FIXME expose cycle params of the scheduler config to arguments 22 | if args.sched == 'cosine': 23 | lr_scheduler = CosineLRScheduler( 24 | optimizer, 25 | t_initial=num_epochs, 26 | t_mul=1.0, 27 | lr_min=args.min_lr, 28 | decay_rate=args.decay_rate, 29 | warmup_lr_init=args.warmup_lr, 30 | warmup_t=args.warmup_epochs, 31 | cycle_limit=1, 32 | t_in_epochs=True, 33 | noise_range_t=noise_range, 34 | noise_pct=args.lr_noise_pct, 35 | noise_std=args.lr_noise_std, 36 | noise_seed=args.seed, 37 | ) 38 | num_epochs = lr_scheduler.get_cycle_length() + args.cooldown_epochs 39 | elif args.sched == 'tanh': 40 | lr_scheduler = TanhLRScheduler( 41 | optimizer, 42 | t_initial=num_epochs, 43 | t_mul=1.0, 44 | lr_min=args.min_lr, 45 | warmup_lr_init=args.warmup_lr, 46 | warmup_t=args.warmup_epochs, 47 | cycle_limit=1, 48 | t_in_epochs=True, 49 | noise_range_t=noise_range, 50 | noise_pct=args.lr_noise_pct, 51 | noise_std=args.lr_noise_std, 52 | noise_seed=args.seed, 53 | ) 54 | num_epochs = lr_scheduler.get_cycle_length() + args.cooldown_epochs 55 | elif args.sched == 'step': 56 | lr_scheduler = StepLRScheduler( 57 | optimizer, 58 | decay_t=args.decay_epochs, 59 | decay_rate=args.decay_rate, 60 | warmup_lr_init=args.warmup_lr, 61 | warmup_t=args.warmup_epochs, 62 | noise_range_t=noise_range, 63 | noise_pct=args.lr_noise_pct, 64 | noise_std=args.lr_noise_std, 65 | noise_seed=args.seed, 66 | ) 67 | elif args.sched == 'plateau': 68 | lr_scheduler = PlateauLRScheduler( 69 | optimizer, 70 | decay_rate=args.decay_rate, 71 | patience_t=args.patience_epochs, 72 | lr_min=args.min_lr, 73 | warmup_lr_init=args.warmup_lr, 74 | warmup_t=args.warmup_epochs, 75 | cooldown_t=args.cooldown_epochs, 76 | ) 77 | 78 | return lr_scheduler, num_epochs 79 | -------------------------------------------------------------------------------- /timm/scheduler/step_lr.py: -------------------------------------------------------------------------------- 1 | import math 2 | import torch 3 | 4 | from .scheduler import Scheduler 5 | 6 | 7 | class StepLRScheduler(Scheduler): 8 | """Step LR schedule with warmup: decay the base LR by decay_rate every decay_t steps. 9 | """ 10 | 11 | def __init__(self, 12 | optimizer: torch.optim.Optimizer, 13 | decay_t: float, 14 | decay_rate: float = 1., 15 | warmup_t=0, 16 | warmup_lr_init=0, 17 | t_in_epochs=True, 18 | noise_range_t=None, 19 | noise_pct=0.67, 20 | noise_std=1.0, 21 | noise_seed=42, 22 | initialize=True, 23 | ) -> None: 24 | super().__init__( 25 | optimizer, param_group_field="lr", 26 | noise_range_t=noise_range_t, noise_pct=noise_pct, noise_std=noise_std, noise_seed=noise_seed, 27 | initialize=initialize) 28 | 29 | self.decay_t = decay_t 30 | self.decay_rate = decay_rate 31 | self.warmup_t = warmup_t 32 | self.warmup_lr_init = warmup_lr_init 33 | self.t_in_epochs = t_in_epochs 34 | if self.warmup_t: 35 | self.warmup_steps = [(v - warmup_lr_init) / self.warmup_t for v in self.base_values] 36 |
super().update_groups(self.warmup_lr_init) 37 | else: 38 | self.warmup_steps = [1 for _ in self.base_values] 39 | 40 | def _get_lr(self, t): 41 | if t < self.warmup_t: 42 | lrs = [self.warmup_lr_init + t * s for s in self.warmup_steps] 43 | else: 44 | lrs = [v * (self.decay_rate ** (t // self.decay_t)) for v in self.base_values] 45 | return lrs 46 | 47 | def get_epoch_values(self, epoch: int): 48 | if self.t_in_epochs: 49 | return self._get_lr(epoch) 50 | else: 51 | return None 52 | 53 | def get_update_values(self, num_updates: int): 54 | if not self.t_in_epochs: 55 | return self._get_lr(num_updates) 56 | else: 57 | return None 58 | -------------------------------------------------------------------------------- /timm/scheduler/tanh_lr.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import math 3 | import numpy as np 4 | import torch 5 | 6 | from .scheduler import Scheduler 7 | 8 | 9 | logger = logging.getLogger(__name__) 10 | 11 | 12 | class TanhLRScheduler(Scheduler): 13 | """ 14 | Hyperbolic-Tangent decay with restarts. 15 | This is described in the paper https://arxiv.org/abs/1806.01593 16 | """ 17 | 18 | def __init__(self, 19 | optimizer: torch.optim.Optimizer, 20 | t_initial: int, 21 | lb: float = -6., 22 | ub: float = 4., 23 | t_mul: float = 1., 24 | lr_min: float = 0., 25 | decay_rate: float = 1., 26 | warmup_t=0, 27 | warmup_lr_init=0, 28 | warmup_prefix=False, 29 | cycle_limit=0, 30 | t_in_epochs=True, 31 | noise_range_t=None, 32 | noise_pct=0.67, 33 | noise_std=1.0, 34 | noise_seed=42, 35 | initialize=True) -> None: 36 | super().__init__( 37 | optimizer, param_group_field="lr", 38 | noise_range_t=noise_range_t, noise_pct=noise_pct, noise_std=noise_std, noise_seed=noise_seed, 39 | initialize=initialize) 40 | 41 | assert t_initial > 0 42 | assert lr_min >= 0 43 | assert lb < ub 44 | assert cycle_limit >= 0 45 | assert warmup_t >= 0 46 | assert warmup_lr_init >= 0 47 | self.lb = lb 48 | self.ub = ub 49 | self.t_initial = t_initial 50 | self.t_mul = t_mul 51 | self.lr_min = lr_min 52 | self.decay_rate = decay_rate 53 | self.cycle_limit = cycle_limit 54 | self.warmup_t = warmup_t 55 | self.warmup_lr_init = warmup_lr_init 56 | self.warmup_prefix = warmup_prefix 57 | self.t_in_epochs = t_in_epochs 58 | if self.warmup_t: 59 | t_v = self.base_values if self.warmup_prefix else self._get_lr(self.warmup_t) 60 | self.warmup_steps = [(v - warmup_lr_init) / self.warmup_t for v in t_v] 61 | super().update_groups(self.warmup_lr_init) 62 | else: 63 | self.warmup_steps = [1 for _ in self.base_values] 64 | 65 | def _get_lr(self, t): 66 | if t < self.warmup_t: 67 | lrs = [self.warmup_lr_init + t * s for s in self.warmup_steps] 68 | else: 69 | if self.warmup_prefix: 70 | t = t - self.warmup_t 71 | 72 | if self.t_mul != 1: 73 | i = math.floor(math.log(1 - t / self.t_initial * (1 - self.t_mul), self.t_mul)) 74 | t_i = self.t_mul ** i * self.t_initial 75 | t_curr = t - (1 - self.t_mul ** i) / (1 - self.t_mul) * self.t_initial 76 | else: 77 | i = t // self.t_initial 78 | t_i = self.t_initial 79 | t_curr = t - (self.t_initial * i) 80 | 81 | if self.cycle_limit == 0 or (self.cycle_limit > 0 and i < self.cycle_limit): 82 | gamma = self.decay_rate ** i 83 | lr_min = self.lr_min * gamma 84 | lr_max_values = [v * gamma for v in self.base_values] 85 | 86 | tr = t_curr / t_i 87 | lrs = [ 88 | lr_min + 0.5 * (lr_max - lr_min) * (1 - math.tanh(self.lb * (1.
- tr) + self.ub * tr)) 89 | for lr_max in lr_max_values 90 | ] 91 | else: 92 | lrs = [self.lr_min * (self.decay_rate ** self.cycle_limit) for _ in self.base_values] 93 | return lrs 94 | 95 | def get_epoch_values(self, epoch: int): 96 | if self.t_in_epochs: 97 | return self._get_lr(epoch) 98 | else: 99 | return None 100 | 101 | def get_update_values(self, num_updates: int): 102 | if not self.t_in_epochs: 103 | return self._get_lr(num_updates) 104 | else: 105 | return None 106 | 107 | def get_cycle_length(self, cycles=0): 108 | if not cycles: 109 | cycles = self.cycle_limit 110 | assert cycles > 0 111 | if self.t_mul == 1.0: 112 | return self.t_initial * cycles 113 | else: 114 | return int(math.floor(-self.t_initial * (self.t_mul ** cycles - 1) / (1 - self.t_mul))) 115 | -------------------------------------------------------------------------------- /timm/version.py: -------------------------------------------------------------------------------- 1 | __version__ = '0.1.18' 2 | --------------------------------------------------------------------------------
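Putting the scheduler pieces together, a minimal sketch of `create_scheduler` driven by a hand-built namespace; the attribute names match what the factory reads above, and the concrete values are illustrative:

```python
# Illustrative: create_scheduler with a cosine schedule plus cooldown.
from types import SimpleNamespace

import torch
from timm.scheduler import create_scheduler

args = SimpleNamespace(
    epochs=100, sched='cosine', min_lr=1e-5, decay_rate=0.1,
    warmup_lr=1e-4, warmup_epochs=3, cooldown_epochs=10,
    lr_noise=None, lr_noise_pct=0.67, lr_noise_std=1.0, seed=42)

optimizer = torch.optim.SGD([torch.nn.Parameter(torch.zeros(1))], lr=0.1)
lr_scheduler, num_epochs = create_scheduler(args, optimizer)
assert num_epochs == 110  # one cosine cycle (100 epochs) plus cooldown (10)
```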