├── .gitattributes ├── .github ├── FUNDING.yml ├── ISSUE_TEMPLATE │ ├── bug_report.md │ ├── config.yml │ └── feature_request.md └── workflows │ ├── build_documentation.yml │ ├── build_pr_documentation.yml │ ├── tests.yml │ ├── trufflehog.yml │ └── upload_pr_documentation.yml ├── .gitignore ├── CITATION.cff ├── CODE_OF_CONDUCT.md ├── CONTRIBUTING.md ├── LICENSE ├── MANIFEST.in ├── README.md ├── UPGRADING.md ├── avg_checkpoints.py ├── benchmark.py ├── bulk_runner.py ├── clean_checkpoint.py ├── convert ├── convert_from_mxnet.py └── convert_nest_flax.py ├── distributed_train.sh ├── hfdocs ├── README.md └── source │ ├── _toctree.yml │ ├── changes.mdx │ ├── feature_extraction.mdx │ ├── hf_hub.mdx │ ├── index.mdx │ ├── installation.mdx │ ├── models.mdx │ ├── models │ ├── adversarial-inception-v3.mdx │ ├── advprop.mdx │ ├── big-transfer.mdx │ ├── csp-darknet.mdx │ ├── csp-resnet.mdx │ ├── csp-resnext.mdx │ ├── densenet.mdx │ ├── dla.mdx │ ├── dpn.mdx │ ├── ecaresnet.mdx │ ├── efficientnet-pruned.mdx │ ├── efficientnet.mdx │ ├── ensemble-adversarial.mdx │ ├── ese-vovnet.mdx │ ├── fbnet.mdx │ ├── gloun-inception-v3.mdx │ ├── gloun-resnet.mdx │ ├── gloun-resnext.mdx │ ├── gloun-senet.mdx │ ├── gloun-seresnext.mdx │ ├── gloun-xception.mdx │ ├── hrnet.mdx │ ├── ig-resnext.mdx │ ├── inception-resnet-v2.mdx │ ├── inception-v3.mdx │ ├── inception-v4.mdx │ ├── legacy-se-resnet.mdx │ ├── legacy-se-resnext.mdx │ ├── legacy-senet.mdx │ ├── mixnet.mdx │ ├── mnasnet.mdx │ ├── mobilenet-v2.mdx │ ├── mobilenet-v3.mdx │ ├── nasnet.mdx │ ├── noisy-student.mdx │ ├── pnasnet.mdx │ ├── regnetx.mdx │ ├── regnety.mdx │ ├── res2net.mdx │ ├── res2next.mdx │ ├── resnest.mdx │ ├── resnet-d.mdx │ ├── resnet.mdx │ ├── resnext.mdx │ ├── rexnet.mdx │ ├── se-resnet.mdx │ ├── selecsls.mdx │ ├── seresnext.mdx │ ├── skresnet.mdx │ ├── skresnext.mdx │ ├── spnasnet.mdx │ ├── ssl-resnet.mdx │ ├── swsl-resnet.mdx │ ├── swsl-resnext.mdx │ ├── tf-efficientnet-condconv.mdx │ ├── tf-efficientnet-lite.mdx │ ├── tf-efficientnet.mdx │ ├── tf-inception-v3.mdx │ ├── tf-mixnet.mdx │ ├── tf-mobilenet-v3.mdx │ ├── tresnet.mdx │ ├── wide-resnet.mdx │ └── xception.mdx │ ├── quickstart.mdx │ ├── reference │ ├── data.mdx │ ├── models.mdx │ ├── optimizers.mdx │ └── schedulers.mdx │ ├── results.mdx │ └── training_script.mdx ├── hubconf.py ├── inference.py ├── onnx_export.py ├── onnx_validate.py ├── pyproject.toml ├── requirements-dev.txt ├── requirements.txt ├── results ├── README.md ├── benchmark-infer-amp-nchw-pt113-cu117-rtx3090.csv ├── benchmark-infer-amp-nchw-pt210-cu121-rtx3090.csv ├── benchmark-infer-amp-nchw-pt240-cu124-rtx3090.csv ├── benchmark-infer-amp-nchw-pt240-cu124-rtx4090-dynamo.csv ├── benchmark-infer-amp-nchw-pt240-cu124-rtx4090.csv ├── benchmark-infer-amp-nhwc-pt113-cu117-rtx3090.csv ├── benchmark-infer-amp-nhwc-pt210-cu121-rtx3090.csv ├── benchmark-infer-amp-nhwc-pt240-cu124-rtx3090.csv ├── benchmark-infer-amp-nhwc-pt240-cu124-rtx4090.csv ├── benchmark-infer-fp32-nchw-pt221-cpu-i9_10940x-dynamo.csv ├── benchmark-infer-fp32-nchw-pt240-cpu-i7_12700h-dynamo.csv ├── benchmark-infer-fp32-nchw-pt240-cpu-i9_10940x-dynamo.csv ├── benchmark-train-amp-nchw-pt112-cu113-rtx3090.csv ├── benchmark-train-amp-nhwc-pt112-cu113-rtx3090.csv ├── generate_csv_results.py ├── model_metadata-in1k.csv ├── results-imagenet-a-clean.csv ├── results-imagenet-a.csv ├── results-imagenet-r-clean.csv ├── results-imagenet-r.csv ├── results-imagenet-real.csv ├── results-imagenet.csv ├── results-imagenetv2-matched-frequency.csv └── results-sketch.csv ├── setup.cfg ├── tests ├── __init__.py ├── test_layers.py ├── test_models.py ├── test_optim.py └── test_utils.py ├── timm ├── __init__.py ├── data │ ├── __init__.py │ ├── _info │ │ ├── imagenet12k_synsets.txt │ │ ├── imagenet21k_goog_synsets.txt │ │ ├── imagenet21k_goog_to_12k_indices.txt │ │ ├── imagenet21k_goog_to_22k_indices.txt │ │ ├── imagenet21k_miil_synsets.txt │ │ ├── imagenet21k_miil_w21_synsets.txt │ │ ├── imagenet22k_ms_synsets.txt │ │ ├── imagenet22k_ms_to_12k_indices.txt │ │ ├── imagenet22k_ms_to_22k_indices.txt │ │ ├── imagenet22k_synsets.txt │ │ ├── imagenet22k_to_12k_indices.txt │ │ ├── imagenet_a_indices.txt │ │ ├── imagenet_a_synsets.txt │ │ ├── imagenet_r_indices.txt │ │ ├── imagenet_r_synsets.txt │ │ ├── imagenet_real_labels.json │ │ ├── imagenet_synset_to_definition.txt │ │ ├── imagenet_synset_to_lemma.txt │ │ ├── imagenet_synsets.txt │ │ ├── mini_imagenet_indices.txt │ │ └── mini_imagenet_synsets.txt │ ├── auto_augment.py │ ├── config.py │ ├── constants.py │ ├── dataset.py │ ├── dataset_factory.py │ ├── dataset_info.py │ ├── distributed_sampler.py │ ├── imagenet_info.py │ ├── loader.py │ ├── mixup.py │ ├── random_erasing.py │ ├── readers │ │ ├── __init__.py │ │ ├── class_map.py │ │ ├── img_extensions.py │ │ ├── reader.py │ │ ├── reader_factory.py │ │ ├── reader_hfds.py │ │ ├── reader_hfids.py │ │ ├── reader_image_folder.py │ │ ├── reader_image_in_tar.py │ │ ├── reader_image_tar.py │ │ ├── reader_tfds.py │ │ ├── reader_wds.py │ │ └── shared_count.py │ ├── real_labels.py │ ├── tf_preprocessing.py │ ├── transforms.py │ └── transforms_factory.py ├── layers │ ├── __init__.py │ ├── activations.py │ ├── activations_me.py │ ├── adaptive_avgmax_pool.py │ ├── attention.py │ ├── attention2d.py │ ├── attention_pool.py │ ├── attention_pool2d.py │ ├── blur_pool.py │ ├── bottleneck_attn.py │ ├── cbam.py │ ├── classifier.py │ ├── cond_conv2d.py │ ├── config.py │ ├── conv2d_same.py │ ├── conv_bn_act.py │ ├── create_act.py │ ├── create_attn.py │ ├── create_conv2d.py │ ├── create_norm.py │ ├── create_norm_act.py │ ├── drop.py │ ├── eca.py │ ├── evo_norm.py │ ├── fast_norm.py │ ├── filter_response_norm.py │ ├── format.py │ ├── gather_excite.py │ ├── global_context.py │ ├── grid.py │ ├── grn.py │ ├── halo_attn.py │ ├── helpers.py │ ├── hybrid_embed.py │ ├── inplace_abn.py │ ├── interpolate.py │ ├── lambda_layer.py │ ├── layer_scale.py │ ├── linear.py │ ├── median_pool.py │ ├── mixed_conv2d.py │ ├── ml_decoder.py │ ├── mlp.py │ ├── non_local_attn.py │ ├── norm.py │ ├── norm_act.py │ ├── padding.py │ ├── patch_dropout.py │ ├── patch_embed.py │ ├── pool1d.py │ ├── pool2d_same.py │ ├── pos_embed.py │ ├── pos_embed_rel.py │ ├── pos_embed_sincos.py │ ├── selective_kernel.py │ ├── separable_conv.py │ ├── space_to_depth.py │ ├── split_attn.py │ ├── split_batchnorm.py │ ├── squeeze_excite.py │ ├── std_conv.py │ ├── test_time_pool.py │ ├── trace_utils.py │ ├── typing.py │ └── weight_init.py ├── loss │ ├── __init__.py │ ├── asymmetric_loss.py │ ├── binary_cross_entropy.py │ ├── cross_entropy.py │ └── jsd.py ├── models │ ├── __init__.py │ ├── _builder.py │ ├── _efficientnet_blocks.py │ ├── _efficientnet_builder.py │ ├── _factory.py │ ├── _features.py │ ├── _features_fx.py │ ├── _helpers.py │ ├── _hub.py │ ├── _manipulate.py │ ├── _pretrained.py │ ├── _prune.py │ ├── _pruned │ │ ├── ecaresnet101d_pruned.txt │ │ ├── ecaresnet50d_pruned.txt │ │ ├── efficientnet_b1_pruned.txt │ │ ├── efficientnet_b2_pruned.txt │ │ └── efficientnet_b3_pruned.txt │ ├── _registry.py │ ├── beit.py │ ├── byoanet.py │ ├── byobnet.py │ ├── cait.py │ ├── coat.py │ ├── convit.py │ ├── convmixer.py │ ├── convnext.py │ ├── crossvit.py │ ├── cspnet.py │ ├── davit.py │ ├── deit.py │ ├── densenet.py │ ├── dla.py │ ├── dpn.py │ ├── edgenext.py │ ├── efficientformer.py │ ├── efficientformer_v2.py │ ├── efficientnet.py │ ├── efficientvit_mit.py │ ├── efficientvit_msra.py │ ├── eva.py │ ├── factory.py │ ├── fasternet.py │ ├── fastvit.py │ ├── features.py │ ├── focalnet.py │ ├── fx_features.py │ ├── gcvit.py │ ├── ghostnet.py │ ├── hardcorenas.py │ ├── helpers.py │ ├── hgnet.py │ ├── hiera.py │ ├── hieradet_sam2.py │ ├── hrnet.py │ ├── hub.py │ ├── inception_next.py │ ├── inception_resnet_v2.py │ ├── inception_v3.py │ ├── inception_v4.py │ ├── layers │ │ └── __init__.py │ ├── levit.py │ ├── mambaout.py │ ├── maxxvit.py │ ├── metaformer.py │ ├── mlp_mixer.py │ ├── mobilenetv3.py │ ├── mobilevit.py │ ├── mvitv2.py │ ├── nasnet.py │ ├── nest.py │ ├── nextvit.py │ ├── nfnet.py │ ├── pit.py │ ├── pnasnet.py │ ├── pvt_v2.py │ ├── rdnet.py │ ├── registry.py │ ├── regnet.py │ ├── repghost.py │ ├── repvit.py │ ├── res2net.py │ ├── resnest.py │ ├── resnet.py │ ├── resnetv2.py │ ├── rexnet.py │ ├── selecsls.py │ ├── senet.py │ ├── sequencer.py │ ├── shvit.py │ ├── sknet.py │ ├── starnet.py │ ├── swiftformer.py │ ├── swin_transformer.py │ ├── swin_transformer_v2.py │ ├── swin_transformer_v2_cr.py │ ├── tiny_vit.py │ ├── tnt.py │ ├── tresnet.py │ ├── twins.py │ ├── vgg.py │ ├── visformer.py │ ├── vision_transformer.py │ ├── vision_transformer_hybrid.py │ ├── vision_transformer_relpos.py │ ├── vision_transformer_sam.py │ ├── vitamin.py │ ├── volo.py │ ├── vovnet.py │ ├── xception.py │ ├── xception_aligned.py │ └── xcit.py ├── optim │ ├── __init__.py │ ├── _optim_factory.py │ ├── _param_groups.py │ ├── _types.py │ ├── adabelief.py │ ├── adafactor.py │ ├── adafactor_bv.py │ ├── adahessian.py │ ├── adamp.py │ ├── adamw.py │ ├── adan.py │ ├── adopt.py │ ├── kron.py │ ├── lamb.py │ ├── laprop.py │ ├── lars.py │ ├── lion.py │ ├── lookahead.py │ ├── madgrad.py │ ├── mars.py │ ├── nadam.py │ ├── nadamw.py │ ├── nvnovograd.py │ ├── optim_factory.py │ ├── radam.py │ ├── rmsprop_tf.py │ ├── sgdp.py │ └── sgdw.py ├── py.typed ├── scheduler │ ├── __init__.py │ ├── cosine_lr.py │ ├── multistep_lr.py │ ├── plateau_lr.py │ ├── poly_lr.py │ ├── scheduler.py │ ├── scheduler_factory.py │ ├── step_lr.py │ └── tanh_lr.py ├── utils │ ├── __init__.py │ ├── agc.py │ ├── attention_extract.py │ ├── checkpoint_saver.py │ ├── clip_grad.py │ ├── cuda.py │ ├── decay_batch.py │ ├── distributed.py │ ├── jit.py │ ├── log.py │ ├── metrics.py │ ├── misc.py │ ├── model.py │ ├── model_ema.py │ ├── onnx.py │ ├── random.py │ └── summary.py └── version.py ├── train.py └── validate.py /.gitattributes: -------------------------------------------------------------------------------- 1 | *.ipynb linguist-documentation 2 | -------------------------------------------------------------------------------- /.github/FUNDING.yml: -------------------------------------------------------------------------------- 1 | # These are supported funding model platforms 2 | github: rwightman 3 | -------------------------------------------------------------------------------- /.github/ISSUE_TEMPLATE/bug_report.md: -------------------------------------------------------------------------------- 1 | --- 2 | name: Bug report 3 | about: Create a bug report to help us improve. Issues are for reporting bugs or requesting 4 | features, the discussion forum is available for asking questions or seeking help 5 | from the community. 6 | title: "[BUG] Issue title..." 7 | labels: bug 8 | assignees: rwightman 9 | 10 | --- 11 | 12 | **Describe the bug** 13 | A clear and concise description of what the bug is. 14 | 15 | **To Reproduce** 16 | Steps to reproduce the behavior: 17 | 1. 18 | 2. 19 | 20 | **Expected behavior** 21 | A clear and concise description of what you expected to happen. 22 | 23 | **Screenshots** 24 | If applicable, add screenshots to help explain your problem. 25 | 26 | **Desktop (please complete the following information):** 27 | - OS: [e.g. Windows 10, Ubuntu 18.04] 28 | - This repository version [e.g. pip 0.3.1 or commit ref] 29 | - PyTorch version w/ CUDA/cuDNN [e.g. from `conda list`, 1.7.0 py3.8_cuda11.0.221_cudnn8.0.3_0] 30 | 31 | **Additional context** 32 | Add any other context about the problem here. 33 | -------------------------------------------------------------------------------- /.github/ISSUE_TEMPLATE/config.yml: -------------------------------------------------------------------------------- 1 | blank_issues_enabled: false 2 | contact_links: 3 | - name: Community Discussions 4 | url: https://github.com/rwightman/pytorch-image-models/discussions 5 | about: Hparam request in issues will be ignored! Issues are for features and bugs. Questions can be asked in Discussions. 6 | -------------------------------------------------------------------------------- /.github/ISSUE_TEMPLATE/feature_request.md: -------------------------------------------------------------------------------- 1 | --- 2 | name: Feature request 3 | about: Suggest an idea for this project. Hparam requests, training help are not feature requests. 4 | The discussion forum is available for asking questions or seeking help from the community. 5 | title: "[FEATURE] Feature title..." 6 | labels: enhancement 7 | assignees: '' 8 | 9 | --- 10 | 11 | **Is your feature request related to a problem? Please describe.** 12 | A clear and concise description of what the problem is. 13 | 14 | **Describe the solution you'd like** 15 | A clear and concise description of what you want to happen. 16 | 17 | **Describe alternatives you've considered** 18 | A clear and concise description of any alternative solutions or features you've considered. 19 | 20 | **Additional context** 21 | Add any other context or screenshots about the feature request here. 22 | -------------------------------------------------------------------------------- /.github/workflows/build_documentation.yml: -------------------------------------------------------------------------------- 1 | name: Build documentation 2 | 3 | on: 4 | push: 5 | branches: 6 | - main 7 | - doc-builder* 8 | - v*-release 9 | 10 | jobs: 11 | build: 12 | uses: huggingface/doc-builder/.github/workflows/build_main_documentation.yml@main 13 | with: 14 | commit_sha: ${{ github.sha }} 15 | package: pytorch-image-models 16 | package_name: timm 17 | path_to_docs: pytorch-image-models/hfdocs/source 18 | version_tag_suffix: "" 19 | secrets: 20 | hf_token: ${{ secrets.HF_DOC_BUILD_PUSH }} 21 | -------------------------------------------------------------------------------- /.github/workflows/build_pr_documentation.yml: -------------------------------------------------------------------------------- 1 | name: Build PR Documentation 2 | 3 | on: 4 | pull_request: 5 | 6 | concurrency: 7 | group: ${{ github.workflow }}-${{ github.head_ref || github.run_id }} 8 | cancel-in-progress: true 9 | 10 | jobs: 11 | build: 12 | uses: huggingface/doc-builder/.github/workflows/build_pr_documentation.yml@main 13 | with: 14 | commit_sha: ${{ github.event.pull_request.head.sha }} 15 | pr_number: ${{ github.event.number }} 16 | package: pytorch-image-models 17 | package_name: timm 18 | path_to_docs: pytorch-image-models/hfdocs/source 19 | version_tag_suffix: "" 20 | -------------------------------------------------------------------------------- /.github/workflows/tests.yml: -------------------------------------------------------------------------------- 1 | name: Python tests 2 | 3 | on: 4 | push: 5 | branches: [ main ] 6 | pull_request: 7 | branches: [ main ] 8 | 9 | env: 10 | OMP_NUM_THREADS: 2 11 | MKL_NUM_THREADS: 2 12 | 13 | jobs: 14 | test: 15 | name: Run tests on ${{ matrix.os }} with Python ${{ matrix.python }} 16 | strategy: 17 | matrix: 18 | os: [ubuntu-latest] 19 | python: ['3.10', '3.12'] 20 | torch: [{base: '1.13.0', vision: '0.14.0'}, {base: '2.5.1', vision: '0.20.1'}] 21 | testmarker: ['-k "not test_models"', '-m base', '-m cfg', '-m torchscript', '-m features', '-m fxforward', '-m fxbackward'] 22 | exclude: 23 | - python: '3.12' 24 | torch: {base: '1.13.0', vision: '0.14.0'} 25 | runs-on: ${{ matrix.os }} 26 | 27 | steps: 28 | - uses: actions/checkout@v2 29 | - name: Set up Python ${{ matrix.python }} 30 | uses: actions/setup-python@v1 31 | with: 32 | python-version: ${{ matrix.python }} 33 | - name: Install testing dependencies 34 | run: | 35 | python -m pip install --upgrade pip 36 | pip install -r requirements-dev.txt 37 | - name: Install torch on mac 38 | if: startsWith(matrix.os, 'macOS') 39 | run: pip install --no-cache-dir torch==${{ matrix.torch.base }} torchvision==${{ matrix.torch.vision }} 40 | - name: Install torch on Windows 41 | if: startsWith(matrix.os, 'windows') 42 | run: pip install --no-cache-dir torch==${{ matrix.torch.base }} torchvision==${{ matrix.torch.vision }} 43 | - name: Install torch on ubuntu 44 | if: startsWith(matrix.os, 'ubuntu') 45 | run: | 46 | sudo sed -i 's/azure\.//' /etc/apt/sources.list 47 | sudo apt update 48 | sudo apt install -y google-perftools 49 | pip install --no-cache-dir torch==${{ matrix.torch.base }}+cpu torchvision==${{ matrix.torch.vision }}+cpu --index-url https://download.pytorch.org/whl/cpu 50 | - name: Install requirements 51 | run: | 52 | pip install -r requirements.txt 53 | - name: Force old numpy for old torch 54 | if: ${{ matrix.torch.base == '1.13.0' }} 55 | run: pip install --upgrade 'numpy<2.0' 56 | - name: Run tests on Windows 57 | if: startsWith(matrix.os, 'windows') 58 | env: 59 | PYTHONDONTWRITEBYTECODE: 1 60 | run: | 61 | pytest -vv tests 62 | - name: Run '${{ matrix.testmarker }}' tests on Linux / Mac 63 | if: ${{ !startsWith(matrix.os, 'windows') }} 64 | env: 65 | LD_PRELOAD: /usr/lib/x86_64-linux-gnu/libtcmalloc.so.4 66 | PYTHONDONTWRITEBYTECODE: 1 67 | run: | 68 | pytest -vv --forked --durations=0 ${{ matrix.testmarker }} tests 69 | -------------------------------------------------------------------------------- /.github/workflows/trufflehog.yml: -------------------------------------------------------------------------------- 1 | on: 2 | push: 3 | 4 | name: Secret Leaks 5 | 6 | jobs: 7 | trufflehog: 8 | runs-on: ubuntu-latest 9 | steps: 10 | - name: Checkout code 11 | uses: actions/checkout@v4 12 | with: 13 | fetch-depth: 0 14 | - name: Secret Scanning 15 | uses: trufflesecurity/trufflehog@main 16 | -------------------------------------------------------------------------------- /.github/workflows/upload_pr_documentation.yml: -------------------------------------------------------------------------------- 1 | name: Upload PR Documentation 2 | 3 | on: 4 | workflow_run: 5 | workflows: ["Build PR Documentation"] 6 | types: 7 | - completed 8 | 9 | jobs: 10 | build: 11 | uses: huggingface/doc-builder/.github/workflows/upload_pr_documentation.yml@main 12 | with: 13 | package_name: timm 14 | secrets: 15 | hf_token: ${{ secrets.HF_DOC_BUILD_PUSH }} 16 | comment_bot_token: ${{ secrets.COMMENT_BOT_TOKEN }} -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | *.egg-info/ 24 | .installed.cfg 25 | *.egg 26 | MANIFEST 27 | 28 | # PyInstaller 29 | # Usually these files are written by a python script from a template 30 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 31 | *.manifest 32 | *.spec 33 | 34 | # Installer logs 35 | pip-log.txt 36 | pip-delete-this-directory.txt 37 | 38 | # Unit test / coverage reports 39 | htmlcov/ 40 | .tox/ 41 | .coverage 42 | .coverage.* 43 | .cache 44 | nosetests.xml 45 | coverage.xml 46 | *.cover 47 | .hypothesis/ 48 | .pytest_cache/ 49 | 50 | # Translations 51 | *.mo 52 | *.pot 53 | 54 | # Django stuff: 55 | *.log 56 | local_settings.py 57 | db.sqlite3 58 | 59 | # Flask stuff: 60 | instance/ 61 | .webassets-cache 62 | 63 | # Scrapy stuff: 64 | .scrapy 65 | 66 | # Sphinx documentation 67 | docs/_build/ 68 | 69 | # PyBuilder 70 | target/ 71 | 72 | # Jupyter Notebook 73 | .ipynb_checkpoints 74 | 75 | # pyenv 76 | .python-version 77 | 78 | # celery beat schedule file 79 | celerybeat-schedule 80 | 81 | # SageMath parsed files 82 | *.sage.py 83 | 84 | # Environments 85 | .env 86 | .venv 87 | env/ 88 | venv/ 89 | ENV/ 90 | env.bak/ 91 | venv.bak/ 92 | 93 | # Spyder project settings 94 | .spyderproject 95 | .spyproject 96 | 97 | # Rope project settings 98 | .ropeproject 99 | 100 | # PyCharm 101 | .idea 102 | 103 | output/ 104 | 105 | # PyTorch weights 106 | *.tar 107 | *.pth 108 | *.pt 109 | *.torch 110 | *.gz 111 | Untitled.ipynb 112 | Testing notebook.ipynb 113 | 114 | # Root dir exclusions 115 | /*.csv 116 | /*.yaml 117 | /*.json 118 | /*.jpg 119 | /*.png 120 | /*.zip 121 | /*.tar.* -------------------------------------------------------------------------------- /CITATION.cff: -------------------------------------------------------------------------------- 1 | message: "If you use this software, please cite it as below." 2 | title: "PyTorch Image Models" 3 | version: "1.2.2" 4 | doi: "10.5281/zenodo.4414861" 5 | authors: 6 | - family-names: Wightman 7 | given-names: Ross 8 | version: 1.0.11 9 | year: "2019" 10 | url: "https://github.com/huggingface/pytorch-image-models" 11 | license: "Apache 2.0" -------------------------------------------------------------------------------- /MANIFEST.in: -------------------------------------------------------------------------------- 1 | include timm/models/_pruned/*.txt 2 | include timm/data/_info/*.txt 3 | include timm/data/_info/*.json 4 | -------------------------------------------------------------------------------- /UPGRADING.md: -------------------------------------------------------------------------------- 1 | # Upgrading from previous versions 2 | 3 | I generally try to maintain code interface and especially model weight compatibility across many `timm` versions. Sometimes there are exceptions. 4 | 5 | ## Checkpoint remapping 6 | 7 | Pretrained weight remapping is handled by `checkpoint_filter_fn` in a model implementation module. This remaps old pretrained checkpoints to new, and also 3rd party (original) checkpoints to `timm` format if the model was modified when brought into `timm`. 8 | 9 | The `checkpoint_filter_fn` is automatically called when loading pretrained weights via `pretrained=True`, but they can be called manually if you call the fn directly with the current model instance and old state dict. 10 | 11 | ## Upgrading from 0.6 and earlier 12 | 13 | Many changes were made since the 0.6.x stable releases. They were previewed in 0.8.x dev releases but not everyone transitioned. 14 | * `timm.models.layers` moved to `timm.layers`: 15 | * `from timm.models.layers import name` will still work via deprecation mapping (but please transition to `timm.layers`). 16 | * `import timm.models.layers.module` or `from timm.models.layers.module import name` needs to be changed now. 17 | * Builder, helper, non-model modules in `timm.models` have a `_` prefix added, ie `timm.models.helpers` -> `timm.models._helpers`, there are temporary deprecation mapping files but those will be removed. 18 | * All models now support `architecture.pretrained_tag` naming (ex `resnet50.rsb_a1`). 19 | * The pretrained_tag is the specific weight variant (different head) for the architecture. 20 | * Using only `architecture` defaults to the first weights in the default_cfgs for that model architecture. 21 | * In adding pretrained tags, many model names that existed to differentiate were renamed to use the tag (ex: `vit_base_patch16_224_in21k` -> `vit_base_patch16_224.augreg_in21k`). There are deprecation mappings for these. 22 | * A number of models had their checkpoints remapped to match architecture changes needed to better support `features_only=True`, there are `checkpoint_filter_fn` methods in any model module that was remapped. These can be passed to `timm.models.load_checkpoint(..., filter_fn=timm.models.swin_transformer_v2.checkpoint_filter_fn)` to remap your existing checkpoint. 23 | * The Hugging Face Hub (https://huggingface.co/timm) is now the primary source for `timm` weights. Model cards include link to papers, original source, license. 24 | * Previous 0.6.x can be cloned from [0.6.x](https://github.com/rwightman/pytorch-image-models/tree/0.6.x) branch or installed via pip with version. 25 | -------------------------------------------------------------------------------- /distributed_train.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | NUM_PROC=$1 3 | shift 4 | torchrun --nproc_per_node=$NUM_PROC train.py "$@" 5 | 6 | -------------------------------------------------------------------------------- /hfdocs/README.md: -------------------------------------------------------------------------------- 1 | # Hugging Face Timm Docs 2 | 3 | ## Getting Started 4 | 5 | ``` 6 | pip install git+https://github.com/huggingface/doc-builder.git@main#egg=hf-doc-builder 7 | pip install watchdog black 8 | ``` 9 | 10 | ## Preview the Docs Locally 11 | 12 | ``` 13 | doc-builder preview timm hfdocs/source 14 | ``` 15 | -------------------------------------------------------------------------------- /hfdocs/source/hf_hub.mdx: -------------------------------------------------------------------------------- 1 | # Sharing and Loading Models From the Hugging Face Hub 2 | 3 | The `timm` library has a built-in integration with the Hugging Face Hub, making it easy to share and load models from the 🤗 Hub. 4 | 5 | In this short guide, we'll see how to: 6 | 1. Share a `timm` model on the Hub 7 | 2. How to load that model back from the Hub 8 | 9 | ## Authenticating 10 | 11 | First, you'll need to make sure you have the `huggingface_hub` package installed. 12 | 13 | ```bash 14 | pip install huggingface_hub 15 | ``` 16 | 17 | Then, you'll need to authenticate yourself. You can do this by running the following command: 18 | 19 | ```bash 20 | huggingface-cli login 21 | ``` 22 | 23 | Or, if you're using a notebook, you can use the `notebook_login` helper: 24 | 25 | ```py 26 | >>> from huggingface_hub import notebook_login 27 | >>> notebook_login() 28 | ``` 29 | 30 | ## Sharing a Model 31 | 32 | ```py 33 | >>> import timm 34 | >>> model = timm.create_model('resnet18', pretrained=True, num_classes=4) 35 | ``` 36 | 37 | Here is where you would normally train or fine-tune the model. We'll skip that for the sake of this tutorial. 38 | 39 | Let's pretend we've now fine-tuned the model. The next step would be to push it to the Hub! We can do this with the `timm.models.hub.push_to_hf_hub` function. 40 | 41 | ```py 42 | >>> model_cfg = dict(label_names=['a', 'b', 'c', 'd']) 43 | >>> timm.models.push_to_hf_hub(model, 'resnet18-random', model_config=model_cfg) 44 | ``` 45 | 46 | Running the above would push the model to `/resnet18-random` on the Hub. You can now share this model with your friends, or use it in your own code! 47 | 48 | ## Loading a Model 49 | 50 | Loading a model from the Hub is as simple as calling `timm.create_model` with the `pretrained` argument set to the name of the model you want to load. In this case, we'll use [`nateraw/resnet18-random`](https://huggingface.co/nateraw/resnet18-random), which is the model we just pushed to the Hub. 51 | 52 | ```py 53 | >>> model_reloaded = timm.create_model('hf_hub:nateraw/resnet18-random', pretrained=True) 54 | ``` 55 | -------------------------------------------------------------------------------- /hfdocs/source/index.mdx: -------------------------------------------------------------------------------- 1 | # timm 2 | 3 | 4 | 5 | `timm` is a library containing SOTA computer vision models, layers, utilities, optimizers, schedulers, data-loaders, augmentations, and training/evaluation scripts. 6 | 7 | It comes packaged with >700 pretrained models, and is designed to be flexible and easy to use. 8 | 9 | Read the [quick start guide](quickstart) to get up and running with the `timm` library. You will learn how to load, discover, and use pretrained models included in the library. 10 | 11 |
12 | 22 |
23 | -------------------------------------------------------------------------------- /hfdocs/source/installation.mdx: -------------------------------------------------------------------------------- 1 | # Installation 2 | 3 | Before you start, you'll need to setup your environment and install the appropriate packages. `timm` is tested on **Python 3+**. 4 | 5 | ## Virtual Environment 6 | 7 | You should install `timm` in a [virtual environment](https://docs.python.org/3/library/venv.html) to keep things tidy and avoid dependency conflicts. 8 | 9 | 1. Create and navigate to your project directory: 10 | 11 | ```bash 12 | mkdir ~/my-project 13 | cd ~/my-project 14 | ``` 15 | 16 | 2. Start a virtual environment inside your directory: 17 | 18 | ```bash 19 | python -m venv .env 20 | ``` 21 | 22 | 3. Activate and deactivate the virtual environment with the following commands: 23 | 24 | ```bash 25 | # Activate the virtual environment 26 | source .env/bin/activate 27 | 28 | # Deactivate the virtual environment 29 | source .env/bin/deactivate 30 | ``` 31 | 32 | Once you've created your virtual environment, you can install `timm` in it. 33 | 34 | ## Using pip 35 | 36 | The most straightforward way to install `timm` is with pip: 37 | 38 | ```bash 39 | pip install timm 40 | ``` 41 | 42 | Alternatively, you can install `timm` from GitHub directly to get the latest, bleeding-edge version: 43 | 44 | ```bash 45 | pip install git+https://github.com/rwightman/pytorch-image-models.git 46 | ``` 47 | 48 | Run the following command to check if `timm` has been properly installed: 49 | 50 | ```bash 51 | python -c "from timm import list_models; print(list_models(pretrained=True)[:5])" 52 | ``` 53 | 54 | This command lists the first five pretrained models available in `timm` (which are sorted alphebetically). You should see the following output: 55 | 56 | ```python 57 | ['adv_inception_v3', 'bat_resnext26ts', 'beit_base_patch16_224', 'beit_base_patch16_224_in22k', 'beit_base_patch16_384'] 58 | ``` 59 | 60 | ## From Source 61 | 62 | Building `timm` from source lets you make changes to the code base. To install from the source, clone the repository and install with the following commands: 63 | 64 | ```bash 65 | git clone https://github.com/rwightman/pytorch-image-models.git 66 | cd pytorch-image-models 67 | pip install -e . 68 | ``` 69 | 70 | Again, you can check if `timm` was properly installed with the following command: 71 | 72 | ```bash 73 | python -c "from timm import list_models; print(list_models(pretrained=True)[:5])" 74 | ``` 75 | -------------------------------------------------------------------------------- /hfdocs/source/reference/data.mdx: -------------------------------------------------------------------------------- 1 | # Data 2 | 3 | [[autodoc]] timm.data.create_dataset 4 | 5 | [[autodoc]] timm.data.create_loader 6 | 7 | [[autodoc]] timm.data.create_transform 8 | 9 | [[autodoc]] timm.data.resolve_data_config -------------------------------------------------------------------------------- /hfdocs/source/reference/models.mdx: -------------------------------------------------------------------------------- 1 | # Models 2 | 3 | [[autodoc]] timm.create_model 4 | 5 | [[autodoc]] timm.list_models 6 | -------------------------------------------------------------------------------- /hfdocs/source/reference/optimizers.mdx: -------------------------------------------------------------------------------- 1 | # Optimization 2 | 3 | This page contains the API reference documentation for learning rate optimizers included in `timm`. 4 | 5 | ## Optimizers 6 | 7 | ### Factory functions 8 | 9 | [[autodoc]] timm.optim.create_optimizer_v2 10 | [[autodoc]] timm.optim.list_optimizers 11 | [[autodoc]] timm.optim.get_optimizer_class 12 | 13 | ### Optimizer Classes 14 | 15 | [[autodoc]] timm.optim.adabelief.AdaBelief 16 | [[autodoc]] timm.optim.adafactor.Adafactor 17 | [[autodoc]] timm.optim.adafactor_bv.AdafactorBigVision 18 | [[autodoc]] timm.optim.adahessian.Adahessian 19 | [[autodoc]] timm.optim.adamp.AdamP 20 | [[autodoc]] timm.optim.adan.Adan 21 | [[autodoc]] timm.optim.adopt.Adopt 22 | [[autodoc]] timm.optim.lamb.Lamb 23 | [[autodoc]] timm.optim.laprop.LaProp 24 | [[autodoc]] timm.optim.lars.Lars 25 | [[autodoc]] timm.optim.lion.Lion 26 | [[autodoc]] timm.optim.lookahead.Lookahead 27 | [[autodoc]] timm.optim.madgrad.MADGRAD 28 | [[autodoc]] timm.optim.mars.Mars 29 | [[autodoc]] timm.optim.nadamw.NAdamW 30 | [[autodoc]] timm.optim.nvnovograd.NvNovoGrad 31 | [[autodoc]] timm.optim.rmsprop_tf.RMSpropTF 32 | [[autodoc]] timm.optim.sgdp.SGDP 33 | [[autodoc]] timm.optim.sgdw.SGDW -------------------------------------------------------------------------------- /hfdocs/source/reference/schedulers.mdx: -------------------------------------------------------------------------------- 1 | # Learning Rate Schedulers 2 | 3 | This page contains the API reference documentation for learning rate schedulers included in `timm`. 4 | 5 | ## Schedulers 6 | 7 | ### Factory functions 8 | 9 | [[autodoc]] timm.scheduler.scheduler_factory.create_scheduler 10 | [[autodoc]] timm.scheduler.scheduler_factory.create_scheduler_v2 11 | 12 | ### Scheduler Classes 13 | 14 | [[autodoc]] timm.scheduler.cosine_lr.CosineLRScheduler 15 | [[autodoc]] timm.scheduler.multistep_lr.MultiStepLRScheduler 16 | [[autodoc]] timm.scheduler.plateau_lr.PlateauLRScheduler 17 | [[autodoc]] timm.scheduler.poly_lr.PolyLRScheduler 18 | [[autodoc]] timm.scheduler.step_lr.StepLRScheduler 19 | [[autodoc]] timm.scheduler.tanh_lr.TanhLRScheduler 20 | -------------------------------------------------------------------------------- /hubconf.py: -------------------------------------------------------------------------------- 1 | dependencies = ['torch'] 2 | import timm 3 | globals().update(timm.models._registry._model_entrypoints) 4 | -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [build-system] 2 | requires = ["pdm-backend"] 3 | build-backend = "pdm.backend" 4 | 5 | [project] 6 | name = "timm" 7 | authors = [ 8 | {name = "Ross Wightman", email = "ross@huggingface.co"}, 9 | ] 10 | description = "PyTorch Image Models" 11 | readme = "README.md" 12 | requires-python = ">=3.8" 13 | keywords = ["pytorch", "image-classification"] 14 | license = {text = "Apache-2.0"} 15 | classifiers = [ 16 | 'Development Status :: 5 - Production/Stable', 17 | 'Intended Audience :: Education', 18 | 'Intended Audience :: Science/Research', 19 | 'License :: OSI Approved :: Apache Software License', 20 | 'Programming Language :: Python :: 3.8', 21 | 'Programming Language :: Python :: 3.9', 22 | 'Programming Language :: Python :: 3.10', 23 | 'Programming Language :: Python :: 3.11', 24 | 'Programming Language :: Python :: 3.12', 25 | 'Topic :: Scientific/Engineering', 26 | 'Topic :: Scientific/Engineering :: Artificial Intelligence', 27 | 'Topic :: Software Development', 28 | 'Topic :: Software Development :: Libraries', 29 | 'Topic :: Software Development :: Libraries :: Python Modules', 30 | ] 31 | dependencies = [ 32 | 'torch', 33 | 'torchvision', 34 | 'pyyaml', 35 | 'huggingface_hub', 36 | 'safetensors', 37 | ] 38 | dynamic = ["version"] 39 | 40 | [project.urls] 41 | homepage = "https://github.com/huggingface/pytorch-image-models" 42 | documentation = "https://huggingface.co/docs/timm/en/index" 43 | repository = "https://github.com/huggingface/pytorch-image-models" 44 | 45 | [tool.pdm.dev-dependencies] 46 | test = [ 47 | 'pytest', 48 | 'pytest-timeout', 49 | 'pytest-xdist', 50 | 'pytest-forked', 51 | 'expecttest', 52 | ] 53 | 54 | [tool.pdm.version] 55 | source = "file" 56 | path = "timm/version.py" 57 | 58 | [tool.pytest.ini_options] 59 | testpaths = ['tests'] 60 | markers = [ 61 | "base: marker for model tests using the basic setup", 62 | "cfg: marker for model tests checking the config", 63 | "torchscript: marker for model tests using torchscript", 64 | "features: marker for model tests checking feature extraction", 65 | "fxforward: marker for model tests using torch fx (only forward)", 66 | "fxbackward: marker for model tests using torch fx (only backward)", 67 | ] -------------------------------------------------------------------------------- /requirements-dev.txt: -------------------------------------------------------------------------------- 1 | pytest 2 | pytest-timeout 3 | pytest-xdist 4 | pytest-forked 5 | expecttest 6 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | torch>=1.7 2 | torchvision 3 | pyyaml 4 | huggingface_hub 5 | safetensors>=0.2 6 | numpy 7 | -------------------------------------------------------------------------------- /results/generate_csv_results.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import pandas as pd 3 | 4 | 5 | results = { 6 | 'results-imagenet.csv': [ 7 | 'results-imagenet-real.csv', 8 | 'results-imagenetv2-matched-frequency.csv', 9 | 'results-sketch.csv' 10 | ], 11 | 'results-imagenet-a-clean.csv': [ 12 | 'results-imagenet-a.csv', 13 | ], 14 | 'results-imagenet-r-clean.csv': [ 15 | 'results-imagenet-r.csv', 16 | ], 17 | } 18 | 19 | 20 | def diff(base_df, test_csv): 21 | base_df['mi'] = base_df.model + '-' + base_df.img_size.astype('str') 22 | base_models = base_df['mi'].values 23 | test_df = pd.read_csv(test_csv) 24 | test_df['mi'] = test_df.model + '-' + test_df.img_size.astype('str') 25 | test_models = test_df['mi'].values 26 | 27 | rank_diff = np.zeros_like(test_models, dtype='object') 28 | top1_diff = np.zeros_like(test_models, dtype='object') 29 | top5_diff = np.zeros_like(test_models, dtype='object') 30 | 31 | for rank, model in enumerate(test_models): 32 | if model in base_models: 33 | base_rank = int(np.where(base_models == model)[0]) 34 | top1_d = test_df['top1'][rank] - base_df['top1'][base_rank] 35 | top5_d = test_df['top5'][rank] - base_df['top5'][base_rank] 36 | 37 | # rank_diff 38 | if rank == base_rank: 39 | rank_diff[rank] = f'0' 40 | elif rank > base_rank: 41 | rank_diff[rank] = f'-{rank - base_rank}' 42 | else: 43 | rank_diff[rank] = f'+{base_rank - rank}' 44 | 45 | # top1_diff 46 | if top1_d >= .0: 47 | top1_diff[rank] = f'+{top1_d:.3f}' 48 | else: 49 | top1_diff[rank] = f'-{abs(top1_d):.3f}' 50 | 51 | # top5_diff 52 | if top5_d >= .0: 53 | top5_diff[rank] = f'+{top5_d:.3f}' 54 | else: 55 | top5_diff[rank] = f'-{abs(top5_d):.3f}' 56 | 57 | else: 58 | rank_diff[rank] = '' 59 | top1_diff[rank] = '' 60 | top5_diff[rank] = '' 61 | 62 | test_df['top1_diff'] = top1_diff 63 | test_df['top5_diff'] = top5_diff 64 | test_df['rank_diff'] = rank_diff 65 | 66 | test_df.drop('mi', axis=1, inplace=True) 67 | base_df.drop('mi', axis=1, inplace=True) 68 | test_df['param_count'] = test_df['param_count'].map('{:,.2f}'.format) 69 | test_df.sort_values(['top1', 'top5', 'model'], ascending=[False, False, True], inplace=True) 70 | test_df.to_csv(test_csv, index=False, float_format='%.3f') 71 | 72 | 73 | for base_results, test_results in results.items(): 74 | base_df = pd.read_csv(base_results) 75 | base_df.sort_values(['top1', 'top5', 'model'], ascending=[False, False, True], inplace=True) 76 | for test_csv in test_results: 77 | diff(base_df, test_csv) 78 | base_df['param_count'] = base_df['param_count'].map('{:,.2f}'.format) 79 | base_df.to_csv(base_results, index=False, float_format='%.3f') 80 | -------------------------------------------------------------------------------- /setup.cfg: -------------------------------------------------------------------------------- 1 | [dist_conda] 2 | 3 | conda_name_differences = 'torch:pytorch' 4 | channels = pytorch 5 | noarch = True 6 | 7 | [metadata] 8 | 9 | url = "https://github.com/huggingface/pytorch-image-models" -------------------------------------------------------------------------------- /tests/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/huggingface/pytorch-image-models/a22366e3ce52568193bd49d64f4e88fb01796965/tests/__init__.py -------------------------------------------------------------------------------- /timm/__init__.py: -------------------------------------------------------------------------------- 1 | from .version import __version__ as __version__ 2 | from .layers import ( 3 | is_scriptable as is_scriptable, 4 | is_exportable as is_exportable, 5 | set_scriptable as set_scriptable, 6 | set_exportable as set_exportable, 7 | ) 8 | from .models import ( 9 | create_model as create_model, 10 | list_models as list_models, 11 | list_pretrained as list_pretrained, 12 | is_model as is_model, 13 | list_modules as list_modules, 14 | model_entrypoint as model_entrypoint, 15 | is_model_pretrained as is_model_pretrained, 16 | get_pretrained_cfg as get_pretrained_cfg, 17 | get_pretrained_cfg_value as get_pretrained_cfg_value, 18 | ) 19 | -------------------------------------------------------------------------------- /timm/data/__init__.py: -------------------------------------------------------------------------------- 1 | from .auto_augment import RandAugment, AutoAugment, rand_augment_ops, auto_augment_policy,\ 2 | rand_augment_transform, auto_augment_transform 3 | from .config import resolve_data_config, resolve_model_data_config 4 | from .constants import * 5 | from .dataset import ImageDataset, IterableImageDataset, AugMixDataset 6 | from .dataset_factory import create_dataset 7 | from .dataset_info import DatasetInfo, CustomDatasetInfo 8 | from .imagenet_info import ImageNetInfo, infer_imagenet_subset 9 | from .loader import create_loader 10 | from .mixup import Mixup, FastCollateMixup 11 | from .readers import create_reader 12 | from .readers import get_img_extensions, is_img_extension, set_img_extensions, add_img_extensions, del_img_extensions 13 | from .real_labels import RealLabelsImagenet 14 | from .transforms import * 15 | from .transforms_factory import create_transform 16 | -------------------------------------------------------------------------------- /timm/data/_info/imagenet_a_indices.txt: -------------------------------------------------------------------------------- 1 | 6 2 | 11 3 | 13 4 | 15 5 | 17 6 | 22 7 | 23 8 | 27 9 | 30 10 | 37 11 | 39 12 | 42 13 | 47 14 | 50 15 | 57 16 | 70 17 | 71 18 | 76 19 | 79 20 | 89 21 | 90 22 | 94 23 | 96 24 | 97 25 | 99 26 | 105 27 | 107 28 | 108 29 | 110 30 | 113 31 | 124 32 | 125 33 | 130 34 | 132 35 | 143 36 | 144 37 | 150 38 | 151 39 | 207 40 | 234 41 | 235 42 | 254 43 | 277 44 | 283 45 | 287 46 | 291 47 | 295 48 | 298 49 | 301 50 | 306 51 | 307 52 | 308 53 | 309 54 | 310 55 | 311 56 | 313 57 | 314 58 | 315 59 | 317 60 | 319 61 | 323 62 | 324 63 | 326 64 | 327 65 | 330 66 | 334 67 | 335 68 | 336 69 | 347 70 | 361 71 | 363 72 | 372 73 | 378 74 | 386 75 | 397 76 | 400 77 | 401 78 | 402 79 | 404 80 | 407 81 | 411 82 | 416 83 | 417 84 | 420 85 | 425 86 | 428 87 | 430 88 | 437 89 | 438 90 | 445 91 | 456 92 | 457 93 | 461 94 | 462 95 | 470 96 | 472 97 | 483 98 | 486 99 | 488 100 | 492 101 | 496 102 | 514 103 | 516 104 | 528 105 | 530 106 | 539 107 | 542 108 | 543 109 | 549 110 | 552 111 | 557 112 | 561 113 | 562 114 | 569 115 | 572 116 | 573 117 | 575 118 | 579 119 | 589 120 | 606 121 | 607 122 | 609 123 | 614 124 | 626 125 | 627 126 | 640 127 | 641 128 | 642 129 | 643 130 | 658 131 | 668 132 | 677 133 | 682 134 | 684 135 | 687 136 | 701 137 | 704 138 | 719 139 | 736 140 | 746 141 | 749 142 | 752 143 | 758 144 | 763 145 | 765 146 | 768 147 | 773 148 | 774 149 | 776 150 | 779 151 | 780 152 | 786 153 | 792 154 | 797 155 | 802 156 | 803 157 | 804 158 | 813 159 | 815 160 | 820 161 | 823 162 | 831 163 | 833 164 | 835 165 | 839 166 | 845 167 | 847 168 | 850 169 | 859 170 | 862 171 | 870 172 | 879 173 | 880 174 | 888 175 | 890 176 | 897 177 | 900 178 | 907 179 | 913 180 | 924 181 | 932 182 | 933 183 | 934 184 | 937 185 | 943 186 | 945 187 | 947 188 | 951 189 | 954 190 | 956 191 | 957 192 | 959 193 | 971 194 | 972 195 | 980 196 | 981 197 | 984 198 | 986 199 | 987 200 | 988 201 | -------------------------------------------------------------------------------- /timm/data/_info/imagenet_a_synsets.txt: -------------------------------------------------------------------------------- 1 | n01498041 2 | n01531178 3 | n01534433 4 | n01558993 5 | n01580077 6 | n01614925 7 | n01616318 8 | n01631663 9 | n01641577 10 | n01669191 11 | n01677366 12 | n01687978 13 | n01694178 14 | n01698640 15 | n01735189 16 | n01770081 17 | n01770393 18 | n01774750 19 | n01784675 20 | n01819313 21 | n01820546 22 | n01833805 23 | n01843383 24 | n01847000 25 | n01855672 26 | n01882714 27 | n01910747 28 | n01914609 29 | n01924916 30 | n01944390 31 | n01985128 32 | n01986214 33 | n02007558 34 | n02009912 35 | n02037110 36 | n02051845 37 | n02077923 38 | n02085620 39 | n02099601 40 | n02106550 41 | n02106662 42 | n02110958 43 | n02119022 44 | n02123394 45 | n02127052 46 | n02129165 47 | n02133161 48 | n02137549 49 | n02165456 50 | n02174001 51 | n02177972 52 | n02190166 53 | n02206856 54 | n02219486 55 | n02226429 56 | n02231487 57 | n02233338 58 | n02236044 59 | n02259212 60 | n02268443 61 | n02279972 62 | n02280649 63 | n02281787 64 | n02317335 65 | n02325366 66 | n02346627 67 | n02356798 68 | n02361337 69 | n02410509 70 | n02445715 71 | n02454379 72 | n02486410 73 | n02492035 74 | n02504458 75 | n02655020 76 | n02669723 77 | n02672831 78 | n02676566 79 | n02690373 80 | n02701002 81 | n02730930 82 | n02777292 83 | n02782093 84 | n02787622 85 | n02793495 86 | n02797295 87 | n02802426 88 | n02814860 89 | n02815834 90 | n02837789 91 | n02879718 92 | n02883205 93 | n02895154 94 | n02906734 95 | n02948072 96 | n02951358 97 | n02980441 98 | n02992211 99 | n02999410 100 | n03014705 101 | n03026506 102 | n03124043 103 | n03125729 104 | n03187595 105 | n03196217 106 | n03223299 107 | n03250847 108 | n03255030 109 | n03291819 110 | n03325584 111 | n03355925 112 | n03384352 113 | n03388043 114 | n03417042 115 | n03443371 116 | n03444034 117 | n03445924 118 | n03452741 119 | n03483316 120 | n03584829 121 | n03590841 122 | n03594945 123 | n03617480 124 | n03666591 125 | n03670208 126 | n03717622 127 | n03720891 128 | n03721384 129 | n03724870 130 | n03775071 131 | n03788195 132 | n03804744 133 | n03837869 134 | n03840681 135 | n03854065 136 | n03888257 137 | n03891332 138 | n03935335 139 | n03982430 140 | n04019541 141 | n04033901 142 | n04039381 143 | n04067472 144 | n04086273 145 | n04099969 146 | n04118538 147 | n04131690 148 | n04133789 149 | n04141076 150 | n04146614 151 | n04147183 152 | n04179913 153 | n04208210 154 | n04235860 155 | n04252077 156 | n04252225 157 | n04254120 158 | n04270147 159 | n04275548 160 | n04310018 161 | n04317175 162 | n04344873 163 | n04347754 164 | n04355338 165 | n04366367 166 | n04376876 167 | n04389033 168 | n04399382 169 | n04442312 170 | n04456115 171 | n04482393 172 | n04507155 173 | n04509417 174 | n04532670 175 | n04540053 176 | n04554684 177 | n04562935 178 | n04591713 179 | n04606251 180 | n07583066 181 | n07695742 182 | n07697313 183 | n07697537 184 | n07714990 185 | n07718472 186 | n07720875 187 | n07734744 188 | n07749582 189 | n07753592 190 | n07760859 191 | n07768694 192 | n07831146 193 | n09229709 194 | n09246464 195 | n09472597 196 | n09835506 197 | n11879895 198 | n12057211 199 | n12144580 200 | n12267677 201 | -------------------------------------------------------------------------------- /timm/data/_info/imagenet_r_indices.txt: -------------------------------------------------------------------------------- 1 | 1 2 | 2 3 | 4 4 | 6 5 | 8 6 | 9 7 | 11 8 | 13 9 | 22 10 | 23 11 | 26 12 | 29 13 | 31 14 | 39 15 | 47 16 | 63 17 | 71 18 | 76 19 | 79 20 | 84 21 | 90 22 | 94 23 | 96 24 | 97 25 | 99 26 | 100 27 | 105 28 | 107 29 | 113 30 | 122 31 | 125 32 | 130 33 | 132 34 | 144 35 | 145 36 | 147 37 | 148 38 | 150 39 | 151 40 | 155 41 | 160 42 | 161 43 | 162 44 | 163 45 | 171 46 | 172 47 | 178 48 | 187 49 | 195 50 | 199 51 | 203 52 | 207 53 | 208 54 | 219 55 | 231 56 | 232 57 | 234 58 | 235 59 | 242 60 | 245 61 | 247 62 | 250 63 | 251 64 | 254 65 | 259 66 | 260 67 | 263 68 | 265 69 | 267 70 | 269 71 | 276 72 | 277 73 | 281 74 | 288 75 | 289 76 | 291 77 | 292 78 | 293 79 | 296 80 | 299 81 | 301 82 | 308 83 | 309 84 | 310 85 | 311 86 | 314 87 | 315 88 | 319 89 | 323 90 | 327 91 | 330 92 | 334 93 | 335 94 | 337 95 | 338 96 | 340 97 | 341 98 | 344 99 | 347 100 | 353 101 | 355 102 | 361 103 | 362 104 | 365 105 | 366 106 | 367 107 | 368 108 | 372 109 | 388 110 | 390 111 | 393 112 | 397 113 | 401 114 | 407 115 | 413 116 | 414 117 | 425 118 | 428 119 | 430 120 | 435 121 | 437 122 | 441 123 | 447 124 | 448 125 | 457 126 | 462 127 | 463 128 | 469 129 | 470 130 | 471 131 | 472 132 | 476 133 | 483 134 | 487 135 | 515 136 | 546 137 | 555 138 | 558 139 | 570 140 | 579 141 | 583 142 | 587 143 | 593 144 | 594 145 | 596 146 | 609 147 | 613 148 | 617 149 | 621 150 | 629 151 | 637 152 | 657 153 | 658 154 | 701 155 | 717 156 | 724 157 | 763 158 | 768 159 | 774 160 | 776 161 | 779 162 | 780 163 | 787 164 | 805 165 | 812 166 | 815 167 | 820 168 | 824 169 | 833 170 | 847 171 | 852 172 | 866 173 | 875 174 | 883 175 | 889 176 | 895 177 | 907 178 | 928 179 | 931 180 | 932 181 | 933 182 | 934 183 | 936 184 | 937 185 | 943 186 | 945 187 | 947 188 | 948 189 | 949 190 | 951 191 | 953 192 | 954 193 | 957 194 | 963 195 | 965 196 | 967 197 | 980 198 | 981 199 | 983 200 | 988 201 | -------------------------------------------------------------------------------- /timm/data/_info/imagenet_r_synsets.txt: -------------------------------------------------------------------------------- 1 | n01443537 2 | n01484850 3 | n01494475 4 | n01498041 5 | n01514859 6 | n01518878 7 | n01531178 8 | n01534433 9 | n01614925 10 | n01616318 11 | n01630670 12 | n01632777 13 | n01644373 14 | n01677366 15 | n01694178 16 | n01748264 17 | n01770393 18 | n01774750 19 | n01784675 20 | n01806143 21 | n01820546 22 | n01833805 23 | n01843383 24 | n01847000 25 | n01855672 26 | n01860187 27 | n01882714 28 | n01910747 29 | n01944390 30 | n01983481 31 | n01986214 32 | n02007558 33 | n02009912 34 | n02051845 35 | n02056570 36 | n02066245 37 | n02071294 38 | n02077923 39 | n02085620 40 | n02086240 41 | n02088094 42 | n02088238 43 | n02088364 44 | n02088466 45 | n02091032 46 | n02091134 47 | n02092339 48 | n02094433 49 | n02096585 50 | n02097298 51 | n02098286 52 | n02099601 53 | n02099712 54 | n02102318 55 | n02106030 56 | n02106166 57 | n02106550 58 | n02106662 59 | n02108089 60 | n02108915 61 | n02109525 62 | n02110185 63 | n02110341 64 | n02110958 65 | n02112018 66 | n02112137 67 | n02113023 68 | n02113624 69 | n02113799 70 | n02114367 71 | n02117135 72 | n02119022 73 | n02123045 74 | n02128385 75 | n02128757 76 | n02129165 77 | n02129604 78 | n02130308 79 | n02134084 80 | n02138441 81 | n02165456 82 | n02190166 83 | n02206856 84 | n02219486 85 | n02226429 86 | n02233338 87 | n02236044 88 | n02268443 89 | n02279972 90 | n02317335 91 | n02325366 92 | n02346627 93 | n02356798 94 | n02363005 95 | n02364673 96 | n02391049 97 | n02395406 98 | n02398521 99 | n02410509 100 | n02423022 101 | n02437616 102 | n02445715 103 | n02447366 104 | n02480495 105 | n02480855 106 | n02481823 107 | n02483362 108 | n02486410 109 | n02510455 110 | n02526121 111 | n02607072 112 | n02655020 113 | n02672831 114 | n02701002 115 | n02749479 116 | n02769748 117 | n02793495 118 | n02797295 119 | n02802426 120 | n02808440 121 | n02814860 122 | n02823750 123 | n02841315 124 | n02843684 125 | n02883205 126 | n02906734 127 | n02909870 128 | n02939185 129 | n02948072 130 | n02950826 131 | n02951358 132 | n02966193 133 | n02980441 134 | n02992529 135 | n03124170 136 | n03272010 137 | n03345487 138 | n03372029 139 | n03424325 140 | n03452741 141 | n03467068 142 | n03481172 143 | n03494278 144 | n03495258 145 | n03498962 146 | n03594945 147 | n03602883 148 | n03630383 149 | n03649909 150 | n03676483 151 | n03710193 152 | n03773504 153 | n03775071 154 | n03888257 155 | n03930630 156 | n03947888 157 | n04086273 158 | n04118538 159 | n04133789 160 | n04141076 161 | n04146614 162 | n04147183 163 | n04192698 164 | n04254680 165 | n04266014 166 | n04275548 167 | n04310018 168 | n04325704 169 | n04347754 170 | n04389033 171 | n04409515 172 | n04465501 173 | n04487394 174 | n04522168 175 | n04536866 176 | n04552348 177 | n04591713 178 | n07614500 179 | n07693725 180 | n07695742 181 | n07697313 182 | n07697537 183 | n07714571 184 | n07714990 185 | n07718472 186 | n07720875 187 | n07734744 188 | n07742313 189 | n07745940 190 | n07749582 191 | n07753275 192 | n07753592 193 | n07768694 194 | n07873807 195 | n07880968 196 | n07920052 197 | n09472597 198 | n09835506 199 | n10565667 200 | n12267677 201 | -------------------------------------------------------------------------------- /timm/data/_info/mini_imagenet_indices.txt: -------------------------------------------------------------------------------- 1 | 12 2 | 15 3 | 51 4 | 64 5 | 70 6 | 96 7 | 99 8 | 107 9 | 111 10 | 121 11 | 149 12 | 166 13 | 173 14 | 176 15 | 207 16 | 214 17 | 228 18 | 242 19 | 244 20 | 245 21 | 249 22 | 251 23 | 256 24 | 266 25 | 270 26 | 275 27 | 279 28 | 291 29 | 299 30 | 301 31 | 306 32 | 310 33 | 359 34 | 364 35 | 392 36 | 403 37 | 412 38 | 427 39 | 440 40 | 454 41 | 471 42 | 476 43 | 478 44 | 484 45 | 494 46 | 502 47 | 503 48 | 507 49 | 519 50 | 524 51 | 533 52 | 538 53 | 546 54 | 553 55 | 556 56 | 567 57 | 569 58 | 584 59 | 597 60 | 602 61 | 604 62 | 605 63 | 629 64 | 655 65 | 657 66 | 659 67 | 683 68 | 687 69 | 702 70 | 709 71 | 713 72 | 735 73 | 741 74 | 758 75 | 779 76 | 781 77 | 800 78 | 801 79 | 807 80 | 815 81 | 819 82 | 847 83 | 854 84 | 858 85 | 860 86 | 880 87 | 881 88 | 883 89 | 909 90 | 912 91 | 914 92 | 919 93 | 925 94 | 927 95 | 934 96 | 950 97 | 972 98 | 973 99 | 997 100 | 998 101 | -------------------------------------------------------------------------------- /timm/data/_info/mini_imagenet_synsets.txt: -------------------------------------------------------------------------------- 1 | n01532829 2 | n01558993 3 | n01704323 4 | n01749939 5 | n01770081 6 | n01843383 7 | n01855672 8 | n01910747 9 | n01930112 10 | n01981276 11 | n02074367 12 | n02089867 13 | n02091244 14 | n02091831 15 | n02099601 16 | n02101006 17 | n02105505 18 | n02108089 19 | n02108551 20 | n02108915 21 | n02110063 22 | n02110341 23 | n02111277 24 | n02113712 25 | n02114548 26 | n02116738 27 | n02120079 28 | n02129165 29 | n02138441 30 | n02165456 31 | n02174001 32 | n02219486 33 | n02443484 34 | n02457408 35 | n02606052 36 | n02687172 37 | n02747177 38 | n02795169 39 | n02823428 40 | n02871525 41 | n02950826 42 | n02966193 43 | n02971356 44 | n02981792 45 | n03017168 46 | n03047690 47 | n03062245 48 | n03075370 49 | n03127925 50 | n03146219 51 | n03207743 52 | n03220513 53 | n03272010 54 | n03337140 55 | n03347037 56 | n03400231 57 | n03417042 58 | n03476684 59 | n03527444 60 | n03535780 61 | n03544143 62 | n03584254 63 | n03676483 64 | n03770439 65 | n03773504 66 | n03775546 67 | n03838899 68 | n03854065 69 | n03888605 70 | n03908618 71 | n03924679 72 | n03980874 73 | n03998194 74 | n04067472 75 | n04146614 76 | n04149813 77 | n04243546 78 | n04251144 79 | n04258138 80 | n04275548 81 | n04296562 82 | n04389033 83 | n04418357 84 | n04435653 85 | n04443257 86 | n04509417 87 | n04515003 88 | n04522168 89 | n04596742 90 | n04604644 91 | n04612504 92 | n06794110 93 | n07584110 94 | n07613480 95 | n07697537 96 | n07747607 97 | n09246464 98 | n09256479 99 | n13054560 100 | n13133613 101 | -------------------------------------------------------------------------------- /timm/data/constants.py: -------------------------------------------------------------------------------- 1 | DEFAULT_CROP_PCT = 0.875 2 | DEFAULT_CROP_MODE = 'center' 3 | IMAGENET_DEFAULT_MEAN = (0.485, 0.456, 0.406) 4 | IMAGENET_DEFAULT_STD = (0.229, 0.224, 0.225) 5 | IMAGENET_INCEPTION_MEAN = (0.5, 0.5, 0.5) 6 | IMAGENET_INCEPTION_STD = (0.5, 0.5, 0.5) 7 | IMAGENET_DPN_MEAN = (124 / 255, 117 / 255, 104 / 255) 8 | IMAGENET_DPN_STD = tuple([1 / (.0167 * 255)] * 3) 9 | OPENAI_CLIP_MEAN = (0.48145466, 0.4578275, 0.40821073) 10 | OPENAI_CLIP_STD = (0.26862954, 0.26130258, 0.27577711) 11 | -------------------------------------------------------------------------------- /timm/data/dataset_info.py: -------------------------------------------------------------------------------- 1 | from abc import ABC, abstractmethod 2 | from typing import Dict, List, Optional, Union 3 | 4 | 5 | class DatasetInfo(ABC): 6 | 7 | def __init__(self): 8 | pass 9 | 10 | @abstractmethod 11 | def num_classes(self): 12 | pass 13 | 14 | @abstractmethod 15 | def label_names(self): 16 | pass 17 | 18 | @abstractmethod 19 | def label_descriptions(self, detailed: bool = False, as_dict: bool = False) -> Union[List[str], Dict[str, str]]: 20 | pass 21 | 22 | @abstractmethod 23 | def index_to_label_name(self, index) -> str: 24 | pass 25 | 26 | @abstractmethod 27 | def index_to_description(self, index: int, detailed: bool = False) -> str: 28 | pass 29 | 30 | @abstractmethod 31 | def label_name_to_description(self, label: str, detailed: bool = False) -> str: 32 | pass 33 | 34 | 35 | class CustomDatasetInfo(DatasetInfo): 36 | """ DatasetInfo that wraps passed values for custom datasets.""" 37 | 38 | def __init__( 39 | self, 40 | label_names: Union[List[str], Dict[int, str]], 41 | label_descriptions: Optional[Dict[str, str]] = None 42 | ): 43 | super().__init__() 44 | assert len(label_names) > 0 45 | self._label_names = label_names # label index => label name mapping 46 | self._label_descriptions = label_descriptions # label name => label description mapping 47 | if self._label_descriptions is not None: 48 | # validate descriptions (label names required) 49 | assert isinstance(self._label_descriptions, dict) 50 | for n in self._label_names: 51 | assert n in self._label_descriptions 52 | 53 | def num_classes(self): 54 | return len(self._label_names) 55 | 56 | def label_names(self): 57 | return self._label_names 58 | 59 | def label_descriptions(self, detailed: bool = False, as_dict: bool = False) -> Union[List[str], Dict[str, str]]: 60 | return self._label_descriptions 61 | 62 | def label_name_to_description(self, label: str, detailed: bool = False) -> str: 63 | if self._label_descriptions: 64 | return self._label_descriptions[label] 65 | return label # return label name itself if a descriptions is not present 66 | 67 | def index_to_label_name(self, index) -> str: 68 | assert 0 <= index < len(self._label_names) 69 | return self._label_names[index] 70 | 71 | def index_to_description(self, index: int, detailed: bool = False) -> str: 72 | label = self.index_to_label_name(index) 73 | return self.label_name_to_description(label, detailed=detailed) 74 | -------------------------------------------------------------------------------- /timm/data/readers/__init__.py: -------------------------------------------------------------------------------- 1 | from .reader_factory import create_reader 2 | from .img_extensions import * 3 | -------------------------------------------------------------------------------- /timm/data/readers/class_map.py: -------------------------------------------------------------------------------- 1 | import os 2 | import pickle 3 | 4 | 5 | def load_class_map(map_or_filename, root=''): 6 | if isinstance(map_or_filename, dict): 7 | assert dict, 'class_map dict must be non-empty' 8 | return map_or_filename 9 | class_map_path = map_or_filename 10 | if not os.path.exists(class_map_path): 11 | class_map_path = os.path.join(root, class_map_path) 12 | assert os.path.exists(class_map_path), 'Cannot locate specified class map file (%s)' % map_or_filename 13 | class_map_ext = os.path.splitext(map_or_filename)[-1].lower() 14 | if class_map_ext == '.txt': 15 | with open(class_map_path) as f: 16 | class_to_idx = {v.strip(): k for k, v in enumerate(f)} 17 | elif class_map_ext == '.pkl': 18 | with open(class_map_path, 'rb') as f: 19 | class_to_idx = pickle.load(f) 20 | else: 21 | assert False, f'Unsupported class map file extension ({class_map_ext}).' 22 | return class_to_idx 23 | 24 | -------------------------------------------------------------------------------- /timm/data/readers/img_extensions.py: -------------------------------------------------------------------------------- 1 | from copy import deepcopy 2 | 3 | __all__ = ['get_img_extensions', 'is_img_extension', 'set_img_extensions', 'add_img_extensions', 'del_img_extensions'] 4 | 5 | 6 | IMG_EXTENSIONS = ('.png', '.jpg', '.jpeg') # singleton, kept public for bwd compat use 7 | _IMG_EXTENSIONS_SET = set(IMG_EXTENSIONS) # set version, private, kept in sync 8 | 9 | 10 | def _set_extensions(extensions): 11 | global IMG_EXTENSIONS 12 | global _IMG_EXTENSIONS_SET 13 | dedupe = set() # NOTE de-duping tuple while keeping original order 14 | IMG_EXTENSIONS = tuple(x for x in extensions if x not in dedupe and not dedupe.add(x)) 15 | _IMG_EXTENSIONS_SET = set(extensions) 16 | 17 | 18 | def _valid_extension(x: str): 19 | return x and isinstance(x, str) and len(x) >= 2 and x.startswith('.') 20 | 21 | 22 | def is_img_extension(ext): 23 | return ext in _IMG_EXTENSIONS_SET 24 | 25 | 26 | def get_img_extensions(as_set=False): 27 | return deepcopy(_IMG_EXTENSIONS_SET if as_set else IMG_EXTENSIONS) 28 | 29 | 30 | def set_img_extensions(extensions): 31 | assert len(extensions) 32 | for x in extensions: 33 | assert _valid_extension(x) 34 | _set_extensions(extensions) 35 | 36 | 37 | def add_img_extensions(ext): 38 | if not isinstance(ext, (list, tuple, set)): 39 | ext = (ext,) 40 | for x in ext: 41 | assert _valid_extension(x) 42 | extensions = IMG_EXTENSIONS + tuple(ext) 43 | _set_extensions(extensions) 44 | 45 | 46 | def del_img_extensions(ext): 47 | if not isinstance(ext, (list, tuple, set)): 48 | ext = (ext,) 49 | extensions = tuple(x for x in IMG_EXTENSIONS if x not in ext) 50 | _set_extensions(extensions) 51 | -------------------------------------------------------------------------------- /timm/data/readers/reader.py: -------------------------------------------------------------------------------- 1 | from abc import abstractmethod 2 | 3 | 4 | class Reader: 5 | def __init__(self): 6 | pass 7 | 8 | @abstractmethod 9 | def _filename(self, index, basename=False, absolute=False): 10 | pass 11 | 12 | def filename(self, index, basename=False, absolute=False): 13 | return self._filename(index, basename=basename, absolute=absolute) 14 | 15 | def filenames(self, basename=False, absolute=False): 16 | return [self._filename(index, basename=basename, absolute=absolute) for index in range(len(self))] 17 | 18 | -------------------------------------------------------------------------------- /timm/data/readers/reader_factory.py: -------------------------------------------------------------------------------- 1 | import os 2 | from typing import Optional 3 | 4 | from .reader_image_folder import ReaderImageFolder 5 | from .reader_image_in_tar import ReaderImageInTar 6 | 7 | 8 | def create_reader( 9 | name: str, 10 | root: Optional[str] = None, 11 | split: str = 'train', 12 | **kwargs, 13 | ): 14 | kwargs = {k: v for k, v in kwargs.items() if v is not None} 15 | name = name.lower() 16 | name = name.split('/', 1) 17 | prefix = '' 18 | if len(name) > 1: 19 | prefix = name[0] 20 | name = name[-1] 21 | 22 | # FIXME improve the selection right now just tfds prefix or fallback path, will need options to 23 | # explicitly select other options shortly 24 | if prefix == 'hfds': 25 | from .reader_hfds import ReaderHfds # defer Hf datasets import 26 | reader = ReaderHfds(name=name, root=root, split=split, **kwargs) 27 | elif prefix == 'hfids': 28 | from .reader_hfids import ReaderHfids # defer HF datasets import 29 | reader = ReaderHfids(name=name, root=root, split=split, **kwargs) 30 | elif prefix == 'tfds': 31 | from .reader_tfds import ReaderTfds # defer tensorflow import 32 | reader = ReaderTfds(name=name, root=root, split=split, **kwargs) 33 | elif prefix == 'wds': 34 | from .reader_wds import ReaderWds 35 | kwargs.pop('download', False) 36 | reader = ReaderWds(root=root, name=name, split=split, **kwargs) 37 | else: 38 | assert os.path.exists(root) 39 | # default fallback path (backwards compat), use image tar if root is a .tar file, otherwise image folder 40 | # FIXME support split here or in reader? 41 | if os.path.isfile(root) and os.path.splitext(root)[1] == '.tar': 42 | reader = ReaderImageInTar(root, **kwargs) 43 | else: 44 | reader = ReaderImageFolder(root, **kwargs) 45 | return reader 46 | -------------------------------------------------------------------------------- /timm/data/readers/reader_hfds.py: -------------------------------------------------------------------------------- 1 | """ Dataset reader that wraps Hugging Face datasets 2 | 3 | Hacked together by / Copyright 2022 Ross Wightman 4 | """ 5 | import io 6 | import math 7 | from typing import Optional 8 | 9 | import torch 10 | import torch.distributed as dist 11 | from PIL import Image 12 | 13 | try: 14 | import datasets 15 | except ImportError as e: 16 | print("Please install Hugging Face datasets package `pip install datasets`.") 17 | raise e 18 | from .class_map import load_class_map 19 | from .reader import Reader 20 | 21 | 22 | def get_class_labels(info, label_key='label'): 23 | if 'label' not in info.features: 24 | return {} 25 | class_label = info.features[label_key] 26 | class_to_idx = {n: class_label.str2int(n) for n in class_label.names} 27 | return class_to_idx 28 | 29 | 30 | class ReaderHfds(Reader): 31 | 32 | def __init__( 33 | self, 34 | name: str, 35 | root: Optional[str] = None, 36 | split: str = 'train', 37 | class_map: dict = None, 38 | input_key: str = 'image', 39 | target_key: str = 'label', 40 | download: bool = False, 41 | trust_remote_code: bool = False 42 | ): 43 | """ 44 | """ 45 | super().__init__() 46 | self.root = root 47 | self.split = split 48 | self.dataset = datasets.load_dataset( 49 | name, # 'name' maps to path arg in hf datasets 50 | split=split, 51 | cache_dir=self.root, # timm doesn't expect hidden cache dir for datasets, specify a path if root set 52 | trust_remote_code=trust_remote_code 53 | ) 54 | # leave decode for caller, plus we want easy access to original path names... 55 | self.dataset = self.dataset.cast_column(input_key, datasets.Image(decode=False)) 56 | 57 | self.image_key = input_key 58 | self.label_key = target_key 59 | self.remap_class = False 60 | if class_map: 61 | self.class_to_idx = load_class_map(class_map) 62 | self.remap_class = True 63 | else: 64 | self.class_to_idx = get_class_labels(self.dataset.info, self.label_key) 65 | self.split_info = self.dataset.info.splits[split] 66 | self.num_samples = self.split_info.num_examples 67 | 68 | def __getitem__(self, index): 69 | item = self.dataset[index] 70 | image = item[self.image_key] 71 | if 'bytes' in image and image['bytes']: 72 | image = io.BytesIO(image['bytes']) 73 | else: 74 | assert 'path' in image and image['path'] 75 | image = open(image['path'], 'rb') 76 | label = item[self.label_key] 77 | if self.remap_class: 78 | label = self.class_to_idx[label] 79 | return image, label 80 | 81 | def __len__(self): 82 | return len(self.dataset) 83 | 84 | def _filename(self, index, basename=False, absolute=False): 85 | item = self.dataset[index] 86 | return item[self.image_key]['path'] 87 | -------------------------------------------------------------------------------- /timm/data/readers/reader_image_folder.py: -------------------------------------------------------------------------------- 1 | """ A dataset reader that extracts images from folders 2 | 3 | Folders are scanned recursively to find image files. Labels are based 4 | on the folder hierarchy, just leaf folders by default. 5 | 6 | Hacked together by / Copyright 2020 Ross Wightman 7 | """ 8 | import os 9 | from typing import Dict, List, Optional, Set, Tuple, Union 10 | 11 | from timm.utils.misc import natural_key 12 | 13 | from .class_map import load_class_map 14 | from .img_extensions import get_img_extensions 15 | from .reader import Reader 16 | 17 | 18 | def find_images_and_targets( 19 | folder: str, 20 | types: Optional[Union[List, Tuple, Set]] = None, 21 | class_to_idx: Optional[Dict] = None, 22 | leaf_name_only: bool = True, 23 | sort: bool = True 24 | ): 25 | """ Walk folder recursively to discover images and map them to classes by folder names. 26 | 27 | Args: 28 | folder: root of folder to recursively search 29 | types: types (file extensions) to search for in path 30 | class_to_idx: specify mapping for class (folder name) to class index if set 31 | leaf_name_only: use only leaf-name of folder walk for class names 32 | sort: re-sort found images by name (for consistent ordering) 33 | 34 | Returns: 35 | A list of image and target tuples, class_to_idx mapping 36 | """ 37 | types = get_img_extensions(as_set=True) if not types else set(types) 38 | labels = [] 39 | filenames = [] 40 | for root, subdirs, files in os.walk(folder, topdown=False, followlinks=True): 41 | rel_path = os.path.relpath(root, folder) if (root != folder) else '' 42 | label = os.path.basename(rel_path) if leaf_name_only else rel_path.replace(os.path.sep, '_') 43 | for f in files: 44 | base, ext = os.path.splitext(f) 45 | if ext.lower() in types: 46 | filenames.append(os.path.join(root, f)) 47 | labels.append(label) 48 | if class_to_idx is None: 49 | # building class index 50 | unique_labels = set(labels) 51 | sorted_labels = list(sorted(unique_labels, key=natural_key)) 52 | class_to_idx = {c: idx for idx, c in enumerate(sorted_labels)} 53 | images_and_targets = [(f, class_to_idx[l]) for f, l in zip(filenames, labels) if l in class_to_idx] 54 | if sort: 55 | images_and_targets = sorted(images_and_targets, key=lambda k: natural_key(k[0])) 56 | return images_and_targets, class_to_idx 57 | 58 | 59 | class ReaderImageFolder(Reader): 60 | 61 | def __init__( 62 | self, 63 | root, 64 | class_map='', 65 | input_key=None, 66 | ): 67 | super().__init__() 68 | 69 | self.root = root 70 | class_to_idx = None 71 | if class_map: 72 | class_to_idx = load_class_map(class_map, root) 73 | find_types = None 74 | if input_key: 75 | find_types = input_key.split(';') 76 | self.samples, self.class_to_idx = find_images_and_targets( 77 | root, 78 | class_to_idx=class_to_idx, 79 | types=find_types, 80 | ) 81 | if len(self.samples) == 0: 82 | raise RuntimeError( 83 | f'Found 0 images in subfolders of {root}. ' 84 | f'Supported image extensions are {", ".join(get_img_extensions())}') 85 | 86 | def __getitem__(self, index): 87 | path, target = self.samples[index] 88 | return open(path, 'rb'), target 89 | 90 | def __len__(self): 91 | return len(self.samples) 92 | 93 | def _filename(self, index, basename=False, absolute=False): 94 | filename = self.samples[index][0] 95 | if basename: 96 | filename = os.path.basename(filename) 97 | elif not absolute: 98 | filename = os.path.relpath(filename, self.root) 99 | return filename 100 | -------------------------------------------------------------------------------- /timm/data/readers/reader_image_tar.py: -------------------------------------------------------------------------------- 1 | """ A dataset reader that reads single tarfile based datasets 2 | 3 | This reader can read datasets consisting if a single tarfile containing images. 4 | I am planning to deprecated it in favour of ParerImageInTar. 5 | 6 | Hacked together by / Copyright 2020 Ross Wightman 7 | """ 8 | import os 9 | import tarfile 10 | 11 | from timm.utils.misc import natural_key 12 | 13 | from .class_map import load_class_map 14 | from .img_extensions import get_img_extensions 15 | from .reader import Reader 16 | 17 | 18 | def extract_tarinfo(tarfile, class_to_idx=None, sort=True): 19 | extensions = get_img_extensions(as_set=True) 20 | files = [] 21 | labels = [] 22 | for ti in tarfile.getmembers(): 23 | if not ti.isfile(): 24 | continue 25 | dirname, basename = os.path.split(ti.path) 26 | label = os.path.basename(dirname) 27 | ext = os.path.splitext(basename)[1] 28 | if ext.lower() in extensions: 29 | files.append(ti) 30 | labels.append(label) 31 | if class_to_idx is None: 32 | unique_labels = set(labels) 33 | sorted_labels = list(sorted(unique_labels, key=natural_key)) 34 | class_to_idx = {c: idx for idx, c in enumerate(sorted_labels)} 35 | tarinfo_and_targets = [(f, class_to_idx[l]) for f, l in zip(files, labels) if l in class_to_idx] 36 | if sort: 37 | tarinfo_and_targets = sorted(tarinfo_and_targets, key=lambda k: natural_key(k[0].path)) 38 | return tarinfo_and_targets, class_to_idx 39 | 40 | 41 | class ReaderImageTar(Reader): 42 | """ Single tarfile dataset where classes are mapped to folders within tar 43 | NOTE: This class is being deprecated in favour of the more capable ReaderImageInTar that can 44 | operate on folders of tars or tars in tars. 45 | """ 46 | def __init__(self, root, class_map=''): 47 | super().__init__() 48 | 49 | class_to_idx = None 50 | if class_map: 51 | class_to_idx = load_class_map(class_map, root) 52 | assert os.path.isfile(root) 53 | self.root = root 54 | 55 | with tarfile.open(root) as tf: # cannot keep this open across processes, reopen later 56 | self.samples, self.class_to_idx = extract_tarinfo(tf, class_to_idx) 57 | self.imgs = self.samples 58 | self.tarfile = None # lazy init in __getitem__ 59 | 60 | def __getitem__(self, index): 61 | if self.tarfile is None: 62 | self.tarfile = tarfile.open(self.root) 63 | tarinfo, target = self.samples[index] 64 | fileobj = self.tarfile.extractfile(tarinfo) 65 | return fileobj, target 66 | 67 | def __len__(self): 68 | return len(self.samples) 69 | 70 | def _filename(self, index, basename=False, absolute=False): 71 | filename = self.samples[index][0].name 72 | if basename: 73 | filename = os.path.basename(filename) 74 | return filename 75 | -------------------------------------------------------------------------------- /timm/data/readers/shared_count.py: -------------------------------------------------------------------------------- 1 | from multiprocessing import Value 2 | 3 | 4 | class SharedCount: 5 | def __init__(self, epoch: int = 0): 6 | self.shared_epoch = Value('i', epoch) 7 | 8 | @property 9 | def value(self): 10 | return self.shared_epoch.value 11 | 12 | @value.setter 13 | def value(self, epoch): 14 | self.shared_epoch.value = epoch 15 | -------------------------------------------------------------------------------- /timm/data/real_labels.py: -------------------------------------------------------------------------------- 1 | """ Real labels evaluator for ImageNet 2 | Paper: `Are we done with ImageNet?` - https://arxiv.org/abs/2006.07159 3 | Based on Numpy example at https://github.com/google-research/reassessed-imagenet 4 | 5 | Hacked together by / Copyright 2020 Ross Wightman 6 | """ 7 | import os 8 | import json 9 | import numpy as np 10 | import pkgutil 11 | 12 | 13 | class RealLabelsImagenet: 14 | 15 | def __init__(self, filenames, real_json=None, topk=(1, 5)): 16 | if real_json is not None: 17 | with open(real_json) as real_labels: 18 | real_labels = json.load(real_labels) 19 | else: 20 | real_labels = json.loads( 21 | pkgutil.get_data(__name__, os.path.join('_info', 'imagenet_real_labels.json')).decode('utf-8')) 22 | real_labels = {f'ILSVRC2012_val_{i + 1:08d}.JPEG': labels for i, labels in enumerate(real_labels)} 23 | self.real_labels = real_labels 24 | self.filenames = filenames 25 | assert len(self.filenames) == len(self.real_labels) 26 | self.topk = topk 27 | self.is_correct = {k: [] for k in topk} 28 | self.sample_idx = 0 29 | 30 | def add_result(self, output): 31 | maxk = max(self.topk) 32 | _, pred_batch = output.topk(maxk, 1, True, True) 33 | pred_batch = pred_batch.cpu().numpy() 34 | for pred in pred_batch: 35 | filename = self.filenames[self.sample_idx] 36 | filename = os.path.basename(filename) 37 | if self.real_labels[filename]: 38 | for k in self.topk: 39 | self.is_correct[k].append( 40 | any([p in self.real_labels[filename] for p in pred[:k]])) 41 | self.sample_idx += 1 42 | 43 | def get_accuracy(self, k=None): 44 | if k is None: 45 | return {k: float(np.mean(self.is_correct[k])) * 100 for k in self.topk} 46 | else: 47 | return float(np.mean(self.is_correct[k])) * 100 48 | -------------------------------------------------------------------------------- /timm/layers/attention_pool.py: -------------------------------------------------------------------------------- 1 | from typing import Optional, Type 2 | 3 | import torch 4 | import torch.nn as nn 5 | import torch.nn.functional as F 6 | 7 | from .config import use_fused_attn 8 | from .mlp import Mlp 9 | from .weight_init import trunc_normal_tf_ 10 | 11 | 12 | class AttentionPoolLatent(nn.Module): 13 | """ Attention pooling w/ latent query 14 | """ 15 | fused_attn: torch.jit.Final[bool] 16 | 17 | def __init__( 18 | self, 19 | in_features: int, 20 | out_features: int = None, 21 | embed_dim: int = None, 22 | num_heads: int = 8, 23 | feat_size: Optional[int] = None, 24 | mlp_ratio: float = 4.0, 25 | qkv_bias: bool = True, 26 | qk_norm: bool = False, 27 | latent_len: int = 1, 28 | latent_dim: int = None, 29 | pos_embed: str = '', 30 | pool_type: str = 'token', 31 | norm_layer: Optional[Type[nn.Module]] = None, 32 | act_layer: Optional[Type[nn.Module]] = nn.GELU, 33 | drop: float = 0.0, 34 | ): 35 | super().__init__() 36 | embed_dim = embed_dim or in_features 37 | out_features = out_features or in_features 38 | assert embed_dim % num_heads == 0 39 | self.num_heads = num_heads 40 | self.head_dim = embed_dim // num_heads 41 | self.feat_size = feat_size 42 | self.scale = self.head_dim ** -0.5 43 | self.pool = pool_type 44 | self.fused_attn = use_fused_attn() 45 | 46 | if pos_embed == 'abs': 47 | assert feat_size is not None 48 | self.pos_embed = nn.Parameter(torch.zeros(feat_size, in_features)) 49 | else: 50 | self.pos_embed = None 51 | 52 | self.latent_dim = latent_dim or embed_dim 53 | self.latent_len = latent_len 54 | self.latent = nn.Parameter(torch.zeros(1, self.latent_len, embed_dim)) 55 | 56 | self.q = nn.Linear(embed_dim, embed_dim, bias=qkv_bias) 57 | self.kv = nn.Linear(embed_dim, embed_dim * 2, bias=qkv_bias) 58 | if qk_norm: 59 | qk_norm_layer = norm_layer or nn.LayerNorm 60 | self.q_norm = qk_norm_layer(self.head_dim) 61 | self.k_norm = qk_norm_layer(self.head_dim) 62 | else: 63 | self.q_norm = nn.Identity() 64 | self.k_norm = nn.Identity() 65 | self.proj = nn.Linear(embed_dim, embed_dim) 66 | self.proj_drop = nn.Dropout(drop) 67 | 68 | self.norm = norm_layer(out_features) if norm_layer is not None else nn.Identity() 69 | self.mlp = Mlp(embed_dim, int(embed_dim * mlp_ratio), act_layer=act_layer) 70 | 71 | self.init_weights() 72 | 73 | def init_weights(self): 74 | if self.pos_embed is not None: 75 | trunc_normal_tf_(self.pos_embed, std=self.pos_embed.shape[1] ** -0.5) 76 | trunc_normal_tf_(self.latent, std=self.latent_dim ** -0.5) 77 | 78 | def forward(self, x): 79 | B, N, C = x.shape 80 | 81 | if self.pos_embed is not None: 82 | # FIXME interpolate 83 | x = x + self.pos_embed.unsqueeze(0).to(x.dtype) 84 | 85 | q_latent = self.latent.expand(B, -1, -1) 86 | q = self.q(q_latent).reshape(B, self.latent_len, self.num_heads, self.head_dim).transpose(1, 2) 87 | 88 | kv = self.kv(x).reshape(B, N, 2, self.num_heads, self.head_dim).permute(2, 0, 3, 1, 4) 89 | k, v = kv.unbind(0) 90 | 91 | q, k = self.q_norm(q), self.k_norm(k) 92 | 93 | if self.fused_attn: 94 | x = F.scaled_dot_product_attention(q, k, v) 95 | else: 96 | q = q * self.scale 97 | attn = q @ k.transpose(-2, -1) 98 | attn = attn.softmax(dim=-1) 99 | x = attn @ v 100 | x = x.transpose(1, 2).reshape(B, self.latent_len, C) 101 | x = self.proj(x) 102 | x = self.proj_drop(x) 103 | 104 | x = x + self.mlp(self.norm(x)) 105 | 106 | # optional pool if latent seq_len > 1 and pooled output is desired 107 | if self.pool == 'token': 108 | x = x[:, 0] 109 | elif self.pool == 'avg': 110 | x = x.mean(1) 111 | return x -------------------------------------------------------------------------------- /timm/layers/blur_pool.py: -------------------------------------------------------------------------------- 1 | """ 2 | BlurPool layer inspired by 3 | - Kornia's Max_BlurPool2d 4 | - Making Convolutional Networks Shift-Invariant Again :cite:`zhang2019shiftinvar` 5 | 6 | Hacked together by Chris Ha and Ross Wightman 7 | """ 8 | from functools import partial 9 | from typing import Optional, Type 10 | 11 | import torch 12 | import torch.nn as nn 13 | import torch.nn.functional as F 14 | import numpy as np 15 | 16 | from .padding import get_padding 17 | from .typing import LayerType 18 | 19 | 20 | class BlurPool2d(nn.Module): 21 | r"""Creates a module that computes blurs and downsample a given feature map. 22 | See :cite:`zhang2019shiftinvar` for more details. 23 | Corresponds to the Downsample class, which does blurring and subsampling 24 | 25 | Args: 26 | channels = Number of input channels 27 | filt_size (int): binomial filter size for blurring. currently supports 3 (default) and 5. 28 | stride (int): downsampling filter stride 29 | 30 | Returns: 31 | torch.Tensor: the transformed tensor. 32 | """ 33 | def __init__( 34 | self, 35 | channels: Optional[int] = None, 36 | filt_size: int = 3, 37 | stride: int = 2, 38 | pad_mode: str = 'reflect', 39 | ) -> None: 40 | super(BlurPool2d, self).__init__() 41 | assert filt_size > 1 42 | self.channels = channels 43 | self.filt_size = filt_size 44 | self.stride = stride 45 | self.pad_mode = pad_mode 46 | self.padding = [get_padding(filt_size, stride, dilation=1)] * 4 47 | 48 | coeffs = torch.tensor((np.poly1d((0.5, 0.5)) ** (self.filt_size - 1)).coeffs.astype(np.float32)) 49 | blur_filter = (coeffs[:, None] * coeffs[None, :])[None, None, :, :] 50 | if channels is not None: 51 | blur_filter = blur_filter.repeat(self.channels, 1, 1, 1) 52 | self.register_buffer('filt', blur_filter, persistent=False) 53 | 54 | def forward(self, x: torch.Tensor) -> torch.Tensor: 55 | x = F.pad(x, self.padding, mode=self.pad_mode) 56 | if self.channels is None: 57 | channels = x.shape[1] 58 | weight = self.filt.expand(channels, 1, self.filt_size, self.filt_size) 59 | else: 60 | channels = self.channels 61 | weight = self.filt 62 | return F.conv2d(x, weight, stride=self.stride, groups=channels) 63 | 64 | 65 | def create_aa( 66 | aa_layer: LayerType, 67 | channels: Optional[int] = None, 68 | stride: int = 2, 69 | enable: bool = True, 70 | noop: Optional[Type[nn.Module]] = nn.Identity 71 | ) -> nn.Module: 72 | """ Anti-aliasing """ 73 | if not aa_layer or not enable: 74 | return noop() if noop is not None else None 75 | 76 | if isinstance(aa_layer, str): 77 | aa_layer = aa_layer.lower().replace('_', '').replace('-', '') 78 | if aa_layer == 'avg' or aa_layer == 'avgpool': 79 | aa_layer = nn.AvgPool2d 80 | elif aa_layer == 'blur' or aa_layer == 'blurpool': 81 | aa_layer = BlurPool2d 82 | elif aa_layer == 'blurpc': 83 | aa_layer = partial(BlurPool2d, pad_mode='constant') 84 | 85 | else: 86 | assert False, f"Unknown anti-aliasing layer ({aa_layer})." 87 | 88 | try: 89 | return aa_layer(channels=channels, stride=stride) 90 | except TypeError as e: 91 | return aa_layer(stride) 92 | -------------------------------------------------------------------------------- /timm/layers/conv2d_same.py: -------------------------------------------------------------------------------- 1 | """ Conv2d w/ Same Padding 2 | 3 | Hacked together by / Copyright 2020 Ross Wightman 4 | """ 5 | import torch 6 | import torch.nn as nn 7 | import torch.nn.functional as F 8 | from typing import Tuple, Optional 9 | 10 | from .config import is_exportable, is_scriptable 11 | from .padding import pad_same, pad_same_arg, get_padding_value 12 | 13 | 14 | _USE_EXPORT_CONV = False 15 | 16 | 17 | def conv2d_same( 18 | x, 19 | weight: torch.Tensor, 20 | bias: Optional[torch.Tensor] = None, 21 | stride: Tuple[int, int] = (1, 1), 22 | padding: Tuple[int, int] = (0, 0), 23 | dilation: Tuple[int, int] = (1, 1), 24 | groups: int = 1, 25 | ): 26 | x = pad_same(x, weight.shape[-2:], stride, dilation) 27 | return F.conv2d(x, weight, bias, stride, (0, 0), dilation, groups) 28 | 29 | 30 | class Conv2dSame(nn.Conv2d): 31 | """ Tensorflow like 'SAME' convolution wrapper for 2D convolutions 32 | """ 33 | 34 | def __init__( 35 | self, 36 | in_channels, 37 | out_channels, 38 | kernel_size, 39 | stride=1, 40 | padding=0, 41 | dilation=1, 42 | groups=1, 43 | bias=True, 44 | ): 45 | super(Conv2dSame, self).__init__( 46 | in_channels, out_channels, kernel_size, 47 | stride, 0, dilation, groups, bias, 48 | ) 49 | 50 | def forward(self, x): 51 | return conv2d_same( 52 | x, self.weight, self.bias, 53 | self.stride, self.padding, self.dilation, self.groups, 54 | ) 55 | 56 | 57 | class Conv2dSameExport(nn.Conv2d): 58 | """ ONNX export friendly Tensorflow like 'SAME' convolution wrapper for 2D convolutions 59 | 60 | NOTE: This does not currently work with torch.jit.script 61 | """ 62 | 63 | # pylint: disable=unused-argument 64 | def __init__( 65 | self, 66 | in_channels, 67 | out_channels, 68 | kernel_size, 69 | stride=1, 70 | padding=0, 71 | dilation=1, 72 | groups=1, 73 | bias=True, 74 | ): 75 | super(Conv2dSameExport, self).__init__( 76 | in_channels, out_channels, kernel_size, 77 | stride, 0, dilation, groups, bias, 78 | ) 79 | self.pad = None 80 | self.pad_input_size = (0, 0) 81 | 82 | def forward(self, x): 83 | input_size = x.size()[-2:] 84 | if self.pad is None: 85 | pad_arg = pad_same_arg(input_size, self.weight.size()[-2:], self.stride, self.dilation) 86 | self.pad = nn.ZeroPad2d(pad_arg) 87 | self.pad_input_size = input_size 88 | 89 | x = self.pad(x) 90 | return F.conv2d( 91 | x, self.weight, self.bias, 92 | self.stride, self.padding, self.dilation, self.groups, 93 | ) 94 | 95 | 96 | def create_conv2d_pad(in_chs, out_chs, kernel_size, **kwargs): 97 | padding = kwargs.pop('padding', '') 98 | kwargs.setdefault('bias', False) 99 | padding, is_dynamic = get_padding_value(padding, kernel_size, **kwargs) 100 | if is_dynamic: 101 | if _USE_EXPORT_CONV and is_exportable(): 102 | # older PyTorch ver needed this to export same padding reasonably 103 | assert not is_scriptable() # Conv2DSameExport does not work with jit 104 | return Conv2dSameExport(in_chs, out_chs, kernel_size, **kwargs) 105 | else: 106 | return Conv2dSame(in_chs, out_chs, kernel_size, **kwargs) 107 | else: 108 | return nn.Conv2d(in_chs, out_chs, kernel_size, padding=padding, **kwargs) 109 | 110 | 111 | -------------------------------------------------------------------------------- /timm/layers/conv_bn_act.py: -------------------------------------------------------------------------------- 1 | """ Conv2d + BN + Act 2 | 3 | Hacked together by / Copyright 2020 Ross Wightman 4 | """ 5 | from typing import Any, Dict, Optional, Type 6 | 7 | from torch import nn as nn 8 | 9 | from .typing import LayerType, PadType 10 | from .blur_pool import create_aa 11 | from .create_conv2d import create_conv2d 12 | from .create_norm_act import get_norm_act_layer 13 | 14 | 15 | class ConvNormAct(nn.Module): 16 | def __init__( 17 | self, 18 | in_channels: int, 19 | out_channels: int, 20 | kernel_size: int = 1, 21 | stride: int = 1, 22 | padding: PadType = '', 23 | dilation: int = 1, 24 | groups: int = 1, 25 | bias: bool = False, 26 | apply_norm: bool = True, 27 | apply_act: bool = True, 28 | norm_layer: LayerType = nn.BatchNorm2d, 29 | act_layer: Optional[LayerType] = nn.ReLU, 30 | aa_layer: Optional[LayerType] = None, 31 | drop_layer: Optional[Type[nn.Module]] = None, 32 | conv_kwargs: Optional[Dict[str, Any]] = None, 33 | norm_kwargs: Optional[Dict[str, Any]] = None, 34 | act_kwargs: Optional[Dict[str, Any]] = None, 35 | ): 36 | super(ConvNormAct, self).__init__() 37 | conv_kwargs = conv_kwargs or {} 38 | norm_kwargs = norm_kwargs or {} 39 | act_kwargs = act_kwargs or {} 40 | use_aa = aa_layer is not None and stride > 1 41 | 42 | self.conv = create_conv2d( 43 | in_channels, 44 | out_channels, 45 | kernel_size, 46 | stride=1 if use_aa else stride, 47 | padding=padding, 48 | dilation=dilation, 49 | groups=groups, 50 | bias=bias, 51 | **conv_kwargs, 52 | ) 53 | 54 | if apply_norm: 55 | # NOTE for backwards compatibility with models that use separate norm and act layer definitions 56 | norm_act_layer = get_norm_act_layer(norm_layer, act_layer) 57 | # NOTE for backwards (weight) compatibility, norm layer name remains `.bn` 58 | if drop_layer: 59 | norm_kwargs['drop_layer'] = drop_layer 60 | self.bn = norm_act_layer( 61 | out_channels, 62 | apply_act=apply_act, 63 | act_kwargs=act_kwargs, 64 | **norm_kwargs, 65 | ) 66 | else: 67 | self.bn = nn.Sequential() 68 | if drop_layer: 69 | norm_kwargs['drop_layer'] = drop_layer 70 | self.bn.add_module('drop', drop_layer()) 71 | 72 | self.aa = create_aa(aa_layer, out_channels, stride=stride, enable=use_aa, noop=None) 73 | 74 | @property 75 | def in_channels(self): 76 | return self.conv.in_channels 77 | 78 | @property 79 | def out_channels(self): 80 | return self.conv.out_channels 81 | 82 | def forward(self, x): 83 | x = self.conv(x) 84 | x = self.bn(x) 85 | aa = getattr(self, 'aa', None) 86 | if aa is not None: 87 | x = self.aa(x) 88 | return x 89 | 90 | 91 | ConvBnAct = ConvNormAct 92 | ConvNormActAa = ConvNormAct # backwards compat, when they were separate 93 | -------------------------------------------------------------------------------- /timm/layers/create_attn.py: -------------------------------------------------------------------------------- 1 | """ Attention Factory 2 | 3 | Hacked together by / Copyright 2021 Ross Wightman 4 | """ 5 | import torch 6 | from functools import partial 7 | 8 | from .bottleneck_attn import BottleneckAttn 9 | from .cbam import CbamModule, LightCbamModule 10 | from .eca import EcaModule, CecaModule 11 | from .gather_excite import GatherExcite 12 | from .global_context import GlobalContext 13 | from .halo_attn import HaloAttn 14 | from .lambda_layer import LambdaLayer 15 | from .non_local_attn import NonLocalAttn, BatNonLocalAttn 16 | from .selective_kernel import SelectiveKernel 17 | from .split_attn import SplitAttn 18 | from .squeeze_excite import SEModule, EffectiveSEModule 19 | 20 | 21 | def get_attn(attn_type): 22 | if isinstance(attn_type, torch.nn.Module): 23 | return attn_type 24 | module_cls = None 25 | if attn_type: 26 | if isinstance(attn_type, str): 27 | attn_type = attn_type.lower() 28 | # Lightweight attention modules (channel and/or coarse spatial). 29 | # Typically added to existing network architecture blocks in addition to existing convolutions. 30 | if attn_type == 'se': 31 | module_cls = SEModule 32 | elif attn_type == 'ese': 33 | module_cls = EffectiveSEModule 34 | elif attn_type == 'eca': 35 | module_cls = EcaModule 36 | elif attn_type == 'ecam': 37 | module_cls = partial(EcaModule, use_mlp=True) 38 | elif attn_type == 'ceca': 39 | module_cls = CecaModule 40 | elif attn_type == 'ge': 41 | module_cls = GatherExcite 42 | elif attn_type == 'gc': 43 | module_cls = GlobalContext 44 | elif attn_type == 'gca': 45 | module_cls = partial(GlobalContext, fuse_add=True, fuse_scale=False) 46 | elif attn_type == 'cbam': 47 | module_cls = CbamModule 48 | elif attn_type == 'lcbam': 49 | module_cls = LightCbamModule 50 | 51 | # Attention / attention-like modules w/ significant params 52 | # Typically replace some of the existing workhorse convs in a network architecture. 53 | # All of these accept a stride argument and can spatially downsample the input. 54 | elif attn_type == 'sk': 55 | module_cls = SelectiveKernel 56 | elif attn_type == 'splat': 57 | module_cls = SplitAttn 58 | 59 | # Self-attention / attention-like modules w/ significant compute and/or params 60 | # Typically replace some of the existing workhorse convs in a network architecture. 61 | # All of these accept a stride argument and can spatially downsample the input. 62 | elif attn_type == 'lambda': 63 | return LambdaLayer 64 | elif attn_type == 'bottleneck': 65 | return BottleneckAttn 66 | elif attn_type == 'halo': 67 | return HaloAttn 68 | elif attn_type == 'nl': 69 | module_cls = NonLocalAttn 70 | elif attn_type == 'bat': 71 | module_cls = BatNonLocalAttn 72 | 73 | # Woops! 74 | else: 75 | assert False, "Invalid attn module (%s)" % attn_type 76 | elif isinstance(attn_type, bool): 77 | if attn_type: 78 | module_cls = SEModule 79 | else: 80 | module_cls = attn_type 81 | return module_cls 82 | 83 | 84 | def create_attn(attn_type, channels, **kwargs): 85 | module_cls = get_attn(attn_type) 86 | if module_cls is not None: 87 | # NOTE: it's expected the first (positional) argument of all attention layers is the # input channels 88 | return module_cls(channels, **kwargs) 89 | return None 90 | -------------------------------------------------------------------------------- /timm/layers/create_conv2d.py: -------------------------------------------------------------------------------- 1 | """ Create Conv2d Factory Method 2 | 3 | Hacked together by / Copyright 2020 Ross Wightman 4 | """ 5 | 6 | from .mixed_conv2d import MixedConv2d 7 | from .cond_conv2d import CondConv2d 8 | from .conv2d_same import create_conv2d_pad 9 | 10 | 11 | def create_conv2d(in_channels, out_channels, kernel_size, **kwargs): 12 | """ Select a 2d convolution implementation based on arguments 13 | Creates and returns one of torch.nn.Conv2d, Conv2dSame, MixedConv2d, or CondConv2d. 14 | 15 | Used extensively by EfficientNet, MobileNetv3 and related networks. 16 | """ 17 | if isinstance(kernel_size, list): 18 | assert 'num_experts' not in kwargs # MixNet + CondConv combo not supported currently 19 | if 'groups' in kwargs: 20 | groups = kwargs.pop('groups') 21 | if groups == in_channels: 22 | kwargs['depthwise'] = True 23 | else: 24 | assert groups == 1 25 | # We're going to use only lists for defining the MixedConv2d kernel groups, 26 | # ints, tuples, other iterables will continue to pass to normal conv and specify h, w. 27 | m = MixedConv2d(in_channels, out_channels, kernel_size, **kwargs) 28 | else: 29 | depthwise = kwargs.pop('depthwise', False) 30 | # for DW out_channels must be multiple of in_channels as must have out_channels % groups == 0 31 | groups = in_channels if depthwise else kwargs.pop('groups', 1) 32 | if 'num_experts' in kwargs and kwargs['num_experts'] > 0: 33 | m = CondConv2d(in_channels, out_channels, kernel_size, groups=groups, **kwargs) 34 | else: 35 | m = create_conv2d_pad(in_channels, out_channels, kernel_size, groups=groups, **kwargs) 36 | return m 37 | -------------------------------------------------------------------------------- /timm/layers/create_norm.py: -------------------------------------------------------------------------------- 1 | """ Norm Layer Factory 2 | 3 | Create norm modules by string (to mirror create_act and creat_norm-act fns) 4 | 5 | Copyright 2022 Ross Wightman 6 | """ 7 | import functools 8 | import types 9 | from typing import Type 10 | 11 | import torch.nn as nn 12 | 13 | from .norm import GroupNorm, GroupNorm1, LayerNorm, LayerNorm2d, RmsNorm, RmsNorm2d, SimpleNorm, SimpleNorm2d 14 | from torchvision.ops.misc import FrozenBatchNorm2d 15 | 16 | _NORM_MAP = dict( 17 | batchnorm=nn.BatchNorm2d, 18 | batchnorm2d=nn.BatchNorm2d, 19 | batchnorm1d=nn.BatchNorm1d, 20 | groupnorm=GroupNorm, 21 | groupnorm1=GroupNorm1, 22 | layernorm=LayerNorm, 23 | layernorm2d=LayerNorm2d, 24 | rmsnorm=RmsNorm, 25 | rmsnorm2d=RmsNorm2d, 26 | simplenorm=SimpleNorm, 27 | simplenorm2d=SimpleNorm2d, 28 | frozenbatchnorm2d=FrozenBatchNorm2d, 29 | ) 30 | _NORM_TYPES = {m for n, m in _NORM_MAP.items()} 31 | 32 | 33 | def create_norm_layer(layer_name, num_features, **kwargs): 34 | layer = get_norm_layer(layer_name) 35 | layer_instance = layer(num_features, **kwargs) 36 | return layer_instance 37 | 38 | 39 | def get_norm_layer(norm_layer): 40 | if norm_layer is None: 41 | return None 42 | assert isinstance(norm_layer, (type, str, types.FunctionType, functools.partial)) 43 | norm_kwargs = {} 44 | 45 | # unbind partial fn, so args can be rebound later 46 | if isinstance(norm_layer, functools.partial): 47 | norm_kwargs.update(norm_layer.keywords) 48 | norm_layer = norm_layer.func 49 | 50 | if isinstance(norm_layer, str): 51 | if not norm_layer: 52 | return None 53 | layer_name = norm_layer.replace('_', '').lower() 54 | norm_layer = _NORM_MAP[layer_name] 55 | else: 56 | norm_layer = norm_layer 57 | 58 | if norm_kwargs: 59 | norm_layer = functools.partial(norm_layer, **norm_kwargs) # bind/rebind args 60 | return norm_layer 61 | -------------------------------------------------------------------------------- /timm/layers/create_norm_act.py: -------------------------------------------------------------------------------- 1 | """ NormAct (Normalization + Activation Layer) Factory 2 | 3 | Create norm + act combo modules that attempt to be backwards compatible with separate norm + act 4 | instances in models. Where these are used it will be possible to swap separate BN + act layers with 5 | combined modules like IABN or EvoNorms. 6 | 7 | Hacked together by / Copyright 2020 Ross Wightman 8 | """ 9 | import types 10 | import functools 11 | 12 | from .evo_norm import * 13 | from .filter_response_norm import FilterResponseNormAct2d, FilterResponseNormTlu2d 14 | from .norm_act import BatchNormAct2d, GroupNormAct, LayerNormAct, LayerNormAct2d 15 | from .inplace_abn import InplaceAbn 16 | 17 | _NORM_ACT_MAP = dict( 18 | batchnorm=BatchNormAct2d, 19 | batchnorm2d=BatchNormAct2d, 20 | groupnorm=GroupNormAct, 21 | groupnorm1=functools.partial(GroupNormAct, num_groups=1), 22 | layernorm=LayerNormAct, 23 | layernorm2d=LayerNormAct2d, 24 | evonormb0=EvoNorm2dB0, 25 | evonormb1=EvoNorm2dB1, 26 | evonormb2=EvoNorm2dB2, 27 | evonorms0=EvoNorm2dS0, 28 | evonorms0a=EvoNorm2dS0a, 29 | evonorms1=EvoNorm2dS1, 30 | evonorms1a=EvoNorm2dS1a, 31 | evonorms2=EvoNorm2dS2, 32 | evonorms2a=EvoNorm2dS2a, 33 | frn=FilterResponseNormAct2d, 34 | frntlu=FilterResponseNormTlu2d, 35 | inplaceabn=InplaceAbn, 36 | iabn=InplaceAbn, 37 | ) 38 | _NORM_ACT_TYPES = {m for n, m in _NORM_ACT_MAP.items()} 39 | # has act_layer arg to define act type 40 | _NORM_ACT_REQUIRES_ARG = { 41 | BatchNormAct2d, GroupNormAct, LayerNormAct, LayerNormAct2d, FilterResponseNormAct2d, InplaceAbn} 42 | 43 | 44 | def create_norm_act_layer(layer_name, num_features, act_layer=None, apply_act=True, jit=False, **kwargs): 45 | layer = get_norm_act_layer(layer_name, act_layer=act_layer) 46 | layer_instance = layer(num_features, apply_act=apply_act, **kwargs) 47 | if jit: 48 | layer_instance = torch.jit.script(layer_instance) 49 | return layer_instance 50 | 51 | 52 | def get_norm_act_layer(norm_layer, act_layer=None): 53 | if norm_layer is None: 54 | return None 55 | assert isinstance(norm_layer, (type, str, types.FunctionType, functools.partial)) 56 | assert act_layer is None or isinstance(act_layer, (type, str, types.FunctionType, functools.partial)) 57 | norm_act_kwargs = {} 58 | 59 | # unbind partial fn, so args can be rebound later 60 | if isinstance(norm_layer, functools.partial): 61 | norm_act_kwargs.update(norm_layer.keywords) 62 | norm_layer = norm_layer.func 63 | 64 | if isinstance(norm_layer, str): 65 | if not norm_layer: 66 | return None 67 | layer_name = norm_layer.replace('_', '').lower().split('-')[0] 68 | norm_act_layer = _NORM_ACT_MAP[layer_name] 69 | elif norm_layer in _NORM_ACT_TYPES: 70 | norm_act_layer = norm_layer 71 | elif isinstance(norm_layer, types.FunctionType): 72 | # if function type, must be a lambda/fn that creates a norm_act layer 73 | norm_act_layer = norm_layer 74 | else: 75 | type_name = norm_layer.__name__.lower() 76 | if type_name.startswith('batchnorm'): 77 | norm_act_layer = BatchNormAct2d 78 | elif type_name.startswith('groupnorm'): 79 | norm_act_layer = GroupNormAct 80 | elif type_name.startswith('groupnorm1'): 81 | norm_act_layer = functools.partial(GroupNormAct, num_groups=1) 82 | elif type_name.startswith('layernorm2d'): 83 | norm_act_layer = LayerNormAct2d 84 | elif type_name.startswith('layernorm'): 85 | norm_act_layer = LayerNormAct 86 | else: 87 | assert False, f"No equivalent norm_act layer for {type_name}" 88 | 89 | if norm_act_layer in _NORM_ACT_REQUIRES_ARG: 90 | # pass `act_layer` through for backwards compat where `act_layer=None` implies no activation. 91 | # In the future, may force use of `apply_act` with `act_layer` arg bound to relevant NormAct types 92 | norm_act_kwargs.setdefault('act_layer', act_layer) 93 | if norm_act_kwargs: 94 | norm_act_layer = functools.partial(norm_act_layer, **norm_act_kwargs) # bind/rebind args 95 | return norm_act_layer 96 | -------------------------------------------------------------------------------- /timm/layers/filter_response_norm.py: -------------------------------------------------------------------------------- 1 | """ Filter Response Norm in PyTorch 2 | 3 | Based on `Filter Response Normalization Layer` - https://arxiv.org/abs/1911.09737 4 | 5 | Hacked together by / Copyright 2021 Ross Wightman 6 | """ 7 | import torch 8 | import torch.nn as nn 9 | 10 | from .create_act import create_act_layer 11 | from .trace_utils import _assert 12 | 13 | 14 | def inv_instance_rms(x, eps: float = 1e-5): 15 | rms = x.square().float().mean(dim=(2, 3), keepdim=True).add(eps).rsqrt().to(x.dtype) 16 | return rms.expand(x.shape) 17 | 18 | 19 | class FilterResponseNormTlu2d(nn.Module): 20 | def __init__(self, num_features, apply_act=True, eps=1e-5, rms=True, **_): 21 | super(FilterResponseNormTlu2d, self).__init__() 22 | self.apply_act = apply_act # apply activation (non-linearity) 23 | self.rms = rms 24 | self.eps = eps 25 | self.weight = nn.Parameter(torch.ones(num_features)) 26 | self.bias = nn.Parameter(torch.zeros(num_features)) 27 | self.tau = nn.Parameter(torch.zeros(num_features)) if apply_act else None 28 | self.reset_parameters() 29 | 30 | def reset_parameters(self): 31 | nn.init.ones_(self.weight) 32 | nn.init.zeros_(self.bias) 33 | if self.tau is not None: 34 | nn.init.zeros_(self.tau) 35 | 36 | def forward(self, x): 37 | _assert(x.dim() == 4, 'expected 4D input') 38 | x_dtype = x.dtype 39 | v_shape = (1, -1, 1, 1) 40 | x = x * inv_instance_rms(x, self.eps) 41 | x = x * self.weight.view(v_shape).to(dtype=x_dtype) + self.bias.view(v_shape).to(dtype=x_dtype) 42 | return torch.maximum(x, self.tau.reshape(v_shape).to(dtype=x_dtype)) if self.tau is not None else x 43 | 44 | 45 | class FilterResponseNormAct2d(nn.Module): 46 | def __init__(self, num_features, apply_act=True, act_layer=nn.ReLU, inplace=None, rms=True, eps=1e-5, **_): 47 | super(FilterResponseNormAct2d, self).__init__() 48 | if act_layer is not None and apply_act: 49 | self.act = create_act_layer(act_layer, inplace=inplace) 50 | else: 51 | self.act = nn.Identity() 52 | self.rms = rms 53 | self.eps = eps 54 | self.weight = nn.Parameter(torch.ones(num_features)) 55 | self.bias = nn.Parameter(torch.zeros(num_features)) 56 | self.reset_parameters() 57 | 58 | def reset_parameters(self): 59 | nn.init.ones_(self.weight) 60 | nn.init.zeros_(self.bias) 61 | 62 | def forward(self, x): 63 | _assert(x.dim() == 4, 'expected 4D input') 64 | x_dtype = x.dtype 65 | v_shape = (1, -1, 1, 1) 66 | x = x * inv_instance_rms(x, self.eps) 67 | x = x * self.weight.view(v_shape).to(dtype=x_dtype) + self.bias.view(v_shape).to(dtype=x_dtype) 68 | return self.act(x) 69 | -------------------------------------------------------------------------------- /timm/layers/format.py: -------------------------------------------------------------------------------- 1 | from enum import Enum 2 | from typing import Union 3 | 4 | import torch 5 | 6 | 7 | class Format(str, Enum): 8 | NCHW = 'NCHW' 9 | NHWC = 'NHWC' 10 | NCL = 'NCL' 11 | NLC = 'NLC' 12 | 13 | 14 | FormatT = Union[str, Format] 15 | 16 | 17 | def get_spatial_dim(fmt: FormatT): 18 | fmt = Format(fmt) 19 | if fmt is Format.NLC: 20 | dim = (1,) 21 | elif fmt is Format.NCL: 22 | dim = (2,) 23 | elif fmt is Format.NHWC: 24 | dim = (1, 2) 25 | else: 26 | dim = (2, 3) 27 | return dim 28 | 29 | 30 | def get_channel_dim(fmt: FormatT): 31 | fmt = Format(fmt) 32 | if fmt is Format.NHWC: 33 | dim = 3 34 | elif fmt is Format.NLC: 35 | dim = 2 36 | else: 37 | dim = 1 38 | return dim 39 | 40 | 41 | def nchw_to(x: torch.Tensor, fmt: Format): 42 | if fmt == Format.NHWC: 43 | x = x.permute(0, 2, 3, 1) 44 | elif fmt == Format.NLC: 45 | x = x.flatten(2).transpose(1, 2) 46 | elif fmt == Format.NCL: 47 | x = x.flatten(2) 48 | return x 49 | 50 | 51 | def nhwc_to(x: torch.Tensor, fmt: Format): 52 | if fmt == Format.NCHW: 53 | x = x.permute(0, 3, 1, 2) 54 | elif fmt == Format.NLC: 55 | x = x.flatten(1, 2) 56 | elif fmt == Format.NCL: 57 | x = x.flatten(1, 2).transpose(1, 2) 58 | return x 59 | -------------------------------------------------------------------------------- /timm/layers/gather_excite.py: -------------------------------------------------------------------------------- 1 | """ Gather-Excite Attention Block 2 | 3 | Paper: `Gather-Excite: Exploiting Feature Context in CNNs` - https://arxiv.org/abs/1810.12348 4 | 5 | Official code here, but it's only partial impl in Caffe: https://github.com/hujie-frank/GENet 6 | 7 | I've tried to support all of the extent both w/ and w/o params. I don't believe I've seen another 8 | impl that covers all of the cases. 9 | 10 | NOTE: extent=0 + extra_params=False is equivalent to Squeeze-and-Excitation 11 | 12 | Hacked together by / Copyright 2021 Ross Wightman 13 | """ 14 | import math 15 | 16 | from torch import nn as nn 17 | import torch.nn.functional as F 18 | 19 | from .create_act import create_act_layer, get_act_layer 20 | from .create_conv2d import create_conv2d 21 | from .helpers import make_divisible 22 | from .mlp import ConvMlp 23 | 24 | 25 | class GatherExcite(nn.Module): 26 | """ Gather-Excite Attention Module 27 | """ 28 | def __init__( 29 | self, channels, feat_size=None, extra_params=False, extent=0, use_mlp=True, 30 | rd_ratio=1./16, rd_channels=None, rd_divisor=1, add_maxpool=False, 31 | act_layer=nn.ReLU, norm_layer=nn.BatchNorm2d, gate_layer='sigmoid'): 32 | super(GatherExcite, self).__init__() 33 | self.add_maxpool = add_maxpool 34 | act_layer = get_act_layer(act_layer) 35 | self.extent = extent 36 | if extra_params: 37 | self.gather = nn.Sequential() 38 | if extent == 0: 39 | assert feat_size is not None, 'spatial feature size must be specified for global extent w/ params' 40 | self.gather.add_module( 41 | 'conv1', create_conv2d(channels, channels, kernel_size=feat_size, stride=1, depthwise=True)) 42 | if norm_layer: 43 | self.gather.add_module(f'norm1', nn.BatchNorm2d(channels)) 44 | else: 45 | assert extent % 2 == 0 46 | num_conv = int(math.log2(extent)) 47 | for i in range(num_conv): 48 | self.gather.add_module( 49 | f'conv{i + 1}', 50 | create_conv2d(channels, channels, kernel_size=3, stride=2, depthwise=True)) 51 | if norm_layer: 52 | self.gather.add_module(f'norm{i + 1}', nn.BatchNorm2d(channels)) 53 | if i != num_conv - 1: 54 | self.gather.add_module(f'act{i + 1}', act_layer(inplace=True)) 55 | else: 56 | self.gather = None 57 | if self.extent == 0: 58 | self.gk = 0 59 | self.gs = 0 60 | else: 61 | assert extent % 2 == 0 62 | self.gk = self.extent * 2 - 1 63 | self.gs = self.extent 64 | 65 | if not rd_channels: 66 | rd_channels = make_divisible(channels * rd_ratio, rd_divisor, round_limit=0.) 67 | self.mlp = ConvMlp(channels, rd_channels, act_layer=act_layer) if use_mlp else nn.Identity() 68 | self.gate = create_act_layer(gate_layer) 69 | 70 | def forward(self, x): 71 | size = x.shape[-2:] 72 | if self.gather is not None: 73 | x_ge = self.gather(x) 74 | else: 75 | if self.extent == 0: 76 | # global extent 77 | x_ge = x.mean(dim=(2, 3), keepdims=True) 78 | if self.add_maxpool: 79 | # experimental codepath, may remove or change 80 | x_ge = 0.5 * x_ge + 0.5 * x.amax((2, 3), keepdim=True) 81 | else: 82 | x_ge = F.avg_pool2d( 83 | x, kernel_size=self.gk, stride=self.gs, padding=self.gk // 2, count_include_pad=False) 84 | if self.add_maxpool: 85 | # experimental codepath, may remove or change 86 | x_ge = 0.5 * x_ge + 0.5 * F.max_pool2d(x, kernel_size=self.gk, stride=self.gs, padding=self.gk // 2) 87 | x_ge = self.mlp(x_ge) 88 | if x_ge.shape[-1] != 1 or x_ge.shape[-2] != 1: 89 | x_ge = F.interpolate(x_ge, size=size) 90 | return x * self.gate(x_ge) 91 | -------------------------------------------------------------------------------- /timm/layers/global_context.py: -------------------------------------------------------------------------------- 1 | """ Global Context Attention Block 2 | 3 | Paper: `GCNet: Non-local Networks Meet Squeeze-Excitation Networks and Beyond` 4 | - https://arxiv.org/abs/1904.11492 5 | 6 | Official code consulted as reference: https://github.com/xvjiarui/GCNet 7 | 8 | Hacked together by / Copyright 2021 Ross Wightman 9 | """ 10 | from torch import nn as nn 11 | import torch.nn.functional as F 12 | 13 | from .create_act import create_act_layer, get_act_layer 14 | from .helpers import make_divisible 15 | from .mlp import ConvMlp 16 | from .norm import LayerNorm2d 17 | 18 | 19 | class GlobalContext(nn.Module): 20 | 21 | def __init__(self, channels, use_attn=True, fuse_add=False, fuse_scale=True, init_last_zero=False, 22 | rd_ratio=1./8, rd_channels=None, rd_divisor=1, act_layer=nn.ReLU, gate_layer='sigmoid'): 23 | super(GlobalContext, self).__init__() 24 | act_layer = get_act_layer(act_layer) 25 | 26 | self.conv_attn = nn.Conv2d(channels, 1, kernel_size=1, bias=True) if use_attn else None 27 | 28 | if rd_channels is None: 29 | rd_channels = make_divisible(channels * rd_ratio, rd_divisor, round_limit=0.) 30 | if fuse_add: 31 | self.mlp_add = ConvMlp(channels, rd_channels, act_layer=act_layer, norm_layer=LayerNorm2d) 32 | else: 33 | self.mlp_add = None 34 | if fuse_scale: 35 | self.mlp_scale = ConvMlp(channels, rd_channels, act_layer=act_layer, norm_layer=LayerNorm2d) 36 | else: 37 | self.mlp_scale = None 38 | 39 | self.gate = create_act_layer(gate_layer) 40 | self.init_last_zero = init_last_zero 41 | self.reset_parameters() 42 | 43 | def reset_parameters(self): 44 | if self.conv_attn is not None: 45 | nn.init.kaiming_normal_(self.conv_attn.weight, mode='fan_in', nonlinearity='relu') 46 | if self.mlp_add is not None: 47 | nn.init.zeros_(self.mlp_add.fc2.weight) 48 | 49 | def forward(self, x): 50 | B, C, H, W = x.shape 51 | 52 | if self.conv_attn is not None: 53 | attn = self.conv_attn(x).reshape(B, 1, H * W) # (B, 1, H * W) 54 | attn = F.softmax(attn, dim=-1).unsqueeze(3) # (B, 1, H * W, 1) 55 | context = x.reshape(B, C, H * W).unsqueeze(1) @ attn 56 | context = context.view(B, C, 1, 1) 57 | else: 58 | context = x.mean(dim=(2, 3), keepdim=True) 59 | 60 | if self.mlp_scale is not None: 61 | mlp_x = self.mlp_scale(context) 62 | x = x * self.gate(mlp_x) 63 | if self.mlp_add is not None: 64 | mlp_x = self.mlp_add(context) 65 | x = x + mlp_x 66 | 67 | return x 68 | -------------------------------------------------------------------------------- /timm/layers/grid.py: -------------------------------------------------------------------------------- 1 | from typing import Tuple 2 | 3 | import torch 4 | 5 | 6 | def ndgrid(*tensors) -> Tuple[torch.Tensor, ...]: 7 | """generate N-D grid in dimension order. 8 | 9 | The ndgrid function is like meshgrid except that the order of the first two input arguments are switched. 10 | 11 | That is, the statement 12 | [X1,X2,X3] = ndgrid(x1,x2,x3) 13 | 14 | produces the same result as 15 | 16 | [X2,X1,X3] = meshgrid(x2,x1,x3) 17 | 18 | This naming is based on MATLAB, the purpose is to avoid confusion due to torch's change to make 19 | torch.meshgrid behaviour move from matching ndgrid ('ij') indexing to numpy meshgrid defaults of ('xy'). 20 | 21 | """ 22 | try: 23 | return torch.meshgrid(*tensors, indexing='ij') 24 | except TypeError: 25 | # old PyTorch < 1.10 will follow this path as it does not have indexing arg, 26 | # the old behaviour of meshgrid was 'ij' 27 | return torch.meshgrid(*tensors) 28 | 29 | 30 | def meshgrid(*tensors) -> Tuple[torch.Tensor, ...]: 31 | """generate N-D grid in spatial dim order. 32 | 33 | The meshgrid function is similar to ndgrid except that the order of the 34 | first two input and output arguments is switched. 35 | 36 | That is, the statement 37 | 38 | [X,Y,Z] = meshgrid(x,y,z) 39 | produces the same result as 40 | 41 | [Y,X,Z] = ndgrid(y,x,z) 42 | Because of this, meshgrid is better suited to problems in two- or three-dimensional Cartesian space, 43 | while ndgrid is better suited to multidimensional problems that aren't spatially based. 44 | """ 45 | 46 | # NOTE: this will throw in PyTorch < 1.10 as meshgrid did not support indexing arg or have 47 | # capability of generating grid in xy order before then. 48 | return torch.meshgrid(*tensors, indexing='xy') 49 | 50 | -------------------------------------------------------------------------------- /timm/layers/grn.py: -------------------------------------------------------------------------------- 1 | """ Global Response Normalization Module 2 | 3 | Based on the GRN layer presented in 4 | `ConvNeXt-V2 - Co-designing and Scaling ConvNets with Masked Autoencoders` - https://arxiv.org/abs/2301.00808 5 | 6 | This implementation 7 | * works for both NCHW and NHWC tensor layouts 8 | * uses affine param names matching existing torch norm layers 9 | * slightly improves eager mode performance via fused addcmul 10 | 11 | Hacked together by / Copyright 2023 Ross Wightman 12 | """ 13 | 14 | import torch 15 | from torch import nn as nn 16 | 17 | 18 | class GlobalResponseNorm(nn.Module): 19 | """ Global Response Normalization layer 20 | """ 21 | def __init__(self, dim, eps=1e-6, channels_last=True): 22 | super().__init__() 23 | self.eps = eps 24 | if channels_last: 25 | self.spatial_dim = (1, 2) 26 | self.channel_dim = -1 27 | self.wb_shape = (1, 1, 1, -1) 28 | else: 29 | self.spatial_dim = (2, 3) 30 | self.channel_dim = 1 31 | self.wb_shape = (1, -1, 1, 1) 32 | 33 | self.weight = nn.Parameter(torch.zeros(dim)) 34 | self.bias = nn.Parameter(torch.zeros(dim)) 35 | 36 | def forward(self, x): 37 | x_g = x.norm(p=2, dim=self.spatial_dim, keepdim=True) 38 | x_n = x_g / (x_g.mean(dim=self.channel_dim, keepdim=True) + self.eps) 39 | return x + torch.addcmul(self.bias.view(self.wb_shape), self.weight.view(self.wb_shape), x * x_n) 40 | -------------------------------------------------------------------------------- /timm/layers/helpers.py: -------------------------------------------------------------------------------- 1 | """ Layer/Module Helpers 2 | 3 | Hacked together by / Copyright 2020 Ross Wightman 4 | """ 5 | from itertools import repeat 6 | import collections.abc 7 | 8 | 9 | # From PyTorch internals 10 | def _ntuple(n): 11 | def parse(x): 12 | if isinstance(x, collections.abc.Iterable) and not isinstance(x, str): 13 | return tuple(x) 14 | return tuple(repeat(x, n)) 15 | return parse 16 | 17 | 18 | to_1tuple = _ntuple(1) 19 | to_2tuple = _ntuple(2) 20 | to_3tuple = _ntuple(3) 21 | to_4tuple = _ntuple(4) 22 | to_ntuple = _ntuple 23 | 24 | 25 | def make_divisible(v, divisor=8, min_value=None, round_limit=.9): 26 | min_value = min_value or divisor 27 | new_v = max(min_value, int(v + divisor / 2) // divisor * divisor) 28 | # Make sure that round down does not go down by more than 10%. 29 | if new_v < round_limit * v: 30 | new_v += divisor 31 | return new_v 32 | 33 | 34 | def extend_tuple(x, n): 35 | # pads a tuple to specified n by padding with last value 36 | if not isinstance(x, (tuple, list)): 37 | x = (x,) 38 | else: 39 | x = tuple(x) 40 | pad_n = n - len(x) 41 | if pad_n <= 0: 42 | return x[:n] 43 | return x + (x[-1],) * pad_n 44 | -------------------------------------------------------------------------------- /timm/layers/inplace_abn.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn as nn 3 | 4 | try: 5 | from inplace_abn.functions import inplace_abn, inplace_abn_sync 6 | has_iabn = True 7 | except ImportError: 8 | has_iabn = False 9 | 10 | def inplace_abn(x, weight, bias, running_mean, running_var, 11 | training=True, momentum=0.1, eps=1e-05, activation="leaky_relu", activation_param=0.01): 12 | raise ImportError( 13 | "Please install InplaceABN:'pip install git+https://github.com/mapillary/inplace_abn.git@v1.0.12'") 14 | 15 | def inplace_abn_sync(**kwargs): 16 | inplace_abn(**kwargs) 17 | 18 | 19 | class InplaceAbn(nn.Module): 20 | """Activated Batch Normalization 21 | 22 | This gathers a BatchNorm and an activation function in a single module 23 | 24 | Parameters 25 | ---------- 26 | num_features : int 27 | Number of feature channels in the input and output. 28 | eps : float 29 | Small constant to prevent numerical issues. 30 | momentum : float 31 | Momentum factor applied to compute running statistics. 32 | affine : bool 33 | If `True` apply learned scale and shift transformation after normalization. 34 | act_layer : str or nn.Module type 35 | Name or type of the activation functions, one of: `leaky_relu`, `elu` 36 | act_param : float 37 | Negative slope for the `leaky_relu` activation. 38 | """ 39 | 40 | def __init__(self, num_features, eps=1e-5, momentum=0.1, affine=True, apply_act=True, 41 | act_layer="leaky_relu", act_param=0.01, drop_layer=None): 42 | super(InplaceAbn, self).__init__() 43 | self.num_features = num_features 44 | self.affine = affine 45 | self.eps = eps 46 | self.momentum = momentum 47 | if apply_act: 48 | if isinstance(act_layer, str): 49 | assert act_layer in ('leaky_relu', 'elu', 'identity', '') 50 | self.act_name = act_layer if act_layer else 'identity' 51 | else: 52 | # convert act layer passed as type to string 53 | if act_layer == nn.ELU: 54 | self.act_name = 'elu' 55 | elif act_layer == nn.LeakyReLU: 56 | self.act_name = 'leaky_relu' 57 | elif act_layer is None or act_layer == nn.Identity: 58 | self.act_name = 'identity' 59 | else: 60 | assert False, f'Invalid act layer {act_layer.__name__} for IABN' 61 | else: 62 | self.act_name = 'identity' 63 | self.act_param = act_param 64 | if self.affine: 65 | self.weight = nn.Parameter(torch.ones(num_features)) 66 | self.bias = nn.Parameter(torch.zeros(num_features)) 67 | else: 68 | self.register_parameter('weight', None) 69 | self.register_parameter('bias', None) 70 | self.register_buffer('running_mean', torch.zeros(num_features)) 71 | self.register_buffer('running_var', torch.ones(num_features)) 72 | self.reset_parameters() 73 | 74 | def reset_parameters(self): 75 | nn.init.constant_(self.running_mean, 0) 76 | nn.init.constant_(self.running_var, 1) 77 | if self.affine: 78 | nn.init.constant_(self.weight, 1) 79 | nn.init.constant_(self.bias, 0) 80 | 81 | def forward(self, x): 82 | output = inplace_abn( 83 | x, self.weight, self.bias, self.running_mean, self.running_var, 84 | self.training, self.momentum, self.eps, self.act_name, self.act_param) 85 | if isinstance(output, tuple): 86 | output = output[0] 87 | return output 88 | -------------------------------------------------------------------------------- /timm/layers/interpolate.py: -------------------------------------------------------------------------------- 1 | """ Interpolation helpers for timm layers 2 | 3 | RegularGridInterpolator from https://github.com/sbarratt/torch_interpolations 4 | Copyright Shane Barratt, Apache 2.0 license 5 | """ 6 | import torch 7 | from itertools import product 8 | 9 | 10 | class RegularGridInterpolator: 11 | """ Interpolate data defined on a rectilinear grid with even or uneven spacing. 12 | Produces similar results to scipy RegularGridInterpolator or interp2d 13 | in 'linear' mode. 14 | 15 | Taken from https://github.com/sbarratt/torch_interpolations 16 | """ 17 | 18 | def __init__(self, points, values): 19 | self.points = points 20 | self.values = values 21 | 22 | assert isinstance(self.points, tuple) or isinstance(self.points, list) 23 | assert isinstance(self.values, torch.Tensor) 24 | 25 | self.ms = list(self.values.shape) 26 | self.n = len(self.points) 27 | 28 | assert len(self.ms) == self.n 29 | 30 | for i, p in enumerate(self.points): 31 | assert isinstance(p, torch.Tensor) 32 | assert p.shape[0] == self.values.shape[i] 33 | 34 | def __call__(self, points_to_interp): 35 | assert self.points is not None 36 | assert self.values is not None 37 | 38 | assert len(points_to_interp) == len(self.points) 39 | K = points_to_interp[0].shape[0] 40 | for x in points_to_interp: 41 | assert x.shape[0] == K 42 | 43 | idxs = [] 44 | dists = [] 45 | overalls = [] 46 | for p, x in zip(self.points, points_to_interp): 47 | idx_right = torch.bucketize(x, p) 48 | idx_right[idx_right >= p.shape[0]] = p.shape[0] - 1 49 | idx_left = (idx_right - 1).clamp(0, p.shape[0] - 1) 50 | dist_left = x - p[idx_left] 51 | dist_right = p[idx_right] - x 52 | dist_left[dist_left < 0] = 0. 53 | dist_right[dist_right < 0] = 0. 54 | both_zero = (dist_left == 0) & (dist_right == 0) 55 | dist_left[both_zero] = dist_right[both_zero] = 1. 56 | 57 | idxs.append((idx_left, idx_right)) 58 | dists.append((dist_left, dist_right)) 59 | overalls.append(dist_left + dist_right) 60 | 61 | numerator = 0. 62 | for indexer in product([0, 1], repeat=self.n): 63 | as_s = [idx[onoff] for onoff, idx in zip(indexer, idxs)] 64 | bs_s = [dist[1 - onoff] for onoff, dist in zip(indexer, dists)] 65 | numerator += self.values[as_s] * \ 66 | torch.prod(torch.stack(bs_s), dim=0) 67 | denominator = torch.prod(torch.stack(overalls), dim=0) 68 | return numerator / denominator 69 | -------------------------------------------------------------------------------- /timm/layers/layer_scale.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | 4 | 5 | class LayerScale(nn.Module): 6 | """ LayerScale on tensors with channels in last-dim. 7 | """ 8 | def __init__( 9 | self, 10 | dim: int, 11 | init_values: float = 1e-5, 12 | inplace: bool = False, 13 | ) -> None: 14 | super().__init__() 15 | self.inplace = inplace 16 | self.gamma = nn.Parameter(init_values * torch.ones(dim)) 17 | 18 | def forward(self, x: torch.Tensor) -> torch.Tensor: 19 | return x.mul_(self.gamma) if self.inplace else x * self.gamma 20 | 21 | 22 | class LayerScale2d(nn.Module): 23 | """ LayerScale for tensors with torch 2D NCHW layout. 24 | """ 25 | def __init__( 26 | self, 27 | dim: int, 28 | init_values: float = 1e-5, 29 | inplace: bool = False, 30 | ): 31 | super().__init__() 32 | self.inplace = inplace 33 | self.gamma = nn.Parameter(init_values * torch.ones(dim)) 34 | 35 | def forward(self, x): 36 | gamma = self.gamma.view(1, -1, 1, 1) 37 | return x.mul_(gamma) if self.inplace else x * gamma 38 | 39 | -------------------------------------------------------------------------------- /timm/layers/linear.py: -------------------------------------------------------------------------------- 1 | """ Linear layer (alternate definition) 2 | """ 3 | import torch 4 | import torch.nn.functional as F 5 | from torch import nn as nn 6 | 7 | 8 | class Linear(nn.Linear): 9 | r"""Applies a linear transformation to the incoming data: :math:`y = xA^T + b` 10 | 11 | Wraps torch.nn.Linear to support AMP + torchscript usage by manually casting 12 | weight & bias to input.dtype to work around an issue w/ torch.addmm in this use case. 13 | """ 14 | def forward(self, input: torch.Tensor) -> torch.Tensor: 15 | if torch.jit.is_scripting(): 16 | bias = self.bias.to(dtype=input.dtype) if self.bias is not None else None 17 | return F.linear(input, self.weight.to(dtype=input.dtype), bias=bias) 18 | else: 19 | return F.linear(input, self.weight, self.bias) 20 | -------------------------------------------------------------------------------- /timm/layers/median_pool.py: -------------------------------------------------------------------------------- 1 | """ Median Pool 2 | Hacked together by / Copyright 2020 Ross Wightman 3 | """ 4 | import torch.nn as nn 5 | import torch.nn.functional as F 6 | from .helpers import to_2tuple, to_4tuple 7 | 8 | 9 | class MedianPool2d(nn.Module): 10 | """ Median pool (usable as median filter when stride=1) module. 11 | 12 | Args: 13 | kernel_size: size of pooling kernel, int or 2-tuple 14 | stride: pool stride, int or 2-tuple 15 | padding: pool padding, int or 4-tuple (l, r, t, b) as in pytorch F.pad 16 | same: override padding and enforce same padding, boolean 17 | """ 18 | def __init__(self, kernel_size=3, stride=1, padding=0, same=False): 19 | super(MedianPool2d, self).__init__() 20 | self.k = to_2tuple(kernel_size) 21 | self.stride = to_2tuple(stride) 22 | self.padding = to_4tuple(padding) # convert to l, r, t, b 23 | self.same = same 24 | 25 | def _padding(self, x): 26 | if self.same: 27 | ih, iw = x.size()[2:] 28 | if ih % self.stride[0] == 0: 29 | ph = max(self.k[0] - self.stride[0], 0) 30 | else: 31 | ph = max(self.k[0] - (ih % self.stride[0]), 0) 32 | if iw % self.stride[1] == 0: 33 | pw = max(self.k[1] - self.stride[1], 0) 34 | else: 35 | pw = max(self.k[1] - (iw % self.stride[1]), 0) 36 | pl = pw // 2 37 | pr = pw - pl 38 | pt = ph // 2 39 | pb = ph - pt 40 | padding = (pl, pr, pt, pb) 41 | else: 42 | padding = self.padding 43 | return padding 44 | 45 | def forward(self, x): 46 | x = F.pad(x, self._padding(x), mode='reflect') 47 | x = x.unfold(2, self.k[0], self.stride[0]).unfold(3, self.k[1], self.stride[1]) 48 | x = x.contiguous().view(x.size()[:4] + (-1,)).median(dim=-1)[0] 49 | return x 50 | -------------------------------------------------------------------------------- /timm/layers/mixed_conv2d.py: -------------------------------------------------------------------------------- 1 | """ PyTorch Mixed Convolution 2 | 3 | Paper: MixConv: Mixed Depthwise Convolutional Kernels (https://arxiv.org/abs/1907.09595) 4 | 5 | Hacked together by / Copyright 2020 Ross Wightman 6 | """ 7 | 8 | import torch 9 | from torch import nn as nn 10 | 11 | from .conv2d_same import create_conv2d_pad 12 | 13 | 14 | def _split_channels(num_chan, num_groups): 15 | split = [num_chan // num_groups for _ in range(num_groups)] 16 | split[0] += num_chan - sum(split) 17 | return split 18 | 19 | 20 | class MixedConv2d(nn.ModuleDict): 21 | """ Mixed Grouped Convolution 22 | 23 | Based on MDConv and GroupedConv in MixNet impl: 24 | https://github.com/tensorflow/tpu/blob/master/models/official/mnasnet/mixnet/custom_layers.py 25 | """ 26 | def __init__(self, in_channels, out_channels, kernel_size=3, 27 | stride=1, padding='', dilation=1, depthwise=False, **kwargs): 28 | super(MixedConv2d, self).__init__() 29 | 30 | kernel_size = kernel_size if isinstance(kernel_size, list) else [kernel_size] 31 | num_groups = len(kernel_size) 32 | in_splits = _split_channels(in_channels, num_groups) 33 | out_splits = _split_channels(out_channels, num_groups) 34 | self.in_channels = sum(in_splits) 35 | self.out_channels = sum(out_splits) 36 | for idx, (k, in_ch, out_ch) in enumerate(zip(kernel_size, in_splits, out_splits)): 37 | conv_groups = in_ch if depthwise else 1 38 | # use add_module to keep key space clean 39 | self.add_module( 40 | str(idx), 41 | create_conv2d_pad( 42 | in_ch, out_ch, k, stride=stride, 43 | padding=padding, dilation=dilation, groups=conv_groups, **kwargs) 44 | ) 45 | self.splits = in_splits 46 | 47 | def forward(self, x): 48 | x_split = torch.split(x, self.splits, 1) 49 | x_out = [c(x_split[i]) for i, c in enumerate(self.values())] 50 | x = torch.cat(x_out, 1) 51 | return x 52 | -------------------------------------------------------------------------------- /timm/layers/padding.py: -------------------------------------------------------------------------------- 1 | """ Padding Helpers 2 | 3 | Hacked together by / Copyright 2020 Ross Wightman 4 | """ 5 | import math 6 | from typing import List, Tuple, Union 7 | 8 | import torch 9 | import torch.nn.functional as F 10 | 11 | from .helpers import to_2tuple 12 | 13 | 14 | # Calculate symmetric padding for a convolution 15 | def get_padding(kernel_size: int, stride: int = 1, dilation: int = 1, **_) -> Union[int, List[int]]: 16 | if any([isinstance(v, (tuple, list)) for v in [kernel_size, stride, dilation]]): 17 | kernel_size, stride, dilation = to_2tuple(kernel_size), to_2tuple(stride), to_2tuple(dilation) 18 | return [get_padding(*a) for a in zip(kernel_size, stride, dilation)] 19 | padding = ((stride - 1) + dilation * (kernel_size - 1)) // 2 20 | return padding 21 | 22 | 23 | # Calculate asymmetric TensorFlow-like 'SAME' padding for a convolution 24 | def get_same_padding(x: int, kernel_size: int, stride: int, dilation: int): 25 | if isinstance(x, torch.Tensor): 26 | return torch.clamp(((x / stride).ceil() - 1) * stride + (kernel_size - 1) * dilation + 1 - x, min=0) 27 | else: 28 | return max((math.ceil(x / stride) - 1) * stride + (kernel_size - 1) * dilation + 1 - x, 0) 29 | 30 | 31 | # Can SAME padding for given args be done statically? 32 | def is_static_pad(kernel_size: int, stride: int = 1, dilation: int = 1, **_): 33 | if any([isinstance(v, (tuple, list)) for v in [kernel_size, stride, dilation]]): 34 | kernel_size, stride, dilation = to_2tuple(kernel_size), to_2tuple(stride), to_2tuple(dilation) 35 | return all([is_static_pad(*a) for a in zip(kernel_size, stride, dilation)]) 36 | return stride == 1 and (dilation * (kernel_size - 1)) % 2 == 0 37 | 38 | 39 | def pad_same_arg( 40 | input_size: List[int], 41 | kernel_size: List[int], 42 | stride: List[int], 43 | dilation: List[int] = (1, 1), 44 | ) -> List[int]: 45 | ih, iw = input_size 46 | kh, kw = kernel_size 47 | pad_h = get_same_padding(ih, kh, stride[0], dilation[0]) 48 | pad_w = get_same_padding(iw, kw, stride[1], dilation[1]) 49 | return [pad_w // 2, pad_w - pad_w // 2, pad_h // 2, pad_h - pad_h // 2] 50 | 51 | 52 | # Dynamically pad input x with 'SAME' padding for conv with specified args 53 | def pad_same( 54 | x, 55 | kernel_size: List[int], 56 | stride: List[int], 57 | dilation: List[int] = (1, 1), 58 | value: float = 0, 59 | ): 60 | ih, iw = x.size()[-2:] 61 | pad_h = get_same_padding(ih, kernel_size[0], stride[0], dilation[0]) 62 | pad_w = get_same_padding(iw, kernel_size[1], stride[1], dilation[1]) 63 | x = F.pad(x, (pad_w // 2, pad_w - pad_w // 2, pad_h // 2, pad_h - pad_h // 2), value=value) 64 | return x 65 | 66 | 67 | def get_padding_value(padding, kernel_size, **kwargs) -> Tuple[Tuple, bool]: 68 | dynamic = False 69 | if isinstance(padding, str): 70 | # for any string padding, the padding will be calculated for you, one of three ways 71 | padding = padding.lower() 72 | if padding == 'same': 73 | # TF compatible 'SAME' padding, has a performance and GPU memory allocation impact 74 | if is_static_pad(kernel_size, **kwargs): 75 | # static case, no extra overhead 76 | padding = get_padding(kernel_size, **kwargs) 77 | else: 78 | # dynamic 'SAME' padding, has runtime/GPU memory overhead 79 | padding = 0 80 | dynamic = True 81 | elif padding == 'valid': 82 | # 'VALID' padding, same as padding=0 83 | padding = 0 84 | else: 85 | # Default to PyTorch style 'same'-ish symmetric padding 86 | padding = get_padding(kernel_size, **kwargs) 87 | return padding, dynamic 88 | -------------------------------------------------------------------------------- /timm/layers/patch_dropout.py: -------------------------------------------------------------------------------- 1 | from typing import Optional, Tuple, Union 2 | 3 | import torch 4 | import torch.nn as nn 5 | 6 | 7 | class PatchDropout(nn.Module): 8 | """ 9 | https://arxiv.org/abs/2212.00794 and https://arxiv.org/pdf/2208.07220 10 | """ 11 | return_indices: torch.jit.Final[bool] 12 | 13 | def __init__( 14 | self, 15 | prob: float = 0.5, 16 | num_prefix_tokens: int = 1, 17 | ordered: bool = False, 18 | return_indices: bool = False, 19 | ): 20 | super().__init__() 21 | assert 0 <= prob < 1. 22 | self.prob = prob 23 | self.num_prefix_tokens = num_prefix_tokens # exclude CLS token (or other prefix tokens) 24 | self.ordered = ordered 25 | self.return_indices = return_indices 26 | 27 | def forward(self, x) -> Union[torch.Tensor, Tuple[torch.Tensor, Optional[torch.Tensor]]]: 28 | if not self.training or self.prob == 0.: 29 | if self.return_indices: 30 | return x, None 31 | return x 32 | 33 | if self.num_prefix_tokens: 34 | prefix_tokens, x = x[:, :self.num_prefix_tokens], x[:, self.num_prefix_tokens:] 35 | else: 36 | prefix_tokens = None 37 | 38 | B = x.shape[0] 39 | L = x.shape[1] 40 | num_keep = max(1, int(L * (1. - self.prob))) 41 | keep_indices = torch.argsort(torch.randn(B, L, device=x.device), dim=-1)[:, :num_keep] 42 | if self.ordered: 43 | # NOTE does not need to maintain patch order in typical transformer use, 44 | # but possibly useful for debug / visualization 45 | keep_indices = keep_indices.sort(dim=-1)[0] 46 | x = x.gather(1, keep_indices.unsqueeze(-1).expand((-1, -1) + x.shape[2:])) 47 | 48 | if prefix_tokens is not None: 49 | x = torch.cat((prefix_tokens, x), dim=1) 50 | 51 | if self.return_indices: 52 | return x, keep_indices 53 | return x 54 | -------------------------------------------------------------------------------- /timm/layers/pool1d.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | 4 | def global_pool_nlc( 5 | x: torch.Tensor, 6 | pool_type: str = 'token', 7 | num_prefix_tokens: int = 1, 8 | reduce_include_prefix: bool = False, 9 | ): 10 | if not pool_type: 11 | return x 12 | 13 | if pool_type == 'token': 14 | x = x[:, 0] # class token 15 | else: 16 | x = x if reduce_include_prefix else x[:, num_prefix_tokens:] 17 | if pool_type == 'avg': 18 | x = x.mean(dim=1) 19 | elif pool_type == 'avgmax': 20 | x = 0.5 * (x.amax(dim=1) + x.mean(dim=1)) 21 | elif pool_type == 'max': 22 | x = x.amax(dim=1) 23 | else: 24 | assert not pool_type, f'Unknown pool type {pool_type}' 25 | 26 | return x -------------------------------------------------------------------------------- /timm/layers/pool2d_same.py: -------------------------------------------------------------------------------- 1 | """ AvgPool2d w/ Same Padding 2 | 3 | Hacked together by / Copyright 2020 Ross Wightman 4 | """ 5 | import torch 6 | import torch.nn as nn 7 | import torch.nn.functional as F 8 | from typing import List, Tuple, Optional 9 | 10 | from .helpers import to_2tuple 11 | from .padding import pad_same, get_padding_value 12 | 13 | 14 | def avg_pool2d_same(x, kernel_size: List[int], stride: List[int], padding: List[int] = (0, 0), 15 | ceil_mode: bool = False, count_include_pad: bool = True): 16 | # FIXME how to deal with count_include_pad vs not for external padding? 17 | x = pad_same(x, kernel_size, stride) 18 | return F.avg_pool2d(x, kernel_size, stride, (0, 0), ceil_mode, count_include_pad) 19 | 20 | 21 | class AvgPool2dSame(nn.AvgPool2d): 22 | """ Tensorflow like 'SAME' wrapper for 2D average pooling 23 | """ 24 | def __init__(self, kernel_size: int, stride=None, padding=0, ceil_mode=False, count_include_pad=True): 25 | kernel_size = to_2tuple(kernel_size) 26 | stride = to_2tuple(stride) 27 | super(AvgPool2dSame, self).__init__(kernel_size, stride, (0, 0), ceil_mode, count_include_pad) 28 | 29 | def forward(self, x): 30 | x = pad_same(x, self.kernel_size, self.stride) 31 | return F.avg_pool2d( 32 | x, self.kernel_size, self.stride, self.padding, self.ceil_mode, self.count_include_pad) 33 | 34 | 35 | def max_pool2d_same( 36 | x, kernel_size: List[int], stride: List[int], padding: List[int] = (0, 0), 37 | dilation: List[int] = (1, 1), ceil_mode: bool = False): 38 | x = pad_same(x, kernel_size, stride, value=-float('inf')) 39 | return F.max_pool2d(x, kernel_size, stride, (0, 0), dilation, ceil_mode) 40 | 41 | 42 | class MaxPool2dSame(nn.MaxPool2d): 43 | """ Tensorflow like 'SAME' wrapper for 2D max pooling 44 | """ 45 | def __init__(self, kernel_size: int, stride=None, padding=0, dilation=1, ceil_mode=False): 46 | kernel_size = to_2tuple(kernel_size) 47 | stride = to_2tuple(stride) 48 | dilation = to_2tuple(dilation) 49 | super(MaxPool2dSame, self).__init__(kernel_size, stride, (0, 0), dilation, ceil_mode) 50 | 51 | def forward(self, x): 52 | x = pad_same(x, self.kernel_size, self.stride, value=-float('inf')) 53 | return F.max_pool2d(x, self.kernel_size, self.stride, (0, 0), self.dilation, self.ceil_mode) 54 | 55 | 56 | def create_pool2d(pool_type, kernel_size, stride=None, **kwargs): 57 | stride = stride or kernel_size 58 | padding = kwargs.pop('padding', '') 59 | padding, is_dynamic = get_padding_value(padding, kernel_size, stride=stride, **kwargs) 60 | if is_dynamic: 61 | if pool_type == 'avg': 62 | return AvgPool2dSame(kernel_size, stride=stride, **kwargs) 63 | elif pool_type == 'max': 64 | return MaxPool2dSame(kernel_size, stride=stride, **kwargs) 65 | else: 66 | assert False, f'Unsupported pool type {pool_type}' 67 | else: 68 | if pool_type == 'avg': 69 | return nn.AvgPool2d(kernel_size, stride=stride, padding=padding, **kwargs) 70 | elif pool_type == 'max': 71 | return nn.MaxPool2d(kernel_size, stride=stride, padding=padding, **kwargs) 72 | else: 73 | assert False, f'Unsupported pool type {pool_type}' 74 | -------------------------------------------------------------------------------- /timm/layers/pos_embed.py: -------------------------------------------------------------------------------- 1 | """ Position Embedding Utilities 2 | 3 | Hacked together by / Copyright 2022 Ross Wightman 4 | """ 5 | import logging 6 | import math 7 | from typing import List, Tuple, Optional, Union 8 | 9 | import torch 10 | import torch.nn.functional as F 11 | 12 | from .helpers import to_2tuple 13 | 14 | _logger = logging.getLogger(__name__) 15 | 16 | 17 | def resample_abs_pos_embed( 18 | posemb: torch.Tensor, 19 | new_size: List[int], 20 | old_size: Optional[List[int]] = None, 21 | num_prefix_tokens: int = 1, 22 | interpolation: str = 'bicubic', 23 | antialias: bool = True, 24 | verbose: bool = False, 25 | ): 26 | # sort out sizes, assume square if old size not provided 27 | num_pos_tokens = posemb.shape[1] 28 | num_new_tokens = new_size[0] * new_size[1] + num_prefix_tokens 29 | if num_new_tokens == num_pos_tokens and new_size[0] == new_size[1]: 30 | return posemb 31 | 32 | if old_size is None: 33 | hw = int(math.sqrt(num_pos_tokens - num_prefix_tokens)) 34 | old_size = hw, hw 35 | 36 | if num_prefix_tokens: 37 | posemb_prefix, posemb = posemb[:, :num_prefix_tokens], posemb[:, num_prefix_tokens:] 38 | else: 39 | posemb_prefix, posemb = None, posemb 40 | 41 | # do the interpolation 42 | embed_dim = posemb.shape[-1] 43 | orig_dtype = posemb.dtype 44 | posemb = posemb.float() # interpolate needs float32 45 | posemb = posemb.reshape(1, old_size[0], old_size[1], -1).permute(0, 3, 1, 2) 46 | posemb = F.interpolate(posemb, size=new_size, mode=interpolation, antialias=antialias) 47 | posemb = posemb.permute(0, 2, 3, 1).reshape(1, -1, embed_dim) 48 | posemb = posemb.to(orig_dtype) 49 | 50 | # add back extra (class, etc) prefix tokens 51 | if posemb_prefix is not None: 52 | posemb = torch.cat([posemb_prefix, posemb], dim=1) 53 | 54 | if not torch.jit.is_scripting() and verbose: 55 | _logger.info(f'Resized position embedding: {old_size} to {new_size}.') 56 | 57 | return posemb 58 | 59 | 60 | def resample_abs_pos_embed_nhwc( 61 | posemb: torch.Tensor, 62 | new_size: List[int], 63 | interpolation: str = 'bicubic', 64 | antialias: bool = True, 65 | verbose: bool = False, 66 | ): 67 | if new_size[0] == posemb.shape[-3] and new_size[1] == posemb.shape[-2]: 68 | return posemb 69 | 70 | orig_dtype = posemb.dtype 71 | posemb = posemb.float() 72 | posemb = posemb.reshape(1, posemb.shape[-3], posemb.shape[-2], posemb.shape[-1]).permute(0, 3, 1, 2) 73 | posemb = F.interpolate(posemb, size=new_size, mode=interpolation, antialias=antialias) 74 | posemb = posemb.permute(0, 2, 3, 1).to(orig_dtype) 75 | 76 | if not torch.jit.is_scripting() and verbose: 77 | _logger.info(f'Resized position embedding: {posemb.shape[-3:-1]} to {new_size}.') 78 | 79 | return posemb 80 | -------------------------------------------------------------------------------- /timm/layers/separable_conv.py: -------------------------------------------------------------------------------- 1 | """ Depthwise Separable Conv Modules 2 | 3 | Basic DWS convs. Other variations of DWS exist with batch norm or activations between the 4 | DW and PW convs such as the Depthwise modules in MobileNetV2 / EfficientNet and Xception. 5 | 6 | Hacked together by / Copyright 2020 Ross Wightman 7 | """ 8 | from torch import nn as nn 9 | 10 | from .create_conv2d import create_conv2d 11 | from .create_norm_act import get_norm_act_layer 12 | 13 | 14 | class SeparableConvNormAct(nn.Module): 15 | """ Separable Conv w/ trailing Norm and Activation 16 | """ 17 | def __init__(self, in_channels, out_channels, kernel_size=3, stride=1, dilation=1, padding='', bias=False, 18 | channel_multiplier=1.0, pw_kernel_size=1, norm_layer=nn.BatchNorm2d, act_layer=nn.ReLU, 19 | apply_act=True, drop_layer=None): 20 | super(SeparableConvNormAct, self).__init__() 21 | 22 | self.conv_dw = create_conv2d( 23 | in_channels, int(in_channels * channel_multiplier), kernel_size, 24 | stride=stride, dilation=dilation, padding=padding, depthwise=True) 25 | 26 | self.conv_pw = create_conv2d( 27 | int(in_channels * channel_multiplier), out_channels, pw_kernel_size, padding=padding, bias=bias) 28 | 29 | norm_act_layer = get_norm_act_layer(norm_layer, act_layer) 30 | norm_kwargs = dict(drop_layer=drop_layer) if drop_layer is not None else {} 31 | self.bn = norm_act_layer(out_channels, apply_act=apply_act, **norm_kwargs) 32 | 33 | @property 34 | def in_channels(self): 35 | return self.conv_dw.in_channels 36 | 37 | @property 38 | def out_channels(self): 39 | return self.conv_pw.out_channels 40 | 41 | def forward(self, x): 42 | x = self.conv_dw(x) 43 | x = self.conv_pw(x) 44 | x = self.bn(x) 45 | return x 46 | 47 | 48 | SeparableConvBnAct = SeparableConvNormAct 49 | 50 | 51 | class SeparableConv2d(nn.Module): 52 | """ Separable Conv 53 | """ 54 | def __init__(self, in_channels, out_channels, kernel_size=3, stride=1, dilation=1, padding='', bias=False, 55 | channel_multiplier=1.0, pw_kernel_size=1): 56 | super(SeparableConv2d, self).__init__() 57 | 58 | self.conv_dw = create_conv2d( 59 | in_channels, int(in_channels * channel_multiplier), kernel_size, 60 | stride=stride, dilation=dilation, padding=padding, depthwise=True) 61 | 62 | self.conv_pw = create_conv2d( 63 | int(in_channels * channel_multiplier), out_channels, pw_kernel_size, padding=padding, bias=bias) 64 | 65 | @property 66 | def in_channels(self): 67 | return self.conv_dw.in_channels 68 | 69 | @property 70 | def out_channels(self): 71 | return self.conv_pw.out_channels 72 | 73 | def forward(self, x): 74 | x = self.conv_dw(x) 75 | x = self.conv_pw(x) 76 | return x 77 | -------------------------------------------------------------------------------- /timm/layers/space_to_depth.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | 5 | class SpaceToDepth(nn.Module): 6 | bs: torch.jit.Final[int] 7 | 8 | def __init__(self, block_size=4): 9 | super().__init__() 10 | assert block_size == 4 11 | self.bs = block_size 12 | 13 | def forward(self, x): 14 | N, C, H, W = x.size() 15 | x = x.view(N, C, H // self.bs, self.bs, W // self.bs, self.bs) # (N, C, H//bs, bs, W//bs, bs) 16 | x = x.permute(0, 3, 5, 1, 2, 4).contiguous() # (N, bs, bs, C, H//bs, W//bs) 17 | x = x.view(N, C * self.bs * self.bs, H // self.bs, W // self.bs) # (N, C*bs^2, H//bs, W//bs) 18 | return x 19 | 20 | 21 | class DepthToSpace(nn.Module): 22 | 23 | def __init__(self, block_size): 24 | super().__init__() 25 | self.bs = block_size 26 | 27 | def forward(self, x): 28 | N, C, H, W = x.size() 29 | x = x.view(N, self.bs, self.bs, C // (self.bs ** 2), H, W) # (N, bs, bs, C//bs^2, H, W) 30 | x = x.permute(0, 3, 4, 1, 5, 2).contiguous() # (N, C//bs^2, H, bs, W, bs) 31 | x = x.view(N, C // (self.bs ** 2), H * self.bs, W * self.bs) # (N, C//bs^2, H * bs, W * bs) 32 | return x 33 | -------------------------------------------------------------------------------- /timm/layers/split_attn.py: -------------------------------------------------------------------------------- 1 | """ Split Attention Conv2d (for ResNeSt Models) 2 | 3 | Paper: `ResNeSt: Split-Attention Networks` - /https://arxiv.org/abs/2004.08955 4 | 5 | Adapted from original PyTorch impl at https://github.com/zhanghang1989/ResNeSt 6 | 7 | Modified for torchscript compat, performance, and consistency with timm by Ross Wightman 8 | """ 9 | import torch 10 | import torch.nn.functional as F 11 | from torch import nn 12 | 13 | from .helpers import make_divisible 14 | 15 | 16 | class RadixSoftmax(nn.Module): 17 | def __init__(self, radix, cardinality): 18 | super(RadixSoftmax, self).__init__() 19 | self.radix = radix 20 | self.cardinality = cardinality 21 | 22 | def forward(self, x): 23 | batch = x.size(0) 24 | if self.radix > 1: 25 | x = x.view(batch, self.cardinality, self.radix, -1).transpose(1, 2) 26 | x = F.softmax(x, dim=1) 27 | x = x.reshape(batch, -1) 28 | else: 29 | x = torch.sigmoid(x) 30 | return x 31 | 32 | 33 | class SplitAttn(nn.Module): 34 | """Split-Attention (aka Splat) 35 | """ 36 | def __init__(self, in_channels, out_channels=None, kernel_size=3, stride=1, padding=None, 37 | dilation=1, groups=1, bias=False, radix=2, rd_ratio=0.25, rd_channels=None, rd_divisor=8, 38 | act_layer=nn.ReLU, norm_layer=None, drop_layer=None, **kwargs): 39 | super(SplitAttn, self).__init__() 40 | out_channels = out_channels or in_channels 41 | self.radix = radix 42 | mid_chs = out_channels * radix 43 | if rd_channels is None: 44 | attn_chs = make_divisible(in_channels * radix * rd_ratio, min_value=32, divisor=rd_divisor) 45 | else: 46 | attn_chs = rd_channels * radix 47 | 48 | padding = kernel_size // 2 if padding is None else padding 49 | self.conv = nn.Conv2d( 50 | in_channels, mid_chs, kernel_size, stride, padding, dilation, 51 | groups=groups * radix, bias=bias, **kwargs) 52 | self.bn0 = norm_layer(mid_chs) if norm_layer else nn.Identity() 53 | self.drop = drop_layer() if drop_layer is not None else nn.Identity() 54 | self.act0 = act_layer(inplace=True) 55 | self.fc1 = nn.Conv2d(out_channels, attn_chs, 1, groups=groups) 56 | self.bn1 = norm_layer(attn_chs) if norm_layer else nn.Identity() 57 | self.act1 = act_layer(inplace=True) 58 | self.fc2 = nn.Conv2d(attn_chs, mid_chs, 1, groups=groups) 59 | self.rsoftmax = RadixSoftmax(radix, groups) 60 | 61 | def forward(self, x): 62 | x = self.conv(x) 63 | x = self.bn0(x) 64 | x = self.drop(x) 65 | x = self.act0(x) 66 | 67 | B, RC, H, W = x.shape 68 | if self.radix > 1: 69 | x = x.reshape((B, self.radix, RC // self.radix, H, W)) 70 | x_gap = x.sum(dim=1) 71 | else: 72 | x_gap = x 73 | x_gap = x_gap.mean((2, 3), keepdim=True) 74 | x_gap = self.fc1(x_gap) 75 | x_gap = self.bn1(x_gap) 76 | x_gap = self.act1(x_gap) 77 | x_attn = self.fc2(x_gap) 78 | 79 | x_attn = self.rsoftmax(x_attn).view(B, -1, 1, 1) 80 | if self.radix > 1: 81 | out = (x * x_attn.reshape((B, self.radix, RC // self.radix, 1, 1))).sum(dim=1) 82 | else: 83 | out = x * x_attn 84 | return out.contiguous() 85 | -------------------------------------------------------------------------------- /timm/layers/split_batchnorm.py: -------------------------------------------------------------------------------- 1 | """ Split BatchNorm 2 | 3 | A PyTorch BatchNorm layer that splits input batch into N equal parts and passes each through 4 | a separate BN layer. The first split is passed through the parent BN layers with weight/bias 5 | keys the same as the original BN. All other splits pass through BN sub-layers under the '.aux_bn' 6 | namespace. 7 | 8 | This allows easily removing the auxiliary BN layers after training to efficiently 9 | achieve the 'Auxiliary BatchNorm' as described in the AdvProp Paper, section 4.2, 10 | 'Disentangled Learning via An Auxiliary BN' 11 | 12 | Hacked together by / Copyright 2020 Ross Wightman 13 | """ 14 | import torch 15 | import torch.nn as nn 16 | 17 | 18 | class SplitBatchNorm2d(torch.nn.BatchNorm2d): 19 | 20 | def __init__(self, num_features, eps=1e-5, momentum=0.1, affine=True, 21 | track_running_stats=True, num_splits=2): 22 | super().__init__(num_features, eps, momentum, affine, track_running_stats) 23 | assert num_splits > 1, 'Should have at least one aux BN layer (num_splits at least 2)' 24 | self.num_splits = num_splits 25 | self.aux_bn = nn.ModuleList([ 26 | nn.BatchNorm2d(num_features, eps, momentum, affine, track_running_stats) for _ in range(num_splits - 1)]) 27 | 28 | def forward(self, input: torch.Tensor): 29 | if self.training: # aux BN only relevant while training 30 | split_size = input.shape[0] // self.num_splits 31 | assert input.shape[0] == split_size * self.num_splits, "batch size must be evenly divisible by num_splits" 32 | split_input = input.split(split_size) 33 | x = [super().forward(split_input[0])] 34 | for i, a in enumerate(self.aux_bn): 35 | x.append(a(split_input[i + 1])) 36 | return torch.cat(x, dim=0) 37 | else: 38 | return super().forward(input) 39 | 40 | 41 | def convert_splitbn_model(module, num_splits=2): 42 | """ 43 | Recursively traverse module and its children to replace all instances of 44 | ``torch.nn.modules.batchnorm._BatchNorm`` with `SplitBatchnorm2d`. 45 | Args: 46 | module (torch.nn.Module): input module 47 | num_splits: number of separate batchnorm layers to split input across 48 | Example:: 49 | >>> # model is an instance of torch.nn.Module 50 | >>> model = timm.models.convert_splitbn_model(model, num_splits=2) 51 | """ 52 | mod = module 53 | if isinstance(module, torch.nn.modules.instancenorm._InstanceNorm): 54 | return module 55 | if isinstance(module, torch.nn.modules.batchnorm._BatchNorm): 56 | mod = SplitBatchNorm2d( 57 | module.num_features, module.eps, module.momentum, module.affine, 58 | module.track_running_stats, num_splits=num_splits) 59 | mod.running_mean = module.running_mean 60 | mod.running_var = module.running_var 61 | mod.num_batches_tracked = module.num_batches_tracked 62 | if module.affine: 63 | mod.weight.data = module.weight.data.clone().detach() 64 | mod.bias.data = module.bias.data.clone().detach() 65 | for aux in mod.aux_bn: 66 | aux.running_mean = module.running_mean.clone() 67 | aux.running_var = module.running_var.clone() 68 | aux.num_batches_tracked = module.num_batches_tracked.clone() 69 | if module.affine: 70 | aux.weight.data = module.weight.data.clone().detach() 71 | aux.bias.data = module.bias.data.clone().detach() 72 | for name, child in module.named_children(): 73 | mod.add_module(name, convert_splitbn_model(child, num_splits=num_splits)) 74 | del module 75 | return mod 76 | -------------------------------------------------------------------------------- /timm/layers/test_time_pool.py: -------------------------------------------------------------------------------- 1 | """ Test Time Pooling (Average-Max Pool) 2 | 3 | Hacked together by / Copyright 2020 Ross Wightman 4 | """ 5 | 6 | import logging 7 | from torch import nn 8 | import torch.nn.functional as F 9 | 10 | from .adaptive_avgmax_pool import adaptive_avgmax_pool2d 11 | 12 | 13 | _logger = logging.getLogger(__name__) 14 | 15 | 16 | class TestTimePoolHead(nn.Module): 17 | def __init__(self, base, original_pool=7): 18 | super(TestTimePoolHead, self).__init__() 19 | self.base = base 20 | self.original_pool = original_pool 21 | base_fc = self.base.get_classifier() 22 | if isinstance(base_fc, nn.Conv2d): 23 | self.fc = base_fc 24 | else: 25 | self.fc = nn.Conv2d( 26 | self.base.num_features, self.base.num_classes, kernel_size=1, bias=True) 27 | self.fc.weight.data.copy_(base_fc.weight.data.view(self.fc.weight.size())) 28 | self.fc.bias.data.copy_(base_fc.bias.data.view(self.fc.bias.size())) 29 | self.base.reset_classifier(0) # delete original fc layer 30 | 31 | def forward(self, x): 32 | x = self.base.forward_features(x) 33 | x = F.avg_pool2d(x, kernel_size=self.original_pool, stride=1) 34 | x = self.fc(x) 35 | x = adaptive_avgmax_pool2d(x, 1) 36 | return x.view(x.size(0), -1) 37 | 38 | 39 | def apply_test_time_pool(model, config, use_test_size=False): 40 | test_time_pool = False 41 | if not hasattr(model, 'default_cfg') or not model.default_cfg: 42 | return model, False 43 | if use_test_size and 'test_input_size' in model.default_cfg: 44 | df_input_size = model.default_cfg['test_input_size'] 45 | else: 46 | df_input_size = model.default_cfg['input_size'] 47 | if config['input_size'][-1] > df_input_size[-1] and config['input_size'][-2] > df_input_size[-2]: 48 | _logger.info('Target input size %s > pretrained default %s, using test time pooling' % 49 | (str(config['input_size'][-2:]), str(df_input_size[-2:]))) 50 | model = TestTimePoolHead(model, original_pool=model.default_cfg['pool_size']) 51 | test_time_pool = True 52 | return model, test_time_pool 53 | -------------------------------------------------------------------------------- /timm/layers/trace_utils.py: -------------------------------------------------------------------------------- 1 | try: 2 | from torch import _assert 3 | except ImportError: 4 | def _assert(condition: bool, message: str): 5 | assert condition, message 6 | 7 | 8 | def _float_to_int(x: float) -> int: 9 | """ 10 | Symbolic tracing helper to substitute for inbuilt `int`. 11 | Hint: Inbuilt `int` can't accept an argument of type `Proxy` 12 | """ 13 | return int(x) 14 | -------------------------------------------------------------------------------- /timm/layers/typing.py: -------------------------------------------------------------------------------- 1 | from typing import Callable, Tuple, Type, Union 2 | 3 | import torch 4 | 5 | 6 | LayerType = Union[str, Callable, Type[torch.nn.Module]] 7 | PadType = Union[str, int, Tuple[int, int]] 8 | -------------------------------------------------------------------------------- /timm/loss/__init__.py: -------------------------------------------------------------------------------- 1 | from .asymmetric_loss import AsymmetricLossMultiLabel, AsymmetricLossSingleLabel 2 | from .binary_cross_entropy import BinaryCrossEntropy 3 | from .cross_entropy import LabelSmoothingCrossEntropy, SoftTargetCrossEntropy 4 | from .jsd import JsdCrossEntropy 5 | -------------------------------------------------------------------------------- /timm/loss/asymmetric_loss.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | 5 | class AsymmetricLossMultiLabel(nn.Module): 6 | def __init__(self, gamma_neg=4, gamma_pos=1, clip=0.05, eps=1e-8, disable_torch_grad_focal_loss=False): 7 | super(AsymmetricLossMultiLabel, self).__init__() 8 | 9 | self.gamma_neg = gamma_neg 10 | self.gamma_pos = gamma_pos 11 | self.clip = clip 12 | self.disable_torch_grad_focal_loss = disable_torch_grad_focal_loss 13 | self.eps = eps 14 | 15 | def forward(self, x, y): 16 | """" 17 | Parameters 18 | ---------- 19 | x: input logits 20 | y: targets (multi-label binarized vector) 21 | """ 22 | 23 | # Calculating Probabilities 24 | x_sigmoid = torch.sigmoid(x) 25 | xs_pos = x_sigmoid 26 | xs_neg = 1 - x_sigmoid 27 | 28 | # Asymmetric Clipping 29 | if self.clip is not None and self.clip > 0: 30 | xs_neg = (xs_neg + self.clip).clamp(max=1) 31 | 32 | # Basic CE calculation 33 | los_pos = y * torch.log(xs_pos.clamp(min=self.eps)) 34 | los_neg = (1 - y) * torch.log(xs_neg.clamp(min=self.eps)) 35 | loss = los_pos + los_neg 36 | 37 | # Asymmetric Focusing 38 | if self.gamma_neg > 0 or self.gamma_pos > 0: 39 | if self.disable_torch_grad_focal_loss: 40 | torch.set_grad_enabled(False) 41 | pt0 = xs_pos * y 42 | pt1 = xs_neg * (1 - y) # pt = p if t > 0 else 1-p 43 | pt = pt0 + pt1 44 | one_sided_gamma = self.gamma_pos * y + self.gamma_neg * (1 - y) 45 | one_sided_w = torch.pow(1 - pt, one_sided_gamma) 46 | if self.disable_torch_grad_focal_loss: 47 | torch.set_grad_enabled(True) 48 | loss *= one_sided_w 49 | 50 | return -loss.sum() 51 | 52 | 53 | class AsymmetricLossSingleLabel(nn.Module): 54 | def __init__(self, gamma_pos=1, gamma_neg=4, eps: float = 0.1, reduction='mean'): 55 | super(AsymmetricLossSingleLabel, self).__init__() 56 | 57 | self.eps = eps 58 | self.logsoftmax = nn.LogSoftmax(dim=-1) 59 | self.targets_classes = [] # prevent gpu repeated memory allocation 60 | self.gamma_pos = gamma_pos 61 | self.gamma_neg = gamma_neg 62 | self.reduction = reduction 63 | 64 | def forward(self, inputs, target, reduction=None): 65 | """" 66 | Parameters 67 | ---------- 68 | x: input logits 69 | y: targets (1-hot vector) 70 | """ 71 | 72 | num_classes = inputs.size()[-1] 73 | log_preds = self.logsoftmax(inputs) 74 | self.targets_classes = torch.zeros_like(inputs).scatter_(1, target.long().unsqueeze(1), 1) 75 | 76 | # ASL weights 77 | targets = self.targets_classes 78 | anti_targets = 1 - targets 79 | xs_pos = torch.exp(log_preds) 80 | xs_neg = 1 - xs_pos 81 | xs_pos = xs_pos * targets 82 | xs_neg = xs_neg * anti_targets 83 | asymmetric_w = torch.pow(1 - xs_pos - xs_neg, 84 | self.gamma_pos * targets + self.gamma_neg * anti_targets) 85 | log_preds = log_preds * asymmetric_w 86 | 87 | if self.eps > 0: # label smoothing 88 | self.targets_classes = self.targets_classes.mul(1 - self.eps).add(self.eps / num_classes) 89 | 90 | # loss calculation 91 | loss = - self.targets_classes.mul(log_preds) 92 | 93 | loss = loss.sum(dim=-1) 94 | if self.reduction == 'mean': 95 | loss = loss.mean() 96 | 97 | return loss 98 | -------------------------------------------------------------------------------- /timm/loss/binary_cross_entropy.py: -------------------------------------------------------------------------------- 1 | """ Binary Cross Entropy w/ a few extras 2 | 3 | Hacked together by / Copyright 2021 Ross Wightman 4 | """ 5 | from typing import Optional, Union 6 | 7 | import torch 8 | import torch.nn as nn 9 | import torch.nn.functional as F 10 | 11 | 12 | class BinaryCrossEntropy(nn.Module): 13 | """ BCE with optional one-hot from dense targets, label smoothing, thresholding 14 | NOTE for experiments comparing CE to BCE /w label smoothing, may remove 15 | """ 16 | def __init__( 17 | self, 18 | smoothing=0.1, 19 | target_threshold: Optional[float] = None, 20 | weight: Optional[torch.Tensor] = None, 21 | reduction: str = 'mean', 22 | sum_classes: bool = False, 23 | pos_weight: Optional[Union[torch.Tensor, float]] = None, 24 | ): 25 | super(BinaryCrossEntropy, self).__init__() 26 | assert 0. <= smoothing < 1.0 27 | if pos_weight is not None: 28 | if not isinstance(pos_weight, torch.Tensor): 29 | pos_weight = torch.tensor(pos_weight) 30 | self.smoothing = smoothing 31 | self.target_threshold = target_threshold 32 | self.reduction = 'none' if sum_classes else reduction 33 | self.sum_classes = sum_classes 34 | self.register_buffer('weight', weight) 35 | self.register_buffer('pos_weight', pos_weight) 36 | 37 | def forward(self, x: torch.Tensor, target: torch.Tensor) -> torch.Tensor: 38 | batch_size = x.shape[0] 39 | assert batch_size == target.shape[0] 40 | 41 | if target.shape != x.shape: 42 | # NOTE currently assume smoothing or other label softening is applied upstream if targets are already sparse 43 | num_classes = x.shape[-1] 44 | # FIXME should off/on be different for smoothing w/ BCE? Other impl out there differ 45 | off_value = self.smoothing / num_classes 46 | on_value = 1. - self.smoothing + off_value 47 | target = target.long().view(-1, 1) 48 | target = torch.full( 49 | (batch_size, num_classes), 50 | off_value, 51 | device=x.device, dtype=x.dtype).scatter_(1, target, on_value) 52 | 53 | if self.target_threshold is not None: 54 | # Make target 0, or 1 if threshold set 55 | target = target.gt(self.target_threshold).to(dtype=target.dtype) 56 | 57 | loss = F.binary_cross_entropy_with_logits( 58 | x, target, 59 | self.weight, 60 | pos_weight=self.pos_weight, 61 | reduction=self.reduction, 62 | ) 63 | if self.sum_classes: 64 | loss = loss.sum(-1).mean() 65 | return loss 66 | -------------------------------------------------------------------------------- /timm/loss/cross_entropy.py: -------------------------------------------------------------------------------- 1 | """ Cross Entropy w/ smoothing or soft targets 2 | 3 | Hacked together by / Copyright 2021 Ross Wightman 4 | """ 5 | 6 | import torch 7 | import torch.nn as nn 8 | import torch.nn.functional as F 9 | 10 | 11 | class LabelSmoothingCrossEntropy(nn.Module): 12 | """ NLL loss with label smoothing. 13 | """ 14 | def __init__(self, smoothing=0.1): 15 | super(LabelSmoothingCrossEntropy, self).__init__() 16 | assert smoothing < 1.0 17 | self.smoothing = smoothing 18 | self.confidence = 1. - smoothing 19 | 20 | def forward(self, x: torch.Tensor, target: torch.Tensor) -> torch.Tensor: 21 | logprobs = F.log_softmax(x, dim=-1) 22 | nll_loss = -logprobs.gather(dim=-1, index=target.unsqueeze(1)) 23 | nll_loss = nll_loss.squeeze(1) 24 | smooth_loss = -logprobs.mean(dim=-1) 25 | loss = self.confidence * nll_loss + self.smoothing * smooth_loss 26 | return loss.mean() 27 | 28 | 29 | class SoftTargetCrossEntropy(nn.Module): 30 | 31 | def __init__(self): 32 | super(SoftTargetCrossEntropy, self).__init__() 33 | 34 | def forward(self, x: torch.Tensor, target: torch.Tensor) -> torch.Tensor: 35 | loss = torch.sum(-target * F.log_softmax(x, dim=-1), dim=-1) 36 | return loss.mean() 37 | -------------------------------------------------------------------------------- /timm/loss/jsd.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | from .cross_entropy import LabelSmoothingCrossEntropy 6 | 7 | 8 | class JsdCrossEntropy(nn.Module): 9 | """ Jensen-Shannon Divergence + Cross-Entropy Loss 10 | 11 | Based on impl here: https://github.com/google-research/augmix/blob/master/imagenet.py 12 | From paper: 'AugMix: A Simple Data Processing Method to Improve Robustness and Uncertainty - 13 | https://arxiv.org/abs/1912.02781 14 | 15 | Hacked together by / Copyright 2020 Ross Wightman 16 | """ 17 | def __init__(self, num_splits=3, alpha=12, smoothing=0.1): 18 | super().__init__() 19 | self.num_splits = num_splits 20 | self.alpha = alpha 21 | if smoothing is not None and smoothing > 0: 22 | self.cross_entropy_loss = LabelSmoothingCrossEntropy(smoothing) 23 | else: 24 | self.cross_entropy_loss = torch.nn.CrossEntropyLoss() 25 | 26 | def __call__(self, output, target): 27 | split_size = output.shape[0] // self.num_splits 28 | assert split_size * self.num_splits == output.shape[0] 29 | logits_split = torch.split(output, split_size) 30 | 31 | # Cross-entropy is only computed on clean images 32 | loss = self.cross_entropy_loss(logits_split[0], target[:split_size]) 33 | probs = [F.softmax(logits, dim=1) for logits in logits_split] 34 | 35 | # Clamp mixture distribution to avoid exploding KL divergence 36 | logp_mixture = torch.clamp(torch.stack(probs).mean(axis=0), 1e-7, 1).log() 37 | loss += self.alpha * sum([F.kl_div( 38 | logp_mixture, p_split, reduction='batchmean') for p_split in probs]) / len(probs) 39 | return loss 40 | -------------------------------------------------------------------------------- /timm/models/_pretrained.py: -------------------------------------------------------------------------------- 1 | import copy 2 | from collections import deque, defaultdict 3 | from dataclasses import dataclass, field, replace, asdict 4 | from typing import Any, Deque, Dict, Tuple, Optional, Union 5 | 6 | 7 | __all__ = ['PretrainedCfg', 'filter_pretrained_cfg', 'DefaultCfg'] 8 | 9 | 10 | @dataclass 11 | class PretrainedCfg: 12 | """ 13 | """ 14 | # weight source locations 15 | url: Optional[Union[str, Tuple[str, str]]] = None # remote URL 16 | file: Optional[str] = None # local / shared filesystem path 17 | state_dict: Optional[Dict[str, Any]] = None # in-memory state dict 18 | hf_hub_id: Optional[str] = None # Hugging Face Hub model id ('organization/model') 19 | hf_hub_filename: Optional[str] = None # Hugging Face Hub filename (overrides default) 20 | 21 | source: Optional[str] = None # source of cfg / weight location used (url, file, hf-hub) 22 | architecture: Optional[str] = None # architecture variant can be set when not implicit 23 | tag: Optional[str] = None # pretrained tag of source 24 | custom_load: bool = False # use custom model specific model.load_pretrained() (ie for npz files) 25 | 26 | # input / data config 27 | input_size: Tuple[int, int, int] = (3, 224, 224) 28 | test_input_size: Optional[Tuple[int, int, int]] = None 29 | min_input_size: Optional[Tuple[int, int, int]] = None 30 | fixed_input_size: bool = False 31 | interpolation: str = 'bicubic' 32 | crop_pct: float = 0.875 33 | test_crop_pct: Optional[float] = None 34 | crop_mode: str = 'center' 35 | mean: Tuple[float, ...] = (0.485, 0.456, 0.406) 36 | std: Tuple[float, ...] = (0.229, 0.224, 0.225) 37 | 38 | # head / classifier config and meta-data 39 | num_classes: int = 1000 40 | label_offset: Optional[int] = None 41 | label_names: Optional[Tuple[str]] = None 42 | label_descriptions: Optional[Dict[str, str]] = None 43 | 44 | # model attributes that vary with above or required for pretrained adaptation 45 | pool_size: Optional[Tuple[int, ...]] = None 46 | test_pool_size: Optional[Tuple[int, ...]] = None 47 | first_conv: Optional[str] = None 48 | classifier: Optional[str] = None 49 | 50 | license: Optional[str] = None 51 | description: Optional[str] = None 52 | origin_url: Optional[str] = None 53 | paper_name: Optional[str] = None 54 | paper_ids: Optional[Union[str, Tuple[str]]] = None 55 | notes: Optional[Tuple[str]] = None 56 | 57 | @property 58 | def has_weights(self): 59 | return self.url or self.file or self.hf_hub_id 60 | 61 | def to_dict(self, remove_source=False, remove_null=True): 62 | return filter_pretrained_cfg( 63 | asdict(self), 64 | remove_source=remove_source, 65 | remove_null=remove_null 66 | ) 67 | 68 | 69 | def filter_pretrained_cfg(cfg, remove_source=False, remove_null=True): 70 | filtered_cfg = {} 71 | keep_null = {'pool_size', 'first_conv', 'classifier'} # always keep these keys, even if none 72 | for k, v in cfg.items(): 73 | if remove_source and k in {'url', 'file', 'hf_hub_id', 'hf_hub_id', 'hf_hub_filename', 'source'}: 74 | continue 75 | if remove_null and v is None and k not in keep_null: 76 | continue 77 | filtered_cfg[k] = v 78 | return filtered_cfg 79 | 80 | 81 | @dataclass 82 | class DefaultCfg: 83 | tags: Deque[str] = field(default_factory=deque) # priority queue of tags (first is default) 84 | cfgs: Dict[str, PretrainedCfg] = field(default_factory=dict) # pretrained cfgs by tag 85 | is_pretrained: bool = False # at least one of the configs has a pretrained source set 86 | 87 | @property 88 | def default(self): 89 | return self.cfgs[self.tags[0]] 90 | 91 | @property 92 | def default_with_tag(self): 93 | tag = self.tags[0] 94 | return tag, self.cfgs[tag] 95 | -------------------------------------------------------------------------------- /timm/models/factory.py: -------------------------------------------------------------------------------- 1 | from ._factory import * 2 | 3 | import warnings 4 | warnings.warn(f"Importing from {__name__} is deprecated, please import via timm.models", FutureWarning) 5 | -------------------------------------------------------------------------------- /timm/models/features.py: -------------------------------------------------------------------------------- 1 | from ._features import * 2 | 3 | import warnings 4 | warnings.warn(f"Importing from {__name__} is deprecated, please import via timm.models", FutureWarning) 5 | -------------------------------------------------------------------------------- /timm/models/fx_features.py: -------------------------------------------------------------------------------- 1 | from ._features_fx import * 2 | 3 | import warnings 4 | warnings.warn(f"Importing from {__name__} is deprecated, please import via timm.models", FutureWarning) 5 | -------------------------------------------------------------------------------- /timm/models/helpers.py: -------------------------------------------------------------------------------- 1 | from ._builder import * 2 | from ._helpers import * 3 | from ._manipulate import * 4 | from ._prune import * 5 | 6 | import warnings 7 | warnings.warn(f"Importing from {__name__} is deprecated, please import via timm.models", FutureWarning) 8 | -------------------------------------------------------------------------------- /timm/models/hub.py: -------------------------------------------------------------------------------- 1 | from ._hub import * 2 | 3 | import warnings 4 | warnings.warn(f"Importing from {__name__} is deprecated, please import via timm.models", FutureWarning) 5 | -------------------------------------------------------------------------------- /timm/models/layers/__init__.py: -------------------------------------------------------------------------------- 1 | # NOTE timm.models.layers is DEPRECATED, please use timm.layers, this is here to reduce breakages in transition 2 | from timm.layers.activations import * 3 | from timm.layers.adaptive_avgmax_pool import \ 4 | adaptive_avgmax_pool2d, select_adaptive_pool2d, AdaptiveAvgMaxPool2d, SelectAdaptivePool2d 5 | from timm.layers.attention_pool2d import AttentionPool2d, RotAttentionPool2d, RotaryEmbedding 6 | from timm.layers.blur_pool import BlurPool2d 7 | from timm.layers.classifier import ClassifierHead, create_classifier 8 | from timm.layers.cond_conv2d import CondConv2d, get_condconv_initializer 9 | from timm.layers.config import is_exportable, is_scriptable, is_no_jit, set_exportable, set_scriptable, set_no_jit,\ 10 | set_layer_config 11 | from timm.layers.conv2d_same import Conv2dSame, conv2d_same 12 | from timm.layers.conv_bn_act import ConvNormAct, ConvNormActAa, ConvBnAct 13 | from timm.layers.create_act import create_act_layer, get_act_layer, get_act_fn 14 | from timm.layers.create_attn import get_attn, create_attn 15 | from timm.layers.create_conv2d import create_conv2d 16 | from timm.layers.create_norm import get_norm_layer, create_norm_layer 17 | from timm.layers.create_norm_act import get_norm_act_layer, create_norm_act_layer, get_norm_act_layer 18 | from timm.layers.drop import DropBlock2d, DropPath, drop_block_2d, drop_path 19 | from timm.layers.eca import EcaModule, CecaModule, EfficientChannelAttn, CircularEfficientChannelAttn 20 | from timm.layers.evo_norm import EvoNorm2dB0, EvoNorm2dB1, EvoNorm2dB2,\ 21 | EvoNorm2dS0, EvoNorm2dS0a, EvoNorm2dS1, EvoNorm2dS1a, EvoNorm2dS2, EvoNorm2dS2a 22 | from timm.layers.fast_norm import is_fast_norm, set_fast_norm, fast_group_norm, fast_layer_norm 23 | from timm.layers.filter_response_norm import FilterResponseNormTlu2d, FilterResponseNormAct2d 24 | from timm.layers.gather_excite import GatherExcite 25 | from timm.layers.global_context import GlobalContext 26 | from timm.layers.helpers import to_ntuple, to_2tuple, to_3tuple, to_4tuple, make_divisible, extend_tuple 27 | from timm.layers.inplace_abn import InplaceAbn 28 | from timm.layers.linear import Linear 29 | from timm.layers.mixed_conv2d import MixedConv2d 30 | from timm.layers.mlp import Mlp, GluMlp, GatedMlp, ConvMlp 31 | from timm.layers.non_local_attn import NonLocalAttn, BatNonLocalAttn 32 | from timm.layers.norm import GroupNorm, GroupNorm1, LayerNorm, LayerNorm2d 33 | from timm.layers.norm_act import BatchNormAct2d, GroupNormAct, convert_sync_batchnorm 34 | from timm.layers.padding import get_padding, get_same_padding, pad_same 35 | from timm.layers.patch_embed import PatchEmbed 36 | from timm.layers.pool2d_same import AvgPool2dSame, create_pool2d 37 | from timm.layers.squeeze_excite import SEModule, SqueezeExcite, EffectiveSEModule, EffectiveSqueezeExcite 38 | from timm.layers.selective_kernel import SelectiveKernel 39 | from timm.layers.separable_conv import SeparableConv2d, SeparableConvNormAct 40 | from timm.layers.split_attn import SplitAttn 41 | from timm.layers.split_batchnorm import SplitBatchNorm2d, convert_splitbn_model 42 | from timm.layers.std_conv import StdConv2d, StdConv2dSame, ScaledStdConv2d, ScaledStdConv2dSame 43 | from timm.layers.test_time_pool import TestTimePoolHead, apply_test_time_pool 44 | from timm.layers.trace_utils import _assert, _float_to_int 45 | from timm.layers.weight_init import trunc_normal_, trunc_normal_tf_, variance_scaling_, lecun_normal_ 46 | 47 | import warnings 48 | warnings.warn(f"Importing from {__name__} is deprecated, please import via timm.layers", FutureWarning) 49 | -------------------------------------------------------------------------------- /timm/models/registry.py: -------------------------------------------------------------------------------- 1 | from ._registry import * 2 | 3 | import warnings 4 | warnings.warn(f"Importing from {__name__} is deprecated, please import via timm.models", FutureWarning) 5 | -------------------------------------------------------------------------------- /timm/optim/__init__.py: -------------------------------------------------------------------------------- 1 | from .adabelief import AdaBelief 2 | from .adafactor import Adafactor 3 | from .adafactor_bv import AdafactorBigVision 4 | from .adahessian import Adahessian 5 | from .adamp import AdamP 6 | from .adamw import AdamWLegacy 7 | from .adan import Adan 8 | from .adopt import Adopt 9 | from .lamb import Lamb 10 | from .laprop import LaProp 11 | from .lars import Lars 12 | from .lion import Lion 13 | from .lookahead import Lookahead 14 | from .madgrad import MADGRAD 15 | from .mars import Mars 16 | from .nadam import NAdamLegacy 17 | from .nadamw import NAdamW 18 | from .nvnovograd import NvNovoGrad 19 | from .radam import RAdamLegacy 20 | from .rmsprop_tf import RMSpropTF 21 | from .sgdp import SGDP 22 | from .sgdw import SGDW 23 | 24 | # bring common torch.optim Optimizers into timm.optim namespace for consistency 25 | from torch.optim import Adadelta, Adagrad, Adamax, Adam, AdamW, RMSprop, SGD 26 | try: 27 | # in case any very old torch versions being used 28 | from torch.optim import NAdam, RAdam 29 | except ImportError: 30 | pass 31 | 32 | from ._optim_factory import list_optimizers, get_optimizer_class, get_optimizer_info, OptimInfo, OptimizerRegistry, \ 33 | create_optimizer_v2, create_optimizer, optimizer_kwargs 34 | from ._param_groups import param_groups_layer_decay, param_groups_weight_decay, auto_group_layers 35 | -------------------------------------------------------------------------------- /timm/optim/_types.py: -------------------------------------------------------------------------------- 1 | from typing import Any, Dict, Iterable, Union, Protocol, Type 2 | try: 3 | from typing import TypeAlias, TypeVar 4 | except ImportError: 5 | from typing_extensions import TypeAlias, TypeVar 6 | 7 | import torch 8 | import torch.optim 9 | 10 | try: 11 | from torch.optim.optimizer import ParamsT 12 | except (ImportError, TypeError): 13 | ParamsT: TypeAlias = Union[Iterable[torch.Tensor], Iterable[Dict[str, Any]]] 14 | 15 | 16 | OptimType = Type[torch.optim.Optimizer] 17 | 18 | 19 | class OptimizerCallable(Protocol): 20 | """Protocol for optimizer constructor signatures.""" 21 | 22 | def __call__(self, params: ParamsT, **kwargs) -> torch.optim.Optimizer: ... 23 | 24 | 25 | __all__ = ['ParamsT', 'OptimType', 'OptimizerCallable'] -------------------------------------------------------------------------------- /timm/optim/adamp.py: -------------------------------------------------------------------------------- 1 | """ 2 | AdamP Optimizer Implementation copied from https://github.com/clovaai/AdamP/blob/master/adamp/adamp.py 3 | 4 | Paper: `Slowing Down the Weight Norm Increase in Momentum-based Optimizers` - https://arxiv.org/abs/2006.08217 5 | Code: https://github.com/clovaai/AdamP 6 | 7 | Copyright (c) 2020-present NAVER Corp. 8 | MIT license 9 | """ 10 | 11 | import torch 12 | import torch.nn.functional as F 13 | from torch.optim.optimizer import Optimizer 14 | import math 15 | 16 | 17 | def _channel_view(x) -> torch.Tensor: 18 | return x.reshape(x.size(0), -1) 19 | 20 | 21 | def _layer_view(x) -> torch.Tensor: 22 | return x.reshape(1, -1) 23 | 24 | 25 | def projection(p, grad, perturb, delta: float, wd_ratio: float, eps: float): 26 | wd = 1. 27 | expand_size = (-1,) + (1,) * (len(p.shape) - 1) 28 | for view_func in [_channel_view, _layer_view]: 29 | param_view = view_func(p) 30 | grad_view = view_func(grad) 31 | cosine_sim = F.cosine_similarity(grad_view, param_view, dim=1, eps=eps).abs_() 32 | 33 | # FIXME this is a problem for PyTorch XLA 34 | if cosine_sim.max() < delta / math.sqrt(param_view.size(1)): 35 | p_n = p / param_view.norm(p=2, dim=1).add_(eps).reshape(expand_size) 36 | perturb -= p_n * view_func(p_n * perturb).sum(dim=1).reshape(expand_size) 37 | wd = wd_ratio 38 | return perturb, wd 39 | 40 | return perturb, wd 41 | 42 | 43 | class AdamP(Optimizer): 44 | def __init__( 45 | self, 46 | params, 47 | lr=1e-3, 48 | betas=(0.9, 0.999), 49 | eps=1e-8, 50 | weight_decay=0, 51 | delta=0.1, 52 | wd_ratio=0.1, 53 | nesterov=False, 54 | ): 55 | defaults = dict( 56 | lr=lr, 57 | betas=betas, 58 | eps=eps, 59 | weight_decay=weight_decay, 60 | delta=delta, 61 | wd_ratio=wd_ratio, 62 | nesterov=nesterov, 63 | ) 64 | super(AdamP, self).__init__(params, defaults) 65 | 66 | @torch.no_grad() 67 | def step(self, closure=None): 68 | loss = None 69 | if closure is not None: 70 | with torch.enable_grad(): 71 | loss = closure() 72 | 73 | for group in self.param_groups: 74 | for p in group['params']: 75 | if p.grad is None: 76 | continue 77 | 78 | grad = p.grad 79 | beta1, beta2 = group['betas'] 80 | nesterov = group['nesterov'] 81 | 82 | state = self.state[p] 83 | 84 | # State initialization 85 | if len(state) == 0: 86 | state['step'] = 0 87 | state['exp_avg'] = torch.zeros_like(p) 88 | state['exp_avg_sq'] = torch.zeros_like(p) 89 | 90 | # Adam 91 | exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq'] 92 | 93 | state['step'] += 1 94 | bias_correction1 = 1 - beta1 ** state['step'] 95 | bias_correction2 = 1 - beta2 ** state['step'] 96 | 97 | exp_avg.mul_(beta1).add_(grad, alpha=1 - beta1) 98 | exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1 - beta2) 99 | 100 | denom = (exp_avg_sq.sqrt() / math.sqrt(bias_correction2)).add_(group['eps']) 101 | step_size = group['lr'] / bias_correction1 102 | 103 | if nesterov: 104 | perturb = (beta1 * exp_avg + (1 - beta1) * grad) / denom 105 | else: 106 | perturb = exp_avg / denom 107 | 108 | # Projection 109 | wd_ratio = 1. 110 | if len(p.shape) > 1: 111 | perturb, wd_ratio = projection(p, grad, perturb, group['delta'], group['wd_ratio'], group['eps']) 112 | 113 | # Weight decay 114 | if group['weight_decay'] > 0: 115 | p.mul_(1. - group['lr'] * group['weight_decay'] * wd_ratio) 116 | 117 | # Step 118 | p.add_(perturb, alpha=-step_size) 119 | 120 | return loss 121 | -------------------------------------------------------------------------------- /timm/optim/lookahead.py: -------------------------------------------------------------------------------- 1 | """ Lookahead Optimizer Wrapper. 2 | Implementation modified from: https://github.com/alphadl/lookahead.pytorch 3 | Paper: `Lookahead Optimizer: k steps forward, 1 step back` - https://arxiv.org/abs/1907.08610 4 | 5 | Hacked together by / Copyright 2020 Ross Wightman 6 | """ 7 | from collections import OrderedDict 8 | from typing import Callable, Dict 9 | 10 | import torch 11 | from torch.optim.optimizer import Optimizer 12 | from collections import defaultdict 13 | 14 | 15 | class Lookahead(Optimizer): 16 | def __init__(self, base_optimizer, alpha=0.5, k=6): 17 | # NOTE super().__init__() not called on purpose 18 | self._optimizer_step_pre_hooks: Dict[int, Callable] = OrderedDict() 19 | self._optimizer_step_post_hooks: Dict[int, Callable] = OrderedDict() 20 | if not 0.0 <= alpha <= 1.0: 21 | raise ValueError(f'Invalid slow update rate: {alpha}') 22 | if not 1 <= k: 23 | raise ValueError(f'Invalid lookahead steps: {k}') 24 | defaults = dict(lookahead_alpha=alpha, lookahead_k=k, lookahead_step=0) 25 | self._base_optimizer = base_optimizer 26 | self.param_groups = base_optimizer.param_groups 27 | self.defaults = base_optimizer.defaults 28 | self.defaults.update(defaults) 29 | self.state = defaultdict(dict) 30 | # manually add our defaults to the param groups 31 | for name, default in defaults.items(): 32 | for group in self._base_optimizer.param_groups: 33 | group.setdefault(name, default) 34 | 35 | @torch.no_grad() 36 | def update_slow(self, group): 37 | for fast_p in group["params"]: 38 | if fast_p.grad is None: 39 | continue 40 | param_state = self._base_optimizer.state[fast_p] 41 | if 'lookahead_slow_buff' not in param_state: 42 | param_state['lookahead_slow_buff'] = torch.empty_like(fast_p) 43 | param_state['lookahead_slow_buff'].copy_(fast_p) 44 | slow = param_state['lookahead_slow_buff'] 45 | slow.add_(fast_p - slow, alpha=group['lookahead_alpha']) 46 | fast_p.copy_(slow) 47 | 48 | def sync_lookahead(self): 49 | for group in self._base_optimizer.param_groups: 50 | self.update_slow(group) 51 | 52 | @torch.no_grad() 53 | def step(self, closure=None): 54 | loss = self._base_optimizer.step(closure) 55 | for group in self._base_optimizer.param_groups: 56 | group['lookahead_step'] += 1 57 | if group['lookahead_step'] % group['lookahead_k'] == 0: 58 | self.update_slow(group) 59 | return loss 60 | 61 | def state_dict(self): 62 | return self._base_optimizer.state_dict() 63 | 64 | def load_state_dict(self, state_dict): 65 | self._base_optimizer.load_state_dict(state_dict) 66 | self.param_groups = self._base_optimizer.param_groups 67 | -------------------------------------------------------------------------------- /timm/optim/optim_factory.py: -------------------------------------------------------------------------------- 1 | # lots of uses of these functions directly, ala 'import timm.optim.optim_factory as optim_factory', fun :/ 2 | 3 | from ._optim_factory import create_optimizer, create_optimizer_v2, optimizer_kwargs 4 | from ._param_groups import param_groups_layer_decay, param_groups_weight_decay, group_parameters, _layer_map, _group 5 | 6 | import warnings 7 | warnings.warn(f"Importing from {__name__} is deprecated, please import via timm.optim", FutureWarning) 8 | -------------------------------------------------------------------------------- /timm/optim/radam.py: -------------------------------------------------------------------------------- 1 | """RAdam Optimizer. 2 | Implementation lifted from: https://github.com/LiyuanLucasLiu/RAdam 3 | Paper: `On the Variance of the Adaptive Learning Rate and Beyond` - https://arxiv.org/abs/1908.03265 4 | 5 | NOTE: This impl has been deprecated in favour of torch.optim.RAdam and remains as a reference 6 | """ 7 | import math 8 | import torch 9 | from torch.optim.optimizer import Optimizer 10 | 11 | 12 | class RAdamLegacy(Optimizer): 13 | """ PyTorch RAdam optimizer 14 | 15 | NOTE: This impl has been deprecated in favour of torch.optim.AdamW and remains as a reference 16 | """ 17 | def __init__( 18 | self, 19 | params, 20 | lr=1e-3, 21 | betas=(0.9, 0.999), 22 | eps=1e-8, 23 | weight_decay=0, 24 | ): 25 | defaults = dict( 26 | lr=lr, 27 | betas=betas, 28 | eps=eps, 29 | weight_decay=weight_decay, 30 | buffer=[[None, None, None] for _ in range(10)] 31 | ) 32 | super(RAdamLegacy, self).__init__(params, defaults) 33 | 34 | def __setstate__(self, state): 35 | super(RAdamLegacy, self).__setstate__(state) 36 | 37 | @torch.no_grad() 38 | def step(self, closure=None): 39 | loss = None 40 | if closure is not None: 41 | with torch.enable_grad(): 42 | loss = closure() 43 | 44 | for group in self.param_groups: 45 | 46 | for p in group['params']: 47 | if p.grad is None: 48 | continue 49 | grad = p.grad.float() 50 | if grad.is_sparse: 51 | raise RuntimeError('RAdam does not support sparse gradients') 52 | 53 | p_fp32 = p.float() 54 | 55 | state = self.state[p] 56 | 57 | if len(state) == 0: 58 | state['step'] = 0 59 | state['exp_avg'] = torch.zeros_like(p_fp32) 60 | state['exp_avg_sq'] = torch.zeros_like(p_fp32) 61 | else: 62 | state['exp_avg'] = state['exp_avg'].type_as(p_fp32) 63 | state['exp_avg_sq'] = state['exp_avg_sq'].type_as(p_fp32) 64 | 65 | exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq'] 66 | beta1, beta2 = group['betas'] 67 | 68 | exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1 - beta2) 69 | exp_avg.mul_(beta1).add_(grad, alpha=1 - beta1) 70 | 71 | state['step'] += 1 72 | buffered = group['buffer'][int(state['step'] % 10)] 73 | if state['step'] == buffered[0]: 74 | num_sma, step_size = buffered[1], buffered[2] 75 | else: 76 | buffered[0] = state['step'] 77 | beta2_t = beta2 ** state['step'] 78 | num_sma_max = 2 / (1 - beta2) - 1 79 | num_sma = num_sma_max - 2 * state['step'] * beta2_t / (1 - beta2_t) 80 | buffered[1] = num_sma 81 | 82 | # more conservative since it's an approximated value 83 | if num_sma >= 5: 84 | step_size = group['lr'] * math.sqrt( 85 | (1 - beta2_t) * 86 | (num_sma - 4) / (num_sma_max - 4) * 87 | (num_sma - 2) / num_sma * 88 | num_sma_max / (num_sma_max - 2)) / (1 - beta1 ** state['step']) 89 | else: 90 | step_size = group['lr'] / (1 - beta1 ** state['step']) 91 | buffered[2] = step_size 92 | 93 | if group['weight_decay'] != 0: 94 | p_fp32.add_(p_fp32, alpha=-group['weight_decay'] * group['lr']) 95 | 96 | # more conservative since it's an approximated value 97 | if num_sma >= 5: 98 | denom = exp_avg_sq.sqrt().add_(group['eps']) 99 | p_fp32.addcdiv_(exp_avg, denom, value=-step_size) 100 | else: 101 | p_fp32.add_(exp_avg, alpha=-step_size) 102 | 103 | p.copy_(p_fp32) 104 | 105 | return loss 106 | -------------------------------------------------------------------------------- /timm/optim/sgdp.py: -------------------------------------------------------------------------------- 1 | """ 2 | SGDP Optimizer Implementation copied from https://github.com/clovaai/AdamP/blob/master/adamp/sgdp.py 3 | 4 | Paper: `Slowing Down the Weight Norm Increase in Momentum-based Optimizers` - https://arxiv.org/abs/2006.08217 5 | Code: https://github.com/clovaai/AdamP 6 | 7 | Copyright (c) 2020-present NAVER Corp. 8 | MIT license 9 | """ 10 | 11 | import torch 12 | import torch.nn.functional as F 13 | from torch.optim.optimizer import Optimizer, required 14 | import math 15 | 16 | from .adamp import projection 17 | 18 | 19 | class SGDP(Optimizer): 20 | def __init__( 21 | self, 22 | params, 23 | lr=required, 24 | momentum=0, 25 | dampening=0, 26 | weight_decay=0, 27 | nesterov=False, 28 | eps=1e-8, 29 | delta=0.1, 30 | wd_ratio=0.1 31 | ): 32 | defaults = dict( 33 | lr=lr, 34 | momentum=momentum, 35 | dampening=dampening, 36 | weight_decay=weight_decay, 37 | nesterov=nesterov, 38 | eps=eps, 39 | delta=delta, 40 | wd_ratio=wd_ratio, 41 | ) 42 | super(SGDP, self).__init__(params, defaults) 43 | 44 | @torch.no_grad() 45 | def step(self, closure=None): 46 | loss = None 47 | if closure is not None: 48 | with torch.enable_grad(): 49 | loss = closure() 50 | 51 | for group in self.param_groups: 52 | weight_decay = group['weight_decay'] 53 | momentum = group['momentum'] 54 | dampening = group['dampening'] 55 | nesterov = group['nesterov'] 56 | 57 | for p in group['params']: 58 | if p.grad is None: 59 | continue 60 | grad = p.grad 61 | state = self.state[p] 62 | 63 | # State initialization 64 | if len(state) == 0: 65 | state['momentum'] = torch.zeros_like(p) 66 | 67 | # SGD 68 | buf = state['momentum'] 69 | buf.mul_(momentum).add_(grad, alpha=1. - dampening) 70 | if nesterov: 71 | d_p = grad + momentum * buf 72 | else: 73 | d_p = buf 74 | 75 | # Projection 76 | wd_ratio = 1. 77 | if len(p.shape) > 1: 78 | d_p, wd_ratio = projection(p, grad, d_p, group['delta'], group['wd_ratio'], group['eps']) 79 | 80 | # Weight decay 81 | if weight_decay != 0: 82 | p.mul_(1. - group['lr'] * group['weight_decay'] * wd_ratio / (1-momentum)) 83 | 84 | # Step 85 | p.add_(d_p, alpha=-group['lr']) 86 | 87 | return loss 88 | -------------------------------------------------------------------------------- /timm/py.typed: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/huggingface/pytorch-image-models/a22366e3ce52568193bd49d64f4e88fb01796965/timm/py.typed -------------------------------------------------------------------------------- /timm/scheduler/__init__.py: -------------------------------------------------------------------------------- 1 | from .cosine_lr import CosineLRScheduler 2 | from .multistep_lr import MultiStepLRScheduler 3 | from .plateau_lr import PlateauLRScheduler 4 | from .poly_lr import PolyLRScheduler 5 | from .step_lr import StepLRScheduler 6 | from .tanh_lr import TanhLRScheduler 7 | 8 | from .scheduler_factory import create_scheduler, create_scheduler_v2, scheduler_kwargs 9 | -------------------------------------------------------------------------------- /timm/scheduler/multistep_lr.py: -------------------------------------------------------------------------------- 1 | """ MultiStep LR Scheduler 2 | 3 | Basic multi step LR schedule with warmup, noise. 4 | """ 5 | import torch 6 | import bisect 7 | from timm.scheduler.scheduler import Scheduler 8 | from typing import List 9 | 10 | class MultiStepLRScheduler(Scheduler): 11 | """ 12 | """ 13 | 14 | def __init__( 15 | self, 16 | optimizer: torch.optim.Optimizer, 17 | decay_t: List[int], 18 | decay_rate: float = 1., 19 | warmup_t=0, 20 | warmup_lr_init=0, 21 | warmup_prefix=True, 22 | t_in_epochs=True, 23 | noise_range_t=None, 24 | noise_pct=0.67, 25 | noise_std=1.0, 26 | noise_seed=42, 27 | initialize=True, 28 | ) -> None: 29 | super().__init__( 30 | optimizer, 31 | param_group_field="lr", 32 | t_in_epochs=t_in_epochs, 33 | noise_range_t=noise_range_t, 34 | noise_pct=noise_pct, 35 | noise_std=noise_std, 36 | noise_seed=noise_seed, 37 | initialize=initialize, 38 | ) 39 | 40 | self.decay_t = decay_t 41 | self.decay_rate = decay_rate 42 | self.warmup_t = warmup_t 43 | self.warmup_lr_init = warmup_lr_init 44 | self.warmup_prefix = warmup_prefix 45 | if self.warmup_t: 46 | self.warmup_steps = [(v - warmup_lr_init) / self.warmup_t for v in self.base_values] 47 | super().update_groups(self.warmup_lr_init) 48 | else: 49 | self.warmup_steps = [1 for _ in self.base_values] 50 | 51 | def get_curr_decay_steps(self, t): 52 | # find where in the array t goes, 53 | # assumes self.decay_t is sorted 54 | return bisect.bisect_right(self.decay_t, t + 1) 55 | 56 | def _get_lr(self, t: int) -> List[float]: 57 | if t < self.warmup_t: 58 | lrs = [self.warmup_lr_init + t * s for s in self.warmup_steps] 59 | else: 60 | if self.warmup_prefix: 61 | t = t - self.warmup_t 62 | lrs = [v * (self.decay_rate ** self.get_curr_decay_steps(t)) for v in self.base_values] 63 | return lrs 64 | -------------------------------------------------------------------------------- /timm/scheduler/plateau_lr.py: -------------------------------------------------------------------------------- 1 | """ Plateau Scheduler 2 | 3 | Adapts PyTorch plateau scheduler and allows application of noise, warmup. 4 | 5 | Hacked together by / Copyright 2020 Ross Wightman 6 | """ 7 | import torch 8 | from typing import List 9 | 10 | from .scheduler import Scheduler 11 | 12 | 13 | class PlateauLRScheduler(Scheduler): 14 | """Decay the LR by a factor every time the validation loss plateaus.""" 15 | 16 | def __init__( 17 | self, 18 | optimizer, 19 | decay_rate=0.1, 20 | patience_t=10, 21 | verbose=True, 22 | threshold=1e-4, 23 | cooldown_t=0, 24 | warmup_t=0, 25 | warmup_lr_init=0, 26 | lr_min=0, 27 | mode='max', 28 | noise_range_t=None, 29 | noise_type='normal', 30 | noise_pct=0.67, 31 | noise_std=1.0, 32 | noise_seed=None, 33 | initialize=True, 34 | ): 35 | super().__init__( 36 | optimizer, 37 | 'lr', 38 | noise_range_t=noise_range_t, 39 | noise_type=noise_type, 40 | noise_pct=noise_pct, 41 | noise_std=noise_std, 42 | noise_seed=noise_seed, 43 | initialize=initialize, 44 | ) 45 | 46 | self.lr_scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau( 47 | self.optimizer, 48 | patience=patience_t, 49 | factor=decay_rate, 50 | verbose=verbose, 51 | threshold=threshold, 52 | cooldown=cooldown_t, 53 | mode=mode, 54 | min_lr=lr_min 55 | ) 56 | 57 | self.warmup_t = warmup_t 58 | self.warmup_lr_init = warmup_lr_init 59 | if self.warmup_t: 60 | self.warmup_steps = [(v - warmup_lr_init) / self.warmup_t for v in self.base_values] 61 | super().update_groups(self.warmup_lr_init) 62 | else: 63 | self.warmup_steps = [1 for _ in self.base_values] 64 | self.restore_lr = None 65 | 66 | def state_dict(self): 67 | return { 68 | 'best': self.lr_scheduler.best, 69 | 'last_epoch': self.lr_scheduler.last_epoch, 70 | } 71 | 72 | def load_state_dict(self, state_dict): 73 | self.lr_scheduler.best = state_dict['best'] 74 | if 'last_epoch' in state_dict: 75 | self.lr_scheduler.last_epoch = state_dict['last_epoch'] 76 | 77 | # override the base class step fn completely 78 | def step(self, epoch, metric=None): 79 | if epoch <= self.warmup_t: 80 | lrs = [self.warmup_lr_init + epoch * s for s in self.warmup_steps] 81 | super().update_groups(lrs) 82 | else: 83 | if self.restore_lr is not None: 84 | # restore actual LR from before our last noise perturbation before stepping base 85 | for i, param_group in enumerate(self.optimizer.param_groups): 86 | param_group['lr'] = self.restore_lr[i] 87 | self.restore_lr = None 88 | 89 | self.lr_scheduler.step(metric, epoch) # step the base scheduler 90 | 91 | if self._is_apply_noise(epoch): 92 | self._apply_noise(epoch) 93 | 94 | def step_update(self, num_updates: int, metric: float = None): 95 | return None 96 | 97 | def _apply_noise(self, epoch): 98 | noise = self._calculate_noise(epoch) 99 | 100 | # apply the noise on top of previous LR, cache the old value so we can restore for normal 101 | # stepping of base scheduler 102 | restore_lr = [] 103 | for i, param_group in enumerate(self.optimizer.param_groups): 104 | old_lr = float(param_group['lr']) 105 | restore_lr.append(old_lr) 106 | new_lr = old_lr + old_lr * noise 107 | param_group['lr'] = new_lr 108 | self.restore_lr = restore_lr 109 | 110 | def _get_lr(self, t: int) -> List[float]: 111 | assert False, 'should not be called as step is overridden' 112 | -------------------------------------------------------------------------------- /timm/scheduler/poly_lr.py: -------------------------------------------------------------------------------- 1 | """ Polynomial Scheduler 2 | 3 | Polynomial LR schedule with warmup, noise. 4 | 5 | Hacked together by / Copyright 2021 Ross Wightman 6 | """ 7 | import math 8 | import logging 9 | from typing import List 10 | 11 | import torch 12 | 13 | from .scheduler import Scheduler 14 | 15 | 16 | _logger = logging.getLogger(__name__) 17 | 18 | 19 | class PolyLRScheduler(Scheduler): 20 | """ Polynomial LR Scheduler w/ warmup, noise, and k-decay 21 | 22 | k-decay option based on `k-decay: A New Method For Learning Rate Schedule` - https://arxiv.org/abs/2004.05909 23 | """ 24 | 25 | def __init__( 26 | self, 27 | optimizer: torch.optim.Optimizer, 28 | t_initial: int, 29 | power: float = 0.5, 30 | lr_min: float = 0., 31 | cycle_mul: float = 1., 32 | cycle_decay: float = 1., 33 | cycle_limit: int = 1, 34 | warmup_t=0, 35 | warmup_lr_init=0, 36 | warmup_prefix=False, 37 | t_in_epochs=True, 38 | noise_range_t=None, 39 | noise_pct=0.67, 40 | noise_std=1.0, 41 | noise_seed=42, 42 | k_decay=1.0, 43 | initialize=True, 44 | ) -> None: 45 | super().__init__( 46 | optimizer, 47 | param_group_field="lr", 48 | t_in_epochs=t_in_epochs, 49 | noise_range_t=noise_range_t, 50 | noise_pct=noise_pct, 51 | noise_std=noise_std, 52 | noise_seed=noise_seed, 53 | initialize=initialize 54 | ) 55 | 56 | assert t_initial > 0 57 | assert lr_min >= 0 58 | if t_initial == 1 and cycle_mul == 1 and cycle_decay == 1: 59 | _logger.warning("Cosine annealing scheduler will have no effect on the learning " 60 | "rate since t_initial = t_mul = eta_mul = 1.") 61 | self.t_initial = t_initial 62 | self.power = power 63 | self.lr_min = lr_min 64 | self.cycle_mul = cycle_mul 65 | self.cycle_decay = cycle_decay 66 | self.cycle_limit = cycle_limit 67 | self.warmup_t = warmup_t 68 | self.warmup_lr_init = warmup_lr_init 69 | self.warmup_prefix = warmup_prefix 70 | self.k_decay = k_decay 71 | if self.warmup_t: 72 | self.warmup_steps = [(v - warmup_lr_init) / self.warmup_t for v in self.base_values] 73 | super().update_groups(self.warmup_lr_init) 74 | else: 75 | self.warmup_steps = [1 for _ in self.base_values] 76 | 77 | def _get_lr(self, t: int) -> List[float]: 78 | if t < self.warmup_t: 79 | lrs = [self.warmup_lr_init + t * s for s in self.warmup_steps] 80 | else: 81 | if self.warmup_prefix: 82 | t = t - self.warmup_t 83 | 84 | if self.cycle_mul != 1: 85 | i = math.floor(math.log(1 - t / self.t_initial * (1 - self.cycle_mul), self.cycle_mul)) 86 | t_i = self.cycle_mul ** i * self.t_initial 87 | t_curr = t - (1 - self.cycle_mul ** i) / (1 - self.cycle_mul) * self.t_initial 88 | else: 89 | i = t // self.t_initial 90 | t_i = self.t_initial 91 | t_curr = t - (self.t_initial * i) 92 | 93 | gamma = self.cycle_decay ** i 94 | lr_max_values = [v * gamma for v in self.base_values] 95 | k = self.k_decay 96 | 97 | if i < self.cycle_limit: 98 | lrs = [ 99 | self.lr_min + (lr_max - self.lr_min) * (1 - t_curr ** k / t_i ** k) ** self.power 100 | for lr_max in lr_max_values 101 | ] 102 | else: 103 | lrs = [self.lr_min for _ in self.base_values] 104 | 105 | return lrs 106 | 107 | def get_cycle_length(self, cycles=0): 108 | cycles = max(1, cycles or self.cycle_limit) 109 | if self.cycle_mul == 1.0: 110 | t = self.t_initial * cycles 111 | else: 112 | t = int(math.floor(-self.t_initial * (self.cycle_mul ** cycles - 1) / (1 - self.cycle_mul))) 113 | return t + self.warmup_t if self.warmup_prefix else t 114 | -------------------------------------------------------------------------------- /timm/scheduler/step_lr.py: -------------------------------------------------------------------------------- 1 | """ Step Scheduler 2 | 3 | Basic step LR schedule with warmup, noise. 4 | 5 | Hacked together by / Copyright 2020 Ross Wightman 6 | """ 7 | import math 8 | import torch 9 | from typing import List 10 | 11 | 12 | from .scheduler import Scheduler 13 | 14 | 15 | class StepLRScheduler(Scheduler): 16 | """ 17 | """ 18 | 19 | def __init__( 20 | self, 21 | optimizer: torch.optim.Optimizer, 22 | decay_t: float, 23 | decay_rate: float = 1., 24 | warmup_t=0, 25 | warmup_lr_init=0, 26 | warmup_prefix=True, 27 | t_in_epochs=True, 28 | noise_range_t=None, 29 | noise_pct=0.67, 30 | noise_std=1.0, 31 | noise_seed=42, 32 | initialize=True, 33 | ) -> None: 34 | super().__init__( 35 | optimizer, 36 | param_group_field="lr", 37 | t_in_epochs=t_in_epochs, 38 | noise_range_t=noise_range_t, 39 | noise_pct=noise_pct, 40 | noise_std=noise_std, 41 | noise_seed=noise_seed, 42 | initialize=initialize, 43 | ) 44 | 45 | self.decay_t = decay_t 46 | self.decay_rate = decay_rate 47 | self.warmup_t = warmup_t 48 | self.warmup_lr_init = warmup_lr_init 49 | self.warmup_prefix = warmup_prefix 50 | if self.warmup_t: 51 | self.warmup_steps = [(v - warmup_lr_init) / self.warmup_t for v in self.base_values] 52 | super().update_groups(self.warmup_lr_init) 53 | else: 54 | self.warmup_steps = [1 for _ in self.base_values] 55 | 56 | def _get_lr(self, t: int) -> List[float]: 57 | if t < self.warmup_t: 58 | lrs = [self.warmup_lr_init + t * s for s in self.warmup_steps] 59 | else: 60 | if self.warmup_prefix: 61 | t = t - self.warmup_t 62 | lrs = [v * (self.decay_rate ** (t // self.decay_t)) for v in self.base_values] 63 | return lrs 64 | -------------------------------------------------------------------------------- /timm/scheduler/tanh_lr.py: -------------------------------------------------------------------------------- 1 | """ TanH Scheduler 2 | 3 | TanH schedule with warmup, cycle/restarts, noise. 4 | 5 | Hacked together by / Copyright 2021 Ross Wightman 6 | """ 7 | import logging 8 | import math 9 | import numpy as np 10 | import torch 11 | from typing import List 12 | 13 | from .scheduler import Scheduler 14 | 15 | 16 | _logger = logging.getLogger(__name__) 17 | 18 | 19 | class TanhLRScheduler(Scheduler): 20 | """ 21 | Hyberbolic-Tangent decay with restarts. 22 | This is described in the paper https://arxiv.org/abs/1806.01593 23 | """ 24 | 25 | def __init__( 26 | self, 27 | optimizer: torch.optim.Optimizer, 28 | t_initial: int, 29 | lb: float = -7., 30 | ub: float = 3., 31 | lr_min: float = 0., 32 | cycle_mul: float = 1., 33 | cycle_decay: float = 1., 34 | cycle_limit: int = 1, 35 | warmup_t=0, 36 | warmup_lr_init=0, 37 | warmup_prefix=False, 38 | t_in_epochs=True, 39 | noise_range_t=None, 40 | noise_pct=0.67, 41 | noise_std=1.0, 42 | noise_seed=42, 43 | initialize=True, 44 | ) -> None: 45 | super().__init__( 46 | optimizer, 47 | param_group_field="lr", 48 | t_in_epochs=t_in_epochs, 49 | noise_range_t=noise_range_t, 50 | noise_pct=noise_pct, 51 | noise_std=noise_std, 52 | noise_seed=noise_seed, 53 | initialize=initialize, 54 | ) 55 | 56 | assert t_initial > 0 57 | assert lr_min >= 0 58 | assert lb < ub 59 | assert cycle_limit >= 0 60 | assert warmup_t >= 0 61 | assert warmup_lr_init >= 0 62 | self.lb = lb 63 | self.ub = ub 64 | self.t_initial = t_initial 65 | self.lr_min = lr_min 66 | self.cycle_mul = cycle_mul 67 | self.cycle_decay = cycle_decay 68 | self.cycle_limit = cycle_limit 69 | self.warmup_t = warmup_t 70 | self.warmup_lr_init = warmup_lr_init 71 | self.warmup_prefix = warmup_prefix 72 | if self.warmup_t: 73 | t_v = self.base_values if self.warmup_prefix else self._get_lr(self.warmup_t) 74 | self.warmup_steps = [(v - warmup_lr_init) / self.warmup_t for v in t_v] 75 | super().update_groups(self.warmup_lr_init) 76 | else: 77 | self.warmup_steps = [1 for _ in self.base_values] 78 | 79 | def _get_lr(self, t: int) -> List[float]: 80 | if t < self.warmup_t: 81 | lrs = [self.warmup_lr_init + t * s for s in self.warmup_steps] 82 | else: 83 | if self.warmup_prefix: 84 | t = t - self.warmup_t 85 | 86 | if self.cycle_mul != 1: 87 | i = math.floor(math.log(1 - t / self.t_initial * (1 - self.cycle_mul), self.cycle_mul)) 88 | t_i = self.cycle_mul ** i * self.t_initial 89 | t_curr = t - (1 - self.cycle_mul ** i) / (1 - self.cycle_mul) * self.t_initial 90 | else: 91 | i = t // self.t_initial 92 | t_i = self.t_initial 93 | t_curr = t - (self.t_initial * i) 94 | 95 | if i < self.cycle_limit: 96 | gamma = self.cycle_decay ** i 97 | lr_max_values = [v * gamma for v in self.base_values] 98 | 99 | tr = t_curr / t_i 100 | lrs = [ 101 | self.lr_min + 0.5 * (lr_max - self.lr_min) * (1 - math.tanh(self.lb * (1. - tr) + self.ub * tr)) 102 | for lr_max in lr_max_values 103 | ] 104 | else: 105 | lrs = [self.lr_min for _ in self.base_values] 106 | return lrs 107 | 108 | def get_cycle_length(self, cycles=0): 109 | cycles = max(1, cycles or self.cycle_limit) 110 | if self.cycle_mul == 1.0: 111 | t = self.t_initial * cycles 112 | else: 113 | t = int(math.floor(-self.t_initial * (self.cycle_mul ** cycles - 1) / (1 - self.cycle_mul))) 114 | return t + self.warmup_t if self.warmup_prefix else t 115 | -------------------------------------------------------------------------------- /timm/utils/__init__.py: -------------------------------------------------------------------------------- 1 | from .agc import adaptive_clip_grad 2 | from .attention_extract import AttentionExtract 3 | from .checkpoint_saver import CheckpointSaver 4 | from .clip_grad import dispatch_clip_grad 5 | from .cuda import ApexScaler, NativeScaler 6 | from .decay_batch import decay_batch_step, check_batch_size_retry 7 | from .distributed import distribute_bn, reduce_tensor, init_distributed_device,\ 8 | world_info_from_env, is_distributed_env, is_primary 9 | from .jit import set_jit_legacy, set_jit_fuser 10 | from .log import setup_default_logging, FormatterNoInfo 11 | from .metrics import AverageMeter, accuracy 12 | from .misc import natural_key, add_bool_arg, ParseKwargs 13 | from .model import unwrap_model, get_state_dict, freeze, unfreeze, reparameterize_model 14 | from .model_ema import ModelEma, ModelEmaV2, ModelEmaV3 15 | from .random import random_seed 16 | from .summary import update_summary, get_outdir 17 | -------------------------------------------------------------------------------- /timm/utils/agc.py: -------------------------------------------------------------------------------- 1 | """ Adaptive Gradient Clipping 2 | 3 | An impl of AGC, as per (https://arxiv.org/abs/2102.06171): 4 | 5 | @article{brock2021high, 6 | author={Andrew Brock and Soham De and Samuel L. Smith and Karen Simonyan}, 7 | title={High-Performance Large-Scale Image Recognition Without Normalization}, 8 | journal={arXiv preprint arXiv:}, 9 | year={2021} 10 | } 11 | 12 | Code references: 13 | * Official JAX impl (paper authors): https://github.com/deepmind/deepmind-research/tree/master/nfnets 14 | * Phil Wang's PyTorch gist: https://gist.github.com/lucidrains/0d6560077edac419ab5d3aa29e674d5c 15 | 16 | Hacked together by / Copyright 2021 Ross Wightman 17 | """ 18 | import torch 19 | 20 | 21 | def unitwise_norm(x, norm_type=2.0): 22 | if x.ndim <= 1: 23 | return x.norm(norm_type) 24 | else: 25 | # works for nn.ConvNd and nn,Linear where output dim is first in the kernel/weight tensor 26 | # might need special cases for other weights (possibly MHA) where this may not be true 27 | return x.norm(norm_type, dim=tuple(range(1, x.ndim)), keepdim=True) 28 | 29 | 30 | def adaptive_clip_grad(parameters, clip_factor=0.01, eps=1e-3, norm_type=2.0): 31 | if isinstance(parameters, torch.Tensor): 32 | parameters = [parameters] 33 | for p in parameters: 34 | if p.grad is None: 35 | continue 36 | p_data = p.detach() 37 | g_data = p.grad.detach() 38 | max_norm = unitwise_norm(p_data, norm_type=norm_type).clamp_(min=eps).mul_(clip_factor) 39 | grad_norm = unitwise_norm(g_data, norm_type=norm_type) 40 | clipped_grad = g_data * (max_norm / grad_norm.clamp(min=1e-6)) 41 | new_grads = torch.where(grad_norm < max_norm, g_data, clipped_grad) 42 | p.grad.detach().copy_(new_grads) 43 | -------------------------------------------------------------------------------- /timm/utils/attention_extract.py: -------------------------------------------------------------------------------- 1 | import fnmatch 2 | import re 3 | from collections import OrderedDict 4 | from typing import Union, Optional, List 5 | 6 | import torch 7 | 8 | 9 | class AttentionExtract(torch.nn.Module): 10 | # defaults should cover a significant number of timm models with attention maps. 11 | default_node_names = ['*attn.softmax'] 12 | default_module_names = ['*attn_drop'] 13 | 14 | def __init__( 15 | self, 16 | model: Union[torch.nn.Module], 17 | names: Optional[List[str]] = None, 18 | mode: str = 'eval', 19 | method: str = 'fx', 20 | hook_type: str = 'forward', 21 | use_regex: bool = False, 22 | ): 23 | """ Extract attention maps (or other activations) from a model by name. 24 | 25 | Args: 26 | model: Instantiated model to extract from. 27 | names: List of concrete or wildcard names to extract. Names are nodes for fx and modules for hooks. 28 | mode: 'train' or 'eval' model mode. 29 | method: 'fx' or 'hook' extraction method. 30 | hook_type: 'forward' or 'forward_pre' hooks used. 31 | use_regex: Use regex instead of fnmatch 32 | """ 33 | super().__init__() 34 | assert mode in ('train', 'eval') 35 | if mode == 'train': 36 | model = model.train() 37 | else: 38 | model = model.eval() 39 | 40 | assert method in ('fx', 'hook') 41 | if method == 'fx': 42 | # names are activation node names 43 | from timm.models._features_fx import get_graph_node_names, GraphExtractNet 44 | 45 | node_names = get_graph_node_names(model)[0 if mode == 'train' else 1] 46 | names = names or self.default_node_names 47 | if use_regex: 48 | regexes = [re.compile(r) for r in names] 49 | matched = [g for g in node_names if any([r.match(g) for r in regexes])] 50 | else: 51 | matched = [g for g in node_names if any([fnmatch.fnmatch(g, n) for n in names])] 52 | if not matched: 53 | raise RuntimeError(f'No node names found matching {names}.') 54 | 55 | self.model = GraphExtractNet(model, matched, return_dict=True) 56 | self.hooks = None 57 | else: 58 | # names are module names 59 | assert hook_type in ('forward', 'forward_pre') 60 | from timm.models._features import FeatureHooks 61 | 62 | module_names = [n for n, m in model.named_modules()] 63 | names = names or self.default_module_names 64 | if use_regex: 65 | regexes = [re.compile(r) for r in names] 66 | matched = [m for m in module_names if any([r.match(m) for r in regexes])] 67 | else: 68 | matched = [m for m in module_names if any([fnmatch.fnmatch(m, n) for n in names])] 69 | if not matched: 70 | raise RuntimeError(f'No module names found matching {names}.') 71 | 72 | self.model = model 73 | self.hooks = FeatureHooks(matched, model.named_modules(), default_hook_type=hook_type) 74 | 75 | self.names = matched 76 | self.mode = mode 77 | self.method = method 78 | 79 | def forward(self, x): 80 | if self.hooks is not None: 81 | self.model(x) 82 | output = self.hooks.get_output(device=x.device) 83 | else: 84 | output = self.model(x) 85 | return output 86 | -------------------------------------------------------------------------------- /timm/utils/clip_grad.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | from timm.utils.agc import adaptive_clip_grad 4 | 5 | 6 | def dispatch_clip_grad(parameters, value: float, mode: str = 'norm', norm_type: float = 2.0): 7 | """ Dispatch to gradient clipping method 8 | 9 | Args: 10 | parameters (Iterable): model parameters to clip 11 | value (float): clipping value/factor/norm, mode dependant 12 | mode (str): clipping mode, one of 'norm', 'value', 'agc' 13 | norm_type (float): p-norm, default 2.0 14 | """ 15 | if mode == 'norm': 16 | torch.nn.utils.clip_grad_norm_(parameters, value, norm_type=norm_type) 17 | elif mode == 'value': 18 | torch.nn.utils.clip_grad_value_(parameters, value) 19 | elif mode == 'agc': 20 | adaptive_clip_grad(parameters, value, norm_type=norm_type) 21 | else: 22 | assert False, f"Unknown clip mode ({mode})." 23 | 24 | -------------------------------------------------------------------------------- /timm/utils/cuda.py: -------------------------------------------------------------------------------- 1 | """ CUDA / AMP utils 2 | 3 | Hacked together by / Copyright 2020 Ross Wightman 4 | """ 5 | import torch 6 | 7 | try: 8 | from apex import amp 9 | has_apex = True 10 | except ImportError: 11 | amp = None 12 | has_apex = False 13 | 14 | from .clip_grad import dispatch_clip_grad 15 | 16 | 17 | class ApexScaler: 18 | state_dict_key = "amp" 19 | 20 | def __call__( 21 | self, 22 | loss, 23 | optimizer, 24 | clip_grad=None, 25 | clip_mode='norm', 26 | parameters=None, 27 | create_graph=False, 28 | need_update=True, 29 | ): 30 | with amp.scale_loss(loss, optimizer) as scaled_loss: 31 | scaled_loss.backward(create_graph=create_graph) 32 | if need_update: 33 | if clip_grad is not None: 34 | dispatch_clip_grad(amp.master_params(optimizer), clip_grad, mode=clip_mode) 35 | optimizer.step() 36 | 37 | def state_dict(self): 38 | if 'state_dict' in amp.__dict__: 39 | return amp.state_dict() 40 | 41 | def load_state_dict(self, state_dict): 42 | if 'load_state_dict' in amp.__dict__: 43 | amp.load_state_dict(state_dict) 44 | 45 | 46 | class NativeScaler: 47 | state_dict_key = "amp_scaler" 48 | 49 | def __init__(self, device='cuda'): 50 | try: 51 | self._scaler = torch.amp.GradScaler(device=device) 52 | except (AttributeError, TypeError) as e: 53 | self._scaler = torch.cuda.amp.GradScaler() 54 | 55 | def __call__( 56 | self, 57 | loss, 58 | optimizer, 59 | clip_grad=None, 60 | clip_mode='norm', 61 | parameters=None, 62 | create_graph=False, 63 | need_update=True, 64 | ): 65 | self._scaler.scale(loss).backward(create_graph=create_graph) 66 | if need_update: 67 | if clip_grad is not None: 68 | assert parameters is not None 69 | self._scaler.unscale_(optimizer) # unscale the gradients of optimizer's assigned params in-place 70 | dispatch_clip_grad(parameters, clip_grad, mode=clip_mode) 71 | self._scaler.step(optimizer) 72 | self._scaler.update() 73 | 74 | def state_dict(self): 75 | return self._scaler.state_dict() 76 | 77 | def load_state_dict(self, state_dict): 78 | self._scaler.load_state_dict(state_dict) 79 | -------------------------------------------------------------------------------- /timm/utils/decay_batch.py: -------------------------------------------------------------------------------- 1 | """ Batch size decay and retry helpers. 2 | 3 | Copyright 2022 Ross Wightman 4 | """ 5 | import math 6 | 7 | 8 | def decay_batch_step(batch_size, num_intra_steps=2, no_odd=False): 9 | """ power of two batch-size decay with intra steps 10 | 11 | Decay by stepping between powers of 2: 12 | * determine power-of-2 floor of current batch size (base batch size) 13 | * divide above value by num_intra_steps to determine step size 14 | * floor batch_size to nearest multiple of step_size (from base batch size) 15 | Examples: 16 | num_steps == 4 --> 64, 56, 48, 40, 32, 28, 24, 20, 16, 14, 12, 10, 8, 7, 6, 5, 4, 3, 2, 1 17 | num_steps (no_odd=True) == 4 --> 64, 56, 48, 40, 32, 28, 24, 20, 16, 14, 12, 10, 8, 6, 4, 2 18 | num_steps == 2 --> 64, 48, 32, 24, 16, 12, 8, 6, 4, 3, 2, 1 19 | num_steps == 1 --> 64, 32, 16, 8, 4, 2, 1 20 | """ 21 | if batch_size <= 1: 22 | # return 0 for stopping value so easy to use in loop 23 | return 0 24 | base_batch_size = int(2 ** (math.log(batch_size - 1) // math.log(2))) 25 | step_size = max(base_batch_size // num_intra_steps, 1) 26 | batch_size = base_batch_size + ((batch_size - base_batch_size - 1) // step_size) * step_size 27 | if no_odd and batch_size % 2: 28 | batch_size -= 1 29 | return batch_size 30 | 31 | 32 | def check_batch_size_retry(error_str): 33 | """ check failure error string for conditions where batch decay retry should not be attempted 34 | """ 35 | error_str = error_str.lower() 36 | if 'required rank' in error_str: 37 | # Errors involving phrase 'required rank' typically happen when a conv is used that's 38 | # not compatible with channels_last memory format. 39 | return False 40 | if 'illegal' in error_str: 41 | # 'Illegal memory access' errors in CUDA typically leave process in unusable state 42 | return False 43 | return True 44 | -------------------------------------------------------------------------------- /timm/utils/jit.py: -------------------------------------------------------------------------------- 1 | """ JIT scripting/tracing utils 2 | 3 | Hacked together by / Copyright 2020 Ross Wightman 4 | """ 5 | import os 6 | 7 | import torch 8 | 9 | 10 | def set_jit_legacy(): 11 | """ Set JIT executor to legacy w/ support for op fusion 12 | This is hopefully a temporary need in 1.5/1.5.1/1.6 to restore performance due to changes 13 | in the JIT executor. These API are not supported so could change. 14 | """ 15 | # 16 | assert hasattr(torch._C, '_jit_set_profiling_executor'), "Old JIT behavior doesn't exist!" 17 | torch._C._jit_set_profiling_executor(False) 18 | torch._C._jit_set_profiling_mode(False) 19 | torch._C._jit_override_can_fuse_on_gpu(True) 20 | #torch._C._jit_set_texpr_fuser_enabled(True) 21 | 22 | 23 | def set_jit_fuser(fuser): 24 | if fuser == "te": 25 | # default fuser should be == 'te' 26 | torch._C._jit_set_profiling_executor(True) 27 | torch._C._jit_set_profiling_mode(True) 28 | torch._C._jit_override_can_fuse_on_cpu(False) 29 | torch._C._jit_override_can_fuse_on_gpu(True) 30 | torch._C._jit_set_texpr_fuser_enabled(True) 31 | try: 32 | torch._C._jit_set_nvfuser_enabled(False) 33 | except Exception: 34 | pass 35 | elif fuser == "old" or fuser == "legacy": 36 | torch._C._jit_set_profiling_executor(False) 37 | torch._C._jit_set_profiling_mode(False) 38 | torch._C._jit_override_can_fuse_on_gpu(True) 39 | torch._C._jit_set_texpr_fuser_enabled(False) 40 | try: 41 | torch._C._jit_set_nvfuser_enabled(False) 42 | except Exception: 43 | pass 44 | elif fuser == "nvfuser" or fuser == "nvf": 45 | os.environ['PYTORCH_NVFUSER_DISABLE_FALLBACK'] = '1' 46 | #os.environ['PYTORCH_NVFUSER_DISABLE_FMA'] = '1' 47 | #os.environ['PYTORCH_NVFUSER_JIT_OPT_LEVEL'] = '0' 48 | torch._C._jit_set_texpr_fuser_enabled(False) 49 | torch._C._jit_set_profiling_executor(True) 50 | torch._C._jit_set_profiling_mode(True) 51 | torch._C._jit_can_fuse_on_cpu() 52 | torch._C._jit_can_fuse_on_gpu() 53 | torch._C._jit_override_can_fuse_on_cpu(False) 54 | torch._C._jit_override_can_fuse_on_gpu(False) 55 | torch._C._jit_set_nvfuser_guard_mode(True) 56 | torch._C._jit_set_nvfuser_enabled(True) 57 | else: 58 | assert False, f"Invalid jit fuser ({fuser})" 59 | -------------------------------------------------------------------------------- /timm/utils/log.py: -------------------------------------------------------------------------------- 1 | """ Logging helpers 2 | 3 | Hacked together by / Copyright 2020 Ross Wightman 4 | """ 5 | import logging 6 | import logging.handlers 7 | 8 | 9 | class FormatterNoInfo(logging.Formatter): 10 | def __init__(self, fmt='%(levelname)s: %(message)s'): 11 | logging.Formatter.__init__(self, fmt) 12 | 13 | def format(self, record): 14 | if record.levelno == logging.INFO: 15 | return str(record.getMessage()) 16 | return logging.Formatter.format(self, record) 17 | 18 | 19 | def setup_default_logging(default_level=logging.INFO, log_path=''): 20 | console_handler = logging.StreamHandler() 21 | console_handler.setFormatter(FormatterNoInfo()) 22 | logging.root.addHandler(console_handler) 23 | logging.root.setLevel(default_level) 24 | if log_path: 25 | file_handler = logging.handlers.RotatingFileHandler(log_path, maxBytes=(1024 ** 2 * 2), backupCount=3) 26 | file_formatter = logging.Formatter("%(asctime)s - %(name)20s: [%(levelname)8s] - %(message)s") 27 | file_handler.setFormatter(file_formatter) 28 | logging.root.addHandler(file_handler) 29 | -------------------------------------------------------------------------------- /timm/utils/metrics.py: -------------------------------------------------------------------------------- 1 | """ Eval metrics and related 2 | 3 | Hacked together by / Copyright 2020 Ross Wightman 4 | """ 5 | 6 | 7 | class AverageMeter: 8 | """Computes and stores the average and current value""" 9 | def __init__(self): 10 | self.reset() 11 | 12 | def reset(self): 13 | self.val = 0 14 | self.avg = 0 15 | self.sum = 0 16 | self.count = 0 17 | 18 | def update(self, val, n=1): 19 | self.val = val 20 | self.sum += val * n 21 | self.count += n 22 | self.avg = self.sum / self.count 23 | 24 | 25 | def accuracy(output, target, topk=(1,)): 26 | """Computes the accuracy over the k top predictions for the specified values of k""" 27 | maxk = min(max(topk), output.size()[1]) 28 | batch_size = target.size(0) 29 | _, pred = output.topk(maxk, 1, True, True) 30 | pred = pred.t() 31 | correct = pred.eq(target.reshape(1, -1).expand_as(pred)) 32 | return [correct[:min(k, maxk)].reshape(-1).float().sum(0) * 100. / batch_size for k in topk] 33 | -------------------------------------------------------------------------------- /timm/utils/misc.py: -------------------------------------------------------------------------------- 1 | """ Misc utils 2 | 3 | Hacked together by / Copyright 2020 Ross Wightman 4 | """ 5 | import argparse 6 | import ast 7 | import re 8 | 9 | 10 | def natural_key(string_): 11 | """See http://www.codinghorror.com/blog/archives/001018.html""" 12 | return [int(s) if s.isdigit() else s for s in re.split(r'(\d+)', string_.lower())] 13 | 14 | 15 | def add_bool_arg(parser, name, default=False, help=''): 16 | dest_name = name.replace('-', '_') 17 | group = parser.add_mutually_exclusive_group(required=False) 18 | group.add_argument('--' + name, dest=dest_name, action='store_true', help=help) 19 | group.add_argument('--no-' + name, dest=dest_name, action='store_false', help=help) 20 | parser.set_defaults(**{dest_name: default}) 21 | 22 | 23 | class ParseKwargs(argparse.Action): 24 | def __call__(self, parser, namespace, values, option_string=None): 25 | kw = {} 26 | for value in values: 27 | key, value = value.split('=') 28 | try: 29 | kw[key] = ast.literal_eval(value) 30 | except ValueError: 31 | kw[key] = str(value) # fallback to string (avoid need to escape on command line) 32 | setattr(namespace, self.dest, kw) 33 | -------------------------------------------------------------------------------- /timm/utils/onnx.py: -------------------------------------------------------------------------------- 1 | from typing import Optional, Tuple, List 2 | 3 | import torch 4 | 5 | 6 | def onnx_forward(onnx_file, example_input): 7 | import onnxruntime 8 | 9 | sess_options = onnxruntime.SessionOptions() 10 | session = onnxruntime.InferenceSession(onnx_file, sess_options) 11 | input_name = session.get_inputs()[0].name 12 | output = session.run([], {input_name: example_input.numpy()}) 13 | output = output[0] 14 | return output 15 | 16 | 17 | def onnx_export( 18 | model: torch.nn.Module, 19 | output_file: str, 20 | example_input: Optional[torch.Tensor] = None, 21 | training: bool = False, 22 | verbose: bool = False, 23 | check: bool = True, 24 | check_forward: bool = False, 25 | batch_size: int = 64, 26 | input_size: Tuple[int, int, int] = None, 27 | opset: Optional[int] = None, 28 | dynamic_size: bool = False, 29 | aten_fallback: bool = False, 30 | keep_initializers: Optional[bool] = None, 31 | use_dynamo: bool = False, 32 | input_names: List[str] = None, 33 | output_names: List[str] = None, 34 | ): 35 | import onnx 36 | 37 | if training: 38 | training_mode = torch.onnx.TrainingMode.TRAINING 39 | model.train() 40 | else: 41 | training_mode = torch.onnx.TrainingMode.EVAL 42 | model.eval() 43 | 44 | if example_input is None: 45 | if not input_size: 46 | assert hasattr(model, 'default_cfg'), 'Cannot file model default config, input size must be provided' 47 | input_size = model.default_cfg.get('input_size') 48 | example_input = torch.randn((batch_size,) + input_size, requires_grad=training) 49 | 50 | # Run model once before export trace, sets padding for models with Conv2dSameExport. This means 51 | # that the padding for models with Conv2dSameExport (most models with tf_ prefix) is fixed for 52 | # the input img_size specified in this script. 53 | 54 | # Opset >= 11 should allow for dynamic padding, however I cannot get it to work due to 55 | # issues in the tracing of the dynamic padding or errors attempting to export the model after jit 56 | # scripting it (an approach that should work). Perhaps in a future PyTorch or ONNX versions... 57 | with torch.no_grad(): 58 | original_out = model(example_input) 59 | 60 | input_names = input_names or ["input0"] 61 | output_names = output_names or ["output0"] 62 | 63 | dynamic_axes = {'input0': {0: 'batch'}, 'output0': {0: 'batch'}} 64 | if dynamic_size: 65 | dynamic_axes['input0'][2] = 'height' 66 | dynamic_axes['input0'][3] = 'width' 67 | 68 | if aten_fallback: 69 | export_type = torch.onnx.OperatorExportTypes.ONNX_ATEN_FALLBACK 70 | else: 71 | export_type = torch.onnx.OperatorExportTypes.ONNX 72 | 73 | if use_dynamo: 74 | export_options = torch.onnx.ExportOptions(dynamic_shapes=dynamic_size) 75 | export_output = torch.onnx.dynamo_export( 76 | model, 77 | example_input, 78 | export_options=export_options, 79 | ) 80 | export_output.save(output_file) 81 | else: 82 | torch.onnx.export( 83 | model, 84 | example_input, 85 | output_file, 86 | training=training_mode, 87 | export_params=True, 88 | verbose=verbose, 89 | input_names=input_names, 90 | output_names=output_names, 91 | keep_initializers_as_inputs=keep_initializers, 92 | dynamic_axes=dynamic_axes, 93 | opset_version=opset, 94 | operator_export_type=export_type 95 | ) 96 | 97 | if check: 98 | onnx_model = onnx.load(output_file) 99 | onnx.checker.check_model(onnx_model, full_check=True) # assuming throw on error 100 | if check_forward and not training: 101 | import numpy as np 102 | onnx_out = onnx_forward(output_file, example_input) 103 | np.testing.assert_almost_equal(original_out.numpy(), onnx_out, decimal=3) 104 | 105 | -------------------------------------------------------------------------------- /timm/utils/random.py: -------------------------------------------------------------------------------- 1 | import random 2 | import numpy as np 3 | import torch 4 | 5 | 6 | def random_seed(seed=42, rank=0): 7 | torch.manual_seed(seed + rank) 8 | np.random.seed(seed + rank) 9 | random.seed(seed + rank) 10 | -------------------------------------------------------------------------------- /timm/utils/summary.py: -------------------------------------------------------------------------------- 1 | """ Summary utilities 2 | 3 | Hacked together by / Copyright 2020 Ross Wightman 4 | """ 5 | import csv 6 | import os 7 | from collections import OrderedDict 8 | try: 9 | import wandb 10 | except ImportError: 11 | pass 12 | 13 | 14 | def get_outdir(path, *paths, inc=False): 15 | outdir = os.path.join(path, *paths) 16 | if not os.path.exists(outdir): 17 | os.makedirs(outdir) 18 | elif inc: 19 | count = 1 20 | outdir_inc = outdir + '-' + str(count) 21 | while os.path.exists(outdir_inc): 22 | count = count + 1 23 | outdir_inc = outdir + '-' + str(count) 24 | assert count < 100 25 | outdir = outdir_inc 26 | os.makedirs(outdir) 27 | return outdir 28 | 29 | 30 | def update_summary( 31 | epoch, 32 | train_metrics, 33 | eval_metrics, 34 | filename, 35 | lr=None, 36 | write_header=False, 37 | log_wandb=False, 38 | ): 39 | rowd = OrderedDict(epoch=epoch) 40 | rowd.update([('train_' + k, v) for k, v in train_metrics.items()]) 41 | if eval_metrics: 42 | rowd.update([('eval_' + k, v) for k, v in eval_metrics.items()]) 43 | if lr is not None: 44 | rowd['lr'] = lr 45 | if log_wandb: 46 | wandb.log(rowd) 47 | with open(filename, mode='a') as cf: 48 | dw = csv.DictWriter(cf, fieldnames=rowd.keys()) 49 | if write_header: # first iteration (epoch == 1 can't be used) 50 | dw.writeheader() 51 | dw.writerow(rowd) 52 | -------------------------------------------------------------------------------- /timm/version.py: -------------------------------------------------------------------------------- 1 | __version__ = '1.0.15' 2 | --------------------------------------------------------------------------------