├── LICENSE ├── MANIFEST.in ├── README.md ├── avg_checkpoints.py ├── clean_checkpoint.py ├── convert └── convert_from_mxnet.py ├── distributed_train.sh ├── docs ├── archived_changes.md ├── changes.md ├── feature_extraction.md ├── index.md ├── javascripts │ └── tables.js ├── models.md ├── results.md ├── scripts.md └── training_hparam_examples.md ├── fashion-product-images ├── all.csv ├── images │ └── 1163.jpg ├── test.csv ├── train.csv └── val.csv ├── hubconf.py ├── imgs ├── 20210910-211332-efficientnet_b0-224-trainingProcessAccLoss.png ├── 20210910-211332-efficientnet_b0-224-trainingProcessLosses.png ├── 20210910-211332-efficientnet_b0-224.jpg └── experiments.jpg ├── inference.py ├── mkdocs.yml ├── notebooks ├── EffResNetComparison.ipynb └── GeneralizationToImageNetV2.ipynb ├── requirements-docs.txt ├── requirements-sotabench.txt ├── requirements.txt ├── results ├── README.md ├── generate_csv_results.py ├── imagenet21k_goog_synsets.txt ├── imagenet_a_indices.txt ├── imagenet_a_synsets.txt ├── imagenet_r_indices.txt ├── imagenet_r_synsets.txt ├── imagenet_real_labels.json ├── imagenet_synsets.txt ├── results-imagenet-a-clean.csv ├── results-imagenet-a.csv ├── results-imagenet-r-clean.csv ├── results-imagenet-r.csv ├── results-imagenet-real.csv ├── results-imagenet.csv ├── results-imagenetv2-matched-frequency.csv └── results-sketch.csv ├── setup.cfg ├── setup.py ├── simclr └── modules │ ├── __init__.py │ ├── gather.py │ └── nt_xent.py ├── sotabench.py ├── sotabench_setup.sh ├── tests ├── __init__.py ├── test_layers.py └── test_models.py ├── timm ├── __init__.py ├── data │ ├── __init__.py │ ├── __init__original_version_by_ross.py │ ├── auto_augment.py │ ├── config.py │ ├── constants.py │ ├── dataset.py │ ├── dataset_factory.py │ ├── dataset_original_version_by_ross.py │ ├── dataset_without_simclr.py │ ├── distributed_sampler.py │ ├── loader.py │ ├── loader_original_version_by_ross.py │ ├── loader_without_simclr.py │ ├── mixup.py │ ├── parsers │ │ ├── __init__.py │ │ ├── class_map.py │ │ ├── constants.py │ │ ├── parser.py │ │ ├── parser_factory.py │ │ ├── parser_image_folder.py │ │ ├── parser_image_in_tar.py │ │ ├── parser_image_tar.py │ │ └── parser_tfds.py │ ├── random_erasing.py │ ├── real_labels.py │ ├── tf_preprocessing.py │ ├── transforms.py │ └── transforms_factory.py ├── loss │ ├── __init__.py │ ├── asymmetric_loss.py │ ├── cross_entropy.py │ └── jsd.py ├── models │ ├── __init__.py │ ├── __init__original_version_by_ross.py │ ├── byobnet.py │ ├── cspnet.py │ ├── densenet.py │ ├── dla.py │ ├── dpn.py │ ├── efficientnet.py │ ├── efficientnet_blocks.py │ ├── efficientnet_builder.py │ ├── efficientnet_original_version_by_ross.py │ ├── factory.py │ ├── features.py │ ├── gluon_resnet.py │ ├── gluon_xception.py │ ├── helpers.py │ ├── hrnet.py │ ├── inception_resnet_v2.py │ ├── inception_v3.py │ ├── inception_v4.py │ ├── layers │ │ ├── __init__.py │ │ ├── activations.py │ │ ├── activations_jit.py │ │ ├── activations_me.py │ │ ├── adaptive_avgmax_pool.py │ │ ├── anti_aliasing.py │ │ ├── blur_pool.py │ │ ├── cbam.py │ │ ├── classifier.py │ │ ├── cond_conv2d.py │ │ ├── config.py │ │ ├── conv2d_same.py │ │ ├── conv_bn_act.py │ │ ├── create_act.py │ │ ├── create_attn.py │ │ ├── create_conv2d.py │ │ ├── create_norm_act.py │ │ ├── drop.py │ │ ├── eca.py │ │ ├── evo_norm.py │ │ ├── helpers.py │ │ ├── inplace_abn.py │ │ ├── linear.py │ │ ├── median_pool.py │ │ ├── mixed_conv2d.py │ │ ├── norm_act.py │ │ ├── padding.py │ │ ├── pool2d_same.py │ │ ├── se.py │ │ ├── selective_kernel.py │ │ ├── 
separable_conv.py │ │ ├── space_to_depth.py │ │ ├── split_attn.py │ │ ├── split_batchnorm.py │ │ ├── std_conv.py │ │ ├── test_time_pool.py │ │ └── weight_init.py │ ├── mobilenetv3.py │ ├── multi_label_model.py │ ├── multi_label_model_without_simclr.py │ ├── nasnet.py │ ├── nfnet.py │ ├── pnasnet.py │ ├── pruned │ │ ├── ecaresnet101d_pruned.txt │ │ ├── ecaresnet50d_pruned.txt │ │ ├── efficientnet_b1_pruned.txt │ │ ├── efficientnet_b2_pruned.txt │ │ └── efficientnet_b3_pruned.txt │ ├── registry.py │ ├── regnet.py │ ├── res2net.py │ ├── resnest.py │ ├── resnet.py │ ├── resnetv2.py │ ├── rexnet.py │ ├── selecsls.py │ ├── senet.py │ ├── sknet.py │ ├── tresnet.py │ ├── vgg.py │ ├── vision_transformer.py │ ├── vovnet.py │ ├── xception.py │ └── xception_aligned.py ├── optim │ ├── __init__.py │ ├── adafactor.py │ ├── adahessian.py │ ├── adamp.py │ ├── adamp_original_version_by_ross.py │ ├── adamw.py │ ├── centralization.py │ ├── lookahead.py │ ├── nadam.py │ ├── novograd.py │ ├── nvnovograd.py │ ├── optim_factory.py │ ├── optim_factory_original_version_by_ross.py │ ├── radam.py │ ├── rmsprop_tf.py │ └── sgdp.py ├── scheduler │ ├── __init__.py │ ├── cosine_lr.py │ ├── plateau_lr.py │ ├── scheduler.py │ ├── scheduler_factory.py │ ├── step_lr.py │ └── tanh_lr.py ├── utils │ ├── __init__.py │ ├── agc.py │ ├── checkpoint_saver.py │ ├── clip_grad.py │ ├── cuda.py │ ├── distributed.py │ ├── jit.py │ ├── log.py │ ├── metrics.py │ ├── misc.py │ ├── model.py │ ├── model_ema.py │ ├── summary.py │ ├── summary_original_version_by_ross.py │ └── summary_without_simclr.py └── version.py ├── train.py ├── train_original_version_by_ross.py ├── train_without_simclr.py ├── validate.py ├── validate_original_version_by_ross.py └── validate_without_simclr.py /MANIFEST.in: -------------------------------------------------------------------------------- 1 | include timm/models/pruned/*.txt 2 | 3 | -------------------------------------------------------------------------------- /clean_checkpoint.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | """ Checkpoint Cleaning Script 3 | 4 | Takes training checkpoints with GPU tensors, optimizer state, extra dict keys, etc. 5 | and outputs a CPU tensor checkpoint with only the `state_dict` along with SHA256 6 | calculation for model zoo compatibility. 
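Example invocation (a sketch only; the paths are illustrative, not from this repo): `./clean_checkpoint.py --checkpoint ./output/train/last.pth.tar --output ./cleaned.pth --use-ema`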
7 | 8 | Hacked together by / Copyright 2020 Ross Wightman (https://github.com/rwightman) 9 | """ 10 | import torch 11 | import argparse 12 | import os 13 | import hashlib 14 | import shutil 15 | from collections import OrderedDict 16 | 17 | parser = argparse.ArgumentParser(description='PyTorch Checkpoint Cleaner') 18 | parser.add_argument('--checkpoint', default='', type=str, metavar='PATH', 19 | help='path to latest checkpoint (default: none)') 20 | parser.add_argument('--output', default='', type=str, metavar='PATH', 21 | help='output path') 22 | parser.add_argument('--use-ema', dest='use_ema', action='store_true', 23 | help='use ema version of weights if present') 24 | parser.add_argument('--clean-aux-bn', dest='clean_aux_bn', action='store_true', 25 | help='remove auxiliary batch norm layers (from SplitBN training) from checkpoint') 26 | 27 | _TEMP_NAME = './_checkpoint.pth' 28 | 29 | 30 | def main(): 31 | args = parser.parse_args() 32 | 33 | if os.path.exists(args.output): 34 | print("Error: Output filename ({}) already exists.".format(args.output)) 35 | exit(1) 36 | 37 | # Load an existing checkpoint to CPU, strip everything but the state_dict and re-save 38 | if args.checkpoint and os.path.isfile(args.checkpoint): 39 | print("=> Loading checkpoint '{}'".format(args.checkpoint)) 40 | checkpoint = torch.load(args.checkpoint, map_location='cpu') 41 | 42 | new_state_dict = OrderedDict() 43 | if isinstance(checkpoint, dict): 44 | state_dict_key = 'state_dict_ema' if args.use_ema else 'state_dict' 45 | if state_dict_key in checkpoint: 46 | state_dict = checkpoint[state_dict_key] 47 | else: 48 | state_dict = checkpoint 49 | else: 50 | assert False 51 | for k, v in state_dict.items(): 52 | if args.clean_aux_bn and 'aux_bn' in k: 53 | # If all aux_bn keys are removed, the SplitBN layers will end up as normal and 54 | # load with the unmodified model using BatchNorm2d. 
55 | continue 56 | name = k[7:] if k.startswith('module') else k 57 | new_state_dict[name] = v 58 | print("=> Loaded state_dict from '{}'".format(args.checkpoint)) 59 | 60 | try: 61 | torch.save(new_state_dict, _TEMP_NAME, _use_new_zipfile_serialization=False) 62 | except: # fall back for older torch versions that lack _use_new_zipfile_serialization 63 | torch.save(new_state_dict, _TEMP_NAME) 64 | 65 | with open(_TEMP_NAME, 'rb') as f: 66 | sha_hash = hashlib.sha256(f.read()).hexdigest() 67 | 68 | if args.output: 69 | checkpoint_root, checkpoint_base = os.path.split(args.output) 70 | checkpoint_base = os.path.splitext(checkpoint_base)[0] 71 | else: 72 | checkpoint_root = '' 73 | checkpoint_base = os.path.splitext(args.checkpoint)[0] 74 | final_filename = '-'.join([checkpoint_base, sha_hash[:8]]) + '.pth' 75 | shutil.move(_TEMP_NAME, os.path.join(checkpoint_root, final_filename)) 76 | print("=> Saved state_dict to '{}', SHA256: {}".format(final_filename, sha_hash)) 77 | else: 78 | print("Error: Checkpoint ({}) doesn't exist".format(args.checkpoint)) 79 | 80 | 81 | if __name__ == '__main__': 82 | main() 83 | -------------------------------------------------------------------------------- /convert/convert_from_mxnet.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import hashlib 3 | import os 4 | 5 | import mxnet as mx 6 | import gluoncv 7 | import torch 8 | from timm import create_model 9 | 10 | parser = argparse.ArgumentParser(description='Convert from MXNet') 11 | parser.add_argument('--model', default='all', type=str, metavar='MODEL', 12 | help='Name of model to convert (default: "all")') 13 | 14 | 15 | def convert(mxnet_name, torch_name): 16 | # download and load the pre-trained model 17 | net = gluoncv.model_zoo.get_model(mxnet_name, pretrained=True) 18 | 19 | # create corresponding torch model 20 | torch_net = create_model(torch_name) 21 | 22 | mxp = [(k, v) for k, v in net.collect_params().items() if 'running' not in k] 23 | torchp = list(torch_net.named_parameters()) 24 | torch_params = {} 25 | 26 | # convert parameters 27 | # NOTE: we are relying on the fact that the order of parameters 28 | # is usually exactly the same between these models, thus no key name mapping 29 | # is necessary. Asserts will trip if this is not the case.
30 | for (tn, tv), (mn, mv) in zip(torchp, mxp): 31 | m_split = mn.split('_') 32 | t_split = tn.split('.') 33 | print(t_split, m_split) 34 | print(tv.shape, mv.shape) 35 | 36 | # ensure the ordering of BN params matches, since their shapes alone cannot tell them apart 37 | if m_split[-1] == 'gamma': 38 | assert t_split[-1] == 'weight' 39 | if m_split[-1] == 'beta': 40 | assert t_split[-1] == 'bias' 41 | 42 | # ensure shapes match 43 | assert all(t == m for t, m in zip(tv.shape, mv.shape)) 44 | 45 | torch_tensor = torch.from_numpy(mv.data().asnumpy()) 46 | torch_params[tn] = torch_tensor 47 | 48 | # convert buffers (batch norm running stats) 49 | mxb = [(k, v) for k, v in net.collect_params().items() if any(x in k for x in ['running_mean', 'running_var'])] 50 | torchb = [(k, v) for k, v in torch_net.named_buffers() if 'num_batches' not in k] 51 | for (tn, tv), (mn, mv) in zip(torchb, mxb): 52 | print(tn, mn) 53 | print(tv.shape, mv.shape) 54 | 55 | # ensure the ordering of BN buffers matches, since their shapes alone cannot tell them apart 56 | if 'running_var' in tn: 57 | assert 'running_var' in mn 58 | if 'running_mean' in tn: 59 | assert 'running_mean' in mn 60 | 61 | torch_tensor = torch.from_numpy(mv.data().asnumpy()) 62 | torch_params[tn] = torch_tensor 63 | 64 | torch_net.load_state_dict(torch_params) 65 | torch_filename = './%s.pth' % torch_name 66 | torch.save(torch_net.state_dict(), torch_filename) 67 | with open(torch_filename, 'rb') as f: 68 | sha_hash = hashlib.sha256(f.read()).hexdigest() 69 | final_filename = os.path.splitext(torch_filename)[0] + '-' + sha_hash[:8] + '.pth' 70 | os.rename(torch_filename, final_filename) 71 | print("=> Saved converted model to '{}', SHA256: {}".format(final_filename, sha_hash)) 72 | 73 | 74 | def map_mx_to_torch_model(mx_name): 75 | torch_name = mx_name.lower() 76 | if torch_name.startswith('se_'): 77 | torch_name = torch_name.replace('se_', 'se') 78 | elif torch_name.startswith('senet_'): 79 | torch_name = torch_name.replace('senet_', 'senet') 80 | elif torch_name.startswith('inceptionv3'): 81 | torch_name = torch_name.replace('inceptionv3', 'inception_v3') 82 | torch_name = 'gluon_' + torch_name 83 | return torch_name 84 | 85 | 86 | ALL = ['resnet18_v1b', 'resnet34_v1b', 'resnet50_v1b', 'resnet101_v1b', 'resnet152_v1b', 87 | 'resnet50_v1c', 'resnet101_v1c', 'resnet152_v1c', 'resnet50_v1d', 'resnet101_v1d', 'resnet152_v1d', 88 | #'resnet50_v1e', 'resnet101_v1e', 'resnet152_v1e', 89 | 'resnet50_v1s', 'resnet101_v1s', 'resnet152_v1s', 'resnext50_32x4d', 'resnext101_32x4d', 'resnext101_64x4d', 90 | 'se_resnext50_32x4d', 'se_resnext101_32x4d', 'se_resnext101_64x4d', 'senet_154', 'inceptionv3'] 91 | 92 | 93 | def main(): 94 | args = parser.parse_args() 95 | 96 | if not args.model or args.model == 'all': 97 | for mx_model in ALL: 98 | torch_model = map_mx_to_torch_model(mx_model) 99 | convert(mx_model, torch_model) 100 | else: 101 | mx_model = args.model 102 | torch_model = map_mx_to_torch_model(mx_model) 103 | convert(mx_model, torch_model) 104 | 105 | 106 | if __name__ == '__main__': 107 | main() 108 | -------------------------------------------------------------------------------- /distributed_train.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | NUM_PROC=$1 3 | shift 4 | python3 -m torch.distributed.launch --nproc_per_node=$NUM_PROC train.py "$@" 5 | 6 | -------------------------------------------------------------------------------- /docs/index.md: --------------------------------------------------------------------------------
1 | # Getting Started 2 | 3 | ## Install 4 | 5 | The library can be installed with pip: 6 | 7 | ``` 8 | pip install timm 9 | ``` 10 | 11 | !!! info "Conda Environment" 12 | All development and testing have been done in Conda Python 3 environments on Linux x86-64 systems, specifically Python 3.6.x, 3.7.x, and 3.8.x. 13 | 14 | Little to no care has been taken to be Python 2.x friendly, and it will not be supported. If you run into any challenges running on Windows or another OS, I'm definitely open to looking into those issues so long as it's in a reproducible (read Conda) environment. 15 | 16 | PyTorch versions 1.4, 1.5.x, 1.6, and 1.7 have been tested with this code. 17 | 18 | I've tried to keep the dependencies minimal; the setup is as per the PyTorch default install instructions for Conda: 19 | ``` 20 | conda create -n torch-env 21 | conda activate torch-env 22 | conda install -c pytorch pytorch torchvision cudatoolkit=11 23 | conda install pyyaml 24 | ``` 25 | 26 | ## Load a Pretrained Model 27 | 28 | Pretrained models can be loaded using `timm.create_model`: 29 | 30 | ```python 31 | import timm 32 | 33 | m = timm.create_model('mobilenetv3_large_100', pretrained=True) 34 | m.eval() 35 | ``` 36 | 37 | ## List Models with Pretrained Weights 38 | ```python 39 | import timm 40 | from pprint import pprint 41 | model_names = timm.list_models(pretrained=True) 42 | pprint(model_names) 43 | >>> ['adv_inception_v3', 44 | 'cspdarknet53', 45 | 'cspresnext50', 46 | 'densenet121', 47 | 'densenet161', 48 | 'densenet169', 49 | 'densenet201', 50 | 'densenetblur121d', 51 | 'dla34', 52 | 'dla46_c', 53 | ... 54 | ] 55 | ``` 56 | 57 | ## List Model Architectures by Wildcard 58 | ```python 59 | import timm 60 | from pprint import pprint 61 | model_names = timm.list_models('*resne*t*') 62 | pprint(model_names) 63 | >>> ['cspresnet50', 64 | 'cspresnet50d', 65 | 'cspresnet50w', 66 | 'cspresnext50', 67 | ... 68 | ] 69 | ``` 70 | -------------------------------------------------------------------------------- /docs/javascripts/tables.js: -------------------------------------------------------------------------------- 1 | app.location$.subscribe(function() { 2 | var tables = document.querySelectorAll("article table") 3 | tables.forEach(function(table) { 4 | new Tablesort(table) 5 | }) }) -------------------------------------------------------------------------------- /docs/scripts.md: -------------------------------------------------------------------------------- 1 | # Scripts 2 | Train, validation, inference, and checkpoint cleaning scripts are included in the GitHub root folder. Scripts are not currently packaged in the pip release. 3 | 4 | The training and validation scripts evolved from early versions of the [PyTorch Imagenet Examples](https://github.com/pytorch/examples). I have added significant functionality over time, including CUDA-specific performance enhancements based on 5 | [NVIDIA's APEX Examples](https://github.com/NVIDIA/apex/tree/master/examples). 6 | 7 | ## Training Script 8 | 9 | The variety of training args is large and not all combinations of options (or even individual options) have been fully tested. For the training dataset, specify the base folder that contains `train` and `validation` subfolders, as in the sketch below.
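A minimal sketch of that layout (the class and file names here are purely illustrative):

```
/data/imagenet/
├── train/
│   ├── class_a/
│   │   ├── 001.jpg
│   │   └── 002.jpg
│   └── class_b/
│       └── 003.jpg
└── validation/
    ├── class_a/
    │   └── 004.jpg
    └── class_b/
        └── 005.jpg
```

Class labels are derived from the subfolder names (see `timm/data/parsers/parser_image_folder.py`).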
10 | 11 | To train an SE-ResNet34 on ImageNet, locally distributed across 4 GPUs (one process per GPU), with a cosine schedule and a 50% random-erasing probability using per-pixel random values: 12 | 13 | `./distributed_train.sh 4 /data/imagenet --model seresnet34 --sched cosine --epochs 150 --warmup-epochs 5 --lr 0.4 --reprob 0.5 --remode pixel --batch-size 256 --amp -j 4` 14 | 15 | NOTE: It is recommended to use PyTorch 1.7+ w/ PyTorch native AMP and DDP instead of APEX AMP. `--amp` defaults to native AMP as of timm ver 0.4.3. `--apex-amp` will force use of APEX components if they are installed. 16 | 17 | ## Validation / Inference Scripts 18 | 19 | Validation and inference scripts are similar in usage. One outputs metrics on a validation set and the other outputs top-k class ids in a CSV. Specify the folder containing validation images, not the base folder as in the training script. 20 | 21 | To validate with the model's pretrained weights (if they exist): 22 | 23 | `python validate.py /imagenet/validation/ --model seresnext26_32x4d --pretrained` 24 | 25 | To run inference from a checkpoint: 26 | 27 | `python inference.py /imagenet/validation/ --model mobilenetv3_large_100 --checkpoint ./output/train/model_best.pth.tar` -------------------------------------------------------------------------------- /fashion-product-images/images/1163.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yang-ruixin/pytorch-image-models-with-simclr/0a3182be5ef6e008b73bb9f9ce756d26f3dab7c0/fashion-product-images/images/1163.jpg -------------------------------------------------------------------------------- /hubconf.py: -------------------------------------------------------------------------------- 1 | dependencies = ['torch'] 2 | from timm.models import registry 3 | 4 | globals().update(registry._model_entrypoints) 5 | -------------------------------------------------------------------------------- /imgs/20210910-211332-efficientnet_b0-224-trainingProcessAccLoss.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yang-ruixin/pytorch-image-models-with-simclr/0a3182be5ef6e008b73bb9f9ce756d26f3dab7c0/imgs/20210910-211332-efficientnet_b0-224-trainingProcessAccLoss.png -------------------------------------------------------------------------------- /imgs/20210910-211332-efficientnet_b0-224-trainingProcessLosses.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yang-ruixin/pytorch-image-models-with-simclr/0a3182be5ef6e008b73bb9f9ce756d26f3dab7c0/imgs/20210910-211332-efficientnet_b0-224-trainingProcessLosses.png -------------------------------------------------------------------------------- /imgs/20210910-211332-efficientnet_b0-224.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yang-ruixin/pytorch-image-models-with-simclr/0a3182be5ef6e008b73bb9f9ce756d26f3dab7c0/imgs/20210910-211332-efficientnet_b0-224.jpg -------------------------------------------------------------------------------- /imgs/experiments.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yang-ruixin/pytorch-image-models-with-simclr/0a3182be5ef6e008b73bb9f9ce756d26f3dab7c0/imgs/experiments.jpg -------------------------------------------------------------------------------- /mkdocs.yml:
-------------------------------------------------------------------------------- 1 | site_name: 'Pytorch Image Models' 2 | site_description: 'Pretrained Image Recognition Models' 3 | repo_name: 'rwightman/pytorch-image-models' 4 | repo_url: 'https://github.com/rwightman/pytorch-image-models' 5 | nav: 6 | - index.md 7 | - models.md 8 | - results.md 9 | - scripts.md 10 | - training_hparam_examples.md 11 | - feature_extraction.md 12 | - changes.md 13 | - archived_changes.md 14 | theme: 15 | name: 'material' 16 | feature: 17 | tabs: false 18 | extra_javascript: 19 | - 'https://cdnjs.cloudflare.com/ajax/libs/mathjax/2.7.0/MathJax.js?config=TeX-MML-AM_CHTML' 20 | - https://cdnjs.cloudflare.com/ajax/libs/tablesort/5.2.1/tablesort.min.js 21 | - javascripts/tables.js 22 | markdown_extensions: 23 | - codehilite: 24 | linenums: true 25 | - admonition 26 | - pymdownx.arithmatex 27 | - pymdownx.betterem: 28 | smart_enable: all 29 | - pymdownx.caret 30 | - pymdownx.critic 31 | - pymdownx.details 32 | - pymdownx.emoji: 33 | emoji_generator: !!python/name:pymdownx.emoji.to_svg 34 | - pymdownx.inlinehilite 35 | - pymdownx.magiclink 36 | - pymdownx.mark 37 | - pymdownx.smartsymbols 38 | - pymdownx.superfences 39 | - pymdownx.tasklist: 40 | custom_checkbox: true 41 | - pymdownx.tilde 42 | - mdx_truly_sane_lists 43 | -------------------------------------------------------------------------------- /requirements-docs.txt: -------------------------------------------------------------------------------- 1 | mkdocs==1.1.2 2 | mkdocs-material==5.4.0 3 | mdx_truly_sane_lists -------------------------------------------------------------------------------- /requirements-sotabench.txt: -------------------------------------------------------------------------------- 1 | torch==1.4.0 2 | torchvision==0.5.0 3 | pyyaml 4 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | torch>=1.4.0 2 | torchvision>=0.5.0 3 | pyyaml 4 | -------------------------------------------------------------------------------- /results/README.md: -------------------------------------------------------------------------------- 1 | # Validation Results 2 | 3 | This folder contains validation results for the models in this collection that have pretrained weights. Since the focus for this repository is currently ImageNet-1k classification, all of the results are based on datasets compatible with ImageNet-1k classes. 4 | 5 | ## Datasets 6 | 7 | There are currently results for the ImageNet validation set and 5 additional test / label sets. 8 | 9 | The test set results include rank and top-1/top-5 differences from clean validation. For the "Real Labels", ImageNetV2, and Sketch test sets, the differences were calculated against the full 1000 class ImageNet-1k validation set. For both the Adversarial and Rendition sets, the differences were calculated against 'clean' runs on the ImageNet-1k validation set with the same 200 classes used in each test set respectively. 10 | 11 | ### ImageNet Validation - [`results-imagenet.csv`](results-imagenet.csv) 12 | 13 | The standard 50,000 image ImageNet-1k validation set. Model selection during training utilizes this validation set, so it is not a true test set. Question: Does anyone have the official ImageNet-1k test set classification labels now that challenges are done?
14 | 15 | * Source: http://image-net.org/challenges/LSVRC/2012/index 16 | * Paper: "ImageNet Large Scale Visual Recognition Challenge" - https://arxiv.org/abs/1409.0575 17 | 18 | ### ImageNet-"Real Labels" - [`results-imagenet-real.csv`](results-imagenet-real.csv) 19 | 20 | The usual ImageNet-1k validation set with a fresh new set of labels intended to improve on mistakes in the original annotation process. 21 | 22 | * Source: https://github.com/google-research/reassessed-imagenet 23 | * Paper: "Are we done with ImageNet?" - https://arxiv.org/abs/2006.07159 24 | 25 | ### ImageNetV2 Matched Frequency - [`results-imagenetv2-matched-frequency.csv`](results-imagenetv2-matched-frequency.csv) 26 | 27 | An ImageNet test set of 10,000 images sampled from new images roughly 10 years after the original. Care was taken to replicate the original ImageNet curation/sampling process. 28 | 29 | * Source: https://github.com/modestyachts/ImageNetV2 30 | * Paper: "Do ImageNet Classifiers Generalize to ImageNet?" - https://arxiv.org/abs/1902.10811 31 | 32 | ### ImageNet-Sketch - [`results-sketch.csv`](results-sketch.csv) 33 | 34 | 50,000 non-photographic images (or photos of such; sketches, doodles, mostly monochromatic) covering all 1000 ImageNet classes. 35 | 36 | * Source: https://github.com/HaohanWang/ImageNet-Sketch 37 | * Paper: "Learning Robust Global Representations by Penalizing Local Predictive Power" - https://arxiv.org/abs/1905.13549 38 | 39 | ### ImageNet-Adversarial - [`results-imagenet-a.csv`](results-imagenet-a.csv) 40 | 41 | A collection of 7500 images covering 200 of the 1000 ImageNet classes. Images are naturally occurring adversarial examples that confuse typical ImageNet classifiers. This is a challenging dataset: a typical ResNet-50 will score 0% top-1. 42 | 43 | For clean validation with the same 200 classes, see [`results-imagenet-a-clean.csv`](results-imagenet-a-clean.csv) 44 | 45 | * Source: https://github.com/hendrycks/natural-adv-examples 46 | * Paper: "Natural Adversarial Examples" - https://arxiv.org/abs/1907.07174 47 | 48 | 49 | ### ImageNet-Rendition - [`results-imagenet-r.csv`](results-imagenet-r.csv) 50 | 51 | Renditions of 200 ImageNet classes resulting in 30,000 images for testing robustness. 52 | 53 | For clean validation with the same 200 classes, see [`results-imagenet-r-clean.csv`](results-imagenet-r-clean.csv) 54 | 55 | * Source: https://github.com/hendrycks/imagenet-r 56 | * Paper: "The Many Faces of Robustness" - https://arxiv.org/abs/2006.16241 57 | 58 | ## TODO 59 | * Explore adding a reduced version of ImageNet-C (Corruptions) and ImageNet-P (Perturbations) from https://github.com/hendrycks/robustness. The originals are huge and image-size specific.
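The CSVs can be inspected directly with pandas; a minimal sketch (it assumes pandas is installed and the working directory is this `results/` folder; the column names match those used by `generate_csv_results.py` below):

```python
import pandas as pd

# Rank models on the clean ImageNet-1k validation set by top-1 accuracy.
df = pd.read_csv('results-imagenet.csv')
top = df.sort_values('top1', ascending=False)
print(top[['model', 'top1', 'top5', 'param_count']].head(10))
```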
60 | -------------------------------------------------------------------------------- /results/generate_csv_results.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import pandas as pd 3 | 4 | 5 | results = { 6 | 'results-imagenet.csv': [ 7 | 'results-imagenet-real.csv', 8 | 'results-imagenetv2-matched-frequency.csv', 9 | 'results-sketch.csv' 10 | ], 11 | 'results-imagenet-a-clean.csv': [ 12 | 'results-imagenet-a.csv', 13 | ], 14 | 'results-imagenet-r-clean.csv': [ 15 | 'results-imagenet-r.csv', 16 | ], 17 | } 18 | 19 | 20 | def diff(base_df, test_csv): 21 | base_models = base_df['model'].values 22 | test_df = pd.read_csv(test_csv) 23 | test_models = test_df['model'].values 24 | 25 | rank_diff = np.zeros_like(test_models, dtype='object') 26 | top1_diff = np.zeros_like(test_models, dtype='object') 27 | top5_diff = np.zeros_like(test_models, dtype='object') 28 | 29 | for rank, model in enumerate(test_models): 30 | if model in base_models: 31 | base_rank = int(np.where(base_models == model)[0]) 32 | top1_d = test_df['top1'][rank] - base_df['top1'][base_rank] 33 | top5_d = test_df['top5'][rank] - base_df['top5'][base_rank] 34 | 35 | # rank_diff 36 | if rank == base_rank: 37 | rank_diff[rank] = f'0' 38 | elif rank > base_rank: 39 | rank_diff[rank] = f'-{rank - base_rank}' 40 | else: 41 | rank_diff[rank] = f'+{base_rank - rank}' 42 | 43 | # top1_diff 44 | if top1_d >= .0: 45 | top1_diff[rank] = f'+{top1_d:.3f}' 46 | else: 47 | top1_diff[rank] = f'-{abs(top1_d):.3f}' 48 | 49 | # top5_diff 50 | if top5_d >= .0: 51 | top5_diff[rank] = f'+{top5_d:.3f}' 52 | else: 53 | top5_diff[rank] = f'-{abs(top5_d):.3f}' 54 | 55 | else: 56 | rank_diff[rank] = '' 57 | top1_diff[rank] = '' 58 | top5_diff[rank] = '' 59 | 60 | test_df['top1_diff'] = top1_diff 61 | test_df['top5_diff'] = top5_diff 62 | test_df['rank_diff'] = rank_diff 63 | 64 | test_df['param_count'] = test_df['param_count'].map('{:,.2f}'.format) 65 | test_df.sort_values('top1', ascending=False, inplace=True) 66 | test_df.to_csv(test_csv, index=False, float_format='%.3f') 67 | 68 | 69 | for base_results, test_results in results.items(): 70 | base_df = pd.read_csv(base_results) 71 | base_df.sort_values('top1', ascending=False, inplace=True) 72 | for test_csv in test_results: 73 | diff(base_df, test_csv) 74 | base_df['param_count'] = base_df['param_count'].map('{:,.2f}'.format) 75 | base_df.to_csv(base_results, index=False, float_format='%.3f') 76 | -------------------------------------------------------------------------------- /results/imagenet_a_indices.txt: -------------------------------------------------------------------------------- 1 | 6 2 | 11 3 | 13 4 | 15 5 | 17 6 | 22 7 | 23 8 | 27 9 | 30 10 | 37 11 | 39 12 | 42 13 | 47 14 | 50 15 | 57 16 | 70 17 | 71 18 | 76 19 | 79 20 | 89 21 | 90 22 | 94 23 | 96 24 | 97 25 | 99 26 | 105 27 | 107 28 | 108 29 | 110 30 | 113 31 | 124 32 | 125 33 | 130 34 | 132 35 | 143 36 | 144 37 | 150 38 | 151 39 | 207 40 | 234 41 | 235 42 | 254 43 | 277 44 | 283 45 | 287 46 | 291 47 | 295 48 | 298 49 | 301 50 | 306 51 | 307 52 | 308 53 | 309 54 | 310 55 | 311 56 | 313 57 | 314 58 | 315 59 | 317 60 | 319 61 | 323 62 | 324 63 | 326 64 | 327 65 | 330 66 | 334 67 | 335 68 | 336 69 | 347 70 | 361 71 | 363 72 | 372 73 | 378 74 | 386 75 | 397 76 | 400 77 | 401 78 | 402 79 | 404 80 | 407 81 | 411 82 | 416 83 | 417 84 | 420 85 | 425 86 | 428 87 | 430 88 | 437 89 | 438 90 | 445 91 | 456 92 | 457 93 | 461 94 | 462 95 | 470 96 | 472 97 | 483 98 | 486 99 | 488 100 | 492 
101 | 496 102 | 514 103 | 516 104 | 528 105 | 530 106 | 539 107 | 542 108 | 543 109 | 549 110 | 552 111 | 557 112 | 561 113 | 562 114 | 569 115 | 572 116 | 573 117 | 575 118 | 579 119 | 589 120 | 606 121 | 607 122 | 609 123 | 614 124 | 626 125 | 627 126 | 640 127 | 641 128 | 642 129 | 643 130 | 658 131 | 668 132 | 677 133 | 682 134 | 684 135 | 687 136 | 701 137 | 704 138 | 719 139 | 736 140 | 746 141 | 749 142 | 752 143 | 758 144 | 763 145 | 765 146 | 768 147 | 773 148 | 774 149 | 776 150 | 779 151 | 780 152 | 786 153 | 792 154 | 797 155 | 802 156 | 803 157 | 804 158 | 813 159 | 815 160 | 820 161 | 823 162 | 831 163 | 833 164 | 835 165 | 839 166 | 845 167 | 847 168 | 850 169 | 859 170 | 862 171 | 870 172 | 879 173 | 880 174 | 888 175 | 890 176 | 897 177 | 900 178 | 907 179 | 913 180 | 924 181 | 932 182 | 933 183 | 934 184 | 937 185 | 943 186 | 945 187 | 947 188 | 951 189 | 954 190 | 956 191 | 957 192 | 959 193 | 971 194 | 972 195 | 980 196 | 981 197 | 984 198 | 986 199 | 987 200 | 988 201 | -------------------------------------------------------------------------------- /results/imagenet_a_synsets.txt: -------------------------------------------------------------------------------- 1 | n01498041 2 | n01531178 3 | n01534433 4 | n01558993 5 | n01580077 6 | n01614925 7 | n01616318 8 | n01631663 9 | n01641577 10 | n01669191 11 | n01677366 12 | n01687978 13 | n01694178 14 | n01698640 15 | n01735189 16 | n01770081 17 | n01770393 18 | n01774750 19 | n01784675 20 | n01819313 21 | n01820546 22 | n01833805 23 | n01843383 24 | n01847000 25 | n01855672 26 | n01882714 27 | n01910747 28 | n01914609 29 | n01924916 30 | n01944390 31 | n01985128 32 | n01986214 33 | n02007558 34 | n02009912 35 | n02037110 36 | n02051845 37 | n02077923 38 | n02085620 39 | n02099601 40 | n02106550 41 | n02106662 42 | n02110958 43 | n02119022 44 | n02123394 45 | n02127052 46 | n02129165 47 | n02133161 48 | n02137549 49 | n02165456 50 | n02174001 51 | n02177972 52 | n02190166 53 | n02206856 54 | n02219486 55 | n02226429 56 | n02231487 57 | n02233338 58 | n02236044 59 | n02259212 60 | n02268443 61 | n02279972 62 | n02280649 63 | n02281787 64 | n02317335 65 | n02325366 66 | n02346627 67 | n02356798 68 | n02361337 69 | n02410509 70 | n02445715 71 | n02454379 72 | n02486410 73 | n02492035 74 | n02504458 75 | n02655020 76 | n02669723 77 | n02672831 78 | n02676566 79 | n02690373 80 | n02701002 81 | n02730930 82 | n02777292 83 | n02782093 84 | n02787622 85 | n02793495 86 | n02797295 87 | n02802426 88 | n02814860 89 | n02815834 90 | n02837789 91 | n02879718 92 | n02883205 93 | n02895154 94 | n02906734 95 | n02948072 96 | n02951358 97 | n02980441 98 | n02992211 99 | n02999410 100 | n03014705 101 | n03026506 102 | n03124043 103 | n03125729 104 | n03187595 105 | n03196217 106 | n03223299 107 | n03250847 108 | n03255030 109 | n03291819 110 | n03325584 111 | n03355925 112 | n03384352 113 | n03388043 114 | n03417042 115 | n03443371 116 | n03444034 117 | n03445924 118 | n03452741 119 | n03483316 120 | n03584829 121 | n03590841 122 | n03594945 123 | n03617480 124 | n03666591 125 | n03670208 126 | n03717622 127 | n03720891 128 | n03721384 129 | n03724870 130 | n03775071 131 | n03788195 132 | n03804744 133 | n03837869 134 | n03840681 135 | n03854065 136 | n03888257 137 | n03891332 138 | n03935335 139 | n03982430 140 | n04019541 141 | n04033901 142 | n04039381 143 | n04067472 144 | n04086273 145 | n04099969 146 | n04118538 147 | n04131690 148 | n04133789 149 | n04141076 150 | n04146614 151 | n04147183 152 | n04179913 153 | n04208210 154 | 
n04235860 155 | n04252077 156 | n04252225 157 | n04254120 158 | n04270147 159 | n04275548 160 | n04310018 161 | n04317175 162 | n04344873 163 | n04347754 164 | n04355338 165 | n04366367 166 | n04376876 167 | n04389033 168 | n04399382 169 | n04442312 170 | n04456115 171 | n04482393 172 | n04507155 173 | n04509417 174 | n04532670 175 | n04540053 176 | n04554684 177 | n04562935 178 | n04591713 179 | n04606251 180 | n07583066 181 | n07695742 182 | n07697313 183 | n07697537 184 | n07714990 185 | n07718472 186 | n07720875 187 | n07734744 188 | n07749582 189 | n07753592 190 | n07760859 191 | n07768694 192 | n07831146 193 | n09229709 194 | n09246464 195 | n09472597 196 | n09835506 197 | n11879895 198 | n12057211 199 | n12144580 200 | n12267677 201 | -------------------------------------------------------------------------------- /results/imagenet_r_indices.txt: -------------------------------------------------------------------------------- 1 | 1 2 | 2 3 | 4 4 | 6 5 | 8 6 | 9 7 | 11 8 | 13 9 | 22 10 | 23 11 | 26 12 | 29 13 | 31 14 | 39 15 | 47 16 | 63 17 | 71 18 | 76 19 | 79 20 | 84 21 | 90 22 | 94 23 | 96 24 | 97 25 | 99 26 | 100 27 | 105 28 | 107 29 | 113 30 | 122 31 | 125 32 | 130 33 | 132 34 | 144 35 | 145 36 | 147 37 | 148 38 | 150 39 | 151 40 | 155 41 | 160 42 | 161 43 | 162 44 | 163 45 | 171 46 | 172 47 | 178 48 | 187 49 | 195 50 | 199 51 | 203 52 | 207 53 | 208 54 | 219 55 | 231 56 | 232 57 | 234 58 | 235 59 | 242 60 | 245 61 | 247 62 | 250 63 | 251 64 | 254 65 | 259 66 | 260 67 | 263 68 | 265 69 | 267 70 | 269 71 | 276 72 | 277 73 | 281 74 | 288 75 | 289 76 | 291 77 | 292 78 | 293 79 | 296 80 | 299 81 | 301 82 | 308 83 | 309 84 | 310 85 | 311 86 | 314 87 | 315 88 | 319 89 | 323 90 | 327 91 | 330 92 | 334 93 | 335 94 | 337 95 | 338 96 | 340 97 | 341 98 | 344 99 | 347 100 | 353 101 | 355 102 | 361 103 | 362 104 | 365 105 | 366 106 | 367 107 | 368 108 | 372 109 | 388 110 | 390 111 | 393 112 | 397 113 | 401 114 | 407 115 | 413 116 | 414 117 | 425 118 | 428 119 | 430 120 | 435 121 | 437 122 | 441 123 | 447 124 | 448 125 | 457 126 | 462 127 | 463 128 | 469 129 | 470 130 | 471 131 | 472 132 | 476 133 | 483 134 | 487 135 | 515 136 | 546 137 | 555 138 | 558 139 | 570 140 | 579 141 | 583 142 | 587 143 | 593 144 | 594 145 | 596 146 | 609 147 | 613 148 | 617 149 | 621 150 | 629 151 | 637 152 | 657 153 | 658 154 | 701 155 | 717 156 | 724 157 | 763 158 | 768 159 | 774 160 | 776 161 | 779 162 | 780 163 | 787 164 | 805 165 | 812 166 | 815 167 | 820 168 | 824 169 | 833 170 | 847 171 | 852 172 | 866 173 | 875 174 | 883 175 | 889 176 | 895 177 | 907 178 | 928 179 | 931 180 | 932 181 | 933 182 | 934 183 | 936 184 | 937 185 | 943 186 | 945 187 | 947 188 | 948 189 | 949 190 | 951 191 | 953 192 | 954 193 | 957 194 | 963 195 | 965 196 | 967 197 | 980 198 | 981 199 | 983 200 | 988 201 | -------------------------------------------------------------------------------- /results/imagenet_r_synsets.txt: -------------------------------------------------------------------------------- 1 | n01443537 2 | n01484850 3 | n01494475 4 | n01498041 5 | n01514859 6 | n01518878 7 | n01531178 8 | n01534433 9 | n01614925 10 | n01616318 11 | n01630670 12 | n01632777 13 | n01644373 14 | n01677366 15 | n01694178 16 | n01748264 17 | n01770393 18 | n01774750 19 | n01784675 20 | n01806143 21 | n01820546 22 | n01833805 23 | n01843383 24 | n01847000 25 | n01855672 26 | n01860187 27 | n01882714 28 | n01910747 29 | n01944390 30 | n01983481 31 | n01986214 32 | n02007558 33 | n02009912 34 | n02051845 35 | n02056570 36 | n02066245 37 | n02071294 
38 | n02077923 39 | n02085620 40 | n02086240 41 | n02088094 42 | n02088238 43 | n02088364 44 | n02088466 45 | n02091032 46 | n02091134 47 | n02092339 48 | n02094433 49 | n02096585 50 | n02097298 51 | n02098286 52 | n02099601 53 | n02099712 54 | n02102318 55 | n02106030 56 | n02106166 57 | n02106550 58 | n02106662 59 | n02108089 60 | n02108915 61 | n02109525 62 | n02110185 63 | n02110341 64 | n02110958 65 | n02112018 66 | n02112137 67 | n02113023 68 | n02113624 69 | n02113799 70 | n02114367 71 | n02117135 72 | n02119022 73 | n02123045 74 | n02128385 75 | n02128757 76 | n02129165 77 | n02129604 78 | n02130308 79 | n02134084 80 | n02138441 81 | n02165456 82 | n02190166 83 | n02206856 84 | n02219486 85 | n02226429 86 | n02233338 87 | n02236044 88 | n02268443 89 | n02279972 90 | n02317335 91 | n02325366 92 | n02346627 93 | n02356798 94 | n02363005 95 | n02364673 96 | n02391049 97 | n02395406 98 | n02398521 99 | n02410509 100 | n02423022 101 | n02437616 102 | n02445715 103 | n02447366 104 | n02480495 105 | n02480855 106 | n02481823 107 | n02483362 108 | n02486410 109 | n02510455 110 | n02526121 111 | n02607072 112 | n02655020 113 | n02672831 114 | n02701002 115 | n02749479 116 | n02769748 117 | n02793495 118 | n02797295 119 | n02802426 120 | n02808440 121 | n02814860 122 | n02823750 123 | n02841315 124 | n02843684 125 | n02883205 126 | n02906734 127 | n02909870 128 | n02939185 129 | n02948072 130 | n02950826 131 | n02951358 132 | n02966193 133 | n02980441 134 | n02992529 135 | n03124170 136 | n03272010 137 | n03345487 138 | n03372029 139 | n03424325 140 | n03452741 141 | n03467068 142 | n03481172 143 | n03494278 144 | n03495258 145 | n03498962 146 | n03594945 147 | n03602883 148 | n03630383 149 | n03649909 150 | n03676483 151 | n03710193 152 | n03773504 153 | n03775071 154 | n03888257 155 | n03930630 156 | n03947888 157 | n04086273 158 | n04118538 159 | n04133789 160 | n04141076 161 | n04146614 162 | n04147183 163 | n04192698 164 | n04254680 165 | n04266014 166 | n04275548 167 | n04310018 168 | n04325704 169 | n04347754 170 | n04389033 171 | n04409515 172 | n04465501 173 | n04487394 174 | n04522168 175 | n04536866 176 | n04552348 177 | n04591713 178 | n07614500 179 | n07693725 180 | n07695742 181 | n07697313 182 | n07697537 183 | n07714571 184 | n07714990 185 | n07718472 186 | n07720875 187 | n07734744 188 | n07742313 189 | n07745940 190 | n07749582 191 | n07753275 192 | n07753592 193 | n07768694 194 | n07873807 195 | n07880968 196 | n07920052 197 | n09472597 198 | n09835506 199 | n10565667 200 | n12267677 201 | -------------------------------------------------------------------------------- /setup.cfg: -------------------------------------------------------------------------------- 1 | [dist_conda] 2 | 3 | conda_name_differences = 'torch:pytorch' 4 | channels = pytorch 5 | noarch = True 6 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | """ Setup 2 | """ 3 | from setuptools import setup, find_packages 4 | from codecs import open 5 | from os import path 6 | 7 | here = path.abspath(path.dirname(__file__)) 8 | 9 | # Get the long description from the README file 10 | with open(path.join(here, 'README.md'), encoding='utf-8') as f: 11 | long_description = f.read() 12 | 13 | exec(open('timm/version.py').read()) 14 | setup( 15 | name='timm', 16 | version=__version__, 17 | description='(Unofficial) PyTorch Image Models', 18 | long_description=long_description, 19 | 
long_description_content_type='text/markdown', 20 | url='https://github.com/rwightman/pytorch-image-models', 21 | author='Ross Wightman', 22 | author_email='hello@rwightman.com', 23 | classifiers=[ 24 | # How mature is this project? Common values are 25 | # 3 - Alpha 26 | # 4 - Beta 27 | # 5 - Production/Stable 28 | 'Development Status :: 3 - Alpha', 29 | 'Intended Audience :: Education', 30 | 'Intended Audience :: Science/Research', 31 | 'License :: OSI Approved :: Apache Software License', 32 | 'Programming Language :: Python :: 3.6', 33 | 'Programming Language :: Python :: 3.7', 34 | 'Programming Language :: Python :: 3.8', 35 | 'Topic :: Scientific/Engineering', 36 | 'Topic :: Scientific/Engineering :: Artificial Intelligence', 37 | 'Topic :: Software Development', 38 | 'Topic :: Software Development :: Libraries', 39 | 'Topic :: Software Development :: Libraries :: Python Modules', 40 | ], 41 | 42 | # Note that this is a string of words separated by whitespace, not a list. 43 | keywords='pytorch pretrained models efficientnet mobilenetv3 mnasnet', 44 | packages=find_packages(exclude=['convert', 'tests', 'results']), 45 | include_package_data=True, 46 | install_requires=['torch >= 1.4', 'torchvision'], 47 | python_requires='>=3.6', 48 | ) 49 | -------------------------------------------------------------------------------- /simclr/modules/__init__.py: -------------------------------------------------------------------------------- 1 | from .nt_xent import NT_Xent 2 | from .gather import GatherLayer 3 | -------------------------------------------------------------------------------- /simclr/modules/gather.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.distributed as dist 3 | 4 | 5 | class GatherLayer(torch.autograd.Function): 6 | """Gather tensors from all processes, supporting backward propagation.""" 7 | 8 | @staticmethod 9 | def forward(ctx, input): 10 | ctx.save_for_backward(input) 11 | output = [torch.zeros_like(input) for _ in range(dist.get_world_size())] 12 | dist.all_gather(output, input) 13 | return tuple(output) 14 | 15 | @staticmethod 16 | def backward(ctx, *grads): 17 | (input,) = ctx.saved_tensors 18 | grad_out = torch.zeros_like(input) 19 | grad_out[:] = grads[dist.get_rank()] 20 | return grad_out 21 | -------------------------------------------------------------------------------- /simclr/modules/nt_xent.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.distributed as dist 4 | from .gather import GatherLayer 5 | 6 | 7 | class NT_Xent(nn.Module): 8 | def __init__(self, batch_size, temperature, world_size): 9 | super(NT_Xent, self).__init__() 10 | self.batch_size = batch_size 11 | self.temperature = temperature 12 | self.world_size = world_size 13 | 14 | self.mask = self.mask_correlated_samples(batch_size, world_size) 15 | self.criterion = nn.CrossEntropyLoss(reduction="sum") 16 | self.similarity_f = nn.CosineSimilarity(dim=2) 17 | 18 | def mask_correlated_samples(self, batch_size, world_size): 19 | N = 2 * batch_size * world_size 20 | mask = torch.ones((N, N), dtype=bool) 21 | mask = mask.fill_diagonal_(0) 22 | for i in range(batch_size * world_size): 23 | mask[i, batch_size * world_size + i] = 0 # positives sit at this offset (matches the torch.diag offsets in forward) 24 | mask[batch_size * world_size + i, i] = 0 25 | return mask 26 | 27 | def forward(self, z_i, z_j): 28 | """ 29 | We do not sample negative examples explicitly.
30 | Instead, given a positive pair, similar to (Chen et al., 2017), we treat the other 2(N − 1) augmented examples within a minibatch as negative examples. 31 | """ 32 | N = 2 * self.batch_size * self.world_size 33 | 34 | z = torch.cat((z_i, z_j), dim=0) 35 | if self.world_size > 1: 36 | z = torch.cat(GatherLayer.apply(z), dim=0) 37 | 38 | sim = self.similarity_f(z.unsqueeze(1), z.unsqueeze(0)) / self.temperature 39 | 40 | sim_i_j = torch.diag(sim, self.batch_size * self.world_size) 41 | sim_j_i = torch.diag(sim, -self.batch_size * self.world_size) 42 | 43 | # We have 2N samples, but with Distributed training every GPU gets N examples too, resulting in: 2xNxN 44 | positive_samples = torch.cat((sim_i_j, sim_j_i), dim=0).reshape(N, 1) 45 | negative_samples = sim[self.mask].reshape(N, -1) 46 | 47 | labels = torch.zeros(N).to(positive_samples.device).long() 48 | logits = torch.cat((positive_samples, negative_samples), dim=1) 49 | loss = self.criterion(logits, labels) 50 | loss /= N 51 | return loss 52 | -------------------------------------------------------------------------------- /sotabench_setup.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | source /workspace/venv/bin/activate 3 | 4 | pip install --upgrade pip 5 | pip install -r requirements-sotabench.txt 6 | 7 | apt-get update 8 | apt-get install -y libjpeg-dev zlib1g-dev libpng-dev libwebp-dev 9 | pip uninstall -y pillow 10 | CFLAGS="${CFLAGS} -mavx2" pip install -U --no-cache-dir --force-reinstall --no-binary :all:--compile https://github.com/mrT23/pillow-simd/zipball/simd/7.0.x 11 | #CC="cc -mavx2" pip install -U --force-reinstall pillow-simd 12 | 13 | # FIXME this shouldn't be needed but sb dataset upload functionality doesn't seem to work 14 | apt-get install wget 15 | #wget -q https://onedrive.hyper.ai/down/ImageNet/data/ImageNet2012/ILSVRC2012_devkit_t12.tar.gz -P ./.data/vision/imagenet 16 | wget -q https://onedrive.hyper.ai/down/ImageNet/data/ImageNet2012/ILSVRC2012_img_val.tar -P ./.data/vision/imagenet 17 | -------------------------------------------------------------------------------- /tests/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yang-ruixin/pytorch-image-models-with-simclr/0a3182be5ef6e008b73bb9f9ce756d26f3dab7c0/tests/__init__.py -------------------------------------------------------------------------------- /tests/test_layers.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | import torch 3 | import torch.nn as nn 4 | import platform 5 | import os 6 | 7 | from timm.models.layers import create_act_layer, get_act_layer, set_layer_config 8 | 9 | 10 | class MLP(nn.Module): 11 | def __init__(self, act_layer="relu"): 12 | super(MLP, self).__init__() 13 | self.fc1 = nn.Linear(1000, 100) 14 | self.act = create_act_layer(act_layer, inplace=True) 15 | self.fc2 = nn.Linear(100, 10) 16 | 17 | def forward(self, x): 18 | x = self.fc1(x) 19 | x = self.act(x) 20 | x = self.fc2(x) 21 | return x 22 | 23 | 24 | def _run_act_layer_grad(act_type): 25 | x = torch.rand(10, 1000) * 10 26 | m = MLP(act_layer=act_type) 27 | 28 | def _run(x, act_layer=''): 29 | if act_layer: 30 | # replace act layer if set 31 | m.act = create_act_layer(act_layer, inplace=True) 32 | out = m(x) 33 | l = (out - 0).pow(2).sum() 34 | return l 35 | 36 | out_me = _run(x) 37 | 38 | with set_layer_config(scriptable=True): 39 | out_jit = _run(x, act_type) 40 | 41 | 
assert torch.isclose(out_jit, out_me) 42 | 43 | with set_layer_config(no_jit=True): 44 | out_basic = _run(x, act_type) 45 | 46 | assert torch.isclose(out_basic, out_jit) 47 | 48 | 49 | def test_swish_grad(): 50 | for _ in range(100): 51 | _run_act_layer_grad('swish') 52 | 53 | 54 | def test_mish_grad(): 55 | for _ in range(100): 56 | _run_act_layer_grad('mish') 57 | 58 | 59 | def test_hard_sigmoid_grad(): 60 | for _ in range(100): 61 | _run_act_layer_grad('hard_sigmoid') 62 | 63 | 64 | def test_hard_swish_grad(): 65 | for _ in range(100): 66 | _run_act_layer_grad('hard_swish') 67 | 68 | 69 | def test_hard_mish_grad(): 70 | for _ in range(100): 71 | _run_act_layer_grad('hard_mish') 72 | -------------------------------------------------------------------------------- /timm/__init__.py: -------------------------------------------------------------------------------- 1 | from .version import __version__ 2 | from .models import create_model, list_models, is_model, list_modules, model_entrypoint, \ 3 | is_scriptable, is_exportable, set_scriptable, set_exportable 4 | -------------------------------------------------------------------------------- /timm/data/__init__.py: -------------------------------------------------------------------------------- 1 | from .auto_augment import RandAugment, AutoAugment, rand_augment_ops, auto_augment_policy,\ 2 | rand_augment_transform, auto_augment_transform 3 | from .config import resolve_data_config 4 | from .constants import * 5 | from .dataset import ImageDataset, IterableImageDataset, AugMixDataset 6 | from .dataset_factory import create_dataset 7 | from .loader import create_loader 8 | from .mixup import Mixup, FastCollateMixup 9 | from .parsers import create_parser 10 | from .real_labels import RealLabelsImagenet 11 | from .transforms import * 12 | from .transforms_factory import create_transform 13 | 14 | from .dataset import DatasetAttributes, DatasetML # ================================ 15 | -------------------------------------------------------------------------------- /timm/data/__init__original_version_by_ross.py: -------------------------------------------------------------------------------- 1 | from .auto_augment import RandAugment, AutoAugment, rand_augment_ops, auto_augment_policy,\ 2 | rand_augment_transform, auto_augment_transform 3 | from .config import resolve_data_config 4 | from .constants import * 5 | from .dataset import ImageDataset, IterableImageDataset, AugMixDataset 6 | from .dataset_factory import create_dataset 7 | from .loader import create_loader 8 | from .mixup import Mixup, FastCollateMixup 9 | from .parsers import create_parser 10 | from .real_labels import RealLabelsImagenet 11 | from .transforms import * 12 | from .transforms_factory import create_transform -------------------------------------------------------------------------------- /timm/data/config.py: -------------------------------------------------------------------------------- 1 | import logging 2 | from .constants import * 3 | 4 | 5 | _logger = logging.getLogger(__name__) 6 | 7 | 8 | def resolve_data_config(args, default_cfg={}, model=None, use_test_size=False, verbose=True): 9 | new_config = {} 10 | default_cfg = default_cfg 11 | if not default_cfg and model is not None and hasattr(model, 'default_cfg'): 12 | default_cfg = model.default_cfg 13 | 14 | # Resolve input/image size 15 | in_chans = 3 16 | if 'chans' in args and args['chans'] is not None: 17 | in_chans = args['chans'] 18 | 19 | input_size = (in_chans, 224, 224) 20 | if 'input_size' in args and 
args['input_size'] is not None: 21 | assert isinstance(args['input_size'], (tuple, list)) 22 | assert len(args['input_size']) == 3 23 | input_size = tuple(args['input_size']) 24 | in_chans = input_size[0] # input_size overrides in_chans 25 | elif 'img_size' in args and args['img_size'] is not None: 26 | assert isinstance(args['img_size'], int) 27 | input_size = (in_chans, args['img_size'], args['img_size']) 28 | else: 29 | if use_test_size and 'test_input_size' in default_cfg: 30 | input_size = default_cfg['test_input_size'] 31 | elif 'input_size' in default_cfg: 32 | input_size = default_cfg['input_size'] 33 | new_config['input_size'] = input_size 34 | 35 | # resolve interpolation method 36 | new_config['interpolation'] = 'bicubic' 37 | if 'interpolation' in args and args['interpolation']: 38 | new_config['interpolation'] = args['interpolation'] 39 | elif 'interpolation' in default_cfg: 40 | new_config['interpolation'] = default_cfg['interpolation'] 41 | 42 | # resolve dataset + model mean for normalization 43 | new_config['mean'] = IMAGENET_DEFAULT_MEAN 44 | if 'mean' in args and args['mean'] is not None: 45 | mean = tuple(args['mean']) 46 | if len(mean) == 1: 47 | mean = tuple(list(mean) * in_chans) 48 | else: 49 | assert len(mean) == in_chans 50 | new_config['mean'] = mean 51 | elif 'mean' in default_cfg: 52 | new_config['mean'] = default_cfg['mean'] 53 | 54 | # resolve dataset + model std deviation for normalization 55 | new_config['std'] = IMAGENET_DEFAULT_STD 56 | if 'std' in args and args['std'] is not None: 57 | std = tuple(args['std']) 58 | if len(std) == 1: 59 | std = tuple(list(std) * in_chans) 60 | else: 61 | assert len(std) == in_chans 62 | new_config['std'] = std 63 | elif 'std' in default_cfg: 64 | new_config['std'] = default_cfg['std'] 65 | 66 | # resolve default crop percentage 67 | new_config['crop_pct'] = DEFAULT_CROP_PCT 68 | if 'crop_pct' in args and args['crop_pct'] is not None: 69 | new_config['crop_pct'] = args['crop_pct'] 70 | elif 'crop_pct' in default_cfg: 71 | new_config['crop_pct'] = default_cfg['crop_pct'] 72 | 73 | if verbose: 74 | _logger.info('Data processing configuration for current model + dataset:') 75 | for n, v in new_config.items(): 76 | _logger.info('\t%s: %s' % (n, str(v))) 77 | 78 | return new_config 79 | -------------------------------------------------------------------------------- /timm/data/constants.py: -------------------------------------------------------------------------------- 1 | DEFAULT_CROP_PCT = 0.875 2 | IMAGENET_DEFAULT_MEAN = (0.485, 0.456, 0.406) 3 | IMAGENET_DEFAULT_STD = (0.229, 0.224, 0.225) 4 | IMAGENET_INCEPTION_MEAN = (0.5, 0.5, 0.5) 5 | IMAGENET_INCEPTION_STD = (0.5, 0.5, 0.5) 6 | IMAGENET_DPN_MEAN = (124 / 255, 117 / 255, 104 / 255) 7 | IMAGENET_DPN_STD = tuple([1 / (.0167 * 255)] * 3) 8 | -------------------------------------------------------------------------------- /timm/data/dataset_factory.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | from .dataset import IterableImageDataset, ImageDataset 4 | 5 | 6 | def _search_split(root, split): 7 | # look for sub-folder with name of split in root and use that if it exists 8 | split_name = split.split('[')[0] 9 | try_root = os.path.join(root, split_name) 10 | if os.path.exists(try_root): 11 | return try_root 12 | if split_name == 'validation': 13 | try_root = os.path.join(root, 'val') 14 | if os.path.exists(try_root): 15 | return try_root 16 | return root 17 | 18 | 19 | def create_dataset(name, root, 
split='validation', search_split=True, is_training=False, batch_size=None, **kwargs): 20 | name = name.lower() 21 | if name.startswith('tfds'): 22 | ds = IterableImageDataset( 23 | root, parser=name, split=split, is_training=is_training, batch_size=batch_size, **kwargs) 24 | else: 25 | # FIXME support more advance split cfg for ImageFolder/Tar datasets in the future 26 | if search_split and os.path.isdir(root): 27 | root = _search_split(root, split) 28 | ds = ImageDataset(root, parser=name, **kwargs) 29 | return ds 30 | -------------------------------------------------------------------------------- /timm/data/distributed_sampler.py: -------------------------------------------------------------------------------- 1 | import math 2 | import torch 3 | from torch.utils.data import Sampler 4 | import torch.distributed as dist 5 | 6 | 7 | class OrderedDistributedSampler(Sampler): 8 | """Sampler that restricts data loading to a subset of the dataset. 9 | It is especially useful in conjunction with 10 | :class:`torch.nn.parallel.DistributedDataParallel`. In such case, each 11 | process can pass a DistributedSampler instance as a DataLoader sampler, 12 | and load a subset of the original dataset that is exclusive to it. 13 | .. note:: 14 | Dataset is assumed to be of constant size. 15 | Arguments: 16 | dataset: Dataset used for sampling. 17 | num_replicas (optional): Number of processes participating in 18 | distributed training. 19 | rank (optional): Rank of the current process within num_replicas. 20 | """ 21 | 22 | def __init__(self, dataset, num_replicas=None, rank=None): 23 | if num_replicas is None: 24 | if not dist.is_available(): 25 | raise RuntimeError("Requires distributed package to be available") 26 | num_replicas = dist.get_world_size() 27 | if rank is None: 28 | if not dist.is_available(): 29 | raise RuntimeError("Requires distributed package to be available") 30 | rank = dist.get_rank() 31 | self.dataset = dataset 32 | self.num_replicas = num_replicas 33 | self.rank = rank 34 | self.num_samples = int(math.ceil(len(self.dataset) * 1.0 / self.num_replicas)) 35 | self.total_size = self.num_samples * self.num_replicas 36 | 37 | def __iter__(self): 38 | indices = list(range(len(self.dataset))) 39 | 40 | # add extra samples to make it evenly divisible 41 | indices += indices[:(self.total_size - len(indices))] 42 | assert len(indices) == self.total_size 43 | 44 | # subsample 45 | indices = indices[self.rank:self.total_size:self.num_replicas] 46 | assert len(indices) == self.num_samples 47 | 48 | return iter(indices) 49 | 50 | def __len__(self): 51 | return self.num_samples 52 | -------------------------------------------------------------------------------- /timm/data/parsers/__init__.py: -------------------------------------------------------------------------------- 1 | from .parser_factory import create_parser 2 | -------------------------------------------------------------------------------- /timm/data/parsers/class_map.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | 4 | def load_class_map(filename, root=''): 5 | class_map_path = filename 6 | if not os.path.exists(class_map_path): 7 | class_map_path = os.path.join(root, filename) 8 | assert os.path.exists(class_map_path), 'Cannot locate specified class map file (%s)' % filename 9 | class_map_ext = os.path.splitext(filename)[-1].lower() 10 | if class_map_ext == '.txt': 11 | with open(class_map_path) as f: 12 | class_to_idx = {v.strip(): k for k, v in enumerate(f)} 13 | else: 14 
| assert False, 'Unsupported class map extension' 15 | return class_to_idx 16 | 17 | -------------------------------------------------------------------------------- /timm/data/parsers/constants.py: -------------------------------------------------------------------------------- 1 | IMG_EXTENSIONS = ('.png', '.jpg', '.jpeg') 2 | -------------------------------------------------------------------------------- /timm/data/parsers/parser.py: -------------------------------------------------------------------------------- 1 | from abc import abstractmethod 2 | 3 | 4 | class Parser: 5 | def __init__(self): 6 | pass 7 | 8 | @abstractmethod 9 | def _filename(self, index, basename=False, absolute=False): 10 | pass 11 | 12 | def filename(self, index, basename=False, absolute=False): 13 | return self._filename(index, basename=basename, absolute=absolute) 14 | 15 | def filenames(self, basename=False, absolute=False): 16 | return [self._filename(index, basename=basename, absolute=absolute) for index in range(len(self))] 17 | 18 | -------------------------------------------------------------------------------- /timm/data/parsers/parser_factory.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | from .parser_image_folder import ParserImageFolder 4 | from .parser_image_tar import ParserImageTar 5 | from .parser_image_in_tar import ParserImageInTar 6 | 7 | 8 | def create_parser(name, root, split='train', **kwargs): 9 | name = name.lower() 10 | name = name.split('/', 2) 11 | prefix = '' 12 | if len(name) > 1: 13 | prefix = name[0] 14 | name = name[-1] 15 | 16 | # FIXME improve the selection; right now it's just the tfds prefix or the fallback path, will need options to 17 | # explicitly select other options shortly 18 | if prefix == 'tfds': 19 | from .parser_tfds import ParserTfds # defer tensorflow import 20 | parser = ParserTfds(root, name, split=split, shuffle=kwargs.pop('shuffle', False), **kwargs) 21 | else: 22 | assert os.path.exists(root) 23 | # default fallback path (backwards compat), use image tar if root is a .tar file, otherwise image folder 24 | # FIXME support split here, in parser? 25 | if os.path.isfile(root) and os.path.splitext(root)[1] == '.tar': 26 | parser = ParserImageInTar(root, **kwargs) 27 | else: 28 | parser = ParserImageFolder(root, **kwargs) 29 | return parser 30 | -------------------------------------------------------------------------------- /timm/data/parsers/parser_image_folder.py: -------------------------------------------------------------------------------- 1 | """ A dataset parser that reads images from folders 2 | 3 | Folders are scanned recursively to find image files. Labels are based 4 | on the folder hierarchy, just leaf folders by default.
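# A worked sketch of this labeling rule (hypothetical paths): scanning a root containing
#   root/train/dog/001.jpg
#   root/train/cat/002.jpg
# with the default leaf_name_only=True yields the labels 'dog' and 'cat'; find_images_and_targets
# below then builds class_to_idx = {'cat': 0, 'dog': 1} after natural-key sorting.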
5 | 6 | Hacked together by / Copyright 2020 Ross Wightman 7 | """ 8 | import os 9 | 10 | from timm.utils.misc import natural_key 11 | 12 | from .parser import Parser 13 | from .class_map import load_class_map 14 | from .constants import IMG_EXTENSIONS 15 | 16 | 17 | def find_images_and_targets(folder, types=IMG_EXTENSIONS, class_to_idx=None, leaf_name_only=True, sort=True): 18 | labels = [] 19 | filenames = [] 20 | for root, subdirs, files in os.walk(folder, topdown=False, followlinks=True): 21 | rel_path = os.path.relpath(root, folder) if (root != folder) else '' 22 | label = os.path.basename(rel_path) if leaf_name_only else rel_path.replace(os.path.sep, '_') 23 | for f in files: 24 | base, ext = os.path.splitext(f) 25 | if ext.lower() in types: 26 | filenames.append(os.path.join(root, f)) 27 | labels.append(label) 28 | if class_to_idx is None: 29 | # building class index 30 | unique_labels = set(labels) 31 | sorted_labels = list(sorted(unique_labels, key=natural_key)) 32 | class_to_idx = {c: idx for idx, c in enumerate(sorted_labels)} 33 | images_and_targets = [(f, class_to_idx[l]) for f, l in zip(filenames, labels) if l in class_to_idx] 34 | if sort: 35 | images_and_targets = sorted(images_and_targets, key=lambda k: natural_key(k[0])) 36 | return images_and_targets, class_to_idx 37 | 38 | 39 | class ParserImageFolder(Parser): 40 | 41 | def __init__( 42 | self, 43 | root, 44 | class_map=''): 45 | super().__init__() 46 | 47 | self.root = root 48 | class_to_idx = None 49 | if class_map: 50 | class_to_idx = load_class_map(class_map, root) 51 | self.samples, self.class_to_idx = find_images_and_targets(root, class_to_idx=class_to_idx) 52 | if len(self.samples) == 0: 53 | raise RuntimeError( 54 | f'Found 0 images in subfolders of {root}. Supported image extensions are {", ".join(IMG_EXTENSIONS)}') 55 | 56 | def __getitem__(self, index): 57 | path, target = self.samples[index] 58 | return open(path, 'rb'), target 59 | 60 | def __len__(self): 61 | return len(self.samples) 62 | 63 | def _filename(self, index, basename=False, absolute=False): 64 | filename = self.samples[index][0] 65 | if basename: 66 | filename = os.path.basename(filename) 67 | elif not absolute: 68 | filename = os.path.relpath(filename, self.root) 69 | return filename 70 | -------------------------------------------------------------------------------- /timm/data/parsers/parser_image_tar.py: -------------------------------------------------------------------------------- 1 | """ A dataset parser that reads single tarfile based datasets 2 | 3 | This parser can read datasets consisting of a single tarfile containing images. 4 | I am planning to deprecate it in favour of ParserImageInTar.
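# For context, a sketch of the layout this parser expects (hypothetical names):
#   imagenet_val.tar
#   ├── n01440764/ILSVRC2012_val_00000293.JPEG
#   └── n01443537/ILSVRC2012_val_00000236.JPEG
# The parent folder of each image inside the tar becomes its class label
# (see extract_tarinfo below).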
5 | 6 | Hacked together by / Copyright 2020 Ross Wightman 7 | """ 8 | import os 9 | import tarfile 10 | 11 | from .parser import Parser 12 | from .class_map import load_class_map 13 | from .constants import IMG_EXTENSIONS 14 | from timm.utils.misc import natural_key 15 | 16 | 17 | def extract_tarinfo(tarfile, class_to_idx=None, sort=True): 18 | files = [] 19 | labels = [] 20 | for ti in tarfile.getmembers(): 21 | if not ti.isfile(): 22 | continue 23 | dirname, basename = os.path.split(ti.path) 24 | label = os.path.basename(dirname) 25 | ext = os.path.splitext(basename)[1] 26 | if ext.lower() in IMG_EXTENSIONS: 27 | files.append(ti) 28 | labels.append(label) 29 | if class_to_idx is None: 30 | unique_labels = set(labels) 31 | sorted_labels = list(sorted(unique_labels, key=natural_key)) 32 | class_to_idx = {c: idx for idx, c in enumerate(sorted_labels)} 33 | tarinfo_and_targets = [(f, class_to_idx[l]) for f, l in zip(files, labels) if l in class_to_idx] 34 | if sort: 35 | tarinfo_and_targets = sorted(tarinfo_and_targets, key=lambda k: natural_key(k[0].path)) 36 | return tarinfo_and_targets, class_to_idx 37 | 38 | 39 | class ParserImageTar(Parser): 40 | """ Single tarfile dataset where classes are mapped to folders within tar 41 | NOTE: This class is being deprecated in favour of the more capable ParserImageInTar that can 42 | operate on folders of tars or tars in tars. 43 | """ 44 | def __init__(self, root, class_map=''): 45 | super().__init__() 46 | 47 | class_to_idx = None 48 | if class_map: 49 | class_to_idx = load_class_map(class_map, root) 50 | assert os.path.isfile(root) 51 | self.root = root 52 | 53 | with tarfile.open(root) as tf: # cannot keep this open across processes, reopen later 54 | self.samples, self.class_to_idx = extract_tarinfo(tf, class_to_idx) 55 | self.imgs = self.samples 56 | self.tarfile = None # lazy init in __getitem__ 57 | 58 | def __getitem__(self, index): 59 | if self.tarfile is None: 60 | self.tarfile = tarfile.open(self.root) 61 | tarinfo, target = self.samples[index] 62 | fileobj = self.tarfile.extractfile(tarinfo) 63 | return fileobj, target 64 | 65 | def __len__(self): 66 | return len(self.samples) 67 | 68 | def _filename(self, index, basename=False, absolute=False): 69 | filename = self.samples[index][0].name 70 | if basename: 71 | filename = os.path.basename(filename) 72 | return filename 73 | -------------------------------------------------------------------------------- /timm/data/real_labels.py: -------------------------------------------------------------------------------- 1 | """ Real labels evaluator for ImageNet 2 | Paper: `Are we done with ImageNet?` - https://arxiv.org/abs/2006.07159 3 | Based on Numpy example at https://github.com/google-research/reassessed-imagenet 4 | 5 | Hacked together by / Copyright 2020 Ross Wightman 6 | """ 7 | import os 8 | import json 9 | import numpy as np 10 | 11 | 12 | class RealLabelsImagenet: 13 | 14 | def __init__(self, filenames, real_json='real.json', topk=(1, 5)): 15 | with open(real_json) as real_labels: 16 | real_labels = json.load(real_labels) 17 | real_labels = {f'ILSVRC2012_val_{i + 1:08d}.JPEG': labels for i, labels in enumerate(real_labels)} 18 | self.real_labels = real_labels 19 | self.filenames = filenames 20 | assert len(self.filenames) == len(self.real_labels) 21 | self.topk = topk 22 | self.is_correct = {k: [] for k in topk} 23 | self.sample_idx = 0 24 | 25 | def add_result(self, output): 26 | maxk = max(self.topk) 27 | _, pred_batch = output.topk(maxk, 1, True, True) 28 | pred_batch = 
pred_batch.cpu().numpy() 29 | for pred in pred_batch: 30 | filename = self.filenames[self.sample_idx] 31 | filename = os.path.basename(filename) 32 | if self.real_labels[filename]: 33 | for k in self.topk: 34 | self.is_correct[k].append( 35 | any([p in self.real_labels[filename] for p in pred[:k]])) 36 | self.sample_idx += 1 37 | 38 | def get_accuracy(self, k=None): 39 | if k is None: 40 | return {k: float(np.mean(self.is_correct[k])) * 100 for k in self.topk} 41 | else: 42 | return float(np.mean(self.is_correct[k])) * 100 43 | -------------------------------------------------------------------------------- /timm/loss/__init__.py: -------------------------------------------------------------------------------- 1 | from .cross_entropy import LabelSmoothingCrossEntropy, SoftTargetCrossEntropy 2 | from .jsd import JsdCrossEntropy 3 | from .asymmetric_loss import AsymmetricLossMultiLabel, AsymmetricLossSingleLabel-------------------------------------------------------------------------------- /timm/loss/asymmetric_loss.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | 5 | class AsymmetricLossMultiLabel(nn.Module): 6 | def __init__(self, gamma_neg=4, gamma_pos=1, clip=0.05, eps=1e-8, disable_torch_grad_focal_loss=False): 7 | super(AsymmetricLossMultiLabel, self).__init__() 8 | 9 | self.gamma_neg = gamma_neg 10 | self.gamma_pos = gamma_pos 11 | self.clip = clip 12 | self.disable_torch_grad_focal_loss = disable_torch_grad_focal_loss 13 | self.eps = eps 14 | 15 | def forward(self, x, y): 16 | """ 17 | Parameters 18 | ---------- 19 | x: input logits 20 | y: targets (multi-label binarized vector) 21 | """ 22 | 23 | # Calculating Probabilities 24 | x_sigmoid = torch.sigmoid(x) 25 | xs_pos = x_sigmoid 26 | xs_neg = 1 - x_sigmoid 27 | 28 | # Asymmetric Clipping 29 | if self.clip is not None and self.clip > 0: 30 | xs_neg = (xs_neg + self.clip).clamp(max=1) 31 | 32 | # Basic CE calculation 33 | los_pos = y * torch.log(xs_pos.clamp(min=self.eps)) 34 | los_neg = (1 - y) * torch.log(xs_neg.clamp(min=self.eps)) 35 | loss = los_pos + los_neg 36 | 37 | # Asymmetric Focusing 38 | if self.gamma_neg > 0 or self.gamma_pos > 0: 39 | if self.disable_torch_grad_focal_loss: 40 | torch._C.set_grad_enabled(False) 41 | pt0 = xs_pos * y 42 | pt1 = xs_neg * (1 - y) # pt = p if t > 0 else 1-p 43 | pt = pt0 + pt1 44 | one_sided_gamma = self.gamma_pos * y + self.gamma_neg * (1 - y) 45 | one_sided_w = torch.pow(1 - pt, one_sided_gamma) 46 | if self.disable_torch_grad_focal_loss: 47 | torch._C.set_grad_enabled(True) 48 | loss *= one_sided_w 49 | 50 | return -loss.sum() 51 | 52 | 53 | class AsymmetricLossSingleLabel(nn.Module): 54 | def __init__(self, gamma_pos=1, gamma_neg=4, eps: float = 0.1, reduction='mean'): 55 | super(AsymmetricLossSingleLabel, self).__init__() 56 | 57 | self.eps = eps 58 | self.logsoftmax = nn.LogSoftmax(dim=-1) 59 | self.targets_classes = [] # prevent repeated gpu memory allocation 60 | self.gamma_pos = gamma_pos 61 | self.gamma_neg = gamma_neg 62 | self.reduction = reduction 63 | 64 | def forward(self, inputs, target, reduction=None): 65 | """ 66 | Parameters 67 | ---------- 68 | inputs: input logits 69 | target: targets (1-hot vector) 70 | """ 71 | 72 | num_classes = inputs.size()[-1] 73 | log_preds = self.logsoftmax(inputs) 74 | self.targets_classes = torch.zeros_like(inputs).scatter_(1, target.long().unsqueeze(1), 1) 75 | 76 | # ASL weights 77 | targets = self.targets_classes 78 | anti_targets = 1 - targets 79 | xs_pos =
torch.exp(log_preds) 80 | xs_neg = 1 - xs_pos 81 | xs_pos = xs_pos * targets 82 | xs_neg = xs_neg * anti_targets 83 | asymmetric_w = torch.pow(1 - xs_pos - xs_neg, 84 | self.gamma_pos * targets + self.gamma_neg * anti_targets) 85 | log_preds = log_preds * asymmetric_w 86 | 87 | if self.eps > 0: # label smoothing 88 | self.targets_classes.mul_(1 - self.eps).add_(self.eps / num_classes) 89 | 90 | # loss calculation 91 | loss = - self.targets_classes.mul(log_preds) 92 | 93 | loss = loss.sum(dim=-1) 94 | if self.reduction == 'mean': 95 | loss = loss.mean() 96 | 97 | return loss 98 | -------------------------------------------------------------------------------- /timm/loss/cross_entropy.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | 6 | class LabelSmoothingCrossEntropy(nn.Module): 7 | """ 8 | NLL loss with label smoothing. 9 | """ 10 | def __init__(self, smoothing=0.1): 11 | """ 12 | Constructor for the LabelSmoothing module. 13 | :param smoothing: label smoothing factor 14 | """ 15 | super(LabelSmoothingCrossEntropy, self).__init__() 16 | assert smoothing < 1.0 17 | self.smoothing = smoothing 18 | self.confidence = 1. - smoothing 19 | 20 | def forward(self, x, target): 21 | logprobs = F.log_softmax(x, dim=-1) 22 | nll_loss = -logprobs.gather(dim=-1, index=target.unsqueeze(1)) 23 | nll_loss = nll_loss.squeeze(1) 24 | smooth_loss = -logprobs.mean(dim=-1) 25 | loss = self.confidence * nll_loss + self.smoothing * smooth_loss 26 | return loss.mean() 27 | 28 | 29 | class SoftTargetCrossEntropy(nn.Module): 30 | 31 | def __init__(self): 32 | super(SoftTargetCrossEntropy, self).__init__() 33 | 34 | def forward(self, x, target): 35 | loss = torch.sum(-target * F.log_softmax(x, dim=-1), dim=-1) 36 | return loss.mean() 37 | -------------------------------------------------------------------------------- /timm/loss/jsd.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | from .cross_entropy import LabelSmoothingCrossEntropy 6 | 7 | 8 | class JsdCrossEntropy(nn.Module): 9 | """ Jensen-Shannon Divergence + Cross-Entropy Loss 10 | 11 | Based on impl here: https://github.com/google-research/augmix/blob/master/imagenet.py 12 | From paper: 'AugMix: A Simple Data Processing Method to Improve Robustness and Uncertainty - 13 | https://arxiv.org/abs/1912.02781 14 | 15 | Hacked together by / Copyright 2020 Ross Wightman 16 | """ 17 | def __init__(self, num_splits=3, alpha=12, smoothing=0.1): 18 | super().__init__() 19 | self.num_splits = num_splits 20 | self.alpha = alpha 21 | if smoothing is not None and smoothing > 0: 22 | self.cross_entropy_loss = LabelSmoothingCrossEntropy(smoothing) 23 | else: 24 | self.cross_entropy_loss = torch.nn.CrossEntropyLoss() 25 | 26 | def __call__(self, output, target): 27 | split_size = output.shape[0] // self.num_splits 28 | assert split_size * self.num_splits == output.shape[0] 29 | logits_split = torch.split(output, split_size) 30 | 31 | # Cross-entropy is only computed on clean images 32 | loss = self.cross_entropy_loss(logits_split[0], target[:split_size]) 33 | probs = [F.softmax(logits, dim=1) for logits in logits_split] 34 | 35 | # Clamp mixture distribution to avoid exploding KL divergence 36 | logp_mixture = torch.clamp(torch.stack(probs).mean(axis=0), 1e-7, 1).log() 37 | loss += self.alpha * sum([F.kl_div( 38 | logp_mixture, p_split, 
reduction='batchmean') for p_split in probs]) / len(probs) 39 | return loss 40 | -------------------------------------------------------------------------------- /timm/models/__init__.py: -------------------------------------------------------------------------------- 1 | from .byobnet import * 2 | from .cspnet import * 3 | from .densenet import * 4 | from .dla import * 5 | from .dpn import * 6 | from .efficientnet import * 7 | from .gluon_resnet import * 8 | from .gluon_xception import * 9 | from .hrnet import * 10 | from .inception_resnet_v2 import * 11 | from .inception_v3 import * 12 | from .inception_v4 import * 13 | from .mobilenetv3 import * 14 | from .nasnet import * 15 | from .nfnet import * 16 | from .pnasnet import * 17 | from .regnet import * 18 | from .res2net import * 19 | from .resnest import * 20 | from .resnet import * 21 | from .resnetv2 import * 22 | from .rexnet import * 23 | from .selecsls import * 24 | from .senet import * 25 | from .sknet import * 26 | from .tresnet import * 27 | from .vgg import * 28 | from .vision_transformer import * 29 | from .vovnet import * 30 | from .xception import * 31 | from .xception_aligned import * 32 | 33 | from .factory import create_model 34 | from .helpers import load_checkpoint, resume_checkpoint, model_parameters 35 | from .layers import TestTimePoolHead, apply_test_time_pool 36 | from .layers import convert_splitbn_model 37 | from .layers import is_scriptable, is_exportable, set_scriptable, set_exportable, is_no_jit, set_no_jit 38 | from .registry import * 39 | 40 | from .multi_label_model import * # ================================ 41 | -------------------------------------------------------------------------------- /timm/models/__init__original_version_by_ross.py: -------------------------------------------------------------------------------- 1 | from .byobnet import * 2 | from .cspnet import * 3 | from .densenet import * 4 | from .dla import * 5 | from .dpn import * 6 | from .efficientnet import * 7 | from .gluon_resnet import * 8 | from .gluon_xception import * 9 | from .hrnet import * 10 | from .inception_resnet_v2 import * 11 | from .inception_v3 import * 12 | from .inception_v4 import * 13 | from .mobilenetv3 import * 14 | from .nasnet import * 15 | from .nfnet import * 16 | from .pnasnet import * 17 | from .regnet import * 18 | from .res2net import * 19 | from .resnest import * 20 | from .resnet import * 21 | from .resnetv2 import * 22 | from .rexnet import * 23 | from .selecsls import * 24 | from .senet import * 25 | from .sknet import * 26 | from .tresnet import * 27 | from .vgg import * 28 | from .vision_transformer import * 29 | from .vovnet import * 30 | from .xception import * 31 | from .xception_aligned import * 32 | 33 | from .factory import create_model 34 | from .helpers import load_checkpoint, resume_checkpoint, model_parameters 35 | from .layers import TestTimePoolHead, apply_test_time_pool 36 | from .layers import convert_splitbn_model 37 | from .layers import is_scriptable, is_exportable, set_scriptable, set_exportable, is_no_jit, set_no_jit 38 | from .registry import * 39 | -------------------------------------------------------------------------------- /timm/models/factory.py: -------------------------------------------------------------------------------- 1 | from .registry import is_model, is_model_in_modules, model_entrypoint 2 | from .helpers import load_checkpoint 3 | from .layers import set_layer_config 4 | 5 | 6 | def create_model( 7 | model_name, 8 | pretrained=False, 9 | checkpoint_path='', 10 | 
scriptable=None, 11 | exportable=None, 12 | no_jit=None, 13 | **kwargs): 14 | """Create a model 15 | 16 | Args: 17 | model_name (str): name of model to instantiate 18 | pretrained (bool): load pretrained ImageNet-1k weights if true 19 | checkpoint_path (str): path of checkpoint to load after model is initialized 20 | scriptable (bool): set layer config so that model is jit scriptable (not working for all models yet) 21 | exportable (bool): set layer config so that model is traceable / ONNX exportable (not fully impl/obeyed yet) 22 | no_jit (bool): set layer config so that model doesn't utilize jit scripted layers (so far activations only) 23 | 24 | Keyword Args: 25 | drop_rate (float): dropout rate for training (default: 0.0) 26 | global_pool (str): global pool type (default: 'avg') 27 | **: other kwargs are model specific 28 | """ 29 | model_args = dict(pretrained=pretrained) 30 | 31 | # Only EfficientNet and MobileNetV3 models have support for batchnorm params or drop_connect_rate passed as args 32 | is_efficientnet = is_model_in_modules(model_name, ['efficientnet', 'mobilenetv3']) 33 | if not is_efficientnet: 34 | kwargs.pop('bn_tf', None) 35 | kwargs.pop('bn_momentum', None) 36 | kwargs.pop('bn_eps', None) 37 | 38 | # handle backwards compat with drop_connect -> drop_path change 39 | drop_connect_rate = kwargs.pop('drop_connect_rate', None) 40 | if drop_connect_rate is not None and kwargs.get('drop_path_rate', None) is None: 41 | print("WARNING: 'drop_connect' as an argument is deprecated, please use 'drop_path'." 42 | " Setting drop_path to %f." % drop_connect_rate) 43 | kwargs['drop_path_rate'] = drop_connect_rate 44 | 45 | # Parameters that aren't supported by all models or are intended to only override model defaults if set 46 | # should default to None in command line args/cfg. Remove them if they are present and not set so that 47 | # non-supporting models don't break and default args remain in effect.
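# For example (hypothetical values): argparse-style callers commonly pass drop_rate=None
# when a flag is unset; filtering it out below means the model's own default takes effect:
#   create_model('resnet50', drop_rate=None)  -> 'drop_rate' removed, model default used
#   create_model('resnet50', drop_rate=0.2)   -> drop_rate=0.2 passed through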
48 | kwargs = {k: v for k, v in kwargs.items() if v is not None} 49 | 50 | with set_layer_config(scriptable=scriptable, exportable=exportable, no_jit=no_jit): 51 | if is_model(model_name): 52 | create_fn = model_entrypoint(model_name) 53 | model = create_fn(**model_args, **kwargs) 54 | else: 55 | raise RuntimeError('Unknown model (%s)' % model_name) 56 | 57 | if checkpoint_path: 58 | load_checkpoint(model, checkpoint_path) 59 | 60 | return model 61 | -------------------------------------------------------------------------------- /timm/models/layers/__init__.py: -------------------------------------------------------------------------------- 1 | from .activations import * 2 | from .adaptive_avgmax_pool import \ 3 | adaptive_avgmax_pool2d, select_adaptive_pool2d, AdaptiveAvgMaxPool2d, SelectAdaptivePool2d 4 | from .anti_aliasing import AntiAliasDownsampleLayer 5 | from .blur_pool import BlurPool2d 6 | from .classifier import ClassifierHead, create_classifier 7 | from .cond_conv2d import CondConv2d, get_condconv_initializer 8 | from .config import is_exportable, is_scriptable, is_no_jit, set_exportable, set_scriptable, set_no_jit,\ 9 | set_layer_config 10 | from .conv2d_same import Conv2dSame, conv2d_same 11 | from .conv_bn_act import ConvBnAct 12 | from .create_act import create_act_layer, get_act_layer, get_act_fn 13 | from .create_attn import get_attn, create_attn 14 | from .create_conv2d import create_conv2d 15 | from .create_norm_act import get_norm_act_layer, create_norm_act, convert_norm_act 16 | from .drop import DropBlock2d, DropPath, drop_block_2d, drop_path 17 | from .eca import EcaModule, CecaModule 18 | from .evo_norm import EvoNormBatch2d, EvoNormSample2d 19 | from .helpers import to_ntuple, to_2tuple, to_3tuple, to_4tuple, make_divisible 20 | from .inplace_abn import InplaceAbn 21 | from .linear import Linear 22 | from .mixed_conv2d import MixedConv2d 23 | from .norm_act import BatchNormAct2d, GroupNormAct 24 | from .padding import get_padding, get_same_padding, pad_same 25 | from .pool2d_same import AvgPool2dSame, create_pool2d 26 | from .se import SEModule 27 | from .selective_kernel import SelectiveKernelConv 28 | from .separable_conv import SeparableConv2d, SeparableConvBnAct 29 | from .space_to_depth import SpaceToDepthModule 30 | from .split_attn import SplitAttnConv2d 31 | from .split_batchnorm import SplitBatchNorm2d, convert_splitbn_model 32 | from .std_conv import StdConv2d, StdConv2dSame, ScaledStdConv2d, ScaledStdConv2dSame 33 | from .test_time_pool import TestTimePoolHead, apply_test_time_pool 34 | from .weight_init import trunc_normal_ 35 | -------------------------------------------------------------------------------- /timm/models/layers/activations.py: -------------------------------------------------------------------------------- 1 | """ Activations 2 | 3 | A collection of activations fn and modules with a common interface so that they can 4 | easily be swapped. All have an `inplace` arg even if not used. 
5 | 6 | Hacked together by / Copyright 2020 Ross Wightman 7 | """ 8 | 9 | import torch 10 | from torch import nn as nn 11 | from torch.nn import functional as F 12 | 13 | 14 | def swish(x, inplace: bool = False): 15 | """Swish - Described in: https://arxiv.org/abs/1710.05941 16 | """ 17 | return x.mul_(x.sigmoid()) if inplace else x.mul(x.sigmoid()) 18 | 19 | 20 | class Swish(nn.Module): 21 | def __init__(self, inplace: bool = False): 22 | super(Swish, self).__init__() 23 | self.inplace = inplace 24 | 25 | def forward(self, x): 26 | return swish(x, self.inplace) 27 | 28 | 29 | def mish(x, inplace: bool = False): 30 | """Mish: A Self Regularized Non-Monotonic Neural Activation Function - https://arxiv.org/abs/1908.08681 31 | NOTE: I don't have a working inplace variant 32 | """ 33 | return x.mul(F.softplus(x).tanh()) 34 | 35 | 36 | class Mish(nn.Module): 37 | """Mish: A Self Regularized Non-Monotonic Neural Activation Function - https://arxiv.org/abs/1908.08681 38 | """ 39 | def __init__(self, inplace: bool = False): 40 | super(Mish, self).__init__() 41 | 42 | def forward(self, x): 43 | return mish(x) 44 | 45 | 46 | def sigmoid(x, inplace: bool = False): 47 | return x.sigmoid_() if inplace else x.sigmoid() 48 | 49 | 50 | # PyTorch has this, but not with a consistent inplace argument interface 51 | class Sigmoid(nn.Module): 52 | def __init__(self, inplace: bool = False): 53 | super(Sigmoid, self).__init__() 54 | self.inplace = inplace 55 | 56 | def forward(self, x): 57 | return x.sigmoid_() if self.inplace else x.sigmoid() 58 | 59 | 60 | def tanh(x, inplace: bool = False): 61 | return x.tanh_() if inplace else x.tanh() 62 | 63 | 64 | # PyTorch has this, but not with a consistent inplace argument interface 65 | class Tanh(nn.Module): 66 | def __init__(self, inplace: bool = False): 67 | super(Tanh, self).__init__() 68 | self.inplace = inplace 69 | 70 | def forward(self, x): 71 | return x.tanh_() if self.inplace else x.tanh() 72 | 73 | 74 | def hard_swish(x, inplace: bool = False): 75 | inner = F.relu6(x + 3.).div_(6.) 76 | return x.mul_(inner) if inplace else x.mul(inner) 77 | 78 | 79 | class HardSwish(nn.Module): 80 | def __init__(self, inplace: bool = False): 81 | super(HardSwish, self).__init__() 82 | self.inplace = inplace 83 | 84 | def forward(self, x): 85 | return hard_swish(x, self.inplace) 86 | 87 | 88 | def hard_sigmoid(x, inplace: bool = False): 89 | if inplace: 90 | return x.add_(3.).clamp_(0., 6.).div_(6.) 91 | else: 92 | return F.relu6(x + 3.) / 6.
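# A quick numeric sanity check of the piecewise forms above (illustrative values only):
#   hard_sigmoid(-3.) == 0.0, hard_sigmoid(0.) == 0.5, hard_sigmoid(3.) == 1.0
#   hard_swish(x) == x * hard_sigmoid(x), so hard_swish(3.) == 3.0 and hard_swish(-3.) == 0.0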
93 | 94 | 95 | class HardSigmoid(nn.Module): 96 | def __init__(self, inplace: bool = False): 97 | super(HardSigmoid, self).__init__() 98 | self.inplace = inplace 99 | 100 | def forward(self, x): 101 | return hard_sigmoid(x, self.inplace) 102 | 103 | 104 | def hard_mish(x, inplace: bool = False): 105 | """ Hard Mish 106 | Experimental, based on notes by Mish author Diganta Misra at 107 | https://github.com/digantamisra98/H-Mish/blob/0da20d4bc58e696b6803f2523c58d3c8a82782d0/README.md 108 | """ 109 | if inplace: 110 | return x.mul_(0.5 * (x + 2).clamp(min=0, max=2)) 111 | else: 112 | return 0.5 * x * (x + 2).clamp(min=0, max=2) 113 | 114 | 115 | class HardMish(nn.Module): 116 | def __init__(self, inplace: bool = False): 117 | super(HardMish, self).__init__() 118 | self.inplace = inplace 119 | 120 | def forward(self, x): 121 | return hard_mish(x, self.inplace) 122 | 123 | 124 | class PReLU(nn.PReLU): 125 | """Applies PReLU (w/ dummy inplace arg) 126 | """ 127 | def __init__(self, num_parameters: int = 1, init: float = 0.25, inplace: bool = False) -> None: 128 | super(PReLU, self).__init__(num_parameters=num_parameters, init=init) 129 | 130 | def forward(self, input: torch.Tensor) -> torch.Tensor: 131 | return F.prelu(input, self.weight) 132 | 133 | 134 | def gelu(x: torch.Tensor, inplace: bool = False) -> torch.Tensor: 135 | return F.gelu(x) 136 | 137 | 138 | class GELU(nn.Module): 139 | """Applies the Gaussian Error Linear Units function (w/ dummy inplace arg) 140 | """ 141 | def __init__(self, inplace: bool = False): 142 | super(GELU, self).__init__() 143 | 144 | def forward(self, input: torch.Tensor) -> torch.Tensor: 145 | return F.gelu(input) 146 | -------------------------------------------------------------------------------- /timm/models/layers/activations_jit.py: -------------------------------------------------------------------------------- 1 | """ Activations 2 | 3 | A collection of jit-scripted activations fn and modules with a common interface so that they can 4 | easily be swapped. All have an `inplace` arg even if not used. 5 | 6 | All jit scripted activations are lacking in-place variations on purpose, scripted kernel fusion does not 7 | currently work across in-place op boundaries, thus performance is equal to or less than the non-scripted 8 | versions if they contain in-place ops. 9 | 10 | Hacked together by / Copyright 2020 Ross Wightman 11 | """ 12 | 13 | import torch 14 | from torch import nn as nn 15 | from torch.nn import functional as F 16 | 17 | 18 | @torch.jit.script 19 | def swish_jit(x, inplace: bool = False): 20 | """Swish - Described in: https://arxiv.org/abs/1710.05941 21 | """ 22 | return x.mul(x.sigmoid()) 23 | 24 | 25 | @torch.jit.script 26 | def mish_jit(x, _inplace: bool = False): 27 | """Mish: A Self Regularized Non-Monotonic Neural Activation Function - https://arxiv.org/abs/1908.08681 28 | """ 29 | return x.mul(F.softplus(x).tanh()) 30 | 31 | 32 | class SwishJit(nn.Module): 33 | def __init__(self, inplace: bool = False): 34 | super(SwishJit, self).__init__() 35 | 36 | def forward(self, x): 37 | return swish_jit(x) 38 | 39 | 40 | class MishJit(nn.Module): 41 | def __init__(self, inplace: bool = False): 42 | super(MishJit, self).__init__() 43 | 44 | def forward(self, x): 45 | return mish_jit(x) 46 | 47 | 48 | @torch.jit.script 49 | def hard_sigmoid_jit(x, inplace: bool = False): 50 | # return F.relu6(x + 3.) / 6. 51 | return (x + 3).clamp(min=0, max=6).div(6.) # clamp seems ever so slightly faster? 
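# These scripted fns let TorchScript fuse the elementwise ops into a single kernel. A usage
# sketch (the module wrappers below just delegate to these functions):
#   x = torch.randn(8)
#   y = hard_sigmoid_jit(x)  # same result as F.relu6(x + 3.) / 6.
# The factory in create_act.py (later in this listing) falls back to these jit variants when
# a memory-efficient version isn't selected and export/no-jit modes aren't set.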
52 | 53 | 54 | class HardSigmoidJit(nn.Module): 55 | def __init__(self, inplace: bool = False): 56 | super(HardSigmoidJit, self).__init__() 57 | 58 | def forward(self, x): 59 | return hard_sigmoid_jit(x) 60 | 61 | 62 | @torch.jit.script 63 | def hard_swish_jit(x, inplace: bool = False): 64 | # return x * (F.relu6(x + 3.) / 6) 65 | return x * (x + 3).clamp(min=0, max=6).div(6.) # clamp seems ever so slightly faster? 66 | 67 | 68 | class HardSwishJit(nn.Module): 69 | def __init__(self, inplace: bool = False): 70 | super(HardSwishJit, self).__init__() 71 | 72 | def forward(self, x): 73 | return hard_swish_jit(x) 74 | 75 | 76 | @torch.jit.script 77 | def hard_mish_jit(x, inplace: bool = False): 78 | """ Hard Mish 79 | Experimental, based on notes by Mish author Diganta Misra at 80 | https://github.com/digantamisra98/H-Mish/blob/0da20d4bc58e696b6803f2523c58d3c8a82782d0/README.md 81 | """ 82 | return 0.5 * x * (x + 2).clamp(min=0, max=2) 83 | 84 | 85 | class HardMishJit(nn.Module): 86 | def __init__(self, inplace: bool = False): 87 | super(HardMishJit, self).__init__() 88 | 89 | def forward(self, x): 90 | return hard_mish_jit(x) 91 | -------------------------------------------------------------------------------- /timm/models/layers/adaptive_avgmax_pool.py: -------------------------------------------------------------------------------- 1 | """ PyTorch selectable adaptive pooling 2 | Adaptive pooling with the ability to select the type of pooling from: 3 | * 'avg' - Average pooling 4 | * 'max' - Max pooling 5 | * 'avgmax' - Sum of average and max pooling re-scaled by 0.5 6 | * 'catavgmax' - Concatenation of average and max pooling along feature dim, doubles feature dim 7 | 8 | Both a functional and an nn.Module version of the pooling are provided. 9 | 10 | Hacked together by / Copyright 2020 Ross Wightman 11 | """ 12 | import torch 13 | import torch.nn as nn 14 | import torch.nn.functional as F 15 | 16 | 17 | def adaptive_pool_feat_mult(pool_type='avg'): 18 | if pool_type == 'catavgmax': 19 | return 2 20 | else: 21 | return 1 22 | 23 | 24 | def adaptive_avgmax_pool2d(x, output_size=1): 25 | x_avg = F.adaptive_avg_pool2d(x, output_size) 26 | x_max = F.adaptive_max_pool2d(x, output_size) 27 | return 0.5 * (x_avg + x_max) 28 | 29 | 30 | def adaptive_catavgmax_pool2d(x, output_size=1): 31 | x_avg = F.adaptive_avg_pool2d(x, output_size) 32 | x_max = F.adaptive_max_pool2d(x, output_size) 33 | return torch.cat((x_avg, x_max), 1) 34 | 35 | 36 | def select_adaptive_pool2d(x, pool_type='avg', output_size=1): 37 | """Selectable global pooling function with dynamic input kernel size 38 | """ 39 | if pool_type == 'avg': 40 | x = F.adaptive_avg_pool2d(x, output_size) 41 | elif pool_type == 'avgmax': 42 | x = adaptive_avgmax_pool2d(x, output_size) 43 | elif pool_type == 'catavgmax': 44 | x = adaptive_catavgmax_pool2d(x, output_size) 45 | elif pool_type == 'max': 46 | x = F.adaptive_max_pool2d(x, output_size) 47 | else: 48 | assert False, 'Invalid pool type: %s' % pool_type 49 | return x 50 | 51 | 52 | class FastAdaptiveAvgPool2d(nn.Module): 53 | def __init__(self, flatten=False): 54 | super(FastAdaptiveAvgPool2d, self).__init__() 55 | self.flatten = flatten 56 | 57 | def forward(self, x): 58 | return x.mean((2, 3)) if self.flatten else x.mean((2, 3), keepdim=True) 59 | 60 | 61 | class AdaptiveAvgMaxPool2d(nn.Module): 62 | def __init__(self, output_size=1): 63 | super(AdaptiveAvgMaxPool2d, self).__init__() 64 | self.output_size = output_size 65 | 66 | def forward(self, x): 67 | return adaptive_avgmax_pool2d(x,
self.output_size) 68 | 69 | 70 | class AdaptiveCatAvgMaxPool2d(nn.Module): 71 | def __init__(self, output_size=1): 72 | super(AdaptiveCatAvgMaxPool2d, self).__init__() 73 | self.output_size = output_size 74 | 75 | def forward(self, x): 76 | return adaptive_catavgmax_pool2d(x, self.output_size) 77 | 78 | 79 | class SelectAdaptivePool2d(nn.Module): 80 | """Selectable global pooling layer with dynamic input kernel size 81 | """ 82 | def __init__(self, output_size=1, pool_type='fast', flatten=False): 83 | super(SelectAdaptivePool2d, self).__init__() 84 | self.pool_type = pool_type or '' # convert other falsy values to empty string for consistent TS typing 85 | self.flatten = flatten 86 | if pool_type == '': 87 | self.pool = nn.Identity() # pass through 88 | elif pool_type == 'fast': 89 | assert output_size == 1 90 | self.pool = FastAdaptiveAvgPool2d(self.flatten) 91 | self.flatten = False 92 | elif pool_type == 'avg': 93 | self.pool = nn.AdaptiveAvgPool2d(output_size) 94 | elif pool_type == 'avgmax': 95 | self.pool = AdaptiveAvgMaxPool2d(output_size) 96 | elif pool_type == 'catavgmax': 97 | self.pool = AdaptiveCatAvgMaxPool2d(output_size) 98 | elif pool_type == 'max': 99 | self.pool = nn.AdaptiveMaxPool2d(output_size) 100 | else: 101 | assert False, 'Invalid pool type: %s' % pool_type 102 | 103 | def is_identity(self): 104 | return self.pool_type == '' 105 | 106 | def forward(self, x): 107 | x = self.pool(x) 108 | if self.flatten: 109 | x = x.flatten(1) 110 | return x 111 | 112 | def feat_mult(self): 113 | return adaptive_pool_feat_mult(self.pool_type) 114 | 115 | def __repr__(self): 116 | return self.__class__.__name__ + ' (' \ 117 | + 'pool_type=' + self.pool_type \ 118 | + ', flatten=' + str(self.flatten) + ')' 119 | 120 | -------------------------------------------------------------------------------- /timm/models/layers/anti_aliasing.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn.parallel 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | 6 | 7 | class AntiAliasDownsampleLayer(nn.Module): 8 | def __init__(self, channels: int = 0, filt_size: int = 3, stride: int = 2, no_jit: bool = False): 9 | super(AntiAliasDownsampleLayer, self).__init__() 10 | if no_jit: 11 | self.op = Downsample(channels, filt_size, stride) 12 | else: 13 | self.op = DownsampleJIT(channels, filt_size, stride) 14 | 15 | # FIXME I should probably override _apply and clear DownsampleJIT filter cache for .cuda(), .half(), etc calls 16 | 17 | def forward(self, x): 18 | return self.op(x) 19 | 20 | 21 | @torch.jit.script 22 | class DownsampleJIT(object): 23 | def __init__(self, channels: int = 0, filt_size: int = 3, stride: int = 2): 24 | self.channels = channels 25 | self.stride = stride 26 | self.filt_size = filt_size 27 | assert self.filt_size == 3 28 | assert stride == 2 29 | self.filt = {} # lazy init by device for DataParallel compat 30 | 31 | def _create_filter(self, like: torch.Tensor): 32 | filt = torch.tensor([1., 2., 1.], dtype=like.dtype, device=like.device) 33 | filt = filt[:, None] * filt[None, :] 34 | filt = filt / torch.sum(filt) 35 | return filt[None, None, :, :].repeat((self.channels, 1, 1, 1)) 36 | 37 | def __call__(self, input: torch.Tensor): 38 | input_pad = F.pad(input, (1, 1, 1, 1), 'reflect') 39 | filt = self.filt.get(str(input.device), self._create_filter(input)) 40 | return F.conv2d(input_pad, filt, stride=2, padding=0, groups=input.shape[1]) 41 | 42 | 43 | class Downsample(nn.Module): 44 | def __init__(self, 
channels=None, filt_size=3, stride=2): 45 | super(Downsample, self).__init__() 46 | self.channels = channels 47 | self.filt_size = filt_size 48 | self.stride = stride 49 | 50 | assert self.filt_size == 3 51 | filt = torch.tensor([1., 2., 1.]) 52 | filt = filt[:, None] * filt[None, :] 53 | filt = filt / torch.sum(filt) 54 | 55 | # self.filt = filt[None, None, :, :].repeat((self.channels, 1, 1, 1)) 56 | self.register_buffer('filt', filt[None, None, :, :].repeat((self.channels, 1, 1, 1))) 57 | 58 | def forward(self, input): 59 | input_pad = F.pad(input, (1, 1, 1, 1), 'reflect') 60 | return F.conv2d(input_pad, self.filt, stride=self.stride, padding=0, groups=input.shape[1]) 61 | -------------------------------------------------------------------------------- /timm/models/layers/blur_pool.py: -------------------------------------------------------------------------------- 1 | """ 2 | BlurPool layer inspired by 3 | - Kornia's Max_BlurPool2d 4 | - Making Convolutional Networks Shift-Invariant Again :cite:`zhang2019shiftinvar` 5 | 6 | FIXME merge this impl with those in `anti_aliasing.py` 7 | 8 | Hacked together by Chris Ha and Ross Wightman 9 | """ 10 | 11 | import torch 12 | import torch.nn as nn 13 | import torch.nn.functional as F 14 | import numpy as np 15 | from typing import Dict 16 | from .padding import get_padding 17 | 18 | 19 | class BlurPool2d(nn.Module): 20 | r"""Creates a module that blurs and downsamples a given feature map. 21 | See :cite:`zhang2019shiftinvar` for more details. 22 | Corresponds to the Downsample class, which does blurring and subsampling 23 | 24 | Args: 25 | channels (int): Number of input channels 26 | filt_size (int): binomial filter size for blurring. currently supports 3 (default) and 5. 27 | stride (int): downsampling filter stride 28 | 29 | Returns: 30 | torch.Tensor: the transformed tensor. 31 | """ 32 | filt: Dict[str, torch.Tensor] 33 | 34 | def __init__(self, channels, filt_size=3, stride=2) -> None: 35 | super(BlurPool2d, self).__init__() 36 | assert filt_size > 1 37 | self.channels = channels 38 | self.filt_size = filt_size 39 | self.stride = stride 40 | pad_size = [get_padding(filt_size, stride, dilation=1)] * 4 41 | self.padding = nn.ReflectionPad2d(pad_size) 42 | self._coeffs = torch.tensor((np.poly1d((0.5, 0.5)) ** (self.filt_size - 1)).coeffs) # for torchscript compat 43 | self.filt = {} # lazy init by device for DataParallel compat 44 | 45 | def _create_filter(self, like: torch.Tensor): 46 | blur_filter = (self._coeffs[:, None] * self._coeffs[None, :]).to(dtype=like.dtype, device=like.device) 47 | return blur_filter[None, None, :, :].repeat(self.channels, 1, 1, 1) 48 | 49 | def _apply(self, fn): 50 | # override nn.Module _apply, reset filter cache if used; return super's result so .to()/.cuda() chain correctly 51 | self.filt = {} 52 | return super(BlurPool2d, self)._apply(fn) 53 | 54 | def forward(self, input_tensor: torch.Tensor) -> torch.Tensor: 55 | C = input_tensor.shape[1] 56 | blur_filt = self.filt.get(str(input_tensor.device), self._create_filter(input_tensor)) 57 | return F.conv2d( 58 | self.padding(input_tensor), blur_filt, stride=self.stride, groups=C) 59 | -------------------------------------------------------------------------------- /timm/models/layers/cbam.py: -------------------------------------------------------------------------------- 1 | """ CBAM (sort-of) Attention 2 | 3 | Experimental impl of CBAM: Convolutional Block Attention Module: https://arxiv.org/abs/1807.06521 4 | 5 | WARNING: Results with these attention layers have been mixed.
They can significantly reduce performance on 6 | some tasks, especially fine-grained ones, it seems. I may end up removing this impl. 7 | 8 | Hacked together by / Copyright 2020 Ross Wightman 9 | """ 10 | 11 | import torch 12 | from torch import nn as nn 13 | import torch.nn.functional as F 14 | from .conv_bn_act import ConvBnAct 15 | 16 | 17 | class ChannelAttn(nn.Module): 18 | """ Original CBAM channel attention module, currently avg + max pool variant only. 19 | """ 20 | def __init__(self, channels, reduction=16, act_layer=nn.ReLU): 21 | super(ChannelAttn, self).__init__() 22 | self.fc1 = nn.Conv2d(channels, channels // reduction, 1, bias=False) 23 | self.act = act_layer(inplace=True) 24 | self.fc2 = nn.Conv2d(channels // reduction, channels, 1, bias=False) 25 | 26 | def forward(self, x): 27 | x_avg = x.mean((2, 3), keepdim=True) 28 | x_max = F.adaptive_max_pool2d(x, 1) 29 | x_avg = self.fc2(self.act(self.fc1(x_avg))) 30 | x_max = self.fc2(self.act(self.fc1(x_max))) 31 | x_attn = x_avg + x_max 32 | return x * x_attn.sigmoid() 33 | 34 | 35 | class LightChannelAttn(ChannelAttn): 36 | """An experimental 'lightweight' variant that sums avg + max pool first 37 | """ 38 | def __init__(self, channels, reduction=16): 39 | super(LightChannelAttn, self).__init__(channels, reduction) 40 | 41 | def forward(self, x): 42 | x_pool = 0.5 * x.mean((2, 3), keepdim=True) + 0.5 * F.adaptive_max_pool2d(x, 1) 43 | x_attn = self.fc2(self.act(self.fc1(x_pool))) 44 | return x * x_attn.sigmoid() 45 | 46 | 47 | class SpatialAttn(nn.Module): 48 | """ Original CBAM spatial attention module 49 | """ 50 | def __init__(self, kernel_size=7): 51 | super(SpatialAttn, self).__init__() 52 | self.conv = ConvBnAct(2, 1, kernel_size, act_layer=None) 53 | 54 | def forward(self, x): 55 | x_avg = torch.mean(x, dim=1, keepdim=True) 56 | x_max = torch.max(x, dim=1, keepdim=True)[0] 57 | x_attn = torch.cat([x_avg, x_max], dim=1) 58 | x_attn = self.conv(x_attn) 59 | return x * x_attn.sigmoid() 60 | 61 | 62 | class LightSpatialAttn(nn.Module): 63 | """An experimental 'lightweight' variant that sums avg_pool and max_pool results.
64 | """ 65 | def __init__(self, kernel_size=7): 66 | super(LightSpatialAttn, self).__init__() 67 | self.conv = ConvBnAct(1, 1, kernel_size, act_layer=None) 68 | 69 | def forward(self, x): 70 | x_avg = torch.mean(x, dim=1, keepdim=True) 71 | x_max = torch.max(x, dim=1, keepdim=True)[0] 72 | x_attn = 0.5 * x_avg + 0.5 * x_max 73 | x_attn = self.conv(x_attn) 74 | return x * x_attn.sigmoid() 75 | 76 | 77 | class CbamModule(nn.Module): 78 | def __init__(self, channels, spatial_kernel_size=7): 79 | super(CbamModule, self).__init__() 80 | self.channel = ChannelAttn(channels) 81 | self.spatial = SpatialAttn(spatial_kernel_size) 82 | 83 | def forward(self, x): 84 | x = self.channel(x) 85 | x = self.spatial(x) 86 | return x 87 | 88 | 89 | class LightCbamModule(nn.Module): 90 | def __init__(self, channels, spatial_kernel_size=7): 91 | super(LightCbamModule, self).__init__() 92 | self.channel = LightChannelAttn(channels) 93 | self.spatial = LightSpatialAttn(spatial_kernel_size) 94 | 95 | def forward(self, x): 96 | x = self.channel(x) 97 | x = self.spatial(x) 98 | return x 99 | 100 | -------------------------------------------------------------------------------- /timm/models/layers/classifier.py: -------------------------------------------------------------------------------- 1 | """ Classifier head and layer factory 2 | 3 | Hacked together by / Copyright 2020 Ross Wightman 4 | """ 5 | from torch import nn as nn 6 | from torch.nn import functional as F 7 | 8 | from .adaptive_avgmax_pool import SelectAdaptivePool2d 9 | from .linear import Linear 10 | 11 | 12 | def _create_pool(num_features, num_classes, pool_type='avg', use_conv=False): 13 | flatten_in_pool = not use_conv # flatten when we use a Linear layer after pooling 14 | if not pool_type: 15 | assert num_classes == 0 or use_conv,\ 16 | 'Pooling can only be disabled if classifier is also removed or conv classifier is used' 17 | flatten_in_pool = False # disable flattening if pooling is pass-through (no pooling) 18 | global_pool = SelectAdaptivePool2d(pool_type=pool_type, flatten=flatten_in_pool) 19 | num_pooled_features = num_features * global_pool.feat_mult() 20 | return global_pool, num_pooled_features 21 | 22 | 23 | def _create_fc(num_features, num_classes, pool_type='avg', use_conv=False): 24 | if num_classes <= 0: 25 | fc = nn.Identity() # pass-through (no classifier) 26 | elif use_conv: 27 | fc = nn.Conv2d(num_features, num_classes, 1, bias=True) 28 | else: 29 | # NOTE: using my Linear wrapper that fixes AMP + torchscript casting issue 30 | fc = Linear(num_features, num_classes, bias=True) 31 | return fc 32 | 33 | 34 | def create_classifier(num_features, num_classes, pool_type='avg', use_conv=False): 35 | global_pool, num_pooled_features = _create_pool(num_features, num_classes, pool_type, use_conv=use_conv) 36 | fc = _create_fc(num_pooled_features, num_classes, use_conv=use_conv) 37 | return global_pool, fc 38 | 39 | 40 | class ClassifierHead(nn.Module): 41 | """Classifier head w/ configurable global pooling and dropout.""" 42 | 43 | def __init__(self, in_chs, num_classes, pool_type='avg', drop_rate=0., use_conv=False): 44 | super(ClassifierHead, self).__init__() 45 | self.drop_rate = drop_rate 46 | self.global_pool, num_pooled_features = _create_pool(in_chs, num_classes, pool_type, use_conv=use_conv) 47 | self.fc = _create_fc(num_pooled_features, num_classes, use_conv=use_conv) 48 | self.flatten_after_fc = use_conv and pool_type 49 | 50 | def forward(self, x): 51 | x = self.global_pool(x) 52 | if self.drop_rate: 53 | x = F.dropout(x, 
p=float(self.drop_rate), training=self.training) 54 | x = self.fc(x) 55 | return x 56 | -------------------------------------------------------------------------------- /timm/models/layers/config.py: -------------------------------------------------------------------------------- 1 | """ Model / Layer Config singleton state 2 | """ 3 | from typing import Any, Optional 4 | 5 | __all__ = [ 6 | 'is_exportable', 'is_scriptable', 'is_no_jit', 7 | 'set_exportable', 'set_scriptable', 'set_no_jit', 'set_layer_config' 8 | ] 9 | 10 | # Set to True if you prefer layers with no jit optimization (includes activations) 11 | _NO_JIT = False 12 | 13 | # Set to True if you prefer activation layers with no jit optimization 14 | # NOTE not currently used, as there is no difference between no_jit and no_activation_jit; the only layers obeying 15 | # the jit flags so far are activations. This will change as more layers are updated and/or added. 16 | _NO_ACTIVATION_JIT = False 17 | 18 | # Set to True if exporting a model with Same padding via ONNX 19 | _EXPORTABLE = False 20 | 21 | # Set to True if you want to use torch.jit.script on a model 22 | _SCRIPTABLE = False 23 | 24 | 25 | def is_no_jit(): 26 | return _NO_JIT 27 | 28 | 29 | class set_no_jit: 30 | def __init__(self, mode: bool) -> None: 31 | global _NO_JIT 32 | self.prev = _NO_JIT 33 | _NO_JIT = mode 34 | 35 | def __enter__(self) -> None: 36 | pass 37 | 38 | def __exit__(self, *args: Any) -> bool: 39 | global _NO_JIT 40 | _NO_JIT = self.prev 41 | return False 42 | 43 | 44 | def is_exportable(): 45 | return _EXPORTABLE 46 | 47 | 48 | class set_exportable: 49 | def __init__(self, mode: bool) -> None: 50 | global _EXPORTABLE 51 | self.prev = _EXPORTABLE 52 | _EXPORTABLE = mode 53 | 54 | def __enter__(self) -> None: 55 | pass 56 | 57 | def __exit__(self, *args: Any) -> bool: 58 | global _EXPORTABLE 59 | _EXPORTABLE = self.prev 60 | return False 61 | 62 | 63 | def is_scriptable(): 64 | return _SCRIPTABLE 65 | 66 | 67 | class set_scriptable: 68 | def __init__(self, mode: bool) -> None: 69 | global _SCRIPTABLE 70 | self.prev = _SCRIPTABLE 71 | _SCRIPTABLE = mode 72 | 73 | def __enter__(self) -> None: 74 | pass 75 | 76 | def __exit__(self, *args: Any) -> bool: 77 | global _SCRIPTABLE 78 | _SCRIPTABLE = self.prev 79 | return False 80 | 81 | 82 | class set_layer_config: 83 | """ Layer config context manager that allows setting all layer config flags at once. 84 | If a flag arg is None, it will not change the current value.
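A usage sketch, mirroring how create_model applies these flags earlier in this listing (create_fn/kwargs as in factory.py):

        with set_layer_config(scriptable=True):
            model = create_fn(**kwargs)  # layers created inside see _SCRIPTABLE == True
        # the previous flag values are restored on exit (see __exit__ below)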
85 | """ 86 | def __init__( 87 | self, 88 | scriptable: Optional[bool] = None, 89 | exportable: Optional[bool] = None, 90 | no_jit: Optional[bool] = None, 91 | no_activation_jit: Optional[bool] = None): 92 | global _SCRIPTABLE 93 | global _EXPORTABLE 94 | global _NO_JIT 95 | global _NO_ACTIVATION_JIT 96 | self.prev = _SCRIPTABLE, _EXPORTABLE, _NO_JIT, _NO_ACTIVATION_JIT 97 | if scriptable is not None: 98 | _SCRIPTABLE = scriptable 99 | if exportable is not None: 100 | _EXPORTABLE = exportable 101 | if no_jit is not None: 102 | _NO_JIT = no_jit 103 | if no_activation_jit is not None: 104 | _NO_ACTIVATION_JIT = no_activation_jit 105 | 106 | def __enter__(self) -> None: 107 | pass 108 | 109 | def __exit__(self, *args: Any) -> bool: 110 | global _SCRIPTABLE 111 | global _EXPORTABLE 112 | global _NO_JIT 113 | global _NO_ACTIVATION_JIT 114 | _SCRIPTABLE, _EXPORTABLE, _NO_JIT, _NO_ACTIVATION_JIT = self.prev 115 | return False 116 | -------------------------------------------------------------------------------- /timm/models/layers/conv2d_same.py: -------------------------------------------------------------------------------- 1 | """ Conv2d w/ Same Padding 2 | 3 | Hacked together by / Copyright 2020 Ross Wightman 4 | """ 5 | import torch 6 | import torch.nn as nn 7 | import torch.nn.functional as F 8 | from typing import Tuple, Optional 9 | 10 | from .padding import pad_same, get_padding_value 11 | 12 | 13 | def conv2d_same( 14 | x, weight: torch.Tensor, bias: Optional[torch.Tensor] = None, stride: Tuple[int, int] = (1, 1), 15 | padding: Tuple[int, int] = (0, 0), dilation: Tuple[int, int] = (1, 1), groups: int = 1): 16 | x = pad_same(x, weight.shape[-2:], stride, dilation) 17 | return F.conv2d(x, weight, bias, stride, (0, 0), dilation, groups) 18 | 19 | 20 | class Conv2dSame(nn.Conv2d): 21 | """ Tensorflow like 'SAME' convolution wrapper for 2D convolutions 22 | """ 23 | 24 | def __init__(self, in_channels, out_channels, kernel_size, stride=1, 25 | padding=0, dilation=1, groups=1, bias=True): 26 | super(Conv2dSame, self).__init__( 27 | in_channels, out_channels, kernel_size, stride, 0, dilation, groups, bias) 28 | 29 | def forward(self, x): 30 | return conv2d_same(x, self.weight, self.bias, self.stride, self.padding, self.dilation, self.groups) 31 | 32 | 33 | def create_conv2d_pad(in_chs, out_chs, kernel_size, **kwargs): 34 | padding = kwargs.pop('padding', '') 35 | kwargs.setdefault('bias', False) 36 | padding, is_dynamic = get_padding_value(padding, kernel_size, **kwargs) 37 | if is_dynamic: 38 | return Conv2dSame(in_chs, out_chs, kernel_size, **kwargs) 39 | else: 40 | return nn.Conv2d(in_chs, out_chs, kernel_size, padding=padding, **kwargs) 41 | 42 | 43 | -------------------------------------------------------------------------------- /timm/models/layers/conv_bn_act.py: -------------------------------------------------------------------------------- 1 | """ Conv2d + BN + Act 2 | 3 | Hacked together by / Copyright 2020 Ross Wightman 4 | """ 5 | from torch import nn as nn 6 | 7 | from .create_conv2d import create_conv2d 8 | from .create_norm_act import convert_norm_act 9 | 10 | 11 | class ConvBnAct(nn.Module): 12 | def __init__(self, in_channels, out_channels, kernel_size=1, stride=1, padding='', dilation=1, groups=1, 13 | bias=False, apply_act=True, norm_layer=nn.BatchNorm2d, act_layer=nn.ReLU, aa_layer=None, 14 | drop_block=None): 15 | super(ConvBnAct, self).__init__() 16 | use_aa = aa_layer is not None 17 | 18 | self.conv = create_conv2d( 19 | in_channels, out_channels, kernel_size, stride=1 
if use_aa else stride, 20 | padding=padding, dilation=dilation, groups=groups, bias=bias) 21 | 22 | # NOTE for backwards compatibility with models that use separate norm and act layer definitions 23 | norm_act_layer = convert_norm_act(norm_layer, act_layer) 24 | self.bn = norm_act_layer(out_channels, apply_act=apply_act, drop_block=drop_block) 25 | self.aa = aa_layer(channels=out_channels) if stride == 2 and use_aa else None 26 | 27 | @property 28 | def in_channels(self): 29 | return self.conv.in_channels 30 | 31 | @property 32 | def out_channels(self): 33 | return self.conv.out_channels 34 | 35 | def forward(self, x): 36 | x = self.conv(x) 37 | x = self.bn(x) 38 | if self.aa is not None: 39 | x = self.aa(x) 40 | return x 41 | -------------------------------------------------------------------------------- /timm/models/layers/create_act.py: -------------------------------------------------------------------------------- 1 | """ Activation Factory 2 | Hacked together by / Copyright 2020 Ross Wightman 3 | """ 4 | from .activations import * 5 | from .activations_jit import * 6 | from .activations_me import * 7 | from .config import is_exportable, is_scriptable, is_no_jit 8 | 9 | # PyTorch has an optimized, native 'silu' (aka 'swish') operator as of PyTorch 1.7. This code 10 | # will use native version if present. Eventually, the custom Swish layers will be removed 11 | # and only native 'silu' will be used. 12 | _has_silu = 'silu' in dir(torch.nn.functional) 13 | 14 | _ACT_FN_DEFAULT = dict( 15 | silu=F.silu if _has_silu else swish, 16 | swish=F.silu if _has_silu else swish, 17 | mish=mish, 18 | relu=F.relu, 19 | relu6=F.relu6, 20 | leaky_relu=F.leaky_relu, 21 | elu=F.elu, 22 | celu=F.celu, 23 | selu=F.selu, 24 | gelu=gelu, 25 | sigmoid=sigmoid, 26 | tanh=tanh, 27 | hard_sigmoid=hard_sigmoid, 28 | hard_swish=hard_swish, 29 | hard_mish=hard_mish, 30 | ) 31 | 32 | _ACT_FN_JIT = dict( 33 | silu=F.silu if _has_silu else swish_jit, 34 | swish=F.silu if _has_silu else swish_jit, 35 | mish=mish_jit, 36 | hard_sigmoid=hard_sigmoid_jit, 37 | hard_swish=hard_swish_jit, 38 | hard_mish=hard_mish_jit 39 | ) 40 | 41 | _ACT_FN_ME = dict( 42 | silu=F.silu if _has_silu else swish_me, 43 | swish=F.silu if _has_silu else swish_me, 44 | mish=mish_me, 45 | hard_sigmoid=hard_sigmoid_me, 46 | hard_swish=hard_swish_me, 47 | hard_mish=hard_mish_me, 48 | ) 49 | 50 | _ACT_LAYER_DEFAULT = dict( 51 | silu=nn.SiLU if _has_silu else Swish, 52 | swish=nn.SiLU if _has_silu else Swish, 53 | mish=Mish, 54 | relu=nn.ReLU, 55 | relu6=nn.ReLU6, 56 | leaky_relu=nn.LeakyReLU, 57 | elu=nn.ELU, 58 | prelu=PReLU, 59 | celu=nn.CELU, 60 | selu=nn.SELU, 61 | gelu=GELU, 62 | sigmoid=Sigmoid, 63 | tanh=Tanh, 64 | hard_sigmoid=HardSigmoid, 65 | hard_swish=HardSwish, 66 | hard_mish=HardMish, 67 | ) 68 | 69 | _ACT_LAYER_JIT = dict( 70 | silu=nn.SiLU if _has_silu else SwishJit, 71 | swish=nn.SiLU if _has_silu else SwishJit, 72 | mish=MishJit, 73 | hard_sigmoid=HardSigmoidJit, 74 | hard_swish=HardSwishJit, 75 | hard_mish=HardMishJit 76 | ) 77 | 78 | _ACT_LAYER_ME = dict( 79 | silu=nn.SiLU if _has_silu else SwishMe, 80 | swish=nn.SiLU if _has_silu else SwishMe, 81 | mish=MishMe, 82 | hard_sigmoid=HardSigmoidMe, 83 | hard_swish=HardSwishMe, 84 | hard_mish=HardMishMe, 85 | ) 86 | 87 | 88 | def get_act_fn(name='relu'): 89 | """ Activation Function Factory 90 | Fetching activation fns by name with this function allows export or torch script friendly 91 | functions to be returned dynamically based on current config. 
92 | """ 93 | if not name: 94 | return None 95 | if not (is_no_jit() or is_exportable() or is_scriptable()): 96 | # If not exporting or scripting the model, first look for a memory-efficient version with 97 | # custom autograd, then fallback 98 | if name in _ACT_FN_ME: 99 | return _ACT_FN_ME[name] 100 | if is_exportable() and name in ('silu', 'swish'): 101 | # FIXME PyTorch SiLU doesn't ONNX export, this is a temp hack 102 | return swish 103 | if not (is_no_jit() or is_exportable()): 104 | if name in _ACT_FN_JIT: 105 | return _ACT_FN_JIT[name] 106 | return _ACT_FN_DEFAULT[name] 107 | 108 | 109 | def get_act_layer(name='relu'): 110 | """ Activation Layer Factory 111 | Fetching activation layers by name with this function allows export or torch script friendly 112 | functions to be returned dynamically based on current config. 113 | """ 114 | if not name: 115 | return None 116 | if not (is_no_jit() or is_exportable() or is_scriptable()): 117 | if name in _ACT_LAYER_ME: 118 | return _ACT_LAYER_ME[name] 119 | if is_exportable() and name in ('silu', 'swish'): 120 | # FIXME PyTorch SiLU doesn't ONNX export, this is a temp hack 121 | return Swish 122 | if not (is_no_jit() or is_exportable()): 123 | if name in _ACT_LAYER_JIT: 124 | return _ACT_LAYER_JIT[name] 125 | return _ACT_LAYER_DEFAULT[name] 126 | 127 | 128 | def create_act_layer(name, inplace=False, **kwargs): 129 | act_layer = get_act_layer(name) 130 | if act_layer is not None: 131 | return act_layer(inplace=inplace, **kwargs) 132 | else: 133 | return None 134 | -------------------------------------------------------------------------------- /timm/models/layers/create_attn.py: -------------------------------------------------------------------------------- 1 | """ Select AttentionFactory Method 2 | 3 | Hacked together by / Copyright 2020 Ross Wightman 4 | """ 5 | import torch 6 | from .se import SEModule, EffectiveSEModule 7 | from .eca import EcaModule, CecaModule 8 | from .cbam import CbamModule, LightCbamModule 9 | 10 | 11 | def get_attn(attn_type): 12 | if isinstance(attn_type, torch.nn.Module): 13 | return attn_type 14 | module_cls = None 15 | if attn_type is not None: 16 | if isinstance(attn_type, str): 17 | attn_type = attn_type.lower() 18 | if attn_type == 'se': 19 | module_cls = SEModule 20 | elif attn_type == 'ese': 21 | module_cls = EffectiveSEModule 22 | elif attn_type == 'eca': 23 | module_cls = EcaModule 24 | elif attn_type == 'ceca': 25 | module_cls = CecaModule 26 | elif attn_type == 'cbam': 27 | module_cls = CbamModule 28 | elif attn_type == 'lcbam': 29 | module_cls = LightCbamModule 30 | else: 31 | assert False, "Invalid attn module (%s)" % attn_type 32 | elif isinstance(attn_type, bool): 33 | if attn_type: 34 | module_cls = SEModule 35 | else: 36 | module_cls = attn_type 37 | return module_cls 38 | 39 | 40 | def create_attn(attn_type, channels, **kwargs): 41 | module_cls = get_attn(attn_type) 42 | if module_cls is not None: 43 | # NOTE: it's expected the first (positional) argument of all attention layers is the # input channels 44 | return module_cls(channels, **kwargs) 45 | return None 46 | -------------------------------------------------------------------------------- /timm/models/layers/create_conv2d.py: -------------------------------------------------------------------------------- 1 | """ Create Conv2d Factory Method 2 | 3 | Hacked together by / Copyright 2020 Ross Wightman 4 | """ 5 | 6 | from .mixed_conv2d import MixedConv2d 7 | from .cond_conv2d import CondConv2d 8 | from .conv2d_same import create_conv2d_pad 9 
| 10 | 11 | def create_conv2d(in_channels, out_channels, kernel_size, **kwargs): 12 | """ Select a 2d convolution implementation based on arguments 13 | Creates and returns one of torch.nn.Conv2d, Conv2dSame, MixedConv2d, or CondConv2d. 14 | 15 | Used extensively by EfficientNet, MobileNetV3 and related networks. 16 | """ 17 | if isinstance(kernel_size, list): 18 | assert 'num_experts' not in kwargs # MixNet + CondConv combo not supported currently 19 | assert 'groups' not in kwargs # MixedConv groups are defined by kernel list 20 | # We're going to use only lists for defining the MixedConv2d kernel groups, 21 | # ints, tuples, other iterables will continue to pass to normal conv and specify h, w. 22 | m = MixedConv2d(in_channels, out_channels, kernel_size, **kwargs) 23 | else: 24 | depthwise = kwargs.pop('depthwise', False) 25 | # for DW, out_channels must be a multiple of in_channels since out_channels % groups must == 0 26 | groups = in_channels if depthwise else kwargs.pop('groups', 1) 27 | if 'num_experts' in kwargs and kwargs['num_experts'] > 0: 28 | m = CondConv2d(in_channels, out_channels, kernel_size, groups=groups, **kwargs) 29 | else: 30 | m = create_conv2d_pad(in_channels, out_channels, kernel_size, groups=groups, **kwargs) 31 | return m 32 | -------------------------------------------------------------------------------- /timm/models/layers/create_norm_act.py: -------------------------------------------------------------------------------- 1 | """ NormAct (Normalization + Activation Layer) Factory 2 | 3 | Create norm + act combo modules that attempt to be backwards compatible with separate norm + act 4 | instances in models. Where these are used it will be possible to swap separate BN + act layers with 5 | combined modules like IABN or EvoNorms. 6 | 7 | Hacked together by / Copyright 2020 Ross Wightman 8 | """ 9 | import types 10 | import functools 11 | 12 | import torch 13 | import torch.nn as nn 14 | 15 | from .evo_norm import EvoNormBatch2d, EvoNormSample2d 16 | from .norm_act import BatchNormAct2d, GroupNormAct 17 | from .inplace_abn import InplaceAbn 18 | 19 | _NORM_ACT_TYPES = {BatchNormAct2d, GroupNormAct, EvoNormBatch2d, EvoNormSample2d, InplaceAbn} 20 | _NORM_ACT_REQUIRES_ARG = {BatchNormAct2d, GroupNormAct, InplaceAbn} # requires act_layer arg to define act type 21 | 22 | 23 | def get_norm_act_layer(layer_class): 24 | layer_class = layer_class.replace('_', '').lower() 25 | if layer_class.startswith("batchnorm"): 26 | layer = BatchNormAct2d 27 | elif layer_class.startswith("groupnorm"): 28 | layer = GroupNormAct 29 | elif layer_class == "evonormbatch": 30 | layer = EvoNormBatch2d 31 | elif layer_class == "evonormsample": 32 | layer = EvoNormSample2d 33 | elif layer_class == "iabn" or layer_class == "inplaceabn": 34 | layer = InplaceAbn 35 | else: 36 | assert False, "Invalid norm_act layer (%s)" % layer_class 37 | return layer 38 | 39 | 40 | def create_norm_act(layer_type, num_features, apply_act=True, jit=False, **kwargs): 41 | layer_parts = layer_type.split('-') # e.g. batchnorm-leaky_relu 42 | assert len(layer_parts) in (1, 2) 43 | layer = get_norm_act_layer(layer_parts[0]) 44 | #activation_class = layer_parts[1].lower() if len(layer_parts) > 1 else '' # FIXME support string act selection?
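# NOTE: only the norm half of a combo string like 'batchnorm-leaky_relu' is resolved above; per the FIXME, the act half is not parsed yet, so select the activation via an act_layer kwarg instead.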
45 | layer_instance = layer(num_features, apply_act=apply_act, **kwargs) 46 | if jit: 47 | layer_instance = torch.jit.script(layer_instance) 48 | return layer_instance 49 | 50 | 51 | def convert_norm_act(norm_layer, act_layer): 52 | assert isinstance(norm_layer, (type, str, types.FunctionType, functools.partial)) 53 | assert act_layer is None or isinstance(act_layer, (type, str, types.FunctionType, functools.partial)) 54 | norm_act_kwargs = {} 55 | 56 | # unbind partial fn, so args can be rebound later 57 | if isinstance(norm_layer, functools.partial): 58 | norm_act_kwargs.update(norm_layer.keywords) 59 | norm_layer = norm_layer.func 60 | 61 | if isinstance(norm_layer, str): 62 | norm_act_layer = get_norm_act_layer(norm_layer) 63 | elif norm_layer in _NORM_ACT_TYPES: 64 | norm_act_layer = norm_layer 65 | elif isinstance(norm_layer, types.FunctionType): 66 | # if function type, must be a lambda/fn that creates a norm_act layer 67 | norm_act_layer = norm_layer 68 | else: 69 | type_name = norm_layer.__name__.lower() 70 | if type_name.startswith('batchnorm'): 71 | norm_act_layer = BatchNormAct2d 72 | elif type_name.startswith('groupnorm'): 73 | norm_act_layer = GroupNormAct 74 | else: 75 | assert False, f"No equivalent norm_act layer for {type_name}" 76 | 77 | if norm_act_layer in _NORM_ACT_REQUIRES_ARG: 78 | # pass `act_layer` through for backwards compat where `act_layer=None` implies no activation. 79 | # In the future, may force use of `apply_act` with `act_layer` arg bound to relevant NormAct types 80 | norm_act_kwargs.setdefault('act_layer', act_layer) 81 | if norm_act_kwargs: 82 | norm_act_layer = functools.partial(norm_act_layer, **norm_act_kwargs) # bind/rebind args 83 | return norm_act_layer 84 | -------------------------------------------------------------------------------- /timm/models/layers/evo_norm.py: -------------------------------------------------------------------------------- 1 | """EvoNormB0 (Batched) and EvoNormS0 (Sample) in PyTorch 2 | 3 | An attempt at getting decent performing EvoNorms running in PyTorch. 4 | While currently faster than other impl, still quite a ways off the built-in BN 5 | in terms of memory usage and throughput (roughly 5x mem, 1/2 - 1/3x speed). 6 | 7 | Still very much a WIP, fiddling with buffer usage, in-place/jit optimizations, and layouts. 
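For reference, the two variants implemented below compute (var taken over the indicated dims; weight, bias, v learned):
  EvoNormB0: y = weight * x / max(sqrt(var_batch(x) + eps), v * x + sqrt(var_hw(x) + eps)) + bias
  EvoNormS0: y = weight * x * sigmoid(v * x) / sqrt(var_group(x) + eps) + bias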
8 | 9 | Hacked together by / Copyright 2020 Ross Wightman 10 | """ 11 | 12 | import torch 13 | import torch.nn as nn 14 | 15 | 16 | class EvoNormBatch2d(nn.Module): 17 | def __init__(self, num_features, apply_act=True, momentum=0.1, eps=1e-5, drop_block=None): 18 | super(EvoNormBatch2d, self).__init__() 19 | self.apply_act = apply_act # apply activation (non-linearity) 20 | self.momentum = momentum 21 | self.eps = eps 22 | param_shape = (1, num_features, 1, 1) 23 | self.weight = nn.Parameter(torch.ones(param_shape), requires_grad=True) 24 | self.bias = nn.Parameter(torch.zeros(param_shape), requires_grad=True) 25 | if apply_act: 26 | self.v = nn.Parameter(torch.ones(param_shape), requires_grad=True) 27 | self.register_buffer('running_var', torch.ones(1, num_features, 1, 1)) 28 | self.reset_parameters() 29 | 30 | def reset_parameters(self): 31 | nn.init.ones_(self.weight) 32 | nn.init.zeros_(self.bias) 33 | if self.apply_act: 34 | nn.init.ones_(self.v) 35 | 36 | def forward(self, x): 37 | assert x.dim() == 4, 'expected 4D input' 38 | x_type = x.dtype 39 | if self.training: 40 | var = x.var(dim=(0, 2, 3), unbiased=False, keepdim=True) 41 | n = x.numel() / x.shape[1] 42 | self.running_var.copy_( 43 | var.detach() * self.momentum * (n / (n - 1)) + self.running_var * (1 - self.momentum)) 44 | else: 45 | var = self.running_var 46 | 47 | if self.apply_act: 48 | v = self.v.to(dtype=x_type) 49 | d = x * v + (x.var(dim=(2, 3), unbiased=False, keepdim=True) + self.eps).sqrt().to(dtype=x_type) 50 | d = d.max((var + self.eps).sqrt().to(dtype=x_type)) 51 | x = x / d 52 | return x * self.weight + self.bias 53 | 54 | 55 | class EvoNormSample2d(nn.Module): 56 | def __init__(self, num_features, apply_act=True, groups=8, eps=1e-5, drop_block=None): 57 | super(EvoNormSample2d, self).__init__() 58 | self.apply_act = apply_act # apply activation (non-linearity) 59 | self.groups = groups 60 | self.eps = eps 61 | param_shape = (1, num_features, 1, 1) 62 | self.weight = nn.Parameter(torch.ones(param_shape), requires_grad=True) 63 | self.bias = nn.Parameter(torch.zeros(param_shape), requires_grad=True) 64 | if apply_act: 65 | self.v = nn.Parameter(torch.ones(param_shape), requires_grad=True) 66 | self.reset_parameters() 67 | 68 | def reset_parameters(self): 69 | nn.init.ones_(self.weight) 70 | nn.init.zeros_(self.bias) 71 | if self.apply_act: 72 | nn.init.ones_(self.v) 73 | 74 | def forward(self, x): 75 | assert x.dim() == 4, 'expected 4D input' 76 | B, C, H, W = x.shape 77 | assert C % self.groups == 0 78 | if self.apply_act: 79 | n = x * (x * self.v).sigmoid() 80 | x = x.reshape(B, self.groups, -1) 81 | x = n.reshape(B, self.groups, -1) / (x.var(dim=-1, unbiased=False, keepdim=True) + self.eps).sqrt() 82 | x = x.reshape(B, C, H, W) 83 | return x * self.weight + self.bias 84 | -------------------------------------------------------------------------------- /timm/models/layers/helpers.py: -------------------------------------------------------------------------------- 1 | """ Layer/Module Helpers 2 | 3 | Hacked together by / Copyright 2020 Ross Wightman 4 | """ 5 | from itertools import repeat 6 | import collections.abc 7 | 8 | 9 | # From PyTorch internals 10 | def _ntuple(n): 11 | def parse(x): 12 | if isinstance(x, collections.abc.Iterable): 13 | return x 14 | return tuple(repeat(x, n)) 15 | return parse 16 | 17 | 18 | to_1tuple = _ntuple(1) 19 | to_2tuple = _ntuple(2) 20 | to_3tuple = _ntuple(3) 21 | to_4tuple = _ntuple(4) 22 | to_ntuple = _ntuple 23 | 24 | 25 | def make_divisible(v, divisor=8, min_value=None): 
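    # Round v to the nearest multiple of `divisor` (ties round up), never dropping below `min_value`;
    # the `< 0.9 * v` check below ensures rounding down never removes more than 10% of the channels.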
26 | min_value = min_value or divisor 27 | new_v = max(min_value, int(v + divisor / 2) // divisor * divisor) 28 | # Make sure that round down does not go down by more than 10%. 29 | if new_v < 0.9 * v: 30 | new_v += divisor 31 | return new_v 32 | -------------------------------------------------------------------------------- /timm/models/layers/inplace_abn.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn as nn 3 | 4 | try: 5 | from inplace_abn.functions import inplace_abn, inplace_abn_sync 6 | has_iabn = True 7 | except ImportError: 8 | has_iabn = False 9 | 10 | def inplace_abn(x, weight, bias, running_mean, running_var, 11 | training=True, momentum=0.1, eps=1e-05, activation="leaky_relu", activation_param=0.01): 12 | raise ImportError( 13 | "Please install InplaceABN:'pip install git+https://github.com/mapillary/inplace_abn.git@v1.0.11'") 14 | 15 | def inplace_abn_sync(**kwargs): 16 | inplace_abn(**kwargs) 17 | 18 | 19 | class InplaceAbn(nn.Module): 20 | """Activated Batch Normalization 21 | 22 | This gathers a BatchNorm and an activation function in a single module 23 | 24 | Parameters 25 | ---------- 26 | num_features : int 27 | Number of feature channels in the input and output. 28 | eps : float 29 | Small constant to prevent numerical issues. 30 | momentum : float 31 | Momentum factor applied to compute running statistics. 32 | affine : bool 33 | If `True` apply learned scale and shift transformation after normalization. 34 | act_layer : str or nn.Module type 35 | Name or type of the activation functions, one of: `leaky_relu`, `elu` 36 | act_param : float 37 | Negative slope for the `leaky_relu` activation. 38 | """ 39 | 40 | def __init__(self, num_features, eps=1e-5, momentum=0.1, affine=True, apply_act=True, 41 | act_layer="leaky_relu", act_param=0.01, drop_block=None): 42 | super(InplaceAbn, self).__init__() 43 | self.num_features = num_features 44 | self.affine = affine 45 | self.eps = eps 46 | self.momentum = momentum 47 | if apply_act: 48 | if isinstance(act_layer, str): 49 | assert act_layer in ('leaky_relu', 'elu', 'identity', '') 50 | self.act_name = act_layer if act_layer else 'identity' 51 | else: 52 | # convert act layer passed as type to string 53 | if act_layer == nn.ELU: 54 | self.act_name = 'elu' 55 | elif act_layer == nn.LeakyReLU: 56 | self.act_name = 'leaky_relu' 57 | elif act_layer == nn.Identity: 58 | self.act_name = 'identity' 59 | else: 60 | assert False, f'Invalid act layer {act_layer.__name__} for IABN' 61 | else: 62 | self.act_name = 'identity' 63 | self.act_param = act_param 64 | if self.affine: 65 | self.weight = nn.Parameter(torch.ones(num_features)) 66 | self.bias = nn.Parameter(torch.zeros(num_features)) 67 | else: 68 | self.register_parameter('weight', None) 69 | self.register_parameter('bias', None) 70 | self.register_buffer('running_mean', torch.zeros(num_features)) 71 | self.register_buffer('running_var', torch.ones(num_features)) 72 | self.reset_parameters() 73 | 74 | def reset_parameters(self): 75 | nn.init.constant_(self.running_mean, 0) 76 | nn.init.constant_(self.running_var, 1) 77 | if self.affine: 78 | nn.init.constant_(self.weight, 1) 79 | nn.init.constant_(self.bias, 0) 80 | 81 | def forward(self, x): 82 | output = inplace_abn( 83 | x, self.weight, self.bias, self.running_mean, self.running_var, 84 | self.training, self.momentum, self.eps, self.act_name, self.act_param) 85 | if isinstance(output, tuple): 86 | output = output[0] 87 | return output 88 | 
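A minimal usage sketch for the norm + act factory above (my illustration, not part of the source tree; it assumes the `timm` package in this repo is importable and the layer modules sit at the paths shown in the directory listing):

import torch
from timm.models.layers.create_norm_act import create_norm_act

x = torch.randn(2, 32, 8, 8)
# 'batchnorm' resolves to BatchNormAct2d; the activation defaults to ReLU (act_layer arg)
bn_act = create_norm_act('batchnorm', num_features=32, apply_act=True)
# 'evonormbatch' resolves to EvoNormBatch2d; its non-linearity is built into the norm (the v term)
evo = create_norm_act('evonormbatch', num_features=32, apply_act=True)
print(bn_act(x).shape, evo(x).shape)  # both remain torch.Size([2, 32, 8, 8])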
-------------------------------------------------------------------------------- /timm/models/layers/linear.py: -------------------------------------------------------------------------------- 1 | """ Linear layer (alternate definition) 2 | """ 3 | import torch 4 | import torch.nn.functional as F 5 | from torch import nn as nn 6 | 7 | 8 | class Linear(nn.Linear): 9 | r"""Applies a linear transformation to the incoming data: :math:`y = xA^T + b` 10 | 11 | Wraps torch.nn.Linear to support AMP + torchscript usage by manually casting 12 | weight & bias to input.dtype to work around an issue w/ torch.addmm in this use case. 13 | """ 14 | def forward(self, input: torch.Tensor) -> torch.Tensor: 15 | if torch.jit.is_scripting(): 16 | bias = self.bias.to(dtype=input.dtype) if self.bias is not None else None 17 | return F.linear(input, self.weight.to(dtype=input.dtype), bias=bias) 18 | else: 19 | return F.linear(input, self.weight, self.bias) 20 | -------------------------------------------------------------------------------- /timm/models/layers/median_pool.py: -------------------------------------------------------------------------------- 1 | """ Median Pool 2 | Hacked together by / Copyright 2020 Ross Wightman 3 | """ 4 | import torch.nn as nn 5 | import torch.nn.functional as F 6 | from .helpers import to_2tuple, to_4tuple 7 | 8 | 9 | class MedianPool2d(nn.Module): 10 | """ Median pool (usable as median filter when stride=1) module. 11 | 12 | Args: 13 | kernel_size: size of pooling kernel, int or 2-tuple 14 | stride: pool stride, int or 2-tuple 15 | padding: pool padding, int or 4-tuple (l, r, t, b) as in pytorch F.pad 16 | same: override padding and enforce same padding, boolean 17 | """ 18 | def __init__(self, kernel_size=3, stride=1, padding=0, same=False): 19 | super(MedianPool2d, self).__init__() 20 | self.k = to_2tuple(kernel_size) 21 | self.stride = to_2tuple(stride) 22 | self.padding = to_4tuple(padding) # convert to l, r, t, b 23 | self.same = same 24 | 25 | def _padding(self, x): 26 | if self.same: 27 | ih, iw = x.size()[2:] 28 | if ih % self.stride[0] == 0: 29 | ph = max(self.k[0] - self.stride[0], 0) 30 | else: 31 | ph = max(self.k[0] - (ih % self.stride[0]), 0) 32 | if iw % self.stride[1] == 0: 33 | pw = max(self.k[1] - self.stride[1], 0) 34 | else: 35 | pw = max(self.k[1] - (iw % self.stride[1]), 0) 36 | pl = pw // 2 37 | pr = pw - pl 38 | pt = ph // 2 39 | pb = ph - pt 40 | padding = (pl, pr, pt, pb) 41 | else: 42 | padding = self.padding 43 | return padding 44 | 45 | def forward(self, x): 46 | x = F.pad(x, self._padding(x), mode='reflect') 47 | x = x.unfold(2, self.k[0], self.stride[0]).unfold(3, self.k[1], self.stride[1]) 48 | x = x.contiguous().view(x.size()[:4] + (-1,)).median(dim=-1)[0] 49 | return x 50 | -------------------------------------------------------------------------------- /timm/models/layers/mixed_conv2d.py: -------------------------------------------------------------------------------- 1 | """ PyTorch Mixed Convolution 2 | 3 | Paper: MixConv: Mixed Depthwise Convolutional Kernels (https://arxiv.org/abs/1907.09595) 4 | 5 | Hacked together by / Copyright 2020 Ross Wightman 6 | """ 7 | 8 | import torch 9 | from torch import nn as nn 10 | 11 | from .conv2d_same import create_conv2d_pad 12 | 13 | 14 | def _split_channels(num_chan, num_groups): 15 | split = [num_chan // num_groups for _ in range(num_groups)] 16 | split[0] += num_chan - sum(split) 17 | return split 18 | 19 | 20 | class MixedConv2d(nn.ModuleDict): 21 | """ Mixed Grouped Convolution 22 | 23 | Based on 
MDConv and GroupedConv in MixNet impl: 24 | https://github.com/tensorflow/tpu/blob/master/models/official/mnasnet/mixnet/custom_layers.py 25 | """ 26 | def __init__(self, in_channels, out_channels, kernel_size=3, 27 | stride=1, padding='', dilation=1, depthwise=False, **kwargs): 28 | super(MixedConv2d, self).__init__() 29 | 30 | kernel_size = kernel_size if isinstance(kernel_size, list) else [kernel_size] 31 | num_groups = len(kernel_size) 32 | in_splits = _split_channels(in_channels, num_groups) 33 | out_splits = _split_channels(out_channels, num_groups) 34 | self.in_channels = sum(in_splits) 35 | self.out_channels = sum(out_splits) 36 | for idx, (k, in_ch, out_ch) in enumerate(zip(kernel_size, in_splits, out_splits)): 37 | conv_groups = in_ch if depthwise else 1 38 | # use add_module to keep key space clean 39 | self.add_module( 40 | str(idx), 41 | create_conv2d_pad( 42 | in_ch, out_ch, k, stride=stride, 43 | padding=padding, dilation=dilation, groups=conv_groups, **kwargs) 44 | ) 45 | self.splits = in_splits 46 | 47 | def forward(self, x): 48 | x_split = torch.split(x, self.splits, 1) 49 | x_out = [c(x_split[i]) for i, c in enumerate(self.values())] 50 | x = torch.cat(x_out, 1) 51 | return x 52 | -------------------------------------------------------------------------------- /timm/models/layers/norm_act.py: -------------------------------------------------------------------------------- 1 | """ Normalization + Activation Layers 2 | """ 3 | import torch 4 | from torch import nn as nn 5 | from torch.nn import functional as F 6 | 7 | from .create_act import get_act_layer 8 | 9 | 10 | class BatchNormAct2d(nn.BatchNorm2d): 11 | """BatchNorm + Activation 12 | 13 | This module performs BatchNorm + Activation in a manner that will remain backwards 14 | compatible with weights trained with separate bn, act. This is why we inherit from BN 15 | instead of composing it as a .bn member. 16 | """ 17 | def __init__(self, num_features, eps=1e-5, momentum=0.1, affine=True, track_running_stats=True, 18 | apply_act=True, act_layer=nn.ReLU, inplace=True, drop_block=None): 19 | super(BatchNormAct2d, self).__init__( 20 | num_features, eps=eps, momentum=momentum, affine=affine, track_running_stats=track_running_stats) 21 | if isinstance(act_layer, str): 22 | act_layer = get_act_layer(act_layer) 23 | if act_layer is not None and apply_act: 24 | act_args = dict(inplace=True) if inplace else {} 25 | self.act = act_layer(**act_args) 26 | else: 27 | self.act = nn.Identity() 28 | 29 | def _forward_jit(self, x): 30 | """ A cut & paste of the contents of the PyTorch BatchNorm2d forward function 31 | """ 32 | # exponential_average_factor is set to self.momentum 33 | # (when it is available) only so that it gets updated 34 | # in ONNX graph when this node is exported to ONNX.
35 | if self.momentum is None: 36 | exponential_average_factor = 0.0 37 | else: 38 | exponential_average_factor = self.momentum 39 | 40 | if self.training and self.track_running_stats: 41 | # TODO: if statement only here to tell the jit to skip emitting this when it is None 42 | if self.num_batches_tracked is not None: 43 | self.num_batches_tracked += 1 44 | if self.momentum is None: # use cumulative moving average 45 | exponential_average_factor = 1.0 / float(self.num_batches_tracked) 46 | else: # use exponential moving average 47 | exponential_average_factor = self.momentum 48 | 49 | x = F.batch_norm( 50 | x, self.running_mean, self.running_var, self.weight, self.bias, 51 | self.training or not self.track_running_stats, 52 | exponential_average_factor, self.eps) 53 | return x 54 | 55 | @torch.jit.ignore 56 | def _forward_python(self, x): 57 | return super(BatchNormAct2d, self).forward(x) 58 | 59 | def forward(self, x): 60 | # FIXME cannot call parent forward() and maintain jit.script compatibility? 61 | if torch.jit.is_scripting(): 62 | x = self._forward_jit(x) 63 | else: 64 | x = self._forward_python(x) 65 | x = self.act(x) 66 | return x 67 | 68 | 69 | class GroupNormAct(nn.GroupNorm): 70 | # NOTE num_channel and num_groups order flipped for easier layer swaps / binding of fixed args 71 | def __init__(self, num_channels, num_groups, eps=1e-5, affine=True, 72 | apply_act=True, act_layer=nn.ReLU, inplace=True, drop_block=None): 73 | super(GroupNormAct, self).__init__(num_groups, num_channels, eps=eps, affine=affine) 74 | if isinstance(act_layer, str): 75 | act_layer = get_act_layer(act_layer) 76 | if act_layer is not None and apply_act: 77 | act_args = dict(inplace=True) if inplace else {} 78 | self.act = act_layer(**act_args) 79 | else: 80 | self.act = nn.Identity() 81 | 82 | def forward(self, x): 83 | x = F.group_norm(x, self.num_groups, self.weight, self.bias, self.eps) 84 | x = self.act(x) 85 | return x 86 | -------------------------------------------------------------------------------- /timm/models/layers/padding.py: -------------------------------------------------------------------------------- 1 | """ Padding Helpers 2 | 3 | Hacked together by / Copyright 2020 Ross Wightman 4 | """ 5 | import math 6 | from typing import List, Tuple 7 | 8 | import torch.nn.functional as F 9 | 10 | 11 | # Calculate symmetric padding for a convolution 12 | def get_padding(kernel_size: int, stride: int = 1, dilation: int = 1, **_) -> int: 13 | padding = ((stride - 1) + dilation * (kernel_size - 1)) // 2 14 | return padding 15 | 16 | 17 | # Calculate asymmetric TensorFlow-like 'SAME' padding for a convolution 18 | def get_same_padding(x: int, k: int, s: int, d: int): 19 | return max((math.ceil(x / s) - 1) * s + (k - 1) * d + 1 - x, 0) 20 | 21 | 22 | # Can SAME padding for given args be done statically? 
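# (true when stride == 1 and the effective kernel extent dilation * (k - 1) + 1 is odd, so a fixed
#  symmetric pad gives 'SAME' output for any input size; otherwise padding must be computed at runtime)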
23 | def is_static_pad(kernel_size: int, stride: int = 1, dilation: int = 1, **_): 24 | return stride == 1 and (dilation * (kernel_size - 1)) % 2 == 0 25 | 26 | 27 | # Dynamically pad input x with 'SAME' padding for conv with specified args 28 | def pad_same(x, k: List[int], s: List[int], d: List[int] = (1, 1), value: float = 0): 29 | ih, iw = x.size()[-2:] 30 | pad_h, pad_w = get_same_padding(ih, k[0], s[0], d[0]), get_same_padding(iw, k[1], s[1], d[1]) 31 | if pad_h > 0 or pad_w > 0: 32 | x = F.pad(x, [pad_w // 2, pad_w - pad_w // 2, pad_h // 2, pad_h - pad_h // 2], value=value) 33 | return x 34 | 35 | 36 | def get_padding_value(padding, kernel_size, **kwargs) -> Tuple[Tuple, bool]: 37 | dynamic = False 38 | if isinstance(padding, str): 39 | # for any string padding, the padding will be calculated for you, one of three ways 40 | padding = padding.lower() 41 | if padding == 'same': 42 | # TF compatible 'SAME' padding, has a performance and GPU memory allocation impact 43 | if is_static_pad(kernel_size, **kwargs): 44 | # static case, no extra overhead 45 | padding = get_padding(kernel_size, **kwargs) 46 | else: 47 | # dynamic 'SAME' padding, has runtime/GPU memory overhead 48 | padding = 0 49 | dynamic = True 50 | elif padding == 'valid': 51 | # 'VALID' padding, same as padding=0 52 | padding = 0 53 | else: 54 | # Default to PyTorch style 'same'-ish symmetric padding 55 | padding = get_padding(kernel_size, **kwargs) 56 | return padding, dynamic 57 | -------------------------------------------------------------------------------- /timm/models/layers/pool2d_same.py: -------------------------------------------------------------------------------- 1 | """ AvgPool2d w/ Same Padding 2 | 3 | Hacked together by / Copyright 2020 Ross Wightman 4 | """ 5 | import torch 6 | import torch.nn as nn 7 | import torch.nn.functional as F 8 | from typing import List, Tuple, Optional 9 | 10 | from .helpers import to_2tuple 11 | from .padding import pad_same, get_padding_value 12 | 13 | 14 | def avg_pool2d_same(x, kernel_size: List[int], stride: List[int], padding: List[int] = (0, 0), 15 | ceil_mode: bool = False, count_include_pad: bool = True): 16 | # FIXME how to deal with count_include_pad vs not for external padding? 
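    # pad to the TF 'SAME' spec at runtime based on the actual input size, then pool with padding=0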
17 | x = pad_same(x, kernel_size, stride) 18 | return F.avg_pool2d(x, kernel_size, stride, (0, 0), ceil_mode, count_include_pad) 19 | 20 | 21 | class AvgPool2dSame(nn.AvgPool2d): 22 | """ Tensorflow like 'SAME' wrapper for 2D average pooling 23 | """ 24 | def __init__(self, kernel_size: int, stride=None, padding=0, ceil_mode=False, count_include_pad=True): 25 | kernel_size = to_2tuple(kernel_size) 26 | stride = to_2tuple(stride) 27 | super(AvgPool2dSame, self).__init__(kernel_size, stride, (0, 0), ceil_mode, count_include_pad) 28 | 29 | def forward(self, x): 30 | return avg_pool2d_same( 31 | x, self.kernel_size, self.stride, self.padding, self.ceil_mode, self.count_include_pad) 32 | 33 | 34 | def max_pool2d_same( 35 | x, kernel_size: List[int], stride: List[int], padding: List[int] = (0, 0), 36 | dilation: List[int] = (1, 1), ceil_mode: bool = False): 37 | x = pad_same(x, kernel_size, stride, value=-float('inf')) 38 | return F.max_pool2d(x, kernel_size, stride, (0, 0), dilation, ceil_mode) 39 | 40 | 41 | class MaxPool2dSame(nn.MaxPool2d): 42 | """ Tensorflow like 'SAME' wrapper for 2D max pooling 43 | """ 44 | def __init__(self, kernel_size: int, stride=None, padding=0, dilation=1, ceil_mode=False): 45 | kernel_size = to_2tuple(kernel_size) 46 | stride = to_2tuple(stride) 47 | dilation = to_2tuple(dilation) 48 | super(MaxPool2dSame, self).__init__(kernel_size, stride, (0, 0), dilation, ceil_mode=ceil_mode) # keyword avoids mis-binding ceil_mode to nn.MaxPool2d's return_indices arg 49 | 50 | def forward(self, x): 51 | return max_pool2d_same(x, self.kernel_size, self.stride, self.padding, self.dilation, self.ceil_mode) 52 | 53 | 54 | def create_pool2d(pool_type, kernel_size, stride=None, **kwargs): 55 | stride = stride or kernel_size 56 | padding = kwargs.pop('padding', '') 57 | padding, is_dynamic = get_padding_value(padding, kernel_size, stride=stride, **kwargs) 58 | if is_dynamic: 59 | if pool_type == 'avg': 60 | return AvgPool2dSame(kernel_size, stride=stride, **kwargs) 61 | elif pool_type == 'max': 62 | return MaxPool2dSame(kernel_size, stride=stride, **kwargs) 63 | else: 64 | assert False, f'Unsupported pool type {pool_type}' 65 | else: 66 | if pool_type == 'avg': 67 | return nn.AvgPool2d(kernel_size, stride=stride, padding=padding, **kwargs) 68 | elif pool_type == 'max': 69 | return nn.MaxPool2d(kernel_size, stride=stride, padding=padding, **kwargs) 70 | else: 71 | assert False, f'Unsupported pool type {pool_type}' 72 | -------------------------------------------------------------------------------- /timm/models/layers/se.py: -------------------------------------------------------------------------------- 1 | from torch import nn as nn 2 | import torch.nn.functional as F 3 | 4 | from .create_act import create_act_layer 5 | from .helpers import make_divisible 6 | 7 | 8 | class SEModule(nn.Module): 9 | """ SE Module as defined in original SE-Nets with a few additions 10 | Additions include: 11 | * min_channels can be specified to keep reduced channel count at a minimum (default: 8) 12 | * divisor can be specified to keep channels rounded to specified values (default: 1) 13 | * reduction channels can be specified directly by arg (if reduction_channels is set) 14 | * reduction channels can be specified by float ratio (if reduction_ratio is set) 15 | """ 16 | def __init__(self, channels, reduction=16, act_layer=nn.ReLU, gate_layer='sigmoid', 17 | reduction_ratio=None, reduction_channels=None, min_channels=8, divisor=1): 18 | super(SEModule, self).__init__() 19 | if reduction_channels is not None: 20 | reduction_channels =
reduction_channels # direct specification highest priority, no rounding/min done 21 | elif reduction_ratio is not None: 22 | reduction_channels = make_divisible(channels * reduction_ratio, divisor, min_channels) 23 | else: 24 | reduction_channels = make_divisible(channels // reduction, divisor, min_channels) 25 | self.fc1 = nn.Conv2d(channels, reduction_channels, kernel_size=1, bias=True) 26 | self.act = act_layer(inplace=True) 27 | self.fc2 = nn.Conv2d(reduction_channels, channels, kernel_size=1, bias=True) 28 | self.gate = create_act_layer(gate_layer) 29 | 30 | def forward(self, x): 31 | x_se = x.mean((2, 3), keepdim=True) 32 | x_se = self.fc1(x_se) 33 | x_se = self.act(x_se) 34 | x_se = self.fc2(x_se) 35 | return x * self.gate(x_se) 36 | 37 | 38 | class EffectiveSEModule(nn.Module): 39 | """ 'Effective Squeeze-Excitation 40 | From `CenterMask : Real-Time Anchor-Free Instance Segmentation` - https://arxiv.org/abs/1911.06667 41 | """ 42 | def __init__(self, channels, gate_layer='hard_sigmoid'): 43 | super(EffectiveSEModule, self).__init__() 44 | self.fc = nn.Conv2d(channels, channels, kernel_size=1, padding=0) 45 | self.gate = create_act_layer(gate_layer, inplace=True) 46 | 47 | def forward(self, x): 48 | x_se = x.mean((2, 3), keepdim=True) 49 | x_se = self.fc(x_se) 50 | return x * self.gate(x_se) 51 | -------------------------------------------------------------------------------- /timm/models/layers/separable_conv.py: -------------------------------------------------------------------------------- 1 | """ Depthwise Separable Conv Modules 2 | 3 | Basic DWS convs. Other variations of DWS exist with batch norm or activations between the 4 | DW and PW convs such as the Depthwise modules in MobileNetV2 / EfficientNet and Xception. 5 | 6 | Hacked together by / Copyright 2020 Ross Wightman 7 | """ 8 | from torch import nn as nn 9 | 10 | from .create_conv2d import create_conv2d 11 | from .create_norm_act import convert_norm_act 12 | 13 | 14 | class SeparableConvBnAct(nn.Module): 15 | """ Separable Conv w/ trailing Norm and Activation 16 | """ 17 | def __init__(self, in_channels, out_channels, kernel_size=3, stride=1, dilation=1, padding='', bias=False, 18 | channel_multiplier=1.0, pw_kernel_size=1, norm_layer=nn.BatchNorm2d, act_layer=nn.ReLU, 19 | apply_act=True, drop_block=None): 20 | super(SeparableConvBnAct, self).__init__() 21 | 22 | self.conv_dw = create_conv2d( 23 | in_channels, int(in_channels * channel_multiplier), kernel_size, 24 | stride=stride, dilation=dilation, padding=padding, depthwise=True) 25 | 26 | self.conv_pw = create_conv2d( 27 | int(in_channels * channel_multiplier), out_channels, pw_kernel_size, padding=padding, bias=bias) 28 | 29 | norm_act_layer = convert_norm_act(norm_layer, act_layer) 30 | self.bn = norm_act_layer(out_channels, apply_act=apply_act, drop_block=drop_block) 31 | 32 | @property 33 | def in_channels(self): 34 | return self.conv_dw.in_channels 35 | 36 | @property 37 | def out_channels(self): 38 | return self.conv_pw.out_channels 39 | 40 | def forward(self, x): 41 | x = self.conv_dw(x) 42 | x = self.conv_pw(x) 43 | if self.bn is not None: 44 | x = self.bn(x) 45 | return x 46 | 47 | 48 | class SeparableConv2d(nn.Module): 49 | """ Separable Conv 50 | """ 51 | def __init__(self, in_channels, out_channels, kernel_size=3, stride=1, dilation=1, padding='', bias=False, 52 | channel_multiplier=1.0, pw_kernel_size=1): 53 | super(SeparableConv2d, self).__init__() 54 | 55 | self.conv_dw = create_conv2d( 56 | in_channels, int(in_channels * channel_multiplier), 
kernel_size, 57 | stride=stride, dilation=dilation, padding=padding, depthwise=True) 58 | 59 | self.conv_pw = create_conv2d( 60 | int(in_channels * channel_multiplier), out_channels, pw_kernel_size, padding=padding, bias=bias) 61 | 62 | @property 63 | def in_channels(self): 64 | return self.conv_dw.in_channels 65 | 66 | @property 67 | def out_channels(self): 68 | return self.conv_pw.out_channels 69 | 70 | def forward(self, x): 71 | x = self.conv_dw(x) 72 | x = self.conv_pw(x) 73 | return x 74 | -------------------------------------------------------------------------------- /timm/models/layers/space_to_depth.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | 5 | class SpaceToDepth(nn.Module): 6 | def __init__(self, block_size=4): 7 | super().__init__() 8 | assert block_size == 4 9 | self.bs = block_size 10 | 11 | def forward(self, x): 12 | N, C, H, W = x.size() 13 | x = x.view(N, C, H // self.bs, self.bs, W // self.bs, self.bs) # (N, C, H//bs, bs, W//bs, bs) 14 | x = x.permute(0, 3, 5, 1, 2, 4).contiguous() # (N, bs, bs, C, H//bs, W//bs) 15 | x = x.view(N, C * (self.bs ** 2), H // self.bs, W // self.bs) # (N, C*bs^2, H//bs, W//bs) 16 | return x 17 | 18 | 19 | @torch.jit.script 20 | class SpaceToDepthJit(object): 21 | def __call__(self, x: torch.Tensor): 22 | # assuming hard-coded that block_size==4 for acceleration 23 | N, C, H, W = x.size() 24 | x = x.view(N, C, H // 4, 4, W // 4, 4) # (N, C, H//bs, bs, W//bs, bs) 25 | x = x.permute(0, 3, 5, 1, 2, 4).contiguous() # (N, bs, bs, C, H//bs, W//bs) 26 | x = x.view(N, C * 16, H // 4, W // 4) # (N, C*bs^2, H//bs, W//bs) 27 | return x 28 | 29 | 30 | class SpaceToDepthModule(nn.Module): 31 | def __init__(self, no_jit=False): 32 | super().__init__() 33 | if not no_jit: 34 | self.op = SpaceToDepthJit() 35 | else: 36 | self.op = SpaceToDepth() 37 | 38 | def forward(self, x): 39 | return self.op(x) 40 | 41 | 42 | class DepthToSpace(nn.Module): 43 | 44 | def __init__(self, block_size): 45 | super().__init__() 46 | self.bs = block_size 47 | 48 | def forward(self, x): 49 | N, C, H, W = x.size() 50 | x = x.view(N, self.bs, self.bs, C // (self.bs ** 2), H, W) # (N, bs, bs, C//bs^2, H, W) 51 | x = x.permute(0, 3, 4, 1, 5, 2).contiguous() # (N, C//bs^2, H, bs, W, bs) 52 | x = x.view(N, C // (self.bs ** 2), H * self.bs, W * self.bs) # (N, C//bs^2, H * bs, W * bs) 53 | return x 54 | -------------------------------------------------------------------------------- /timm/models/layers/split_attn.py: -------------------------------------------------------------------------------- 1 | """ Split Attention Conv2d (for ResNeSt Models) 2 | 3 | Paper: `ResNeSt: Split-Attention Networks` - /https://arxiv.org/abs/2004.08955 4 | 5 | Adapted from original PyTorch impl at https://github.com/zhanghang1989/ResNeSt 6 | 7 | Modified for torchscript compat, performance, and consistency with timm by Ross Wightman 8 | """ 9 | import torch 10 | import torch.nn.functional as F 11 | from torch import nn 12 | 13 | 14 | class RadixSoftmax(nn.Module): 15 | def __init__(self, radix, cardinality): 16 | super(RadixSoftmax, self).__init__() 17 | self.radix = radix 18 | self.cardinality = cardinality 19 | 20 | def forward(self, x): 21 | batch = x.size(0) 22 | if self.radix > 1: 23 | x = x.view(batch, self.cardinality, self.radix, -1).transpose(1, 2) 24 | x = F.softmax(x, dim=1) 25 | x = x.reshape(batch, -1) 26 | else: 27 | x = torch.sigmoid(x) 28 | return x 29 | 30 | 31 | class SplitAttnConv2d(nn.Module): 32 
| """Split-Attention Conv2d 33 | """ 34 | def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=0, 35 | dilation=1, groups=1, bias=False, radix=2, reduction_factor=4, 36 | act_layer=nn.ReLU, norm_layer=None, drop_block=None, **kwargs): 37 | super(SplitAttnConv2d, self).__init__() 38 | self.radix = radix 39 | self.drop_block = drop_block 40 | mid_chs = out_channels * radix 41 | attn_chs = max(in_channels * radix // reduction_factor, 32) 42 | 43 | self.conv = nn.Conv2d( 44 | in_channels, mid_chs, kernel_size, stride, padding, dilation, 45 | groups=groups * radix, bias=bias, **kwargs) 46 | self.bn0 = norm_layer(mid_chs) if norm_layer is not None else None 47 | self.act0 = act_layer(inplace=True) 48 | self.fc1 = nn.Conv2d(out_channels, attn_chs, 1, groups=groups) 49 | self.bn1 = norm_layer(attn_chs) if norm_layer is not None else None 50 | self.act1 = act_layer(inplace=True) 51 | self.fc2 = nn.Conv2d(attn_chs, mid_chs, 1, groups=groups) 52 | self.rsoftmax = RadixSoftmax(radix, groups) 53 | 54 | @property 55 | def in_channels(self): 56 | return self.conv.in_channels 57 | 58 | @property 59 | def out_channels(self): 60 | return self.fc1.out_channels 61 | 62 | def forward(self, x): 63 | x = self.conv(x) 64 | if self.bn0 is not None: 65 | x = self.bn0(x) 66 | if self.drop_block is not None: 67 | x = self.drop_block(x) 68 | x = self.act0(x) 69 | 70 | B, RC, H, W = x.shape 71 | if self.radix > 1: 72 | x = x.reshape((B, self.radix, RC // self.radix, H, W)) 73 | x_gap = x.sum(dim=1) 74 | else: 75 | x_gap = x 76 | x_gap = F.adaptive_avg_pool2d(x_gap, 1) 77 | x_gap = self.fc1(x_gap) 78 | if self.bn1 is not None: 79 | x_gap = self.bn1(x_gap) 80 | x_gap = self.act1(x_gap) 81 | x_attn = self.fc2(x_gap) 82 | 83 | x_attn = self.rsoftmax(x_attn).view(B, -1, 1, 1) 84 | if self.radix > 1: 85 | out = (x * x_attn.reshape((B, self.radix, RC // self.radix, 1, 1))).sum(dim=1) 86 | else: 87 | out = x * x_attn 88 | return out.contiguous() 89 | -------------------------------------------------------------------------------- /timm/models/layers/split_batchnorm.py: -------------------------------------------------------------------------------- 1 | """ Split BatchNorm 2 | 3 | A PyTorch BatchNorm layer that splits input batch into N equal parts and passes each through 4 | a separate BN layer. The first split is passed through the parent BN layers with weight/bias 5 | keys the same as the original BN. All other splits pass through BN sub-layers under the '.aux_bn' 6 | namespace. 
7 | 8 | This allows easily removing the auxiliary BN layers after training to efficiently 9 | achieve the 'Auxiliary BatchNorm' as described in the AdvProp Paper, section 4.2, 10 | 'Disentangled Learning via An Auxiliary BN' 11 | 12 | Hacked together by / Copyright 2020 Ross Wightman 13 | """ 14 | import torch 15 | import torch.nn as nn 16 | 17 | 18 | class SplitBatchNorm2d(torch.nn.BatchNorm2d): 19 | 20 | def __init__(self, num_features, eps=1e-5, momentum=0.1, affine=True, 21 | track_running_stats=True, num_splits=2): 22 | super().__init__(num_features, eps, momentum, affine, track_running_stats) 23 | assert num_splits > 1, 'Should have at least one aux BN layer (num_splits at least 2)' 24 | self.num_splits = num_splits 25 | self.aux_bn = nn.ModuleList([ 26 | nn.BatchNorm2d(num_features, eps, momentum, affine, track_running_stats) for _ in range(num_splits - 1)]) 27 | 28 | def forward(self, input: torch.Tensor): 29 | if self.training: # aux BN only relevant while training 30 | split_size = input.shape[0] // self.num_splits 31 | assert input.shape[0] == split_size * self.num_splits, "batch size must be evenly divisible by num_splits" 32 | split_input = input.split(split_size) 33 | x = [super().forward(split_input[0])] 34 | for i, a in enumerate(self.aux_bn): 35 | x.append(a(split_input[i + 1])) 36 | return torch.cat(x, dim=0) 37 | else: 38 | return super().forward(input) 39 | 40 | 41 | def convert_splitbn_model(module, num_splits=2): 42 | """ 43 | Recursively traverse module and its children to replace all instances of 44 | ``torch.nn.modules.batchnorm._BatchNorm`` with `SplitBatchnorm2d`. 45 | Args: 46 | module (torch.nn.Module): input module 47 | num_splits: number of separate batchnorm layers to split input across 48 | Example:: 49 | >>> # model is an instance of torch.nn.Module 50 | >>> model = timm.models.convert_splitbn_model(model, num_splits=2) 51 | """ 52 | mod = module 53 | if isinstance(module, torch.nn.modules.instancenorm._InstanceNorm): 54 | return module 55 | if isinstance(module, torch.nn.modules.batchnorm._BatchNorm): 56 | mod = SplitBatchNorm2d( 57 | module.num_features, module.eps, module.momentum, module.affine, 58 | module.track_running_stats, num_splits=num_splits) 59 | mod.running_mean = module.running_mean 60 | mod.running_var = module.running_var 61 | mod.num_batches_tracked = module.num_batches_tracked 62 | if module.affine: 63 | mod.weight.data = module.weight.data.clone().detach() 64 | mod.bias.data = module.bias.data.clone().detach() 65 | for aux in mod.aux_bn: 66 | aux.running_mean = module.running_mean.clone() 67 | aux.running_var = module.running_var.clone() 68 | aux.num_batches_tracked = module.num_batches_tracked.clone() 69 | if module.affine: 70 | aux.weight.data = module.weight.data.clone().detach() 71 | aux.bias.data = module.bias.data.clone().detach() 72 | for name, child in module.named_children(): 73 | mod.add_module(name, convert_splitbn_model(child, num_splits=num_splits)) 74 | del module 75 | return mod 76 | -------------------------------------------------------------------------------- /timm/models/layers/test_time_pool.py: -------------------------------------------------------------------------------- 1 | """ Test Time Pooling (Average-Max Pool) 2 | 3 | Hacked together by / Copyright 2020 Ross Wightman 4 | """ 5 | 6 | import logging 7 | from torch import nn 8 | import torch.nn.functional as F 9 | 10 | from .adaptive_avgmax_pool import adaptive_avgmax_pool2d 11 | 12 | 13 | _logger = logging.getLogger(__name__) 14 | 15 | 16 | class 
TestTimePoolHead(nn.Module): 17 | def __init__(self, base, original_pool=7): 18 | super(TestTimePoolHead, self).__init__() 19 | self.base = base 20 | self.original_pool = original_pool 21 | base_fc = self.base.get_classifier() 22 | if isinstance(base_fc, nn.Conv2d): 23 | self.fc = base_fc 24 | else: 25 | self.fc = nn.Conv2d( 26 | self.base.num_features, self.base.num_classes, kernel_size=1, bias=True) 27 | self.fc.weight.data.copy_(base_fc.weight.data.view(self.fc.weight.size())) 28 | self.fc.bias.data.copy_(base_fc.bias.data.view(self.fc.bias.size())) 29 | self.base.reset_classifier(0) # delete original fc layer 30 | 31 | def forward(self, x): 32 | x = self.base.forward_features(x) 33 | x = F.avg_pool2d(x, kernel_size=self.original_pool, stride=1) 34 | x = self.fc(x) 35 | x = adaptive_avgmax_pool2d(x, 1) 36 | return x.view(x.size(0), -1) 37 | 38 | 39 | def apply_test_time_pool(model, config, use_test_size=True): 40 | test_time_pool = False 41 | if not hasattr(model, 'default_cfg') or not model.default_cfg: 42 | return model, False 43 | if use_test_size and 'test_input_size' in model.default_cfg: 44 | df_input_size = model.default_cfg['test_input_size'] 45 | else: 46 | df_input_size = model.default_cfg['input_size'] 47 | if config['input_size'][-1] > df_input_size[-1] and config['input_size'][-2] > df_input_size[-2]: 48 | _logger.info('Target input size %s > pretrained default %s, using test time pooling' % 49 | (str(config['input_size'][-2:]), str(df_input_size[-2:]))) 50 | model = TestTimePoolHead(model, original_pool=model.default_cfg['pool_size']) 51 | test_time_pool = True 52 | return model, test_time_pool 53 | -------------------------------------------------------------------------------- /timm/models/layers/weight_init.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import math 3 | import warnings 4 | 5 | 6 | def _no_grad_trunc_normal_(tensor, mean, std, a, b): 7 | # Cut & paste from PyTorch official master until it's in a few official releases - RW 8 | # Method based on https://people.sc.fsu.edu/~jburkardt/presentations/truncated_normal.pdf 9 | def norm_cdf(x): 10 | # Computes standard normal cumulative distribution function 11 | return (1. + math.erf(x / math.sqrt(2.))) / 2. 12 | 13 | if (mean < a - 2 * std) or (mean > b + 2 * std): 14 | warnings.warn("mean is more than 2 std from [a, b] in nn.init.trunc_normal_. " 15 | "The distribution of values may be incorrect.", 16 | stacklevel=2) 17 | 18 | with torch.no_grad(): 19 | # Values are generated by using a truncated uniform distribution and 20 | # then using the inverse CDF for the normal distribution. 21 | # Get upper and lower cdf values 22 | l = norm_cdf((a - mean) / std) 23 | u = norm_cdf((b - mean) / std) 24 | 25 | # Uniformly fill tensor with values from [l, u], then translate to 26 | # [2l-1, 2u-1]. 27 | tensor.uniform_(2 * l - 1, 2 * u - 1) 28 | 29 | # Use inverse cdf transform for normal distribution to get truncated 30 | # standard normal 31 | tensor.erfinv_() 32 | 33 | # Transform to proper mean, std 34 | tensor.mul_(std * math.sqrt(2.)) 35 | tensor.add_(mean) 36 | 37 | # Clamp to ensure it's in the proper range 38 | tensor.clamp_(min=a, max=b) 39 | return tensor 40 | 41 | 42 | def trunc_normal_(tensor, mean=0., std=1., a=-2., b=2.): 43 | # type: (Tensor, float, float, float, float) -> Tensor 44 | r"""Fills the input Tensor with values drawn from a truncated 45 | normal distribution. 
The values are effectively drawn from the 46 | normal distribution :math:`\mathcal{N}(\text{mean}, \text{std}^2)` 47 | with values outside :math:`[a, b]` redrawn until they are within 48 | the bounds. The method used for generating the random values works 49 | best when :math:`a \leq \text{mean} \leq b`. 50 | Args: 51 | tensor: an n-dimensional `torch.Tensor` 52 | mean: the mean of the normal distribution 53 | std: the standard deviation of the normal distribution 54 | a: the minimum cutoff value 55 | b: the maximum cutoff value 56 | Examples: 57 | >>> w = torch.empty(3, 5) 58 | >>> nn.init.trunc_normal_(w) 59 | """ 60 | return _no_grad_trunc_normal_(tensor, mean, std, a, b) 61 | -------------------------------------------------------------------------------- /timm/models/pruned/ecaresnet50d_pruned.txt: -------------------------------------------------------------------------------- 1 | conv1.0.weight:[32, 3, 3, 3]***conv1.1.weight:[32]***conv1.3.weight:[32, 32, 3, 3]***conv1.4.weight:[32]***conv1.6.weight:[64, 32, 3, 3]***bn1.weight:[64]***layer1.0.conv1.weight:[47, 64, 1, 1]***layer1.0.bn1.weight:[47]***layer1.0.conv2.weight:[18, 47, 3, 3]***layer1.0.bn2.weight:[18]***layer1.0.conv3.weight:[19, 18, 1, 1]***layer1.0.bn3.weight:[19]***layer1.0.se.conv.weight:[1, 1, 5]***layer1.0.downsample.1.weight:[19, 64, 1, 1]***layer1.0.downsample.2.weight:[19]***layer1.1.conv1.weight:[52, 19, 1, 1]***layer1.1.bn1.weight:[52]***layer1.1.conv2.weight:[22, 52, 3, 3]***layer1.1.bn2.weight:[22]***layer1.1.conv3.weight:[19, 22, 1, 1]***layer1.1.bn3.weight:[19]***layer1.1.se.conv.weight:[1, 1, 5]***layer1.2.conv1.weight:[64, 19, 1, 1]***layer1.2.bn1.weight:[64]***layer1.2.conv2.weight:[35, 64, 3, 3]***layer1.2.bn2.weight:[35]***layer1.2.conv3.weight:[19, 35, 1, 1]***layer1.2.bn3.weight:[19]***layer1.2.se.conv.weight:[1, 1, 5]***layer2.0.conv1.weight:[85, 19, 1, 1]***layer2.0.bn1.weight:[85]***layer2.0.conv2.weight:[37, 85, 3, 3]***layer2.0.bn2.weight:[37]***layer2.0.conv3.weight:[171, 37, 1, 1]***layer2.0.bn3.weight:[171]***layer2.0.se.conv.weight:[1, 1, 5]***layer2.0.downsample.1.weight:[171, 19, 1, 1]***layer2.0.downsample.2.weight:[171]***layer2.1.conv1.weight:[107, 171, 1, 1]***layer2.1.bn1.weight:[107]***layer2.1.conv2.weight:[80, 107, 3, 3]***layer2.1.bn2.weight:[80]***layer2.1.conv3.weight:[171, 80, 1, 1]***layer2.1.bn3.weight:[171]***layer2.1.se.conv.weight:[1, 1, 5]***layer2.2.conv1.weight:[120, 171, 1, 1]***layer2.2.bn1.weight:[120]***layer2.2.conv2.weight:[85, 120, 3, 3]***layer2.2.bn2.weight:[85]***layer2.2.conv3.weight:[171, 85, 1, 1]***layer2.2.bn3.weight:[171]***layer2.2.se.conv.weight:[1, 1, 5]***layer2.3.conv1.weight:[125, 171, 1, 1]***layer2.3.bn1.weight:[125]***layer2.3.conv2.weight:[87, 125, 3, 3]***layer2.3.bn2.weight:[87]***layer2.3.conv3.weight:[171, 87, 1, 1]***layer2.3.bn3.weight:[171]***layer2.3.se.conv.weight:[1, 1, 5]***layer3.0.conv1.weight:[198, 171, 1, 1]***layer3.0.bn1.weight:[198]***layer3.0.conv2.weight:[126, 198, 3, 3]***layer3.0.bn2.weight:[126]***layer3.0.conv3.weight:[818, 126, 1, 1]***layer3.0.bn3.weight:[818]***layer3.0.se.conv.weight:[1, 1, 5]***layer3.0.downsample.1.weight:[818, 171, 1, 1]***layer3.0.downsample.2.weight:[818]***layer3.1.conv1.weight:[255, 818, 1, 1]***layer3.1.bn1.weight:[255]***layer3.1.conv2.weight:[232, 255, 3, 3]***layer3.1.bn2.weight:[232]***layer3.1.conv3.weight:[818, 232, 1, 1]***layer3.1.bn3.weight:[818]***layer3.1.se.conv.weight:[1, 1, 5]***layer3.2.conv1.weight:[256, 818, 1, 1]***layer3.2.bn1.weight:[256]***layer3.2.conv2.weight:[233, 256, 
3, 3]***layer3.2.bn2.weight:[233]***layer3.2.conv3.weight:[818, 233, 1, 1]***layer3.2.bn3.weight:[818]***layer3.2.se.conv.weight:[1, 1, 5]***layer3.3.conv1.weight:[253, 818, 1, 1]***layer3.3.bn1.weight:[253]***layer3.3.conv2.weight:[235, 253, 3, 3]***layer3.3.bn2.weight:[235]***layer3.3.conv3.weight:[818, 235, 1, 1]***layer3.3.bn3.weight:[818]***layer3.3.se.conv.weight:[1, 1, 5]***layer3.4.conv1.weight:[256, 818, 1, 1]***layer3.4.bn1.weight:[256]***layer3.4.conv2.weight:[225, 256, 3, 3]***layer3.4.bn2.weight:[225]***layer3.4.conv3.weight:[818, 225, 1, 1]***layer3.4.bn3.weight:[818]***layer3.4.se.conv.weight:[1, 1, 5]***layer3.5.conv1.weight:[256, 818, 1, 1]***layer3.5.bn1.weight:[256]***layer3.5.conv2.weight:[239, 256, 3, 3]***layer3.5.bn2.weight:[239]***layer3.5.conv3.weight:[818, 239, 1, 1]***layer3.5.bn3.weight:[818]***layer3.5.se.conv.weight:[1, 1, 5]***layer4.0.conv1.weight:[492, 818, 1, 1]***layer4.0.bn1.weight:[492]***layer4.0.conv2.weight:[237, 492, 3, 3]***layer4.0.bn2.weight:[237]***layer4.0.conv3.weight:[2022, 237, 1, 1]***layer4.0.bn3.weight:[2022]***layer4.0.se.conv.weight:[1, 1, 7]***layer4.0.downsample.1.weight:[2022, 818, 1, 1]***layer4.0.downsample.2.weight:[2022]***layer4.1.conv1.weight:[512, 2022, 1, 1]***layer4.1.bn1.weight:[512]***layer4.1.conv2.weight:[500, 512, 3, 3]***layer4.1.bn2.weight:[500]***layer4.1.conv3.weight:[2022, 500, 1, 1]***layer4.1.bn3.weight:[2022]***layer4.1.se.conv.weight:[1, 1, 7]***layer4.2.conv1.weight:[512, 2022, 1, 1]***layer4.2.bn1.weight:[512]***layer4.2.conv2.weight:[490, 512, 3, 3]***layer4.2.bn2.weight:[490]***layer4.2.conv3.weight:[2022, 490, 1, 1]***layer4.2.bn3.weight:[2022]***layer4.2.se.conv.weight:[1, 1, 7]***fc.weight:[1000, 2022]***layer1_2_conv3_M.weight:[256, 19]***layer2_3_conv3_M.weight:[512, 171]***layer3_5_conv3_M.weight:[1024, 818]***layer4_2_conv3_M.weight:[2048, 2022] -------------------------------------------------------------------------------- /timm/models/registry.py: -------------------------------------------------------------------------------- 1 | """ Model Registry 2 | Hacked together by / Copyright 2020 Ross Wightman 3 | """ 4 | 5 | import sys 6 | import re 7 | import fnmatch 8 | from collections import defaultdict 9 | 10 | __all__ = ['list_models', 'is_model', 'model_entrypoint', 'list_modules', 'is_model_in_modules'] 11 | 12 | _module_to_models = defaultdict(set) # dict of sets to check membership of model in module 13 | _model_to_module = {} # mapping of model names to module names 14 | _model_entrypoints = {} # mapping of model names to entrypoint fns 15 | _model_has_pretrained = set() # set of model names that have pretrained weight url present 16 | 17 | 18 | def register_model(fn): 19 | # lookup containing module 20 | mod = sys.modules[fn.__module__] 21 | module_name_split = fn.__module__.split('.') 22 | module_name = module_name_split[-1] if len(module_name_split) else '' 23 | 24 | # add model to __all__ in module 25 | model_name = fn.__name__ 26 | if hasattr(mod, '__all__'): 27 | mod.__all__.append(model_name) 28 | else: 29 | mod.__all__ = [model_name] 30 | 31 | # add entries to registry dict/sets 32 | _model_entrypoints[model_name] = fn 33 | _model_to_module[model_name] = module_name 34 | _module_to_models[module_name].add(model_name) 35 | has_pretrained = False # check if model has a pretrained url to allow filtering on this 36 | if hasattr(mod, 'default_cfgs') and model_name in mod.default_cfgs: 37 | # this will catch all models that have entrypoint matching cfg key, but miss any aliasing 38 | # 
entrypoints or non-matching combos 39 | has_pretrained = 'url' in mod.default_cfgs[model_name] and 'http' in mod.default_cfgs[model_name]['url'] 40 | if has_pretrained: 41 | _model_has_pretrained.add(model_name) 42 | return fn 43 | 44 | 45 | def _natural_key(string_): 46 | return [int(s) if s.isdigit() else s for s in re.split(r'(\d+)', string_.lower())] 47 | 48 | 49 | def list_models(filter='', module='', pretrained=False, exclude_filters=''): 50 | """ Return list of available model names, sorted alphabetically 51 | 52 | Args: 53 | filter (str) - Wildcard filter string that works with fnmatch 54 | module (str) - Limit model selection to a specific sub-module (ie 'gen_efficientnet') 55 | pretrained (bool) - Include only models with pretrained weights if True 56 | exclude_filters (str or list[str]) - Wildcard filters to exclude models after including them with filter 57 | 58 | Example: 59 | list_models('gluon_resnet*') -- returns all models starting with 'gluon_resnet' 60 | list_models('*resnext*', 'resnet') -- returns all models with 'resnext' in the 'resnet' module 61 | """ 62 | if module: 63 | models = list(_module_to_models[module]) 64 | else: 65 | models = _model_entrypoints.keys() 66 | if filter: 67 | models = fnmatch.filter(models, filter) # include these models 68 | if exclude_filters: 69 | if not isinstance(exclude_filters, list): 70 | exclude_filters = [exclude_filters] 71 | for xf in exclude_filters: 72 | exclude_models = fnmatch.filter(models, xf) # exclude these models 73 | if len(exclude_models): 74 | models = set(models).difference(exclude_models) 75 | if pretrained: 76 | models = _model_has_pretrained.intersection(models) 77 | return list(sorted(models, key=_natural_key)) 78 | 79 | 80 | def is_model(model_name): 81 | """ Check if a model name exists 82 | """ 83 | return model_name in _model_entrypoints 84 | 85 | 86 | def model_entrypoint(model_name): 87 | """Fetch a model entrypoint for specified model name 88 | """ 89 | return _model_entrypoints[model_name] 90 | 91 | 92 | def list_modules(): 93 | """ Return list of module names that contain models / model entrypoints 94 | """ 95 | modules = _module_to_models.keys() 96 | return list(sorted(modules)) 97 | 98 | 99 | def is_model_in_modules(model_name, module_names): 100 | """Check if a model exists within a subset of modules 101 | Args: 102 | model_name (str) - name of model to check 103 | module_names (tuple, list, set) - names of modules to search in 104 | """ 105 | assert isinstance(module_names, (tuple, list, set)) 106 | return any(model_name in _module_to_models[n] for n in module_names) 107 | 108 | -------------------------------------------------------------------------------- /timm/optim/__init__.py: -------------------------------------------------------------------------------- 1 | from .adamp import AdamP 2 | from .adamw import AdamW 3 | from .adafactor import Adafactor 4 | from .adahessian import Adahessian 5 | from .lookahead import Lookahead 6 | from .nadam import Nadam 7 | from .novograd import NovoGrad 8 | from .nvnovograd import NvNovoGrad 9 | from .radam import RAdam 10 | from .rmsprop_tf import RMSpropTF 11 | from .sgdp import SGDP 12 | 13 | from .optim_factory import create_optimizer -------------------------------------------------------------------------------- /timm/optim/adamp_original_version_by_ross.py: -------------------------------------------------------------------------------- 1 | """ 2 | AdamP Optimizer Implementation copied from https://github.com/clovaai/AdamP/blob/master/adamp/adamp.py 3 | 4 |
Paper: `Slowing Down the Weight Norm Increase in Momentum-based Optimizers` - https://arxiv.org/abs/2006.08217 5 | Code: https://github.com/clovaai/AdamP 6 | 7 | Copyright (c) 2020-present NAVER Corp. 8 | MIT license 9 | """ 10 | 11 | import torch 12 | import torch.nn as nn 13 | from torch.optim.optimizer import Optimizer, required 14 | import math 15 | 16 | class AdamP(Optimizer): 17 | def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8, 18 | weight_decay=0, delta=0.1, wd_ratio=0.1, nesterov=False): 19 | defaults = dict(lr=lr, betas=betas, eps=eps, weight_decay=weight_decay, 20 | delta=delta, wd_ratio=wd_ratio, nesterov=nesterov) 21 | super(AdamP, self).__init__(params, defaults) 22 | 23 | def _channel_view(self, x): 24 | return x.view(x.size(0), -1) 25 | 26 | def _layer_view(self, x): 27 | return x.view(1, -1) 28 | 29 | def _cosine_similarity(self, x, y, eps, view_func): 30 | x = view_func(x) 31 | y = view_func(y) 32 | 33 | x_norm = x.norm(dim=1).add_(eps) 34 | y_norm = y.norm(dim=1).add_(eps) 35 | dot = (x * y).sum(dim=1) 36 | 37 | return dot.abs() / x_norm / y_norm 38 | 39 | def _projection(self, p, grad, perturb, delta, wd_ratio, eps): 40 | wd = 1 41 | expand_size = [-1] + [1] * (len(p.shape) - 1) 42 | for view_func in [self._channel_view, self._layer_view]: 43 | 44 | cosine_sim = self._cosine_similarity(grad, p.data, eps, view_func) 45 | 46 | if cosine_sim.max() < delta / math.sqrt(view_func(p.data).size(1)): 47 | p_n = p.data / view_func(p.data).norm(dim=1).view(expand_size).add_(eps) 48 | perturb -= p_n * view_func(p_n * perturb).sum(dim=1).view(expand_size) 49 | wd = wd_ratio 50 | 51 | return perturb, wd 52 | 53 | return perturb, wd 54 | 55 | def step(self, closure=None): 56 | loss = None 57 | if closure is not None: 58 | loss = closure() 59 | 60 | for group in self.param_groups: 61 | for p in group['params']: 62 | if p.grad is None: 63 | continue 64 | 65 | grad = p.grad.data 66 | beta1, beta2 = group['betas'] 67 | nesterov = group['nesterov'] 68 | 69 | state = self.state[p] 70 | 71 | # State initialization 72 | if len(state) == 0: 73 | state['step'] = 0 74 | state['exp_avg'] = torch.zeros_like(p.data) 75 | state['exp_avg_sq'] = torch.zeros_like(p.data) 76 | 77 | # Adam 78 | exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq'] 79 | 80 | state['step'] += 1 81 | bias_correction1 = 1 - beta1 ** state['step'] 82 | bias_correction2 = 1 - beta2 ** state['step'] 83 | 84 | exp_avg.mul_(beta1).add_(1 - beta1, grad) 85 | exp_avg_sq.mul_(beta2).addcmul_(1 - beta2, grad, grad) 86 | 87 | denom = (exp_avg_sq.sqrt() / math.sqrt(bias_correction2)).add_(group['eps']) 88 | step_size = group['lr'] / bias_correction1 89 | 90 | if nesterov: 91 | perturb = (beta1 * exp_avg + (1 - beta1) * grad) / denom 92 | else: 93 | perturb = exp_avg / denom 94 | 95 | # Projection 96 | wd_ratio = 1 97 | if len(p.shape) > 1: 98 | perturb, wd_ratio = self._projection(p, grad, perturb, group['delta'], group['wd_ratio'], group['eps']) 99 | 100 | # Weight decay 101 | if group['weight_decay'] > 0: 102 | p.data.mul_(1 - group['lr'] * group['weight_decay'] * wd_ratio) 103 | 104 | # Step 105 | p.data.add_(-step_size, perturb) 106 | 107 | return loss 108 | -------------------------------------------------------------------------------- /timm/optim/centralization.py: -------------------------------------------------------------------------------- 1 | import torch 2 | # from torch.optim.optimizer import Optimizer, required 3 | 4 | 5 | def centralized_gradient(x, use_gc=True, gc_conv_only=False): 6 | if 
use_gc: 7 | if gc_conv_only: 8 | if len(list(x.size())) > 3: 9 | x.add_(-x.mean(dim=tuple(range(1, len(list(x.size())))), keepdim=True)) 10 | else: 11 | if len(list(x.size())) > 1: 12 | x.add_(-x.mean(dim=tuple(range(1, len(list(x.size())))), keepdim=True)) 13 | return x 14 | -------------------------------------------------------------------------------- /timm/optim/lookahead.py: -------------------------------------------------------------------------------- 1 | """ Lookahead Optimizer Wrapper. 2 | Implementation modified from: https://github.com/alphadl/lookahead.pytorch 3 | Paper: `Lookahead Optimizer: k steps forward, 1 step back` - https://arxiv.org/abs/1907.08610 4 | 5 | Hacked together by / Copyright 2020 Ross Wightman 6 | """ 7 | import torch 8 | from torch.optim.optimizer import Optimizer 9 | from collections import defaultdict 10 | 11 | 12 | class Lookahead(Optimizer): 13 | def __init__(self, base_optimizer, alpha=0.5, k=6): 14 | if not 0.0 <= alpha <= 1.0: 15 | raise ValueError(f'Invalid slow update rate: {alpha}') 16 | if not 1 <= k: 17 | raise ValueError(f'Invalid lookahead steps: {k}') 18 | defaults = dict(lookahead_alpha=alpha, lookahead_k=k, lookahead_step=0) 19 | self.base_optimizer = base_optimizer 20 | self.param_groups = self.base_optimizer.param_groups 21 | self.defaults = base_optimizer.defaults 22 | self.defaults.update(defaults) 23 | self.state = defaultdict(dict) 24 | # manually add our defaults to the param groups 25 | for name, default in defaults.items(): 26 | for group in self.param_groups: 27 | group.setdefault(name, default) 28 | 29 | def update_slow(self, group): 30 | for fast_p in group["params"]: 31 | if fast_p.grad is None: 32 | continue 33 | param_state = self.state[fast_p] 34 | if 'slow_buffer' not in param_state: 35 | param_state['slow_buffer'] = torch.empty_like(fast_p.data) 36 | param_state['slow_buffer'].copy_(fast_p.data) 37 | slow = param_state['slow_buffer'] 38 | slow.add_(group['lookahead_alpha'], fast_p.data - slow) 39 | fast_p.data.copy_(slow) 40 | 41 | def sync_lookahead(self): 42 | for group in self.param_groups: 43 | self.update_slow(group) 44 | 45 | def step(self, closure=None): 46 | #assert id(self.param_groups) == id(self.base_optimizer.param_groups) 47 | loss = self.base_optimizer.step(closure) 48 | for group in self.param_groups: 49 | group['lookahead_step'] += 1 50 | if group['lookahead_step'] % group['lookahead_k'] == 0: 51 | self.update_slow(group) 52 | return loss 53 | 54 | def state_dict(self): 55 | fast_state_dict = self.base_optimizer.state_dict() 56 | slow_state = { 57 | (id(k) if isinstance(k, torch.Tensor) else k): v 58 | for k, v in self.state.items() 59 | } 60 | fast_state = fast_state_dict['state'] 61 | param_groups = fast_state_dict['param_groups'] 62 | return { 63 | 'state': fast_state, 64 | 'slow_state': slow_state, 65 | 'param_groups': param_groups, 66 | } 67 | 68 | def load_state_dict(self, state_dict): 69 | fast_state_dict = { 70 | 'state': state_dict['state'], 71 | 'param_groups': state_dict['param_groups'], 72 | } 73 | self.base_optimizer.load_state_dict(fast_state_dict) 74 | 75 | # We want to restore the slow state, but share param_groups reference 76 | # with base_optimizer. 
This is a bit redundant but least code 77 | slow_state_new = False 78 | if 'slow_state' not in state_dict: 79 | print('Loading state_dict from optimizer without Lookahead applied.') 80 | state_dict['slow_state'] = defaultdict(dict) 81 | slow_state_new = True 82 | slow_state_dict = { 83 | 'state': state_dict['slow_state'], 84 | 'param_groups': state_dict['param_groups'], # this is pointless but saves code 85 | } 86 | super(Lookahead, self).load_state_dict(slow_state_dict) 87 | self.param_groups = self.base_optimizer.param_groups # make both ref same container 88 | if slow_state_new: 89 | # reapply defaults to catch missing lookahead specific ones 90 | for name, default in self.defaults.items(): 91 | for group in self.param_groups: 92 | group.setdefault(name, default) 93 | -------------------------------------------------------------------------------- /timm/optim/nadam.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch.optim import Optimizer 3 | 4 | 5 | class Nadam(Optimizer): 6 | """Implements Nadam algorithm (a variant of Adam based on Nesterov momentum). 7 | 8 | It has been proposed in `Incorporating Nesterov Momentum into Adam`__. 9 | 10 | Arguments: 11 | params (iterable): iterable of parameters to optimize or dicts defining 12 | parameter groups 13 | lr (float, optional): learning rate (default: 2e-3) 14 | betas (Tuple[float, float], optional): coefficients used for computing 15 | running averages of gradient and its square 16 | eps (float, optional): term added to the denominator to improve 17 | numerical stability (default: 1e-8) 18 | weight_decay (float, optional): weight decay (L2 penalty) (default: 0) 19 | schedule_decay (float, optional): momentum schedule decay (default: 4e-3) 20 | 21 | __ http://cs229.stanford.edu/proj2015/054_report.pdf 22 | __ http://www.cs.toronto.edu/~fritz/absps/momentum.pdf 23 | 24 | Originally taken from: https://github.com/pytorch/pytorch/pull/1408 25 | NOTE: Has potential issues but does work well on some problems. 26 | """ 27 | 28 | def __init__(self, params, lr=2e-3, betas=(0.9, 0.999), eps=1e-8, 29 | weight_decay=0, schedule_decay=4e-3): 30 | defaults = dict(lr=lr, betas=betas, eps=eps, 31 | weight_decay=weight_decay, schedule_decay=schedule_decay) 32 | super(Nadam, self).__init__(params, defaults) 33 | 34 | def step(self, closure=None): 35 | """Performs a single optimization step. 36 | 37 | Arguments: 38 | closure (callable, optional): A closure that reevaluates the model 39 | and returns the loss. 40 | """ 41 | loss = None 42 | if closure is not None: 43 | loss = closure() 44 | 45 | for group in self.param_groups: 46 | for p in group['params']: 47 | if p.grad is None: 48 | continue 49 | grad = p.grad.data 50 | state = self.state[p] 51 | 52 | # State initialization 53 | if len(state) == 0: 54 | state['step'] = 0 55 | state['m_schedule'] = 1. 56 | state['exp_avg'] = grad.new().resize_as_(grad).zero_() 57 | state['exp_avg_sq'] = grad.new().resize_as_(grad).zero_() 58 | 59 | # Warming momentum schedule 60 | m_schedule = state['m_schedule'] 61 | schedule_decay = group['schedule_decay'] 62 | exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq'] 63 | beta1, beta2 = group['betas'] 64 | eps = group['eps'] 65 | state['step'] += 1 66 | t = state['step'] 67 | 68 | if group['weight_decay'] != 0: 69 | grad = grad.add(group['weight_decay'], p.data) 70 | 71 | momentum_cache_t = beta1 * \ 72 | (1. - 0.5 * (0.96 ** (t * schedule_decay))) 73 | momentum_cache_t_1 = beta1 * \ 74 | (1. 
- 0.5 * (0.96 ** ((t + 1) * schedule_decay))) 75 | m_schedule_new = m_schedule * momentum_cache_t 76 | m_schedule_next = m_schedule * momentum_cache_t * momentum_cache_t_1 77 | state['m_schedule'] = m_schedule_new 78 | 79 | # Decay the first and second moment running average coefficient 80 | exp_avg.mul_(beta1).add_(1. - beta1, grad) 81 | exp_avg_sq.mul_(beta2).addcmul_(1. - beta2, grad, grad) 82 | exp_avg_sq_prime = exp_avg_sq / (1. - beta2 ** t) 83 | denom = exp_avg_sq_prime.sqrt_().add_(eps) 84 | 85 | p.data.addcdiv_(-group['lr'] * (1. - momentum_cache_t) / (1. - m_schedule_new), grad, denom) 86 | p.data.addcdiv_(-group['lr'] * momentum_cache_t_1 / (1. - m_schedule_next), exp_avg, denom) 87 | 88 | return loss 89 | -------------------------------------------------------------------------------- /timm/optim/novograd.py: -------------------------------------------------------------------------------- 1 | """NovoGrad Optimizer. 2 | Original impl by Masashi Kimura (Convergence Lab): https://github.com/convergence-lab/novograd 3 | Paper: `Stochastic Gradient Methods with Layer-wise Adaptive Moments for Training of Deep Networks` 4 | - https://arxiv.org/abs/1905.11286 5 | """ 6 | 7 | import torch 8 | from torch.optim.optimizer import Optimizer 9 | import math 10 | 11 | 12 | class NovoGrad(Optimizer): 13 | def __init__(self, params, grad_averaging=False, lr=0.1, betas=(0.95, 0.98), eps=1e-8, weight_decay=0): 14 | defaults = dict(lr=lr, betas=betas, eps=eps, weight_decay=weight_decay) 15 | super(NovoGrad, self).__init__(params, defaults) 16 | self._lr = lr 17 | self._beta1 = betas[0] 18 | self._beta2 = betas[1] 19 | self._eps = eps 20 | self._wd = weight_decay 21 | self._grad_averaging = grad_averaging 22 | 23 | self._momentum_initialized = False 24 | 25 | def step(self, closure=None): 26 | loss = None 27 | if closure is not None: 28 | loss = closure() 29 | 30 | if not self._momentum_initialized: 31 | for group in self.param_groups: 32 | for p in group['params']: 33 | if p.grad is None: 34 | continue 35 | state = self.state[p] 36 | grad = p.grad.data 37 | if grad.is_sparse: 38 | raise RuntimeError('NovoGrad does not support sparse gradients') 39 | 40 | v = torch.norm(grad)**2 41 | m = grad/(torch.sqrt(v) + self._eps) + self._wd * p.data 42 | state['step'] = 0 43 | state['v'] = v 44 | state['m'] = m 45 | state['grad_ema'] = None 46 | self._momentum_initialized = True 47 | 48 | for group in self.param_groups: 49 | for p in group['params']: 50 | if p.grad is None: 51 | continue 52 | state = self.state[p] 53 | state['step'] += 1 54 | 55 | step, v, m = state['step'], state['v'], state['m'] 56 | grad_ema = state['grad_ema'] 57 | 58 | grad = p.grad.data 59 | g2 = torch.norm(grad)**2 60 | grad_ema = g2 if grad_ema is None else grad_ema * \ 61 | self._beta2 + g2 * (1. - self._beta2) 62 | grad *= 1.0 / (torch.sqrt(grad_ema) + self._eps) 63 | 64 | if self._grad_averaging: 65 | grad *= (1. - self._beta1) 66 | 67 | g2 = torch.norm(grad)**2 68 | v = self._beta2*v + (1. 
- self._beta2)*g2 69 | m = self._beta1*m + (grad / (torch.sqrt(v) + self._eps) + self._wd * p.data) 70 | bias_correction1 = 1 - self._beta1 ** step 71 | bias_correction2 = 1 - self._beta2 ** step 72 | step_size = group['lr'] * math.sqrt(bias_correction2) / bias_correction1 73 | 74 | state['v'], state['m'] = v, m 75 | state['grad_ema'] = grad_ema 76 | p.data.add_(-step_size, m) 77 | return loss 78 | -------------------------------------------------------------------------------- /timm/optim/sgdp.py: -------------------------------------------------------------------------------- 1 | """ 2 | SGDP Optimizer Implementation copied from https://github.com/clovaai/AdamP/blob/master/adamp/sgdp.py 3 | 4 | Paper: `Slowing Down the Weight Norm Increase in Momentum-based Optimizers` - https://arxiv.org/abs/2006.08217 5 | Code: https://github.com/clovaai/AdamP 6 | 7 | Copyright (c) 2020-present NAVER Corp. 8 | MIT license 9 | """ 10 | 11 | import torch 12 | import torch.nn as nn 13 | from torch.optim.optimizer import Optimizer, required 14 | import math 15 | 16 | class SGDP(Optimizer): 17 | def __init__(self, params, lr=required, momentum=0, dampening=0, 18 | weight_decay=0, nesterov=False, eps=1e-8, delta=0.1, wd_ratio=0.1): 19 | defaults = dict(lr=lr, momentum=momentum, dampening=dampening, weight_decay=weight_decay, 20 | nesterov=nesterov, eps=eps, delta=delta, wd_ratio=wd_ratio) 21 | super(SGDP, self).__init__(params, defaults) 22 | 23 | def _channel_view(self, x): 24 | return x.view(x.size(0), -1) 25 | 26 | def _layer_view(self, x): 27 | return x.view(1, -1) 28 | 29 | def _cosine_similarity(self, x, y, eps, view_func): 30 | x = view_func(x) 31 | y = view_func(y) 32 | 33 | x_norm = x.norm(dim=1).add_(eps) 34 | y_norm = y.norm(dim=1).add_(eps) 35 | dot = (x * y).sum(dim=1) 36 | 37 | return dot.abs() / x_norm / y_norm 38 | 39 | def _projection(self, p, grad, perturb, delta, wd_ratio, eps): 40 | wd = 1 41 | expand_size = [-1] + [1] * (len(p.shape) - 1) 42 | for view_func in [self._channel_view, self._layer_view]: 43 | 44 | cosine_sim = self._cosine_similarity(grad, p.data, eps, view_func) 45 | 46 | if cosine_sim.max() < delta / math.sqrt(view_func(p.data).size(1)): 47 | p_n = p.data / view_func(p.data).norm(dim=1).view(expand_size).add_(eps) 48 | perturb -= p_n * view_func(p_n * perturb).sum(dim=1).view(expand_size) 49 | wd = wd_ratio 50 | 51 | return perturb, wd 52 | 53 | return perturb, wd 54 | 55 | def step(self, closure=None): 56 | loss = None 57 | if closure is not None: 58 | loss = closure() 59 | 60 | for group in self.param_groups: 61 | weight_decay = group['weight_decay'] 62 | momentum = group['momentum'] 63 | dampening = group['dampening'] 64 | nesterov = group['nesterov'] 65 | 66 | for p in group['params']: 67 | if p.grad is None: 68 | continue 69 | grad = p.grad.data 70 | state = self.state[p] 71 | 72 | # State initialization 73 | if len(state) == 0: 74 | state['momentum'] = torch.zeros_like(p.data) 75 | 76 | # SGD 77 | buf = state['momentum'] 78 | buf.mul_(momentum).add_(1 - dampening, grad) 79 | if nesterov: 80 | d_p = grad + momentum * buf 81 | else: 82 | d_p = buf 83 | 84 | # Projection 85 | wd_ratio = 1 86 | if len(p.shape) > 1: 87 | d_p, wd_ratio = self._projection(p, grad, d_p, group['delta'], group['wd_ratio'], group['eps']) 88 | 89 | # Weight decay 90 | if weight_decay != 0: 91 | p.data.mul_(1 - group['lr'] * group['weight_decay'] * wd_ratio / (1-momentum)) 92 | 93 | # Step 94 | p.data.add_(-group['lr'], d_p) 95 | 96 | return loss 97 | 
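For orientation, here is a minimal usage sketch (not part of the repository; the toy model and hyperparameter values are illustrative assumptions) showing how the optimizers above compose, with AdamP as the inner optimizer wrapped by Lookahead:

import torch
import torch.nn as nn
from timm.optim import AdamP, Lookahead

model = nn.Linear(16, 4)                     # toy stand-in for a real network
base = AdamP(model.parameters(), lr=1e-3, weight_decay=1e-2)  # inner optimizer
optimizer = Lookahead(base, alpha=0.5, k=6)  # slow weights sync every k inner steps

loss = model(torch.randn(8, 16)).pow(2).mean()
loss.backward()
optimizer.step()   # steps AdamP, then interpolates toward the slow buffer on every k-th step
model.zero_grad()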
-------------------------------------------------------------------------------- /timm/scheduler/__init__.py: -------------------------------------------------------------------------------- 1 | from .cosine_lr import CosineLRScheduler 2 | from .plateau_lr import PlateauLRScheduler 3 | from .step_lr import StepLRScheduler 4 | from .tanh_lr import TanhLRScheduler 5 | from .scheduler_factory import create_scheduler 6 | -------------------------------------------------------------------------------- /timm/scheduler/cosine_lr.py: -------------------------------------------------------------------------------- 1 | """ Cosine Scheduler 2 | 3 | Cosine LR schedule with warmup, cycle/restarts, noise. 4 | 5 | Hacked together by / Copyright 2020 Ross Wightman 6 | """ 7 | import logging 8 | import math 9 | import numpy as np 10 | import torch 11 | 12 | from .scheduler import Scheduler 13 | 14 | 15 | _logger = logging.getLogger(__name__) 16 | 17 | 18 | class CosineLRScheduler(Scheduler): 19 | """ 20 | Cosine decay with restarts. 21 | This is described in the paper https://arxiv.org/abs/1608.03983. 22 | 23 | Inspiration from 24 | https://github.com/allenai/allennlp/blob/master/allennlp/training/learning_rate_schedulers/cosine.py 25 | """ 26 | 27 | def __init__(self, 28 | optimizer: torch.optim.Optimizer, 29 | t_initial: int, 30 | t_mul: float = 1., 31 | lr_min: float = 0., 32 | decay_rate: float = 1., 33 | warmup_t=0, 34 | warmup_lr_init=0, 35 | warmup_prefix=False, 36 | cycle_limit=0, 37 | t_in_epochs=True, 38 | noise_range_t=None, 39 | noise_pct=0.67, 40 | noise_std=1.0, 41 | noise_seed=42, 42 | initialize=True) -> None: 43 | super().__init__( 44 | optimizer, param_group_field="lr", 45 | noise_range_t=noise_range_t, noise_pct=noise_pct, noise_std=noise_std, noise_seed=noise_seed, 46 | initialize=initialize) 47 | 48 | assert t_initial > 0 49 | assert lr_min >= 0 50 | if t_initial == 1 and t_mul == 1 and decay_rate == 1: 51 | _logger.warning("Cosine annealing scheduler will have no effect on the learning " 52 | "rate since t_initial = t_mul = decay_rate = 1.") 53 | self.t_initial = t_initial 54 | self.t_mul = t_mul 55 | self.lr_min = lr_min 56 | self.decay_rate = decay_rate 57 | self.cycle_limit = cycle_limit 58 | self.warmup_t = warmup_t 59 | self.warmup_lr_init = warmup_lr_init 60 | self.warmup_prefix = warmup_prefix 61 | self.t_in_epochs = t_in_epochs 62 | if self.warmup_t: 63 | self.warmup_steps = [(v - warmup_lr_init) / self.warmup_t for v in self.base_values] 64 | super().update_groups(self.warmup_lr_init) 65 | else: 66 | self.warmup_steps = [1 for _ in self.base_values] 67 | 68 | def _get_lr(self, t): 69 | if t < self.warmup_t: 70 | lrs = [self.warmup_lr_init + t * s for s in self.warmup_steps] 71 | else: 72 | if self.warmup_prefix: 73 | t = t - self.warmup_t 74 | 75 | if self.t_mul != 1: 76 | i = math.floor(math.log(1 - t / self.t_initial * (1 - self.t_mul), self.t_mul)) 77 | t_i = self.t_mul ** i * self.t_initial 78 | t_curr = t - (1 - self.t_mul ** i) / (1 - self.t_mul) * self.t_initial 79 | else: 80 | i = t // self.t_initial 81 | t_i = self.t_initial 82 | t_curr = t - (self.t_initial * i) 83 | 84 | gamma = self.decay_rate ** i 85 | lr_min = self.lr_min * gamma 86 | lr_max_values = [v * gamma for v in self.base_values] 87 | 88 | if self.cycle_limit == 0 or (self.cycle_limit > 0 and i < self.cycle_limit): 89 | lrs = [ 90 | lr_min + 0.5 * (lr_max - lr_min) * (1 + math.cos(math.pi * t_curr / t_i)) for lr_max in lr_max_values 91 | ] 92 | else: 93 | lrs = [self.lr_min for _ in self.base_values] 94 | 
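# At this point lrs holds either the SGDR (https://arxiv.org/abs/1608.03983) cosine values,
# annealing from gamma * lr_max down to gamma * lr_min (gamma = decay_rate ** i) via
# 0.5 * (1 + cos(pi * t_curr / t_i)), or flat lr_min once past cycle_limit.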
95 | return lrs 96 | 97 | def get_epoch_values(self, epoch: int): 98 | if self.t_in_epochs: 99 | return self._get_lr(epoch) 100 | else: 101 | return None 102 | 103 | def get_update_values(self, num_updates: int): 104 | if not self.t_in_epochs: 105 | return self._get_lr(num_updates) 106 | else: 107 | return None 108 | 109 | def get_cycle_length(self, cycles=0): 110 | if not cycles: 111 | cycles = self.cycle_limit 112 | cycles = max(1, cycles) 113 | if self.t_mul == 1.0: 114 | return self.t_initial * cycles 115 | else: 116 | return int(math.floor(-self.t_initial * (self.t_mul ** cycles - 1) / (1 - self.t_mul))) 117 | -------------------------------------------------------------------------------- /timm/scheduler/plateau_lr.py: -------------------------------------------------------------------------------- 1 | """ Plateau Scheduler 2 | 3 | Adapts PyTorch plateau scheduler and allows application of noise, warmup. 4 | 5 | Hacked together by / Copyright 2020 Ross Wightman 6 | """ 7 | import torch 8 | 9 | from .scheduler import Scheduler 10 | 11 | 12 | class PlateauLRScheduler(Scheduler): 13 | """Decay the LR by a factor every time the validation loss plateaus.""" 14 | 15 | def __init__(self, 16 | optimizer, 17 | decay_rate=0.1, 18 | patience_t=10, 19 | verbose=True, 20 | threshold=1e-4, 21 | cooldown_t=0, 22 | warmup_t=0, 23 | warmup_lr_init=0, 24 | lr_min=0, 25 | mode='max', 26 | noise_range_t=None, 27 | noise_type='normal', 28 | noise_pct=0.67, 29 | noise_std=1.0, 30 | noise_seed=None, 31 | initialize=True, 32 | ): 33 | super().__init__(optimizer, 'lr', initialize=initialize) 34 | 35 | self.lr_scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau( 36 | self.optimizer, 37 | patience=patience_t, 38 | factor=decay_rate, 39 | verbose=verbose, 40 | threshold=threshold, 41 | cooldown=cooldown_t, 42 | mode=mode, 43 | min_lr=lr_min 44 | ) 45 | 46 | self.noise_range = noise_range_t 47 | self.noise_pct = noise_pct 48 | self.noise_type = noise_type 49 | self.noise_std = noise_std 50 | self.noise_seed = noise_seed if noise_seed is not None else 42 51 | self.warmup_t = warmup_t 52 | self.warmup_lr_init = warmup_lr_init 53 | if self.warmup_t: 54 | self.warmup_steps = [(v - warmup_lr_init) / self.warmup_t for v in self.base_values] 55 | super().update_groups(self.warmup_lr_init) 56 | else: 57 | self.warmup_steps = [1 for _ in self.base_values] 58 | self.restore_lr = None 59 | 60 | def state_dict(self): 61 | return { 62 | 'best': self.lr_scheduler.best, 63 | 'last_epoch': self.lr_scheduler.last_epoch, 64 | } 65 | 66 | def load_state_dict(self, state_dict): 67 | self.lr_scheduler.best = state_dict['best'] 68 | if 'last_epoch' in state_dict: 69 | self.lr_scheduler.last_epoch = state_dict['last_epoch'] 70 | 71 | # override the base class step fn completely 72 | def step(self, epoch, metric=None): 73 | if epoch <= self.warmup_t: 74 | lrs = [self.warmup_lr_init + epoch * s for s in self.warmup_steps] 75 | super().update_groups(lrs) 76 | else: 77 | if self.restore_lr is not None: 78 | # restore actual LR from before our last noise perturbation before stepping base 79 | for i, param_group in enumerate(self.optimizer.param_groups): 80 | param_group['lr'] = self.restore_lr[i] 81 | self.restore_lr = None 82 | 83 | self.lr_scheduler.step(metric, epoch) # step the base scheduler 84 | 85 | if self.noise_range is not None: 86 | if isinstance(self.noise_range, (list, tuple)): 87 | apply_noise = self.noise_range[0] <= epoch < self.noise_range[1] 88 | else: 89 | apply_noise = epoch >= self.noise_range 90 | if apply_noise: 
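# inside the configured noise window: perturb this epoch's LR (the pre-noise
# values are cached by _apply_noise so the base scheduler can step cleanly later)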
91 | self._apply_noise(epoch) 92 | 93 | def _apply_noise(self, epoch): 94 | g = torch.Generator() 95 | g.manual_seed(self.noise_seed + epoch) 96 | if self.noise_type == 'normal': 97 | while True: 98 | # resample if noise out of percent limit, brute force but shouldn't spin much 99 | noise = torch.randn(1, generator=g).item() 100 | if abs(noise) < self.noise_pct: 101 | break 102 | else: 103 | noise = 2 * (torch.rand(1, generator=g).item() - 0.5) * self.noise_pct 104 | 105 | # apply the noise on top of previous LR, cache the old value so we can restore for normal 106 | # stepping of base scheduler 107 | restore_lr = [] 108 | for i, param_group in enumerate(self.optimizer.param_groups): 109 | old_lr = float(param_group['lr']) 110 | restore_lr.append(old_lr) 111 | new_lr = old_lr + old_lr * noise 112 | param_group['lr'] = new_lr 113 | self.restore_lr = restore_lr 114 | -------------------------------------------------------------------------------- /timm/scheduler/scheduler_factory.py: -------------------------------------------------------------------------------- 1 | """ Scheduler Factory 2 | Hacked together by / Copyright 2020 Ross Wightman 3 | """ 4 | from .cosine_lr import CosineLRScheduler 5 | from .tanh_lr import TanhLRScheduler 6 | from .step_lr import StepLRScheduler 7 | from .plateau_lr import PlateauLRScheduler 8 | 9 | 10 | def create_scheduler(args, optimizer): 11 | num_epochs = args.epochs 12 | 13 | if getattr(args, 'lr_noise', None) is not None: 14 | lr_noise = getattr(args, 'lr_noise') 15 | if isinstance(lr_noise, (list, tuple)): 16 | noise_range = [n * num_epochs for n in lr_noise] 17 | if len(noise_range) == 1: 18 | noise_range = noise_range[0] 19 | else: 20 | noise_range = lr_noise * num_epochs 21 | else: 22 | noise_range = None 23 | 24 | lr_scheduler = None 25 | if args.sched == 'cosine': 26 | lr_scheduler = CosineLRScheduler( 27 | optimizer, 28 | t_initial=num_epochs, 29 | t_mul=getattr(args, 'lr_cycle_mul', 1.), 30 | lr_min=args.min_lr, 31 | decay_rate=args.decay_rate, 32 | warmup_lr_init=args.warmup_lr, 33 | warmup_t=args.warmup_epochs, 34 | cycle_limit=getattr(args, 'lr_cycle_limit', 1), 35 | t_in_epochs=True, 36 | noise_range_t=noise_range, 37 | noise_pct=getattr(args, 'lr_noise_pct', 0.67), 38 | noise_std=getattr(args, 'lr_noise_std', 1.), 39 | noise_seed=getattr(args, 'seed', 42), 40 | ) 41 | num_epochs = lr_scheduler.get_cycle_length() + args.cooldown_epochs 42 | elif args.sched == 'tanh': 43 | lr_scheduler = TanhLRScheduler( 44 | optimizer, 45 | t_initial=num_epochs, 46 | t_mul=getattr(args, 'lr_cycle_mul', 1.), 47 | lr_min=args.min_lr, 48 | warmup_lr_init=args.warmup_lr, 49 | warmup_t=args.warmup_epochs, 50 | cycle_limit=getattr(args, 'lr_cycle_limit', 1), 51 | t_in_epochs=True, 52 | noise_range_t=noise_range, 53 | noise_pct=getattr(args, 'lr_noise_pct', 0.67), 54 | noise_std=getattr(args, 'lr_noise_std', 1.), 55 | noise_seed=getattr(args, 'seed', 42), 56 | ) 57 | num_epochs = lr_scheduler.get_cycle_length() + args.cooldown_epochs 58 | elif args.sched == 'step': 59 | lr_scheduler = StepLRScheduler( 60 | optimizer, 61 | decay_t=args.decay_epochs, 62 | decay_rate=args.decay_rate, 63 | warmup_lr_init=args.warmup_lr, 64 | warmup_t=args.warmup_epochs, 65 | noise_range_t=noise_range, 66 | noise_pct=getattr(args, 'lr_noise_pct', 0.67), 67 | noise_std=getattr(args, 'lr_noise_std', 1.), 68 | noise_seed=getattr(args, 'seed', 42), 69 | ) 70 | elif args.sched == 'plateau': 71 | mode = 'min' if 'loss' in getattr(args, 'eval_metric', '') else 'max' 72 | lr_scheduler = 
PlateauLRScheduler( 73 | optimizer, 74 | decay_rate=args.decay_rate, 75 | patience_t=args.patience_epochs, 76 | lr_min=args.min_lr, 77 | mode=mode, 78 | warmup_lr_init=args.warmup_lr, 79 | warmup_t=args.warmup_epochs, 80 | cooldown_t=0, 81 | noise_range_t=noise_range, 82 | noise_pct=getattr(args, 'lr_noise_pct', 0.67), 83 | noise_std=getattr(args, 'lr_noise_std', 1.), 84 | noise_seed=getattr(args, 'seed', 42), 85 | ) 86 | 87 | return lr_scheduler, num_epochs 88 | -------------------------------------------------------------------------------- /timm/scheduler/step_lr.py: -------------------------------------------------------------------------------- 1 | """ Step Scheduler 2 | 3 | Basic step LR schedule with warmup, noise. 4 | 5 | Hacked together by / Copyright 2020 Ross Wightman 6 | """ 7 | import math 8 | import torch 9 | 10 | from .scheduler import Scheduler 11 | 12 | 13 | class StepLRScheduler(Scheduler): 14 | """ Step LR schedule with warmup, noise. 15 | """ 16 | 17 | def __init__(self, 18 | optimizer: torch.optim.Optimizer, 19 | decay_t: float, 20 | decay_rate: float = 1., 21 | warmup_t=0, 22 | warmup_lr_init=0, 23 | t_in_epochs=True, 24 | noise_range_t=None, 25 | noise_pct=0.67, 26 | noise_std=1.0, 27 | noise_seed=42, 28 | initialize=True, 29 | ) -> None: 30 | super().__init__( 31 | optimizer, param_group_field="lr", 32 | noise_range_t=noise_range_t, noise_pct=noise_pct, noise_std=noise_std, noise_seed=noise_seed, 33 | initialize=initialize) 34 | 35 | self.decay_t = decay_t 36 | self.decay_rate = decay_rate 37 | self.warmup_t = warmup_t 38 | self.warmup_lr_init = warmup_lr_init 39 | self.t_in_epochs = t_in_epochs 40 | if self.warmup_t: 41 | self.warmup_steps = [(v - warmup_lr_init) / self.warmup_t for v in self.base_values] 42 | super().update_groups(self.warmup_lr_init) 43 | else: 44 | self.warmup_steps = [1 for _ in self.base_values] 45 | 46 | def _get_lr(self, t): 47 | if t < self.warmup_t: 48 | lrs = [self.warmup_lr_init + t * s for s in self.warmup_steps] 49 | else: 50 | lrs = [v * (self.decay_rate ** (t // self.decay_t)) for v in self.base_values] 51 | return lrs 52 | 53 | def get_epoch_values(self, epoch: int): 54 | if self.t_in_epochs: 55 | return self._get_lr(epoch) 56 | else: 57 | return None 58 | 59 | def get_update_values(self, num_updates: int): 60 | if not self.t_in_epochs: 61 | return self._get_lr(num_updates) 62 | else: 63 | return None 64 | -------------------------------------------------------------------------------- /timm/scheduler/tanh_lr.py: -------------------------------------------------------------------------------- 1 | """ TanH Scheduler 2 | 3 | TanH schedule with warmup, cycle/restarts, noise. 4 | 5 | Hacked together by / Copyright 2020 Ross Wightman 6 | """ 7 | import logging 8 | import math 9 | import numpy as np 10 | import torch 11 | 12 | from .scheduler import Scheduler 13 | 14 | 15 | _logger = logging.getLogger(__name__) 16 | 17 | 18 | class TanhLRScheduler(Scheduler): 19 | """ 20 | Hyperbolic-Tangent decay with restarts. 
21 | This is described in the paper https://arxiv.org/abs/1806.01593 22 | """ 23 | 24 | def __init__(self, 25 | optimizer: torch.optim.Optimizer, 26 | t_initial: int, 27 | lb: float = -6., 28 | ub: float = 4., 29 | t_mul: float = 1., 30 | lr_min: float = 0., 31 | decay_rate: float = 1., 32 | warmup_t=0, 33 | warmup_lr_init=0, 34 | warmup_prefix=False, 35 | cycle_limit=0, 36 | t_in_epochs=True, 37 | noise_range_t=None, 38 | noise_pct=0.67, 39 | noise_std=1.0, 40 | noise_seed=42, 41 | initialize=True) -> None: 42 | super().__init__( 43 | optimizer, param_group_field="lr", 44 | noise_range_t=noise_range_t, noise_pct=noise_pct, noise_std=noise_std, noise_seed=noise_seed, 45 | initialize=initialize) 46 | 47 | assert t_initial > 0 48 | assert lr_min >= 0 49 | assert lb < ub 50 | assert cycle_limit >= 0 51 | assert warmup_t >= 0 52 | assert warmup_lr_init >= 0 53 | self.lb = lb 54 | self.ub = ub 55 | self.t_initial = t_initial 56 | self.t_mul = t_mul 57 | self.lr_min = lr_min 58 | self.decay_rate = decay_rate 59 | self.cycle_limit = cycle_limit 60 | self.warmup_t = warmup_t 61 | self.warmup_lr_init = warmup_lr_init 62 | self.warmup_prefix = warmup_prefix 63 | self.t_in_epochs = t_in_epochs 64 | if self.warmup_t: 65 | t_v = self.base_values if self.warmup_prefix else self._get_lr(self.warmup_t) 66 | self.warmup_steps = [(v - warmup_lr_init) / self.warmup_t for v in t_v] 67 | super().update_groups(self.warmup_lr_init) 68 | else: 69 | self.warmup_steps = [1 for _ in self.base_values] 70 | 71 | def _get_lr(self, t): 72 | if t < self.warmup_t: 73 | lrs = [self.warmup_lr_init + t * s for s in self.warmup_steps] 74 | else: 75 | if self.warmup_prefix: 76 | t = t - self.warmup_t 77 | 78 | if self.t_mul != 1: 79 | i = math.floor(math.log(1 - t / self.t_initial * (1 - self.t_mul), self.t_mul)) 80 | t_i = self.t_mul ** i * self.t_initial 81 | t_curr = t - (1 - self.t_mul ** i) / (1 - self.t_mul) * self.t_initial 82 | else: 83 | i = t // self.t_initial 84 | t_i = self.t_initial 85 | t_curr = t - (self.t_initial * i) 86 | 87 | if self.cycle_limit == 0 or (self.cycle_limit > 0 and i < self.cycle_limit): 88 | gamma = self.decay_rate ** i 89 | lr_min = self.lr_min * gamma 90 | lr_max_values = [v * gamma for v in self.base_values] 91 | 92 | tr = t_curr / t_i 93 | lrs = [ 94 | lr_min + 0.5 * (lr_max - lr_min) * (1 - math.tanh(self.lb * (1. 
- tr) + self.ub * tr)) 95 | for lr_max in lr_max_values 96 | ] 97 | else: 98 | lrs = [self.lr_min * (self.decay_rate ** self.cycle_limit) for _ in self.base_values] 99 | return lrs 100 | 101 | def get_epoch_values(self, epoch: int): 102 | if self.t_in_epochs: 103 | return self._get_lr(epoch) 104 | else: 105 | return None 106 | 107 | def get_update_values(self, num_updates: int): 108 | if not self.t_in_epochs: 109 | return self._get_lr(num_updates) 110 | else: 111 | return None 112 | 113 | def get_cycle_length(self, cycles=0): 114 | if not cycles: 115 | cycles = self.cycle_limit 116 | cycles = max(1, cycles) 117 | if self.t_mul == 1.0: 118 | return self.t_initial * cycles 119 | else: 120 | return int(math.floor(-self.t_initial * (self.t_mul ** cycles - 1) / (1 - self.t_mul))) 121 | -------------------------------------------------------------------------------- /timm/utils/__init__.py: -------------------------------------------------------------------------------- 1 | from .agc import adaptive_clip_grad 2 | from .checkpoint_saver import CheckpointSaver 3 | from .clip_grad import dispatch_clip_grad 4 | from .cuda import ApexScaler, NativeScaler 5 | from .distributed import distribute_bn, reduce_tensor 6 | from .jit import set_jit_legacy 7 | from .log import setup_default_logging, FormatterNoInfo 8 | from .metrics import AverageMeter, accuracy 9 | from .misc import natural_key, add_bool_arg 10 | from .model import unwrap_model, get_state_dict 11 | from .model_ema import ModelEma, ModelEmaV2 12 | from .summary import update_summary, get_outdir 13 | -------------------------------------------------------------------------------- /timm/utils/agc.py: -------------------------------------------------------------------------------- 1 | """ Adaptive Gradient Clipping 2 | 3 | An impl of AGC, as per (https://arxiv.org/abs/2102.06171): 4 | 5 | @article{brock2021high, 6 | author={Andrew Brock and Soham De and Samuel L. 
Smith and Karen Simonyan}, 7 | title={High-Performance Large-Scale Image Recognition Without Normalization}, 8 | journal={arXiv preprint arXiv:2102.06171}, 9 | year={2021} 10 | } 11 | 12 | Code references: 13 | * Official JAX impl (paper authors): https://github.com/deepmind/deepmind-research/tree/master/nfnets 14 | * Phil Wang's PyTorch gist: https://gist.github.com/lucidrains/0d6560077edac419ab5d3aa29e674d5c 15 | 16 | Hacked together by / Copyright 2021 Ross Wightman 17 | """ 18 | import torch 19 | 20 | 21 | def unitwise_norm(x, norm_type=2.0): 22 | if x.ndim <= 1: 23 | return x.norm(norm_type) 24 | else: 25 | # works for nn.ConvNd and nn.Linear where output dim is first in the kernel/weight tensor 26 | # might need special cases for other weights (possibly MHA) where this may not be true 27 | return x.norm(norm_type, dim=tuple(range(1, x.ndim)), keepdim=True) 28 | 29 | 30 | def adaptive_clip_grad(parameters, clip_factor=0.01, eps=1e-3, norm_type=2.0): 31 | if isinstance(parameters, torch.Tensor): 32 | parameters = [parameters] 33 | for p in parameters: 34 | if p.grad is None: 35 | continue 36 | p_data = p.detach() 37 | g_data = p.grad.detach() 38 | max_norm = unitwise_norm(p_data, norm_type=norm_type).clamp_(min=eps).mul_(clip_factor) 39 | grad_norm = unitwise_norm(g_data, norm_type=norm_type) 40 | clipped_grad = g_data * (max_norm / grad_norm.clamp(min=1e-6)) 41 | new_grads = torch.where(grad_norm < max_norm, g_data, clipped_grad) 42 | p.grad.detach().copy_(new_grads) 43 | -------------------------------------------------------------------------------- /timm/utils/clip_grad.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | from timm.utils.agc import adaptive_clip_grad 4 | 5 | 6 | def dispatch_clip_grad(parameters, value: float, mode: str = 'norm', norm_type: float = 2.0): 7 | """ Dispatch to gradient clipping method 8 | 9 | Args: 10 | parameters (Iterable): model parameters to clip 11 | value (float): clipping value/factor/norm, mode dependent 12 | mode (str): clipping mode, one of 'norm', 'value', 'agc' 13 | norm_type (float): p-norm, default 2.0 14 | """ 15 | if mode == 'norm': 16 | torch.nn.utils.clip_grad_norm_(parameters, value, norm_type=norm_type) 17 | elif mode == 'value': 18 | torch.nn.utils.clip_grad_value_(parameters, value) 19 | elif mode == 'agc': 20 | adaptive_clip_grad(parameters, value, norm_type=norm_type) 21 | else: 22 | assert False, f"Unknown clip mode ({mode})." 
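As a hedged sketch (the toy model and clip value are illustrative assumptions, not repository code), dispatch_clip_grad is meant to sit between backward() and the optimizer step:

import torch
import torch.nn as nn
from timm.utils.clip_grad import dispatch_clip_grad

model = nn.Linear(16, 4)
loss = model(torch.randn(8, 16)).pow(2).mean()
loss.backward()
# mode='norm' clips the global grad norm, mode='value' clips elementwise, and
# mode='agc' routes to adaptive_clip_grad above with `value` acting as clip_factor
dispatch_clip_grad(model.parameters(), value=0.01, mode='agc')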
23 | 24 | -------------------------------------------------------------------------------- /timm/utils/cuda.py: -------------------------------------------------------------------------------- 1 | """ CUDA / AMP utils 2 | 3 | Hacked together by / Copyright 2020 Ross Wightman 4 | """ 5 | import torch 6 | 7 | try: 8 | from apex import amp 9 | has_apex = True 10 | except ImportError: 11 | amp = None 12 | has_apex = False 13 | 14 | from .clip_grad import dispatch_clip_grad 15 | 16 | 17 | class ApexScaler: 18 | state_dict_key = "amp" 19 | 20 | def __call__(self, loss, optimizer, clip_grad=None, clip_mode='norm', parameters=None, create_graph=False): 21 | with amp.scale_loss(loss, optimizer) as scaled_loss: 22 | scaled_loss.backward(create_graph=create_graph) 23 | if clip_grad is not None: 24 | dispatch_clip_grad(amp.master_params(optimizer), clip_grad, mode=clip_mode) 25 | optimizer.step() 26 | 27 | def state_dict(self): 28 | if 'state_dict' in amp.__dict__: 29 | return amp.state_dict() 30 | 31 | def load_state_dict(self, state_dict): 32 | if 'load_state_dict' in amp.__dict__: 33 | amp.load_state_dict(state_dict) 34 | 35 | 36 | class NativeScaler: 37 | state_dict_key = "amp_scaler" 38 | 39 | def __init__(self): 40 | self._scaler = torch.cuda.amp.GradScaler() 41 | 42 | def __call__(self, loss, optimizer, clip_grad=None, clip_mode='norm', parameters=None, create_graph=False): 43 | self._scaler.scale(loss).backward(create_graph=create_graph) 44 | if clip_grad is not None: 45 | assert parameters is not None 46 | self._scaler.unscale_(optimizer) # unscale the gradients of optimizer's assigned params in-place 47 | dispatch_clip_grad(parameters, clip_grad, mode=clip_mode) 48 | self._scaler.step(optimizer) 49 | self._scaler.update() 50 | 51 | def state_dict(self): 52 | return self._scaler.state_dict() 53 | 54 | def load_state_dict(self, state_dict): 55 | self._scaler.load_state_dict(state_dict) 56 | -------------------------------------------------------------------------------- /timm/utils/distributed.py: -------------------------------------------------------------------------------- 1 | """ Distributed training/validation utils 2 | 3 | Hacked together by / Copyright 2020 Ross Wightman 4 | """ 5 | import torch 6 | from torch import distributed as dist 7 | 8 | from .model import unwrap_model 9 | 10 | 11 | def reduce_tensor(tensor, n): 12 | rt = tensor.clone() 13 | dist.all_reduce(rt, op=dist.ReduceOp.SUM) 14 | rt /= n 15 | return rt 16 | 17 | 18 | def distribute_bn(model, world_size, reduce=False): 19 | # ensure every node has the same running bn stats 20 | for bn_name, bn_buf in unwrap_model(model).named_buffers(recurse=True): 21 | if ('running_mean' in bn_name) or ('running_var' in bn_name): 22 | if reduce: 23 | # average bn stats across whole group 24 | torch.distributed.all_reduce(bn_buf, op=dist.ReduceOp.SUM) 25 | bn_buf /= float(world_size) 26 | else: 27 | # broadcast bn stats from rank 0 to whole group 28 | torch.distributed.broadcast(bn_buf, 0) 29 | -------------------------------------------------------------------------------- /timm/utils/jit.py: -------------------------------------------------------------------------------- 1 | """ JIT scripting/tracing utils 2 | 3 | Hacked together by / Copyright 2020 Ross Wightman 4 | """ 5 | import torch 6 | 7 | 8 | def set_jit_legacy(): 9 | """ Set JIT executor to legacy w/ support for op fusion 10 | This is hopefully a temporary need in 1.5/1.5.1/1.6 to restore performance due to changes 11 | in the JIT executor. 
These APIs are not supported, so they could change. 12 | """ 13 | # 14 | assert hasattr(torch._C, '_jit_set_profiling_executor'), "Old JIT behavior doesn't exist!" 15 | torch._C._jit_set_profiling_executor(False) 16 | torch._C._jit_set_profiling_mode(False) 17 | torch._C._jit_override_can_fuse_on_gpu(True) 18 | #torch._C._jit_set_texpr_fuser_enabled(True) 19 | -------------------------------------------------------------------------------- /timm/utils/log.py: -------------------------------------------------------------------------------- 1 | """ Logging helpers 2 | 3 | Hacked together by / Copyright 2020 Ross Wightman 4 | """ 5 | import logging 6 | import logging.handlers 7 | 8 | 9 | class FormatterNoInfo(logging.Formatter): 10 | def __init__(self, fmt='%(levelname)s: %(message)s'): 11 | logging.Formatter.__init__(self, fmt) 12 | 13 | def format(self, record): 14 | if record.levelno == logging.INFO: 15 | return str(record.getMessage()) 16 | return logging.Formatter.format(self, record) 17 | 18 | 19 | def setup_default_logging(default_level=logging.INFO, log_path=''): 20 | console_handler = logging.StreamHandler() 21 | console_handler.setFormatter(FormatterNoInfo()) 22 | logging.root.addHandler(console_handler) 23 | logging.root.setLevel(default_level) 24 | if log_path: 25 | file_handler = logging.handlers.RotatingFileHandler(log_path, maxBytes=(1024 ** 2 * 2), backupCount=3) 26 | file_formatter = logging.Formatter("%(asctime)s - %(name)20s: [%(levelname)8s] - %(message)s") 27 | file_handler.setFormatter(file_formatter) 28 | logging.root.addHandler(file_handler) 29 | -------------------------------------------------------------------------------- /timm/utils/metrics.py: -------------------------------------------------------------------------------- 1 | """ Eval metrics and related 2 | 3 | Hacked together by / Copyright 2020 Ross Wightman 4 | """ 5 | 6 | 7 | class AverageMeter: 8 | """Computes and stores the average and current value""" 9 | def __init__(self): 10 | self.reset() 11 | 12 | def reset(self): 13 | self.val = 0 14 | self.avg = 0 15 | self.sum = 0 16 | self.count = 0 17 | 18 | def update(self, val, n=1): 19 | self.val = val 20 | self.sum += val * n 21 | self.count += n 22 | self.avg = self.sum / self.count 23 | 24 | 25 | def accuracy(output, target, topk=(1,)): 26 | """Computes the accuracy over the k top predictions for the specified values of k""" 27 | maxk = max(topk) 28 | batch_size = target.size(0) 29 | _, pred = output.topk(maxk, 1, True, True) 30 | pred = pred.t() 31 | correct = pred.eq(target.reshape(1, -1).expand_as(pred)) 32 | return [correct[:k].reshape(-1).float().sum(0) * 100. 
/ batch_size for k in topk] 33 | -------------------------------------------------------------------------------- /timm/utils/misc.py: -------------------------------------------------------------------------------- 1 | """ Misc utils 2 | 3 | Hacked together by / Copyright 2020 Ross Wightman 4 | """ 5 | import re 6 | 7 | 8 | def natural_key(string_): 9 | """See http://www.codinghorror.com/blog/archives/001018.html""" 10 | return [int(s) if s.isdigit() else s for s in re.split(r'(\d+)', string_.lower())] 11 | 12 | 13 | def add_bool_arg(parser, name, default=False, help=''): 14 | dest_name = name.replace('-', '_') 15 | group = parser.add_mutually_exclusive_group(required=False) 16 | group.add_argument('--' + name, dest=dest_name, action='store_true', help=help) 17 | group.add_argument('--no-' + name, dest=dest_name, action='store_false', help=help) 18 | parser.set_defaults(**{dest_name: default}) 19 | -------------------------------------------------------------------------------- /timm/utils/model.py: -------------------------------------------------------------------------------- 1 | """ Model / state_dict utils 2 | 3 | Hacked together by / Copyright 2020 Ross Wightman 4 | """ 5 | from .model_ema import ModelEma 6 | 7 | 8 | def unwrap_model(model): 9 | if isinstance(model, ModelEma): 10 | return unwrap_model(model.ema) 11 | else: 12 | return model.module if hasattr(model, 'module') else model 13 | 14 | 15 | def get_state_dict(model, unwrap_fn=unwrap_model): 16 | return unwrap_fn(model).state_dict() 17 | -------------------------------------------------------------------------------- /timm/utils/summary.py: -------------------------------------------------------------------------------- 1 | """ Summary utilities 2 | 3 | Hacked together by / Copyright 2020 Ross Wightman 4 | 5 | Modified by YANG Ruixin for outputting learning rate, SimCLR loss and classification loss 6 | 2021/09/07 7 | https://github.com/yang-ruixin 8 | yang_ruixin@126.com (in China) 9 | rxn.yang@gmail.com (out of China) 10 | """ 11 | import csv 12 | import os 13 | from collections import OrderedDict 14 | 15 | 16 | def get_outdir(path, *paths, inc=False): 17 | outdir = os.path.join(path, *paths) 18 | if not os.path.exists(outdir): 19 | os.makedirs(outdir) 20 | elif inc: 21 | count = 1 22 | outdir_inc = outdir + '-' + str(count) 23 | while os.path.exists(outdir_inc): 24 | count = count + 1 25 | outdir_inc = outdir + '-' + str(count) 26 | assert count < 100 27 | outdir = outdir_inc 28 | os.makedirs(outdir) 29 | return outdir 30 | 31 | 32 | def update_summary(epoch, train_metrics, eval_metrics, lr, simclr_loss, classification_loss, filename, write_header=False): 33 | rowd = OrderedDict(epoch=epoch) 34 | rowd.update([('train_' + k, v) for k, v in train_metrics.items()]) 35 | rowd.update([('eval_' + k, v) for k, v in eval_metrics.items()]) 36 | 37 | # ================================ 38 | rowd.update(OrderedDict(lr=lr)) 39 | rowd.update(OrderedDict(simclr_loss=simclr_loss)) 40 | rowd.update(OrderedDict(classification_loss=classification_loss)) 41 | # ================================ 42 | 43 | with open(filename, mode='a') as cf: 44 | dw = csv.DictWriter(cf, fieldnames=rowd.keys()) 45 | if write_header: # first iteration (epoch == 1 can't be used) 46 | dw.writeheader() 47 | dw.writerow(rowd) 48 | -------------------------------------------------------------------------------- /timm/utils/summary_original_version_by_ross.py: -------------------------------------------------------------------------------- 1 | """ Summary utilities 2 
| 3 | Hacked together by / Copyright 2020 Ross Wightman 4 | """ 5 | import csv 6 | import os 7 | from collections import OrderedDict 8 | 9 | 10 | def get_outdir(path, *paths, inc=False): 11 | outdir = os.path.join(path, *paths) 12 | if not os.path.exists(outdir): 13 | os.makedirs(outdir) 14 | elif inc: 15 | count = 1 16 | outdir_inc = outdir + '-' + str(count) 17 | while os.path.exists(outdir_inc): 18 | count = count + 1 19 | outdir_inc = outdir + '-' + str(count) 20 | assert count < 100 21 | outdir = outdir_inc 22 | os.makedirs(outdir) 23 | return outdir 24 | 25 | 26 | def update_summary(epoch, train_metrics, eval_metrics, filename, write_header=False): 27 | rowd = OrderedDict(epoch=epoch) 28 | rowd.update([('train_' + k, v) for k, v in train_metrics.items()]) 29 | rowd.update([('eval_' + k, v) for k, v in eval_metrics.items()]) 30 | with open(filename, mode='a') as cf: 31 | dw = csv.DictWriter(cf, fieldnames=rowd.keys()) 32 | if write_header: # first iteration (epoch == 1 can't be used) 33 | dw.writeheader() 34 | dw.writerow(rowd) 35 | -------------------------------------------------------------------------------- /timm/utils/summary_without_simclr.py: -------------------------------------------------------------------------------- 1 | """ Summary utilities 2 | 3 | Hacked together by / Copyright 2020 Ross Wightman 4 | 5 | Modified by YANG Ruixin for outputting learning rate 6 | 2021/03/18 7 | https://github.com/yang-ruixin 8 | yang_ruixin@126.com (in China) 9 | rxn.yang@gmail.com (out of China) 10 | """ 11 | import csv 12 | import os 13 | from collections import OrderedDict 14 | 15 | 16 | def get_outdir(path, *paths, inc=False): 17 | outdir = os.path.join(path, *paths) 18 | if not os.path.exists(outdir): 19 | os.makedirs(outdir) 20 | elif inc: 21 | count = 1 22 | outdir_inc = outdir + '-' + str(count) 23 | while os.path.exists(outdir_inc): 24 | count = count + 1 25 | outdir_inc = outdir + '-' + str(count) 26 | assert count < 100 27 | outdir = outdir_inc 28 | os.makedirs(outdir) 29 | return outdir 30 | 31 | 32 | def update_summary(epoch, train_metrics, eval_metrics, lr, filename, write_header=False): 33 | rowd = OrderedDict(epoch=epoch) 34 | rowd.update([('train_' + k, v) for k, v in train_metrics.items()]) 35 | rowd.update([('eval_' + k, v) for k, v in eval_metrics.items()]) 36 | rowd.update(OrderedDict(lr=lr)) # ================================ 37 | with open(filename, mode='a') as cf: 38 | dw = csv.DictWriter(cf, fieldnames=rowd.keys()) 39 | if write_header: # first iteration (epoch == 1 can't be used) 40 | dw.writeheader() 41 | dw.writerow(rowd) 42 | -------------------------------------------------------------------------------- /timm/version.py: -------------------------------------------------------------------------------- 1 | __version__ = '0.4.4' 2 | --------------------------------------------------------------------------------
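To round things off, a small illustrative sketch (the metric values are made up) of the SimCLR-aware update_summary in timm/utils/summary.py, which appends one CSV row per epoch including the learning rate and both losses:

from timm.utils.summary import update_summary

train_metrics = {'loss': 0.91}
eval_metrics = {'loss': 1.02, 'top1': 63.5}
# write_header=True only on the first call so the CSV gets a single header row
update_summary(1, train_metrics, eval_metrics, lr=5e-4, simclr_loss=4.2,
               classification_loss=0.9, filename='summary.csv', write_header=True)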