├── tests ├── __init__.py ├── test_layers.py └── test_utils.py ├── docs ├── models │ ├── .pages │ └── .templates │ │ ├── generate_readmes.py │ │ ├── models │ │ ├── spnasnet.md │ │ ├── gloun-senet.md │ │ ├── nasnet.md │ │ ├── gloun-xception.md │ │ ├── legacy-senet.md │ │ ├── inception-v4.md │ │ ├── skresnext.md │ │ ├── pnasnet.md │ │ ├── inception-resnet-v2.md │ │ ├── csp-resnet.md │ │ ├── csp-resnext.md │ │ ├── fbnet.md │ │ ├── res2next.md │ │ ├── csp-darknet.md │ │ ├── gloun-inception-v3.md │ │ ├── inception-v3.md │ │ ├── ese-vovnet.md │ │ ├── tf-inception-v3.md │ │ ├── wide-resnet.md │ │ └── ensemble-adversarial.md │ │ └── code_snippets.md ├── javascripts │ └── tables.js ├── scripts.md └── index.md ├── timm ├── version.py ├── data │ ├── readers │ │ ├── __init__.py │ │ ├── shared_count.py │ │ ├── reader.py │ │ ├── class_map.py │ │ ├── reader_factory.py │ │ ├── img_extensions.py │ │ ├── reader_hfds.py │ │ ├── reader_image_tar.py │ │ └── reader_image_folder.py │ ├── constants.py │ ├── __init__.py │ ├── _info │ │ ├── imagenet_r_indices.txt │ │ ├── imagenet_a_indices.txt │ │ ├── imagenet_a_synsets.txt │ │ └── imagenet_r_synsets.txt │ ├── real_labels.py │ └── dataset_info.py ├── models │ ├── hub.py │ ├── factory.py │ ├── features.py │ ├── registry.py │ ├── fx_features.py │ ├── helpers.py │ ├── layers │ │ └── __init__.py │ └── __init__.py ├── layers │ ├── typing.py │ ├── trace_utils.py │ ├── linear.py │ ├── helpers.py │ ├── format.py │ ├── grn.py │ ├── blur_pool.py │ ├── create_conv2d.py │ ├── create_norm.py │ ├── median_pool.py │ ├── patch_dropout.py │ ├── space_to_depth.py │ ├── mixed_conv2d.py │ ├── test_time_pool.py │ ├── interpolate.py │ ├── global_context.py │ ├── filter_response_norm.py │ ├── activations_jit.py │ ├── pos_embed.py │ ├── separable_conv.py │ ├── padding.py │ ├── pool2d_same.py │ ├── split_attn.py │ └── inplace_abn.py ├── utils │ ├── random.py │ ├── __init__.py │ ├── clip_grad.py │ ├── metrics.py │ ├── log.py │ ├── misc.py │ ├── summary.py │ ├── agc.py │ ├── decay_batch.py │ ├── cuda.py │ └── jit.py ├── loss │ ├── __init__.py │ ├── cross_entropy.py │ ├── jsd.py │ └── binary_cross_entropy.py ├── __init__.py ├── scheduler │ ├── __init__.py │ ├── step_lr.py │ └── multistep_lr.py └── optim │ ├── __init__.py │ ├── sgdp.py │ └── lookahead.py ├── .gitattributes ├── .github ├── FUNDING.yml ├── ISSUE_TEMPLATE │ ├── config.yml │ ├── feature_request.md │ └── bug_report.md └── workflows │ ├── delete_doc_comment_trigger.yml │ ├── delete_doc_comment.yml │ ├── upload_pr_documentation.yml │ ├── build_documentation.yml │ ├── build_pr_documentation.yml │ └── tests.yml ├── requirements.txt ├── requirements-dev.txt ├── MANIFEST.in ├── hubconf.py ├── setup.cfg ├── distributed_train.sh ├── hfdocs ├── source │ ├── reference │ │ ├── models.mdx │ │ ├── data.mdx │ │ ├── schedulers.mdx │ │ └── optimizers.mdx │ ├── index.mdx │ ├── hf_hub.mdx │ └── installation.mdx └── README.md ├── requirements-docs.txt ├── pyproject.toml ├── model-index.yml ├── mkdocs.yml ├── .gitignore ├── setup.py └── results └── generate_csv_results.py

/tests/__init__.py:
--------------------------------------------------------------------------------
1 | 
--------------------------------------------------------------------------------
/docs/models/.pages:
--------------------------------------------------------------------------------
1 | title: Model Pages
--------------------------------------------------------------------------------
/timm/version.py:
--------------------------------------------------------------------------------
1 | __version__ = '0.9.11'
2 | 
--------------------------------------------------------------------------------
/.gitattributes:
--------------------------------------------------------------------------------
1 | *.ipynb linguist-documentation
2 | 
--------------------------------------------------------------------------------
/.github/FUNDING.yml:
--------------------------------------------------------------------------------
1 | # These are supported funding model platforms
2 | github: rwightman
3 | 
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | torch>=1.7
2 | torchvision
3 | pyyaml
4 | huggingface_hub
5 | safetensors>=0.2
--------------------------------------------------------------------------------
/requirements-dev.txt:
--------------------------------------------------------------------------------
1 | pytest
2 | pytest-timeout
3 | pytest-xdist
4 | pytest-forked
5 | expecttest
6 | 
--------------------------------------------------------------------------------
/timm/data/readers/__init__.py:
--------------------------------------------------------------------------------
1 | from .reader_factory import create_reader
2 | from .img_extensions import *
3 | 
--------------------------------------------------------------------------------
/MANIFEST.in:
--------------------------------------------------------------------------------
1 | include timm/models/_pruned/*.txt
2 | include timm/data/_info/*.txt
3 | include timm/data/_info/*.json
4 | 
--------------------------------------------------------------------------------
/hubconf.py:
--------------------------------------------------------------------------------
1 | dependencies = ['torch']
2 | import timm
3 | globals().update(timm.models._registry._model_entrypoints)
4 | 
--------------------------------------------------------------------------------
/setup.cfg:
--------------------------------------------------------------------------------
1 | [dist_conda]
2 | 
3 | conda_name_differences = 'torch:pytorch'
4 | channels = pytorch
5 | noarch = True
6 | 
--------------------------------------------------------------------------------
/distributed_train.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 | NUM_PROC=$1
3 | shift
4 | torchrun --nproc_per_node="$NUM_PROC" train.py "$@"
5 | 
6 | 
--------------------------------------------------------------------------------
/hfdocs/source/reference/models.mdx:
--------------------------------------------------------------------------------
1 | # Models
2 | 
3 | [[autodoc]] timm.create_model
4 | 
5 | [[autodoc]] timm.list_models
6 | 
--------------------------------------------------------------------------------
/requirements-docs.txt:
--------------------------------------------------------------------------------
1 | mkdocs
2 | mkdocs-material
3 | mkdocs-redirects
4 | mdx_truly_sane_lists
5 | mkdocs-awesome-pages-plugin
6 | 
--------------------------------------------------------------------------------
/timm/models/hub.py:
--------------------------------------------------------------------------------
1 | from ._hub import *
2 | 
3 | import warnings
4 | warnings.warn(f"Importing from {__name__} is deprecated, please import via timm.models", DeprecationWarning)
5 | 
--------------------------------------------------------------------------------
/timm/models/factory.py:
--------------------------------------------------------------------------------
1 | from ._factory import *
2 | 
3 | import warnings
4 | warnings.warn(f"Importing from {__name__} is deprecated, please import via timm.models", DeprecationWarning)
5 | 
--------------------------------------------------------------------------------
/timm/models/features.py:
--------------------------------------------------------------------------------
1 | from ._features import *
2 | 
3 | import warnings
4 | warnings.warn(f"Importing from {__name__} is deprecated, please import via timm.models", DeprecationWarning)
5 | 
--------------------------------------------------------------------------------
/timm/models/registry.py:
--------------------------------------------------------------------------------
1 | from ._registry import *
2 | 
3 | import warnings
4 | warnings.warn(f"Importing from {__name__} is deprecated, please import via timm.models", DeprecationWarning)
5 | 
--------------------------------------------------------------------------------
/timm/models/fx_features.py:
--------------------------------------------------------------------------------
1 | from ._features_fx import *
2 | 
3 | import warnings
4 | warnings.warn(f"Importing from {__name__} is deprecated, please import via timm.models", DeprecationWarning)
5 | 
--------------------------------------------------------------------------------
/docs/javascripts/tables.js:
--------------------------------------------------------------------------------
1 | app.location$.subscribe(function() {
2 |   var tables = document.querySelectorAll("article table")
3 |   tables.forEach(function(table) {
4 |     new Tablesort(table)
5 |   })
6 | })
--------------------------------------------------------------------------------
/timm/layers/typing.py:
--------------------------------------------------------------------------------
1 | from typing import Callable, Tuple, Type, Union
2 | 
3 | import torch
4 | 
5 | 
6 | LayerType = Union[str, Callable, Type[torch.nn.Module]]
7 | PadType = Union[str, int, Tuple[int, int]]
8 | 
--------------------------------------------------------------------------------
/hfdocs/source/reference/data.mdx:
--------------------------------------------------------------------------------
1 | # Data
2 | 
3 | [[autodoc]] timm.data.create_dataset
4 | 
5 | [[autodoc]] timm.data.create_loader
6 | 
7 | [[autodoc]] timm.data.create_transform
8 | 
9 | [[autodoc]] timm.data.resolve_data_config
--------------------------------------------------------------------------------
/timm/utils/random.py:
--------------------------------------------------------------------------------
1 | import random
2 | import numpy as np
3 | import torch
4 | 
5 | 
6 | def random_seed(seed=42, rank=0):
7 |     """Seed the torch, numpy and Python `random` RNGs with `seed + rank` (rank offsets the seed per process)."""
8 |     torch.manual_seed(seed + rank)
9 |     np.random.seed(seed + rank)
10 |     random.seed(seed + rank)
11 | 
--------------------------------------------------------------------------------
/timm/models/helpers.py:
--------------------------------------------------------------------------------
1 | from ._builder import *
2 | from ._helpers import *
3 | from ._manipulate import *
4 | from ._prune import *
5 | 
6 | import warnings
7 | warnings.warn(f"Importing from {__name__} is deprecated, please import via timm.models", DeprecationWarning)
8 | 
--------------------------------------------------------------------------------
/timm/loss/__init__.py:
--------------------------------------------------------------------------------
1 | from .asymmetric_loss import AsymmetricLossMultiLabel, AsymmetricLossSingleLabel
2 | from .binary_cross_entropy import BinaryCrossEntropy
3 | from .cross_entropy import LabelSmoothingCrossEntropy, SoftTargetCrossEntropy
4 | from .jsd import JsdCrossEntropy
5 | 
--------------------------------------------------------------------------------
/.github/ISSUE_TEMPLATE/config.yml:
--------------------------------------------------------------------------------
1 | blank_issues_enabled: false
2 | contact_links:
3 |   - name: Community Discussions
4 |     url: https://github.com/rwightman/pytorch-image-models/discussions
5 |     about: Hparam requests in issues will be ignored! Issues are for features and bugs. Questions can be asked in Discussions.
6 | 
--------------------------------------------------------------------------------
/timm/__init__.py:
--------------------------------------------------------------------------------
1 | from .version import __version__
2 | from .layers import is_scriptable, is_exportable, set_scriptable, set_exportable
3 | from .models import create_model, list_models, list_pretrained, is_model, list_modules, model_entrypoint, \
4 |     is_model_pretrained, get_pretrained_cfg, get_pretrained_cfg_value
5 | 
--------------------------------------------------------------------------------
/hfdocs/README.md:
--------------------------------------------------------------------------------
1 | # Hugging Face Timm Docs
2 | 
3 | ## Getting Started
4 | 
5 | ```
6 | pip install git+https://github.com/huggingface/doc-builder.git@main#egg=hf-doc-builder
7 | pip install watchdog black
8 | ```
9 | 
10 | ## Preview the Docs Locally
11 | 
12 | ```
13 | doc-builder preview timm hfdocs/source
14 | ```
15 | 
--------------------------------------------------------------------------------
/.github/workflows/delete_doc_comment_trigger.yml:
--------------------------------------------------------------------------------
1 | name: Delete doc comment trigger
2 | 
3 | on:
4 |   pull_request:
5 |     types: [ closed ]
6 | 
7 | 
8 | jobs:
9 |   delete:
10 |     uses: huggingface/doc-builder/.github/workflows/delete_doc_comment_trigger.yml@main
11 |     with:
12 |       pr_number: ${{ github.event.number }}
--------------------------------------------------------------------------------
/timm/scheduler/__init__.py:
--------------------------------------------------------------------------------
1 | from .cosine_lr import CosineLRScheduler
2 | from .multistep_lr import MultiStepLRScheduler
3 | from .plateau_lr import PlateauLRScheduler
4 | from .poly_lr import PolyLRScheduler
5 | from .step_lr import StepLRScheduler
6 | from .tanh_lr import TanhLRScheduler
7 | 
8 | from .scheduler_factory import create_scheduler, create_scheduler_v2, scheduler_kwargs
9 | 
--------------------------------------------------------------------------------
/.github/workflows/delete_doc_comment.yml:
--------------------------------------------------------------------------------
1 | name: Delete doc comment
2 | 
3 | on:
4 |   workflow_run:
5 |     workflows: ["Delete doc comment trigger"]
6 |     types:
7 |       - completed
8 | 
9 | jobs:
10 |   delete:
11 |     uses: huggingface/doc-builder/.github/workflows/delete_doc_comment.yml@main
12 |     secrets:
13 |       comment_bot_token: ${{ secrets.COMMENT_BOT_TOKEN }}
--------------------------------------------------------------------------------
/timm/data/readers/shared_count.py:
--------------------------------------------------------------------------------
1 | from multiprocessing import Value
2 | 
3 | 
4 | class SharedCount:
5 |     """Integer counter (e.g. the current epoch) stored in a multiprocessing.Value so it can be shared with other processes."""
6 |     def __init__(self, epoch: int = 0):
7 |         self.shared_epoch = Value('i', epoch)
8 | 
9 |     @property
10 |     def value(self):
11 |         return self.shared_epoch.value
12 | 
13 |     @value.setter
14 |     def value(self, epoch):
15 |         self.shared_epoch.value = epoch
16 | 
--------------------------------------------------------------------------------
/timm/layers/trace_utils.py:
--------------------------------------------------------------------------------
1 | try:
2 |     from torch import _assert
3 | except ImportError:
4 |     def _assert(condition: bool, message: str):
5 |         assert condition, message
6 | 
7 | 
8 | def _float_to_int(x: float) -> int:
9 |     """
10 |     Symbolic tracing helper to substitute for inbuilt `int`.
11 |     Hint: Inbuilt `int` can't accept an argument of type `Proxy`
12 |     """
13 |     return int(x)
14 | 
--------------------------------------------------------------------------------
/.github/workflows/upload_pr_documentation.yml:
--------------------------------------------------------------------------------
1 | name: Upload PR Documentation
2 | 
3 | on:
4 |   workflow_run:
5 |     workflows: ["Build PR Documentation"]
6 |     types:
7 |       - completed
8 | 
9 | jobs:
10 |   build:
11 |     uses: huggingface/doc-builder/.github/workflows/upload_pr_documentation.yml@main
12 |     with:
13 |       package_name: timm
14 |     secrets:
15 |       hf_token: ${{ secrets.HF_DOC_BUILD_PUSH }}
16 |       comment_bot_token: ${{ secrets.COMMENT_BOT_TOKEN }}
--------------------------------------------------------------------------------
/timm/data/constants.py:
--------------------------------------------------------------------------------
1 | DEFAULT_CROP_PCT = 0.875
2 | DEFAULT_CROP_MODE = 'center'
3 | IMAGENET_DEFAULT_MEAN = (0.485, 0.456, 0.406)
4 | IMAGENET_DEFAULT_STD = (0.229, 0.224, 0.225)
5 | IMAGENET_INCEPTION_MEAN = (0.5, 0.5, 0.5)
6 | IMAGENET_INCEPTION_STD = (0.5, 0.5, 0.5)
7 | IMAGENET_DPN_MEAN = (124 / 255, 117 / 255, 104 / 255)
8 | IMAGENET_DPN_STD = tuple([1 / (.0167 * 255)] * 3)
9 | OPENAI_CLIP_MEAN = (0.48145466, 0.4578275, 0.40821073)
10 | OPENAI_CLIP_STD = (0.26862954, 0.26130258, 0.27577711)
11 | 
--------------------------------------------------------------------------------
/timm/data/readers/reader.py:
--------------------------------------------------------------------------------
1 | from abc import abstractmethod
2 | 
3 | 
4 | class Reader:
5 |     def __init__(self):
6 |         pass
7 | 
8 |     @abstractmethod  # NOTE: not enforced at instantiation, Reader does not derive from abc.ABC
9 |     def _filename(self, index, basename=False, absolute=False):
10 |         pass
11 | 
12 |     def filename(self, index, basename=False, absolute=False):
13 |         return self._filename(index, basename=basename, absolute=absolute)
14 | 
15 |     def filenames(self, basename=False, absolute=False):  # relies on the subclass defining __len__
16 |         return [self._filename(index, basename=basename, absolute=absolute) for index in range(len(self))]
17 | 
18 | 
--------------------------------------------------------------------------------
/.github/workflows/build_documentation.yml:
--------------------------------------------------------------------------------
1 | name: Build documentation
2 | 
3 | on:
4 |   push:
5 |     branches:
6 |       - main
7 |       - doc-builder*
8 |       - v*-release
9 | 
10 | jobs:
11 |   build:
12 |     uses: huggingface/doc-builder/.github/workflows/build_main_documentation.yml@main
13 |     with:
14 |       commit_sha: ${{ github.sha }}
15 |       package: pytorch-image-models
16 |       package_name: timm
17 |       path_to_docs: pytorch-image-models/hfdocs/source
18 |       version_tag_suffix: ""
19 |     secrets:
20 |       hf_token: ${{ secrets.HF_DOC_BUILD_PUSH }}
21 | 
--------------------------------------------------------------------------------
/pyproject.toml:
--------------------------------------------------------------------------------
1 | [tool.pytest.ini_options]
2 | markers = [
3 |     "base: marker for model tests using the basic setup",
4 |     "cfg: marker for model tests checking the config",
5 |     "torchscript: marker for model tests using torchscript",
6 |     "features: marker for model tests checking feature extraction",
7 |     "fxforward: marker for model tests using torch fx (only forward)",
8 |     "fxbackward: marker for model tests using torch fx (only backward)",
9 | ]
10 | 
11 | [tool.black]
12 | line-length = 120
13 | target-version = ['py37', 'py38', 'py39', 'py310', 'py311']
14 | skip-string-normalization = true
15 | 
--------------------------------------------------------------------------------
/timm/optim/__init__.py:
--------------------------------------------------------------------------------
1 | from .adabelief import AdaBelief
2 | from .adafactor import Adafactor
3 | from .adahessian import Adahessian
4 | from .adamp import AdamP
5 | from .adamw import AdamW
6 | from .adan import Adan
7 | from .lamb import Lamb
8 | from .lars import Lars
9 | from .lookahead import Lookahead
10 | from .madgrad import MADGRAD
11 | from .nadam import Nadam
12 | from .nvnovograd import NvNovoGrad
13 | from .radam import RAdam
14 | from .rmsprop_tf import RMSpropTF
15 | from .sgdp import SGDP
16 | from .lion import Lion
17 | from .optim_factory import create_optimizer, create_optimizer_v2, optimizer_kwargs
18 | 
--------------------------------------------------------------------------------
/.github/workflows/build_pr_documentation.yml:
--------------------------------------------------------------------------------
1 | name: Build PR Documentation
2 | 
3 | on:
4 |   pull_request:
5 | 
6 | concurrency:
7 |   group: ${{ github.workflow }}-${{ github.head_ref || github.run_id }}
8 |   cancel-in-progress: true
9 | 
10 | jobs:
11 |   build:
12 |     uses: huggingface/doc-builder/.github/workflows/build_pr_documentation.yml@main
13 |     with:
14 |       commit_sha: ${{ github.event.pull_request.head.sha }}
15 |       pr_number: ${{ github.event.number }}
16 |       package: pytorch-image-models
17 |       package_name: timm
18 |       path_to_docs: pytorch-image-models/hfdocs/source
19 |       version_tag_suffix: ""
20 | 
--------------------------------------------------------------------------------
/hfdocs/source/reference/schedulers.mdx:
--------------------------------------------------------------------------------
1 | # Learning Rate Schedulers
2 | 
3 | This page contains the API reference documentation for learning rate schedulers included in `timm`.
4 | 
5 | ## Schedulers
6 | 
7 | ### Factory functions
8 | 
9 | [[autodoc]] timm.scheduler.scheduler_factory.create_scheduler
10 | [[autodoc]] timm.scheduler.scheduler_factory.create_scheduler_v2
11 | 
12 | ### Scheduler Classes
13 | 
14 | [[autodoc]] timm.scheduler.cosine_lr.CosineLRScheduler
15 | [[autodoc]] timm.scheduler.multistep_lr.MultiStepLRScheduler
16 | [[autodoc]] timm.scheduler.plateau_lr.PlateauLRScheduler
17 | [[autodoc]] timm.scheduler.poly_lr.PolyLRScheduler
18 | [[autodoc]] timm.scheduler.step_lr.StepLRScheduler
19 | [[autodoc]] timm.scheduler.tanh_lr.TanhLRScheduler
20 | 
--------------------------------------------------------------------------------
/model-index.yml:
--------------------------------------------------------------------------------
1 | Import:
2 | - ./docs/models/*.md
3 | Library:
4 |   Name: PyTorch Image Models
5 |   Headline: PyTorch image models, scripts, pretrained weights
6 |   Website: https://rwightman.github.io/pytorch-image-models/
7 |   Repository: https://github.com/rwightman/pytorch-image-models
8 |   Docs: https://rwightman.github.io/pytorch-image-models/
9 |   README: "# PyTorch Image Models\r\n\r\nPyTorch Image Models (TIMM) is a library\
10 |     \ for state-of-the-art image classification. With this library you can:\r\n\r\n\
11 |     - Choose from 300+ pre-trained state-of-the-art image classification models.\r\
12 |     \n- Train models afresh on research datasets such as ImageNet using provided scripts.\r\
13 |     \n- Finetune pre-trained models on your own datasets, including the latest cutting\
14 |     \ edge models."
15 | 
--------------------------------------------------------------------------------
/timm/layers/linear.py:
--------------------------------------------------------------------------------
1 | """ Linear layer (alternate definition)
2 | """
3 | import torch
4 | import torch.nn.functional as F
5 | from torch import nn as nn
6 | 
7 | 
8 | class Linear(nn.Linear):
9 |     r"""Applies a linear transformation to the incoming data: :math:`y = xA^T + b`
10 | 
11 |     Wraps torch.nn.Linear to support AMP + torchscript usage by manually casting
12 |     weight & bias to input.dtype to work around an issue w/ torch.addmm in this use case.
13 |     """
14 |     def forward(self, input: torch.Tensor) -> torch.Tensor:
15 |         if torch.jit.is_scripting():
16 |             bias = self.bias.to(dtype=input.dtype) if self.bias is not None else None
17 |             return F.linear(input, self.weight.to(dtype=input.dtype), bias=bias)
18 |         else:
19 |             return F.linear(input, self.weight, self.bias)
20 | 
--------------------------------------------------------------------------------
/timm/utils/__init__.py:
--------------------------------------------------------------------------------
1 | from .agc import adaptive_clip_grad
2 | from .checkpoint_saver import CheckpointSaver
3 | from .clip_grad import dispatch_clip_grad
4 | from .cuda import ApexScaler, NativeScaler
5 | from .decay_batch import decay_batch_step, check_batch_size_retry
6 | from .distributed import distribute_bn, reduce_tensor, init_distributed_device,\
7 |     world_info_from_env, is_distributed_env, is_primary
8 | from .jit import set_jit_legacy, set_jit_fuser
9 | from .log import setup_default_logging, FormatterNoInfo
10 | from .metrics import AverageMeter, accuracy
11 | from .misc import natural_key, add_bool_arg, ParseKwargs
12 | from .model import unwrap_model, get_state_dict, freeze, unfreeze, reparameterize_model
13 | from .model_ema import ModelEma, ModelEmaV2
14 | from .random import random_seed
15 | from .summary import update_summary, get_outdir
16 | 
--------------------------------------------------------------------------------
/.github/ISSUE_TEMPLATE/feature_request.md:
--------------------------------------------------------------------------------
1 | ---
2 | name: Feature request
3 | about: Suggest an idea for this project. Hparam requests and training help are not feature requests.
4 |   The discussion forum is available for asking questions or seeking help from the community.
5 | title: "[FEATURE] Feature title..."
6 | labels: enhancement
7 | assignees: ''
8 | 
9 | ---
10 | 
11 | **Is your feature request related to a problem? Please describe.**
12 | A clear and concise description of what the problem is.
13 | 
14 | **Describe the solution you'd like**
15 | A clear and concise description of what you want to happen.
16 | 
17 | **Describe alternatives you've considered**
18 | A clear and concise description of any alternative solutions or features you've considered.
19 | 
20 | **Additional context**
21 | Add any other context or screenshots about the feature request here.
22 | 
22 | -------------------------------------------------------------------------------- /timm/data/__init__.py: -------------------------------------------------------------------------------- 1 | from .auto_augment import RandAugment, AutoAugment, rand_augment_ops, auto_augment_policy,\ 2 | rand_augment_transform, auto_augment_transform 3 | from .config import resolve_data_config, resolve_model_data_config 4 | from .constants import * 5 | from .dataset import ImageDataset, IterableImageDataset, AugMixDataset 6 | from .dataset_factory import create_dataset 7 | from .dataset_info import DatasetInfo, CustomDatasetInfo 8 | from .imagenet_info import ImageNetInfo, infer_imagenet_subset 9 | from .loader import create_loader 10 | from .mixup import Mixup, FastCollateMixup 11 | from .readers import create_reader 12 | from .readers import get_img_extensions, is_img_extension, set_img_extensions, add_img_extensions, del_img_extensions 13 | from .real_labels import RealLabelsImagenet 14 | from .transforms import * 15 | from .transforms_factory import create_transform 16 | -------------------------------------------------------------------------------- /timm/utils/clip_grad.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | from timm.utils.agc import adaptive_clip_grad 4 | 5 | 6 | def dispatch_clip_grad(parameters, value: float, mode: str = 'norm', norm_type: float = 2.0): 7 | """ Dispatch to gradient clipping method 8 | 9 | Args: 10 | parameters (Iterable): model parameters to clip 11 | value (float): clipping value/factor/norm, mode dependant 12 | mode (str): clipping mode, one of 'norm', 'value', 'agc' 13 | norm_type (float): p-norm, default 2.0 14 | """ 15 | if mode == 'norm': 16 | torch.nn.utils.clip_grad_norm_(parameters, value, norm_type=norm_type) 17 | elif mode == 'value': 18 | torch.nn.utils.clip_grad_value_(parameters, value) 19 | elif mode == 'agc': 20 | adaptive_clip_grad(parameters, value, 
norm_type=norm_type) 21 | else: 22 | assert False, f"Unknown clip mode ({mode})." 23 | 24 | -------------------------------------------------------------------------------- /hfdocs/source/reference/optimizers.mdx: -------------------------------------------------------------------------------- 1 | # Optimization 2 | 3 | This page contains the API reference documentation for learning rate optimizers included in `timm`. 4 | 5 | ## Optimizers 6 | 7 | ### Factory functions 8 | 9 | [[autodoc]] timm.optim.optim_factory.create_optimizer 10 | [[autodoc]] timm.optim.optim_factory.create_optimizer_v2 11 | 12 | ### Optimizer Classes 13 | 14 | [[autodoc]] timm.optim.adabelief.AdaBelief 15 | [[autodoc]] timm.optim.adafactor.Adafactor 16 | [[autodoc]] timm.optim.adahessian.Adahessian 17 | [[autodoc]] timm.optim.adamp.AdamP 18 | [[autodoc]] timm.optim.adamw.AdamW 19 | [[autodoc]] timm.optim.lamb.Lamb 20 | [[autodoc]] timm.optim.lars.Lars 21 | [[autodoc]] timm.optim.lookahead.Lookahead 22 | [[autodoc]] timm.optim.madgrad.MADGRAD 23 | [[autodoc]] timm.optim.nadam.Nadam 24 | [[autodoc]] timm.optim.nvnovograd.NvNovoGrad 25 | [[autodoc]] timm.optim.radam.RAdam 26 | [[autodoc]] timm.optim.rmsprop_tf.RMSpropTF 27 | [[autodoc]] timm.optim.sgdp.SGDP 28 | -------------------------------------------------------------------------------- /timm/data/readers/class_map.py: -------------------------------------------------------------------------------- 1 | import os 2 | import pickle 3 | 4 | 5 | def load_class_map(map_or_filename, root=''): 6 | if isinstance(map_or_filename, dict): 7 | assert dict, 'class_map dict must be non-empty' 8 | return map_or_filename 9 | class_map_path = map_or_filename 10 | if not os.path.exists(class_map_path): 11 | class_map_path = os.path.join(root, class_map_path) 12 | assert os.path.exists(class_map_path), 'Cannot locate specified class map file (%s)' % map_or_filename 13 | class_map_ext = os.path.splitext(map_or_filename)[-1].lower() 14 | if class_map_ext == 
'.txt': 15 | with open(class_map_path) as f: 16 | class_to_idx = {v.strip(): k for k, v in enumerate(f)} 17 | elif class_map_ext == '.pkl': 18 | with open(class_map_path, 'rb') as f: 19 | class_to_idx = pickle.load(f) 20 | else: 21 | assert False, f'Unsupported class map file extension ({class_map_ext}).' 22 | return class_to_idx 23 | 24 | -------------------------------------------------------------------------------- /timm/utils/metrics.py: -------------------------------------------------------------------------------- 1 | """ Eval metrics and related 2 | 3 | Hacked together by / Copyright 2020 Ross Wightman 4 | """ 5 | 6 | 7 | class AverageMeter: 8 | """Computes and stores the average and current value""" 9 | def __init__(self): 10 | self.reset() 11 | 12 | def reset(self): 13 | self.val = 0 14 | self.avg = 0 15 | self.sum = 0 16 | self.count = 0 17 | 18 | def update(self, val, n=1): 19 | self.val = val 20 | self.sum += val * n 21 | self.count += n 22 | self.avg = self.sum / self.count 23 | 24 | 25 | def accuracy(output, target, topk=(1,)): 26 | """Computes the accuracy over the k top predictions for the specified values of k""" 27 | maxk = min(max(topk), output.size()[1]) 28 | batch_size = target.size(0) 29 | _, pred = output.topk(maxk, 1, True, True) 30 | pred = pred.t() 31 | correct = pred.eq(target.reshape(1, -1).expand_as(pred)) 32 | return [correct[:min(k, maxk)].reshape(-1).float().sum(0) * 100. / batch_size for k in topk] 33 | -------------------------------------------------------------------------------- /.github/ISSUE_TEMPLATE/bug_report.md: -------------------------------------------------------------------------------- 1 | --- 2 | name: Bug report 3 | about: Create a bug report to help us improve. Issues are for reporting bugs or requesting 4 | features, the discussion forum is available for asking questions or seeking help 5 | from the community. 6 | title: "[BUG] Issue title..." 
7 | labels: bug
8 | assignees: rwightman
9 | 
10 | ---
11 | 
12 | **Describe the bug**
13 | A clear and concise description of what the bug is.
14 | 
15 | **To Reproduce**
16 | Steps to reproduce the behavior:
17 | 1.
18 | 2.
19 | 
20 | **Expected behavior**
21 | A clear and concise description of what you expected to happen.
22 | 
23 | **Screenshots**
24 | If applicable, add screenshots to help explain your problem.
25 | 
26 | **Desktop (please complete the following information):**
27 |  - OS: [e.g. Windows 10, Ubuntu 18.04]
28 |  - This repository version [e.g. pip 0.3.1 or commit ref]
29 |  - PyTorch version w/ CUDA/cuDNN [e.g. from `conda list`, 1.7.0 py3.8_cuda11.0.221_cudnn8.0.3_0]
30 | 
31 | **Additional context**
32 | Add any other context about the problem here.
33 | 
--------------------------------------------------------------------------------
/timm/utils/log.py:
--------------------------------------------------------------------------------
1 | """ Logging helpers
2 | 
3 | Hacked together by / Copyright 2020 Ross Wightman
4 | """
5 | import logging
6 | import logging.handlers
7 | 
8 | 
9 | class FormatterNoInfo(logging.Formatter):
10 |     def __init__(self, fmt='%(levelname)s: %(message)s'):
11 |         logging.Formatter.__init__(self, fmt)
12 | 
13 |     def format(self, record):
14 |         if record.levelno == logging.INFO:
15 |             return str(record.getMessage())
16 |         return logging.Formatter.format(self, record)
17 | 
18 | 
19 | def setup_default_logging(default_level=logging.INFO, log_path=''):
20 |     console_handler = logging.StreamHandler()
21 |     console_handler.setFormatter(FormatterNoInfo())
22 |     logging.root.addHandler(console_handler)
23 |     logging.root.setLevel(default_level)
24 |     if log_path:
25 |         file_handler = logging.handlers.RotatingFileHandler(log_path, maxBytes=(1024 ** 2 * 2), backupCount=3)
26 |         file_formatter = logging.Formatter("%(asctime)s - %(name)20s: [%(levelname)8s] - %(message)s")
27 |         file_handler.setFormatter(file_formatter)
28 |         logging.root.addHandler(file_handler)
29 | 
--------------------------------------------------------------------------------
/timm/utils/misc.py:
--------------------------------------------------------------------------------
1 | """ Misc utils
2 | 
3 | Hacked together by / Copyright 2020 Ross Wightman
4 | """
5 | import argparse
6 | import ast
7 | import re
8 | 
9 | 
10 | def natural_key(string_):
11 |     """See http://www.codinghorror.com/blog/archives/001018.html"""
12 |     return [int(s) if s.isdigit() else s for s in re.split(r'(\d+)', string_.lower())]
13 | 
14 | 
15 | def add_bool_arg(parser, name, default=False, help=''):
16 |     dest_name = name.replace('-', '_')
17 |     group = parser.add_mutually_exclusive_group(required=False)
18 |     group.add_argument('--' + name, dest=dest_name, action='store_true', help=help)
19 |     group.add_argument('--no-' + name, dest=dest_name, action='store_false', help=help)
20 |     parser.set_defaults(**{dest_name: default})
21 | 
22 | 
23 | class ParseKwargs(argparse.Action):
24 |     def __call__(self, parser, namespace, values, option_string=None):
25 |         kw = {}
26 |         for value in values:
27 |             key, value = value.split('=', 1)  # split on the first '=' only, so values may themselves contain '='
28 |             try:
29 |                 kw[key] = ast.literal_eval(value)
30 |             except (ValueError, SyntaxError):  # literal_eval raises SyntaxError for inputs like 'a b' or '1.2.3'
31 |                 kw[key] = str(value)  # fallback to string (avoid need to escape on command line)
32 |         setattr(namespace, self.dest, kw)
33 | 
--------------------------------------------------------------------------------
/timm/layers/helpers.py:
--------------------------------------------------------------------------------
1 | """ Layer/Module Helpers
2 | 
3 | Hacked together by / Copyright 2020 Ross Wightman
4 | """
5 | from itertools import repeat
6 | import collections.abc
7 | 
8 | 
9 | # From PyTorch internals
10 | def _ntuple(n):
11 |     def parse(x):
12 |         if isinstance(x, collections.abc.Iterable) and not isinstance(x, str):
13 |             return tuple(x)
14 |         return tuple(repeat(x, n))
15 |     return parse
16 | 
17 | 
18 | to_1tuple = _ntuple(1)
19 | to_2tuple = _ntuple(2)
20 | to_3tuple = _ntuple(3)
21 | to_4tuple = _ntuple(4)
22 | to_ntuple = _ntuple
23 | 
24 | 
25 | def make_divisible(v, divisor=8, min_value=None, round_limit=.9):
26 |     min_value = min_value or divisor
27 |     new_v = max(min_value, int(v + divisor / 2) // divisor * divisor)
28 |     # Make sure that round down does not go below round_limit * v (i.e. by more than 10% with the default .9).
29 |     if new_v < round_limit * v:
30 |         new_v += divisor
31 |     return new_v
32 | 
33 | 
34 | def extend_tuple(x, n):
35 |     # pads (or truncates) x to a tuple of length n, padding with its last value
36 |     if not isinstance(x, (tuple, list)):
37 |         x = (x,)
38 |     else:
39 |         x = tuple(x)
40 |     pad_n = n - len(x)
41 |     if pad_n <= 0:
42 |         return x[:n]
43 |     return x + (x[-1],) * pad_n
44 | 
--------------------------------------------------------------------------------
/timm/loss/cross_entropy.py:
--------------------------------------------------------------------------------
1 | """ Cross Entropy w/ smoothing or soft targets
2 | 
3 | Hacked together by / Copyright 2021 Ross Wightman
4 | """
5 | 
6 | import torch
7 | import torch.nn as nn
8 | import torch.nn.functional as F
9 | 
10 | 
11 | class LabelSmoothingCrossEntropy(nn.Module):
12 |     """ NLL loss with label smoothing.
13 |     """
14 |     def __init__(self, smoothing=0.1):
15 |         super(LabelSmoothingCrossEntropy, self).__init__()
16 |         assert smoothing < 1.0
17 |         self.smoothing = smoothing
18 |         self.confidence = 1.
- smoothing 19 | 20 | def forward(self, x: torch.Tensor, target: torch.Tensor) -> torch.Tensor: 21 | logprobs = F.log_softmax(x, dim=-1) 22 | nll_loss = -logprobs.gather(dim=-1, index=target.unsqueeze(1)) 23 | nll_loss = nll_loss.squeeze(1) 24 | smooth_loss = -logprobs.mean(dim=-1) 25 | loss = self.confidence * nll_loss + self.smoothing * smooth_loss 26 | return loss.mean() 27 | 28 | 29 | class SoftTargetCrossEntropy(nn.Module): 30 | 31 | def __init__(self): 32 | super(SoftTargetCrossEntropy, self).__init__() 33 | 34 | def forward(self, x: torch.Tensor, target: torch.Tensor) -> torch.Tensor: 35 | loss = torch.sum(-target * F.log_softmax(x, dim=-1), dim=-1) 36 | return loss.mean() 37 | -------------------------------------------------------------------------------- /timm/layers/format.py: -------------------------------------------------------------------------------- 1 | from enum import Enum 2 | from typing import Union 3 | 4 | import torch 5 | 6 | 7 | class Format(str, Enum): 8 | NCHW = 'NCHW' 9 | NHWC = 'NHWC' 10 | NCL = 'NCL' 11 | NLC = 'NLC' 12 | 13 | 14 | FormatT = Union[str, Format] 15 | 16 | 17 | def get_spatial_dim(fmt: FormatT): 18 | fmt = Format(fmt) 19 | if fmt is Format.NLC: 20 | dim = (1,) 21 | elif fmt is Format.NCL: 22 | dim = (2,) 23 | elif fmt is Format.NHWC: 24 | dim = (1, 2) 25 | else: 26 | dim = (2, 3) 27 | return dim 28 | 29 | 30 | def get_channel_dim(fmt: FormatT): 31 | fmt = Format(fmt) 32 | if fmt is Format.NHWC: 33 | dim = 3 34 | elif fmt is Format.NLC: 35 | dim = 2 36 | else: 37 | dim = 1 38 | return dim 39 | 40 | 41 | def nchw_to(x: torch.Tensor, fmt: Format): 42 | if fmt == Format.NHWC: 43 | x = x.permute(0, 2, 3, 1) 44 | elif fmt == Format.NLC: 45 | x = x.flatten(2).transpose(1, 2) 46 | elif fmt == Format.NCL: 47 | x = x.flatten(2) 48 | return x 49 | 50 | 51 | def nhwc_to(x: torch.Tensor, fmt: Format): 52 | if fmt == Format.NCHW: 53 | x = x.permute(0, 3, 1, 2) 54 | elif fmt == Format.NLC: 55 | x = x.flatten(1, 2) 56 | elif fmt == 
Format.NCL: 57 | x = x.flatten(1, 2).transpose(1, 2) 58 | return x 59 | -------------------------------------------------------------------------------- /timm/layers/grn.py: -------------------------------------------------------------------------------- 1 | """ Global Response Normalization Module 2 | 3 | Based on the GRN layer presented in 4 | `ConvNeXt-V2 - Co-designing and Scaling ConvNets with Masked Autoencoders` - https://arxiv.org/abs/2301.00808 5 | 6 | This implementation 7 | * works for both NCHW and NHWC tensor layouts 8 | * uses affine param names matching existing torch norm layers 9 | * slightly improves eager mode performance via fused addcmul 10 | 11 | Hacked together by / Copyright 2023 Ross Wightman 12 | """ 13 | 14 | import torch 15 | from torch import nn as nn 16 | 17 | 18 | class GlobalResponseNorm(nn.Module): 19 | """ Global Response Normalization layer 20 | """ 21 | def __init__(self, dim, eps=1e-6, channels_last=True): 22 | super().__init__() 23 | self.eps = eps 24 | if channels_last: 25 | self.spatial_dim = (1, 2) 26 | self.channel_dim = -1 27 | self.wb_shape = (1, 1, 1, -1) 28 | else: 29 | self.spatial_dim = (2, 3) 30 | self.channel_dim = 1 31 | self.wb_shape = (1, -1, 1, 1) 32 | 33 | self.weight = nn.Parameter(torch.zeros(dim)) 34 | self.bias = nn.Parameter(torch.zeros(dim)) 35 | 36 | def forward(self, x): 37 | x_g = x.norm(p=2, dim=self.spatial_dim, keepdim=True) 38 | x_n = x_g / (x_g.mean(dim=self.channel_dim, keepdim=True) + self.eps) 39 | return x + torch.addcmul(self.bias.view(self.wb_shape), self.weight.view(self.wb_shape), x * x_n) 40 | -------------------------------------------------------------------------------- /timm/utils/summary.py: -------------------------------------------------------------------------------- 1 | """ Summary utilities 2 | 3 | Hacked together by / Copyright 2020 Ross Wightman 4 | """ 5 | import csv 6 | import os 7 | from collections import OrderedDict 8 | try: 9 | import wandb 10 | except ImportError: 11 | 
pass 12 | 13 | 14 | def get_outdir(path, *paths, inc=False): 15 | outdir = os.path.join(path, *paths) 16 | if not os.path.exists(outdir): 17 | os.makedirs(outdir) 18 | elif inc: 19 | count = 1 20 | outdir_inc = outdir + '-' + str(count) 21 | while os.path.exists(outdir_inc): 22 | count = count + 1 23 | outdir_inc = outdir + '-' + str(count) 24 | assert count < 100 25 | outdir = outdir_inc 26 | os.makedirs(outdir) 27 | return outdir 28 | 29 | 30 | def update_summary( 31 | epoch, 32 | train_metrics, 33 | eval_metrics, 34 | filename, 35 | lr=None, 36 | write_header=False, 37 | log_wandb=False, 38 | ): 39 | rowd = OrderedDict(epoch=epoch) 40 | rowd.update([('train_' + k, v) for k, v in train_metrics.items()]) 41 | rowd.update([('eval_' + k, v) for k, v in eval_metrics.items()]) 42 | if lr is not None: 43 | rowd['lr'] = lr 44 | if log_wandb: 45 | wandb.log(rowd) 46 | with open(filename, mode='a') as cf: 47 | dw = csv.DictWriter(cf, fieldnames=rowd.keys()) 48 | if write_header: # first iteration (epoch == 1 can't be used) 49 | dw.writeheader() 50 | dw.writerow(rowd) 51 | -------------------------------------------------------------------------------- /timm/data/readers/reader_factory.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | from .reader_image_folder import ReaderImageFolder 4 | from .reader_image_in_tar import ReaderImageInTar 5 | 6 | 7 | def create_reader(name, root, split='train', **kwargs): 8 | name = name.lower() 9 | name = name.split('/', 1) 10 | prefix = '' 11 | if len(name) > 1: 12 | prefix = name[0] 13 | name = name[-1] 14 | 15 | # FIXME improve the selection right now just tfds prefix or fallback path, will need options to 16 | # explicitly select other options shortly 17 | if prefix == 'hfds': 18 | from .reader_hfds import ReaderHfds # defer tensorflow import 19 | reader = ReaderHfds(root, name, split=split, **kwargs) 20 | elif prefix == 'tfds': 21 | from .reader_tfds import ReaderTfds # defer 
tensorflow import 22 | reader = ReaderTfds(root, name, split=split, **kwargs) 23 | elif prefix == 'wds': 24 | from .reader_wds import ReaderWds 25 | kwargs.pop('download', False) 26 | reader = ReaderWds(root, name, split=split, **kwargs) 27 | else: 28 | assert os.path.exists(root) 29 | # default fallback path (backwards compat), use image tar if root is a .tar file, otherwise image folder 30 | # FIXME support split here or in reader? 31 | if os.path.isfile(root) and os.path.splitext(root)[1] == '.tar': 32 | reader = ReaderImageInTar(root, **kwargs) 33 | else: 34 | reader = ReaderImageFolder(root, **kwargs) 35 | return reader 36 | -------------------------------------------------------------------------------- /hfdocs/source/index.mdx: -------------------------------------------------------------------------------- 1 | # timm 2 | 3 | 4 | 5 | `timm` is a library containing SOTA computer vision models, layers, utilities, optimizers, schedulers, data-loaders, augmentations, and training/evaluation scripts. 6 | 7 | It comes packaged with >700 pretrained models, and is designed to be flexible and easy to use. 8 | 9 | Read the [quick start guide](quickstart) to get up and running with the `timm` library. You will learn how to load, discover, and use pretrained models included in the library. 10 | 11 |
12 |
13 |
Tutorials
15 |

Learn the basics and become familiar with timm. Start here if you are using timm for the first time!

16 |
17 |
Reference
19 |

Technical descriptions of how timm classes and methods work.

20 |
21 |
22 |
23 | -------------------------------------------------------------------------------- /timm/data/readers/img_extensions.py: -------------------------------------------------------------------------------- 1 | from copy import deepcopy 2 | 3 | __all__ = ['get_img_extensions', 'is_img_extension', 'set_img_extensions', 'add_img_extensions', 'del_img_extensions'] 4 | 5 | 6 | IMG_EXTENSIONS = ('.png', '.jpg', '.jpeg') # singleton, kept public for bwd compat use 7 | _IMG_EXTENSIONS_SET = set(IMG_EXTENSIONS) # set version, private, kept in sync 8 | 9 | 10 | def _set_extensions(extensions): 11 | global IMG_EXTENSIONS 12 | global _IMG_EXTENSIONS_SET 13 | dedupe = set() # NOTE de-duping tuple while keeping original order 14 | IMG_EXTENSIONS = tuple(x for x in extensions if x not in dedupe and not dedupe.add(x)) 15 | _IMG_EXTENSIONS_SET = set(extensions) 16 | 17 | 18 | def _valid_extension(x: str): 19 | return x and isinstance(x, str) and len(x) >= 2 and x.startswith('.') 20 | 21 | 22 | def is_img_extension(ext): 23 | return ext in _IMG_EXTENSIONS_SET 24 | 25 | 26 | def get_img_extensions(as_set=False): 27 | return deepcopy(_IMG_EXTENSIONS_SET if as_set else IMG_EXTENSIONS) 28 | 29 | 30 | def set_img_extensions(extensions): 31 | assert len(extensions) 32 | for x in extensions: 33 | assert _valid_extension(x) 34 | _set_extensions(extensions) 35 | 36 | 37 | def add_img_extensions(ext): 38 | if not isinstance(ext, (list, tuple, set)): 39 | ext = (ext,) 40 | for x in ext: 41 | assert _valid_extension(x) 42 | extensions = IMG_EXTENSIONS + tuple(ext) 43 | _set_extensions(extensions) 44 | 45 | 46 | def del_img_extensions(ext): 47 | if not isinstance(ext, (list, tuple, set)): 48 | ext = (ext,) 49 | extensions = tuple(x for x in IMG_EXTENSIONS if x not in ext) 50 | _set_extensions(extensions) 51 | -------------------------------------------------------------------------------- /timm/loss/jsd.py: -------------------------------------------------------------------------------- 1 | 
import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | from .cross_entropy import LabelSmoothingCrossEntropy 6 | 7 | 8 | class JsdCrossEntropy(nn.Module): 9 | """ Jensen-Shannon Divergence + Cross-Entropy Loss 10 | 11 | Based on impl here: https://github.com/google-research/augmix/blob/master/imagenet.py 12 | From paper: 'AugMix: A Simple Data Processing Method to Improve Robustness and Uncertainty - 13 | https://arxiv.org/abs/1912.02781 14 | 15 | Hacked together by / Copyright 2020 Ross Wightman 16 | """ 17 | def __init__(self, num_splits=3, alpha=12, smoothing=0.1): 18 | super().__init__() 19 | self.num_splits = num_splits 20 | self.alpha = alpha 21 | if smoothing is not None and smoothing > 0: 22 | self.cross_entropy_loss = LabelSmoothingCrossEntropy(smoothing) 23 | else: 24 | self.cross_entropy_loss = torch.nn.CrossEntropyLoss() 25 | 26 | def __call__(self, output, target): 27 | split_size = output.shape[0] // self.num_splits 28 | assert split_size * self.num_splits == output.shape[0] 29 | logits_split = torch.split(output, split_size) 30 | 31 | # Cross-entropy is only computed on clean images 32 | loss = self.cross_entropy_loss(logits_split[0], target[:split_size]) 33 | probs = [F.softmax(logits, dim=1) for logits in logits_split] 34 | 35 | # Clamp mixture distribution to avoid exploding KL divergence 36 | logp_mixture = torch.clamp(torch.stack(probs).mean(axis=0), 1e-7, 1).log() 37 | loss += self.alpha * sum([F.kl_div( 38 | logp_mixture, p_split, reduction='batchmean') for p_split in probs]) / len(probs) 39 | return loss 40 | -------------------------------------------------------------------------------- /timm/layers/blur_pool.py: -------------------------------------------------------------------------------- 1 | """ 2 | BlurPool layer inspired by 3 | - Kornia's Max_BlurPool2d 4 | - Making Convolutional Networks Shift-Invariant Again :cite:`zhang2019shiftinvar` 5 | 6 | Hacked together by Chris Ha and Ross Wightman 7 | """ 8 | 9 | 
import torch 10 | import torch.nn as nn 11 | import torch.nn.functional as F 12 | import numpy as np 13 | from .padding import get_padding 14 | 15 | 16 | class BlurPool2d(nn.Module): 17 | r"""Creates a module that computes blurs and downsample a given feature map. 18 | See :cite:`zhang2019shiftinvar` for more details. 19 | Corresponds to the Downsample class, which does blurring and subsampling 20 | 21 | Args: 22 | channels = Number of input channels 23 | filt_size (int): binomial filter size for blurring. currently supports 3 (default) and 5. 24 | stride (int): downsampling filter stride 25 | 26 | Returns: 27 | torch.Tensor: the transformed tensor. 28 | """ 29 | def __init__(self, channels, filt_size=3, stride=2) -> None: 30 | super(BlurPool2d, self).__init__() 31 | assert filt_size > 1 32 | self.channels = channels 33 | self.filt_size = filt_size 34 | self.stride = stride 35 | self.padding = [get_padding(filt_size, stride, dilation=1)] * 4 36 | coeffs = torch.tensor((np.poly1d((0.5, 0.5)) ** (self.filt_size - 1)).coeffs.astype(np.float32)) 37 | blur_filter = (coeffs[:, None] * coeffs[None, :])[None, None, :, :].repeat(self.channels, 1, 1, 1) 38 | self.register_buffer('filt', blur_filter, persistent=False) 39 | 40 | def forward(self, x: torch.Tensor) -> torch.Tensor: 41 | x = F.pad(x, self.padding, 'reflect') 42 | return F.conv2d(x, self.filt, stride=self.stride, groups=self.channels) 43 | -------------------------------------------------------------------------------- /timm/layers/create_conv2d.py: -------------------------------------------------------------------------------- 1 | """ Create Conv2d Factory Method 2 | 3 | Hacked together by / Copyright 2020 Ross Wightman 4 | """ 5 | 6 | from .mixed_conv2d import MixedConv2d 7 | from .cond_conv2d import CondConv2d 8 | from .conv2d_same import create_conv2d_pad 9 | 10 | 11 | def create_conv2d(in_channels, out_channels, kernel_size, **kwargs): 12 | """ Select a 2d convolution implementation based on arguments 13 | 
Creates and returns one of torch.nn.Conv2d, Conv2dSame, MixedConv2d, or CondConv2d. 14 | 15 | Used extensively by EfficientNet, MobileNetv3 and related networks. 16 | """ 17 | if isinstance(kernel_size, list): 18 | assert 'num_experts' not in kwargs # MixNet + CondConv combo not supported currently 19 | if 'groups' in kwargs: 20 | groups = kwargs.pop('groups') 21 | if groups == in_channels: 22 | kwargs['depthwise'] = True 23 | else: 24 | assert groups == 1 25 | # We're going to use only lists for defining the MixedConv2d kernel groups, 26 | # ints, tuples, other iterables will continue to pass to normal conv and specify h, w. 27 | m = MixedConv2d(in_channels, out_channels, kernel_size, **kwargs) 28 | else: 29 | depthwise = kwargs.pop('depthwise', False) 30 | # for DW out_channels must be multiple of in_channels as must have out_channels % groups == 0 31 | groups = in_channels if depthwise else kwargs.pop('groups', 1) 32 | if 'num_experts' in kwargs and kwargs['num_experts'] > 0: 33 | m = CondConv2d(in_channels, out_channels, kernel_size, groups=groups, **kwargs) 34 | else: 35 | m = create_conv2d_pad(in_channels, out_channels, kernel_size, groups=groups, **kwargs) 36 | return m 37 | -------------------------------------------------------------------------------- /timm/utils/agc.py: -------------------------------------------------------------------------------- 1 | """ Adaptive Gradient Clipping 2 | 3 | An impl of AGC, as per (https://arxiv.org/abs/2102.06171): 4 | 5 | @article{brock2021high, 6 | author={Andrew Brock and Soham De and Samuel L. 
Smith and Karen Simonyan}, 7 | title={High-Performance Large-Scale Image Recognition Without Normalization}, 8 | journal={arXiv preprint arXiv:}, 9 | year={2021} 10 | } 11 | 12 | Code references: 13 | * Official JAX impl (paper authors): https://github.com/deepmind/deepmind-research/tree/master/nfnets 14 | * Phil Wang's PyTorch gist: https://gist.github.com/lucidrains/0d6560077edac419ab5d3aa29e674d5c 15 | 16 | Hacked together by / Copyright 2021 Ross Wightman 17 | """ 18 | import torch 19 | 20 | 21 | def unitwise_norm(x, norm_type=2.0): 22 | if x.ndim <= 1: 23 | return x.norm(norm_type) 24 | else: 25 | # works for nn.ConvNd and nn,Linear where output dim is first in the kernel/weight tensor 26 | # might need special cases for other weights (possibly MHA) where this may not be true 27 | return x.norm(norm_type, dim=tuple(range(1, x.ndim)), keepdim=True) 28 | 29 | 30 | def adaptive_clip_grad(parameters, clip_factor=0.01, eps=1e-3, norm_type=2.0): 31 | if isinstance(parameters, torch.Tensor): 32 | parameters = [parameters] 33 | for p in parameters: 34 | if p.grad is None: 35 | continue 36 | p_data = p.detach() 37 | g_data = p.grad.detach() 38 | max_norm = unitwise_norm(p_data, norm_type=norm_type).clamp_(min=eps).mul_(clip_factor) 39 | grad_norm = unitwise_norm(g_data, norm_type=norm_type) 40 | clipped_grad = g_data * (max_norm / grad_norm.clamp(min=1e-6)) 41 | new_grads = torch.where(grad_norm < max_norm, g_data, clipped_grad) 42 | p.grad.detach().copy_(new_grads) 43 | -------------------------------------------------------------------------------- /timm/layers/create_norm.py: -------------------------------------------------------------------------------- 1 | """ Norm Layer Factory 2 | 3 | Create norm modules by string (to mirror create_act and creat_norm-act fns) 4 | 5 | Copyright 2022 Ross Wightman 6 | """ 7 | import functools 8 | import types 9 | from typing import Type 10 | 11 | import torch.nn as nn 12 | 13 | from .norm import GroupNorm, GroupNorm1, LayerNorm, 
LayerNorm2d, RmsNorm 14 | from torchvision.ops.misc import FrozenBatchNorm2d 15 | 16 | _NORM_MAP = dict( 17 | batchnorm=nn.BatchNorm2d, 18 | batchnorm2d=nn.BatchNorm2d, 19 | batchnorm1d=nn.BatchNorm1d, 20 | groupnorm=GroupNorm, 21 | groupnorm1=GroupNorm1, 22 | layernorm=LayerNorm, 23 | layernorm2d=LayerNorm2d, 24 | rmsnorm=RmsNorm, 25 | frozenbatchnorm2d=FrozenBatchNorm2d, 26 | ) 27 | _NORM_TYPES = {m for n, m in _NORM_MAP.items()} 28 | 29 | 30 | def create_norm_layer(layer_name, num_features, **kwargs): 31 | layer = get_norm_layer(layer_name) 32 | layer_instance = layer(num_features, **kwargs) 33 | return layer_instance 34 | 35 | 36 | def get_norm_layer(norm_layer): 37 | if norm_layer is None: 38 | return None 39 | assert isinstance(norm_layer, (type, str, types.FunctionType, functools.partial)) 40 | norm_kwargs = {} 41 | 42 | # unbind partial fn, so args can be rebound later 43 | if isinstance(norm_layer, functools.partial): 44 | norm_kwargs.update(norm_layer.keywords) 45 | norm_layer = norm_layer.func 46 | 47 | if isinstance(norm_layer, str): 48 | if not norm_layer: 49 | return None 50 | layer_name = norm_layer.replace('_', '') 51 | norm_layer = _NORM_MAP[layer_name] 52 | else: 53 | norm_layer = norm_layer 54 | 55 | if norm_kwargs: 56 | norm_layer = functools.partial(norm_layer, **norm_kwargs) # bind/rebind args 57 | return norm_layer 58 | -------------------------------------------------------------------------------- /mkdocs.yml: -------------------------------------------------------------------------------- 1 | site_name: 'Pytorch Image Models' 2 | site_description: 'Pretrained Image Recognition Models' 3 | repo_name: 'rwightman/pytorch-image-models' 4 | repo_url: 'https://github.com/rwightman/pytorch-image-models' 5 | nav: 6 | - index.md 7 | - models.md 8 | - ...
| models/*.md 9 | - results.md 10 | - scripts.md 11 | - training_hparam_examples.md 12 | - feature_extraction.md 13 | - changes.md 14 | - archived_changes.md 15 | theme: 16 | name: 'material' 17 | feature: 18 | tabs: false 19 | extra_javascript: 20 | - 'https://cdnjs.cloudflare.com/ajax/libs/mathjax/2.7.0/MathJax.js?config=TeX-MML-AM_CHTML' 21 | - https://cdnjs.cloudflare.com/ajax/libs/tablesort/5.2.1/tablesort.min.js 22 | - javascripts/tables.js 23 | markdown_extensions: 24 | - codehilite: 25 | linenums: true 26 | - admonition 27 | - pymdownx.arithmatex 28 | - pymdownx.betterem: 29 | smart_enable: all 30 | - pymdownx.caret 31 | - pymdownx.critic 32 | - pymdownx.details 33 | - pymdownx.emoji: 34 | emoji_generator: !!python/name:pymdownx.emoji.to_svg 35 | - pymdownx.inlinehilite 36 | - pymdownx.magiclink 37 | - pymdownx.mark 38 | - pymdownx.smartsymbols 39 | - pymdownx.superfences 40 | - pymdownx.tasklist: 41 | custom_checkbox: true 42 | - pymdownx.tilde 43 | - mdx_truly_sane_lists 44 | plugins: 45 | - search 46 | - awesome-pages 47 | - redirects: 48 | redirect_maps: 49 | 'index.md': 'https://huggingface.co/docs/timm/index' 50 | 'models.md': 'https://huggingface.co/docs/timm/models' 51 | 'results.md': 'https://huggingface.co/docs/timm/results' 52 | 'scripts.md': 'https://huggingface.co/docs/timm/training_script' 53 | 'training_hparam_examples.md': 'https://huggingface.co/docs/timm/training_script#training-examples' 54 | 'feature_extraction.md': 'https://huggingface.co/docs/timm/feature_extraction' 55 | -------------------------------------------------------------------------------- /timm/utils/decay_batch.py: -------------------------------------------------------------------------------- 1 | """ Batch size decay and retry helpers. 
2 | 3 | Copyright 2022 Ross Wightman 4 | """ 5 | import math 6 | 7 | 8 | def decay_batch_step(batch_size, num_intra_steps=2, no_odd=False): 9 | """ power of two batch-size decay with intra steps 10 | 11 | Decay by stepping between powers of 2: 12 | * determine power-of-2 floor of current batch size (base batch size) 13 | * divide above value by num_intra_steps to determine step size 14 | * floor batch_size to nearest multiple of step_size (from base batch size) 15 | Examples: 16 | num_steps == 4 --> 64, 56, 48, 40, 32, 28, 24, 20, 16, 14, 12, 10, 8, 7, 6, 5, 4, 3, 2, 1 17 | num_steps (no_odd=True) == 4 --> 64, 56, 48, 40, 32, 28, 24, 20, 16, 14, 12, 10, 8, 6, 4, 2 18 | num_steps == 2 --> 64, 48, 32, 24, 16, 12, 8, 6, 4, 3, 2, 1 19 | num_steps == 1 --> 64, 32, 16, 8, 4, 2, 1 20 | """ 21 | if batch_size <= 1: 22 | # return 0 for stopping value so easy to use in loop 23 | return 0 24 | base_batch_size = int(2 ** (math.log(batch_size - 1) // math.log(2))) 25 | step_size = max(base_batch_size // num_intra_steps, 1) 26 | batch_size = base_batch_size + ((batch_size - base_batch_size - 1) // step_size) * step_size 27 | if no_odd and batch_size % 2: 28 | batch_size -= 1 29 | return batch_size 30 | 31 | 32 | def check_batch_size_retry(error_str): 33 | """ check failure error string for conditions where batch decay retry should not be attempted 34 | """ 35 | error_str = error_str.lower() 36 | if 'required rank' in error_str: 37 | # Errors involving phrase 'required rank' typically happen when a conv is used that's 38 | # not compatible with channels_last memory format. 
39 | return False 40 | if 'illegal' in error_str: 41 | # 'Illegal memory access' errors in CUDA typically leave process in unusable state 42 | return False 43 | return True 44 | -------------------------------------------------------------------------------- /docs/scripts.md: -------------------------------------------------------------------------------- 1 | # Scripts 2 | Train, validation, inference, and checkpoint cleaning scripts are included in the GitHub root folder. Scripts are not currently packaged in the pip release. 3 | 4 | The training and validation scripts evolved from early versions of the [PyTorch Imagenet Examples](https://github.com/pytorch/examples). I have added significant functionality over time, including CUDA-specific performance enhancements based on 5 | [NVIDIA's APEX Examples](https://github.com/NVIDIA/apex/tree/master/examples). 6 | 7 | ## Training Script 8 | 9 | The variety of training args is large and not all combinations of options (or even individual options) have been fully tested. For the training dataset folder, specify the base folder that contains the `train` and `validation` folders. 10 | 11 | To train an SE-ResNet34 on ImageNet, locally distributed, 4 GPUs, one process per GPU w/ cosine schedule, random-erasing prob of 50% and per-pixel random value: 12 | 13 | `./distributed_train.sh 4 /data/imagenet --model seresnet34 --sched cosine --epochs 150 --warmup-epochs 5 --lr 0.4 --reprob 0.5 --remode pixel --batch-size 256 --amp -j 4` 14 | 15 | NOTE: It is recommended to use PyTorch 1.9+ w/ PyTorch native AMP and DDP instead of APEX AMP. `--amp` defaults to native AMP as of timm ver 0.4.3. `--apex-amp` will force use of APEX components if they are installed. 16 | 17 | ## Validation / Inference Scripts 18 | 19 | Validation and inference scripts are similar in usage. One outputs metrics on a validation set and the other outputs top-k class ids in a CSV file. Specify the folder containing the validation images, not the base folder as in the training script.
20 | 21 | To validate with the model's pretrained weights (if they exist): 22 | 23 | `python validate.py /imagenet/validation/ --model seresnext26_32x4d --pretrained` 24 | 25 | To run inference from a checkpoint: 26 | 27 | `python inference.py /imagenet/validation/ --model mobilenetv3_large_100 --checkpoint ./output/train/model_best.pth.tar` -------------------------------------------------------------------------------- /timm/data/_info/imagenet_r_indices.txt: -------------------------------------------------------------------------------- 1 | 1 2 | 2 3 | 4 4 | 6 5 | 8 6 | 9 7 | 11 8 | 13 9 | 22 10 | 23 11 | 26 12 | 29 13 | 31 14 | 39 15 | 47 16 | 63 17 | 71 18 | 76 19 | 79 20 | 84 21 | 90 22 | 94 23 | 96 24 | 97 25 | 99 26 | 100 27 | 105 28 | 107 29 | 113 30 | 122 31 | 125 32 | 130 33 | 132 34 | 144 35 | 145 36 | 147 37 | 148 38 | 150 39 | 151 40 | 155 41 | 160 42 | 161 43 | 162 44 | 163 45 | 171 46 | 172 47 | 178 48 | 187 49 | 195 50 | 199 51 | 203 52 | 207 53 | 208 54 | 219 55 | 231 56 | 232 57 | 234 58 | 235 59 | 242 60 | 245 61 | 247 62 | 250 63 | 251 64 | 254 65 | 259 66 | 260 67 | 263 68 | 265 69 | 267 70 | 269 71 | 276 72 | 277 73 | 281 74 | 288 75 | 289 76 | 291 77 | 292 78 | 293 79 | 296 80 | 299 81 | 301 82 | 308 83 | 309 84 | 310 85 | 311 86 | 314 87 | 315 88 | 319 89 | 323 90 | 327 91 | 330 92 | 334 93 | 335 94 | 337 95 | 338 96 | 340 97 | 341 98 | 344 99 | 347 100 | 353 101 | 355 102 | 361 103 | 362 104 | 365 105 | 366 106 | 367 107 | 368 108 | 372 109 | 388 110 | 390 111 | 393 112 | 397 113 | 401 114 | 407 115 | 413 116 | 414 117 | 425 118 | 428 119 | 430 120 | 435 121 | 437 122 | 441 123 | 447 124 | 448 125 | 457 126 | 462 127 | 463 128 | 469 129 | 470 130 | 471 131 | 472 132 | 476 133 | 483 134 | 487 135 | 515 136 | 546 137 | 555 138 | 558 139 | 570 140 | 579 141 | 583 142 | 587 143 | 593 144 | 594 145 | 596 146 | 609 147 | 613 148 | 617 149 | 621 150 | 629 151 | 637 152 | 657 153 | 658 154 | 701 155 | 717 156 | 724 157 | 763 158 | 768 159 | 774 
160 | 776 161 | 779 162 | 780 163 | 787 164 | 805 165 | 812 166 | 815 167 | 820 168 | 824 169 | 833 170 | 847 171 | 852 172 | 866 173 | 875 174 | 883 175 | 889 176 | 895 177 | 907 178 | 928 179 | 931 180 | 932 181 | 933 182 | 934 183 | 936 184 | 937 185 | 943 186 | 945 187 | 947 188 | 948 189 | 949 190 | 951 191 | 953 192 | 954 193 | 957 194 | 963 195 | 965 196 | 967 197 | 980 198 | 981 199 | 983 200 | 988 201 | -------------------------------------------------------------------------------- /timm/layers/median_pool.py: -------------------------------------------------------------------------------- 1 | """ Median Pool 2 | Hacked together by / Copyright 2020 Ross Wightman 3 | """ 4 | import torch.nn as nn 5 | import torch.nn.functional as F 6 | from .helpers import to_2tuple, to_4tuple 7 | 8 | 9 | class MedianPool2d(nn.Module): 10 | """ Median pool (usable as median filter when stride=1) module. 11 | 12 | Args: 13 | kernel_size: size of pooling kernel, int or 2-tuple 14 | stride: pool stride, int or 2-tuple 15 | padding: pool padding, int or 4-tuple (l, r, t, b) as in pytorch F.pad 16 | same: override padding and enforce same padding, boolean 17 | """ 18 | def __init__(self, kernel_size=3, stride=1, padding=0, same=False): 19 | super(MedianPool2d, self).__init__() 20 | self.k = to_2tuple(kernel_size) 21 | self.stride = to_2tuple(stride) 22 | self.padding = to_4tuple(padding) # convert to l, r, t, b 23 | self.same = same 24 | 25 | def _padding(self, x): 26 | if self.same: 27 | ih, iw = x.size()[2:] 28 | if ih % self.stride[0] == 0: 29 | ph = max(self.k[0] - self.stride[0], 0) 30 | else: 31 | ph = max(self.k[0] - (ih % self.stride[0]), 0) 32 | if iw % self.stride[1] == 0: 33 | pw = max(self.k[1] - self.stride[1], 0) 34 | else: 35 | pw = max(self.k[1] - (iw % self.stride[1]), 0) 36 | pl = pw // 2 37 | pr = pw - pl 38 | pt = ph // 2 39 | pb = ph - pt 40 | padding = (pl, pr, pt, pb) 41 | else: 42 | padding = self.padding 43 | return padding 44 | 45 | def forward(self, 
x): 46 | x = F.pad(x, self._padding(x), mode='reflect') 47 | x = x.unfold(2, self.k[0], self.stride[0]).unfold(3, self.k[1], self.stride[1]) 48 | x = x.contiguous().view(x.size()[:4] + (-1,)).median(dim=-1)[0] 49 | return x 50 | -------------------------------------------------------------------------------- /timm/data/_info/imagenet_a_indices.txt: -------------------------------------------------------------------------------- 1 | 6 2 | 11 3 | 13 4 | 15 5 | 17 6 | 22 7 | 23 8 | 27 9 | 30 10 | 37 11 | 39 12 | 42 13 | 47 14 | 50 15 | 57 16 | 70 17 | 71 18 | 76 19 | 79 20 | 89 21 | 90 22 | 94 23 | 96 24 | 97 25 | 99 26 | 105 27 | 107 28 | 108 29 | 110 30 | 113 31 | 124 32 | 125 33 | 130 34 | 132 35 | 143 36 | 144 37 | 150 38 | 151 39 | 207 40 | 234 41 | 235 42 | 254 43 | 277 44 | 283 45 | 287 46 | 291 47 | 295 48 | 298 49 | 301 50 | 306 51 | 307 52 | 308 53 | 309 54 | 310 55 | 311 56 | 313 57 | 314 58 | 315 59 | 317 60 | 319 61 | 323 62 | 324 63 | 326 64 | 327 65 | 330 66 | 334 67 | 335 68 | 336 69 | 347 70 | 361 71 | 363 72 | 372 73 | 378 74 | 386 75 | 397 76 | 400 77 | 401 78 | 402 79 | 404 80 | 407 81 | 411 82 | 416 83 | 417 84 | 420 85 | 425 86 | 428 87 | 430 88 | 437 89 | 438 90 | 445 91 | 456 92 | 457 93 | 461 94 | 462 95 | 470 96 | 472 97 | 483 98 | 486 99 | 488 100 | 492 101 | 496 102 | 514 103 | 516 104 | 528 105 | 530 106 | 539 107 | 542 108 | 543 109 | 549 110 | 552 111 | 557 112 | 561 113 | 562 114 | 569 115 | 572 116 | 573 117 | 575 118 | 579 119 | 589 120 | 606 121 | 607 122 | 609 123 | 614 124 | 626 125 | 627 126 | 640 127 | 641 128 | 642 129 | 643 130 | 658 131 | 668 132 | 677 133 | 682 134 | 684 135 | 687 136 | 701 137 | 704 138 | 719 139 | 736 140 | 746 141 | 749 142 | 752 143 | 758 144 | 763 145 | 765 146 | 768 147 | 773 148 | 774 149 | 776 150 | 779 151 | 780 152 | 786 153 | 792 154 | 797 155 | 802 156 | 803 157 | 804 158 | 813 159 | 815 160 | 820 161 | 823 162 | 831 163 | 833 164 | 835 165 | 839 166 | 845 167 | 847 168 | 850 169 | 859 170 | 862 
171 | 870 172 | 879 173 | 880 174 | 888 175 | 890 176 | 897 177 | 900 178 | 907 179 | 913 180 | 924 181 | 932 182 | 933 183 | 934 184 | 937 185 | 943 186 | 945 187 | 947 188 | 951 189 | 954 190 | 956 191 | 957 192 | 959 193 | 971 194 | 972 195 | 980 196 | 981 197 | 984 198 | 986 199 | 987 200 | 988 201 | -------------------------------------------------------------------------------- /timm/layers/patch_dropout.py: -------------------------------------------------------------------------------- 1 | from typing import Optional, Tuple, Union 2 | 3 | import torch 4 | import torch.nn as nn 5 | 6 | 7 | class PatchDropout(nn.Module): 8 | """ 9 | https://arxiv.org/abs/2212.00794 10 | """ 11 | return_indices: torch.jit.Final[bool] 12 | 13 | def __init__( 14 | self, 15 | prob: float = 0.5, 16 | num_prefix_tokens: int = 1, 17 | ordered: bool = False, 18 | return_indices: bool = False, 19 | ): 20 | super().__init__() 21 | assert 0 <= prob < 1. 22 | self.prob = prob 23 | self.num_prefix_tokens = num_prefix_tokens # exclude CLS token (or other prefix tokens) 24 | self.ordered = ordered 25 | self.return_indices = return_indices 26 | 27 | def forward(self, x) -> Union[torch.Tensor, Tuple[torch.Tensor, Optional[torch.Tensor]]]: 28 | if not self.training or self.prob == 0.: 29 | if self.return_indices: 30 | return x, None 31 | return x 32 | 33 | if self.num_prefix_tokens: 34 | prefix_tokens, x = x[:, :self.num_prefix_tokens], x[:, self.num_prefix_tokens:] 35 | else: 36 | prefix_tokens = None 37 | 38 | B = x.shape[0] 39 | L = x.shape[1] 40 | num_keep = max(1, int(L * (1. 
- self.prob))) 41 | keep_indices = torch.argsort(torch.randn(B, L, device=x.device), dim=-1)[:, :num_keep] 42 | if self.ordered: 43 | # NOTE does not need to maintain patch order in typical transformer use, 44 | # but possibly useful for debug / visualization 45 | keep_indices = keep_indices.sort(dim=-1)[0] 46 | x = x.gather(1, keep_indices.unsqueeze(-1).expand((-1, -1) + x.shape[2:])) 47 | 48 | if prefix_tokens is not None: 49 | x = torch.cat((prefix_tokens, x), dim=1) 50 | 51 | if self.return_indices: 52 | return x, keep_indices 53 | return x 54 | -------------------------------------------------------------------------------- /timm/data/real_labels.py: -------------------------------------------------------------------------------- 1 | """ Real labels evaluator for ImageNet 2 | Paper: `Are we done with ImageNet?` - https://arxiv.org/abs/2006.07159 3 | Based on Numpy example at https://github.com/google-research/reassessed-imagenet 4 | 5 | Hacked together by / Copyright 2020 Ross Wightman 6 | """ 7 | import os 8 | import json 9 | import numpy as np 10 | import pkgutil 11 | 12 | 13 | class RealLabelsImagenet: 14 | 15 | def __init__(self, filenames, real_json=None, topk=(1, 5)): 16 | if real_json is not None: 17 | with open(real_json) as real_labels: 18 | real_labels = json.load(real_labels) 19 | else: 20 | real_labels = json.loads( 21 | pkgutil.get_data(__name__, os.path.join('_info', 'imagenet_real_labels.json')).decode('utf-8')) 22 | real_labels = {f'ILSVRC2012_val_{i + 1:08d}.JPEG': labels for i, labels in enumerate(real_labels)} 23 | self.real_labels = real_labels 24 | self.filenames = filenames 25 | assert len(self.filenames) == len(self.real_labels) 26 | self.topk = topk 27 | self.is_correct = {k: [] for k in topk} 28 | self.sample_idx = 0 29 | 30 | def add_result(self, output): 31 | maxk = max(self.topk) 32 | _, pred_batch = output.topk(maxk, 1, True, True) 33 | pred_batch = pred_batch.cpu().numpy() 34 | for pred in pred_batch: 35 | filename = 
self.filenames[self.sample_idx] 36 | filename = os.path.basename(filename) 37 | if self.real_labels[filename]: 38 | for k in self.topk: 39 | self.is_correct[k].append( 40 | any([p in self.real_labels[filename] for p in pred[:k]])) 41 | self.sample_idx += 1 42 | 43 | def get_accuracy(self, k=None): 44 | if k is None: 45 | return {k: float(np.mean(self.is_correct[k])) * 100 for k in self.topk} 46 | else: 47 | return float(np.mean(self.is_correct[k])) * 100 48 | -------------------------------------------------------------------------------- /timm/scheduler/step_lr.py: -------------------------------------------------------------------------------- 1 | """ Step Scheduler 2 | 3 | Basic step LR schedule with warmup, noise. 4 | 5 | Hacked together by / Copyright 2020 Ross Wightman 6 | """ 7 | import math 8 | import torch 9 | 10 | from .scheduler import Scheduler 11 | 12 | 13 | class StepLRScheduler(Scheduler): 14 | """ 15 | """ 16 | 17 | def __init__( 18 | self, 19 | optimizer: torch.optim.Optimizer, 20 | decay_t: float, 21 | decay_rate: float = 1., 22 | warmup_t=0, 23 | warmup_lr_init=0, 24 | warmup_prefix=True, 25 | t_in_epochs=True, 26 | noise_range_t=None, 27 | noise_pct=0.67, 28 | noise_std=1.0, 29 | noise_seed=42, 30 | initialize=True, 31 | ) -> None: 32 | super().__init__( 33 | optimizer, 34 | param_group_field="lr", 35 | t_in_epochs=t_in_epochs, 36 | noise_range_t=noise_range_t, 37 | noise_pct=noise_pct, 38 | noise_std=noise_std, 39 | noise_seed=noise_seed, 40 | initialize=initialize, 41 | ) 42 | 43 | self.decay_t = decay_t 44 | self.decay_rate = decay_rate 45 | self.warmup_t = warmup_t 46 | self.warmup_lr_init = warmup_lr_init 47 | self.warmup_prefix = warmup_prefix 48 | if self.warmup_t: 49 | self.warmup_steps = [(v - warmup_lr_init) / self.warmup_t for v in self.base_values] 50 | super().update_groups(self.warmup_lr_init) 51 | else: 52 | self.warmup_steps = [1 for _ in self.base_values] 53 | 54 | def _get_lr(self, t): 55 | if t < self.warmup_t: 56 | lrs = 
[self.warmup_lr_init + t * s for s in self.warmup_steps] 57 | else: 58 | if self.warmup_prefix: 59 | t = t - self.warmup_t 60 | lrs = [v * (self.decay_rate ** (t // self.decay_t)) for v in self.base_values] 61 | return lrs 62 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | *.egg-info/ 24 | .installed.cfg 25 | *.egg 26 | MANIFEST 27 | 28 | # PyInstaller 29 | # Usually these files are written by a python script from a template 30 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 31 | *.manifest 32 | *.spec 33 | 34 | # Installer logs 35 | pip-log.txt 36 | pip-delete-this-directory.txt 37 | 38 | # Unit test / coverage reports 39 | htmlcov/ 40 | .tox/ 41 | .coverage 42 | .coverage.* 43 | .cache 44 | nosetests.xml 45 | coverage.xml 46 | *.cover 47 | .hypothesis/ 48 | .pytest_cache/ 49 | 50 | # Translations 51 | *.mo 52 | *.pot 53 | 54 | # Django stuff: 55 | *.log 56 | local_settings.py 57 | db.sqlite3 58 | 59 | # Flask stuff: 60 | instance/ 61 | .webassets-cache 62 | 63 | # Scrapy stuff: 64 | .scrapy 65 | 66 | # Sphinx documentation 67 | docs/_build/ 68 | 69 | # PyBuilder 70 | target/ 71 | 72 | # Jupyter Notebook 73 | .ipynb_checkpoints 74 | 75 | # pyenv 76 | .python-version 77 | 78 | # celery beat schedule file 79 | celerybeat-schedule 80 | 81 | # SageMath parsed files 82 | *.sage.py 83 | 84 | # Environments 85 | .env 86 | .venv 87 | env/ 88 | venv/ 89 | ENV/ 90 | env.bak/ 91 | venv.bak/ 92 | 93 | # Spyder project settings 94 | .spyderproject 95 | .spyproject 96 | 97 | # Rope project settings 
98 | .ropeproject 99 | 100 | # PyCharm 101 | .idea 102 | 103 | output/ 104 | 105 | # PyTorch weights 106 | *.tar 107 | *.pth 108 | *.pt 109 | *.torch 110 | *.gz 111 | Untitled.ipynb 112 | Testing notebook.ipynb 113 | 114 | # Root dir exclusions 115 | /*.csv 116 | /*.yaml 117 | /*.json 118 | /*.jpg 119 | /*.png 120 | /*.zip 121 | /*.tar.* -------------------------------------------------------------------------------- /timm/layers/space_to_depth.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | 5 | class SpaceToDepth(nn.Module): 6 | bs: torch.jit.Final[int] 7 | 8 | def __init__(self, block_size=4): 9 | super().__init__() 10 | assert block_size == 4 11 | self.bs = block_size 12 | 13 | def forward(self, x): 14 | N, C, H, W = x.size() 15 | x = x.view(N, C, H // self.bs, self.bs, W // self.bs, self.bs) # (N, C, H//bs, bs, W//bs, bs) 16 | x = x.permute(0, 3, 5, 1, 2, 4).contiguous() # (N, bs, bs, C, H//bs, W//bs) 17 | x = x.view(N, C * self.bs * self.bs, H // self.bs, W // self.bs) # (N, C*bs^2, H//bs, W//bs) 18 | return x 19 | 20 | 21 | @torch.jit.script 22 | class SpaceToDepthJit: 23 | def __call__(self, x: torch.Tensor): 24 | # assuming hard-coded that block_size==4 for acceleration 25 | N, C, H, W = x.size() 26 | x = x.view(N, C, H // 4, 4, W // 4, 4) # (N, C, H//bs, bs, W//bs, bs) 27 | x = x.permute(0, 3, 5, 1, 2, 4).contiguous() # (N, bs, bs, C, H//bs, W//bs) 28 | x = x.view(N, C * 16, H // 4, W // 4) # (N, C*bs^2, H//bs, W//bs) 29 | return x 30 | 31 | 32 | class SpaceToDepthModule(nn.Module): 33 | def __init__(self, no_jit=False): 34 | super().__init__() 35 | if not no_jit: 36 | self.op = SpaceToDepthJit() 37 | else: 38 | self.op = SpaceToDepth() 39 | 40 | def forward(self, x): 41 | return self.op(x) 42 | 43 | 44 | class DepthToSpace(nn.Module): 45 | 46 | def __init__(self, block_size): 47 | super().__init__() 48 | self.bs = block_size 49 | 50 | def forward(self, x): 51 | N, C, 
H, W = x.size() 52 | x = x.view(N, self.bs, self.bs, C // (self.bs ** 2), H, W) # (N, bs, bs, C//bs^2, H, W) 53 | x = x.permute(0, 3, 4, 1, 5, 2).contiguous() # (N, C//bs^2, H, bs, W, bs) 54 | x = x.view(N, C // (self.bs ** 2), H * self.bs, W * self.bs) # (N, C//bs^2, H * bs, W * bs) 55 | return x 56 | -------------------------------------------------------------------------------- /timm/layers/mixed_conv2d.py: -------------------------------------------------------------------------------- 1 | """ PyTorch Mixed Convolution 2 | 3 | Paper: MixConv: Mixed Depthwise Convolutional Kernels (https://arxiv.org/abs/1907.09595) 4 | 5 | Hacked together by / Copyright 2020 Ross Wightman 6 | """ 7 | 8 | import torch 9 | from torch import nn as nn 10 | 11 | from .conv2d_same import create_conv2d_pad 12 | 13 | 14 | def _split_channels(num_chan, num_groups): 15 | split = [num_chan // num_groups for _ in range(num_groups)] 16 | split[0] += num_chan - sum(split) 17 | return split 18 | 19 | 20 | class MixedConv2d(nn.ModuleDict): 21 | """ Mixed Grouped Convolution 22 | 23 | Based on MDConv and GroupedConv in MixNet impl: 24 | https://github.com/tensorflow/tpu/blob/master/models/official/mnasnet/mixnet/custom_layers.py 25 | """ 26 | def __init__(self, in_channels, out_channels, kernel_size=3, 27 | stride=1, padding='', dilation=1, depthwise=False, **kwargs): 28 | super(MixedConv2d, self).__init__() 29 | 30 | kernel_size = kernel_size if isinstance(kernel_size, list) else [kernel_size] 31 | num_groups = len(kernel_size) 32 | in_splits = _split_channels(in_channels, num_groups) 33 | out_splits = _split_channels(out_channels, num_groups) 34 | self.in_channels = sum(in_splits) 35 | self.out_channels = sum(out_splits) 36 | for idx, (k, in_ch, out_ch) in enumerate(zip(kernel_size, in_splits, out_splits)): 37 | conv_groups = in_ch if depthwise else 1 38 | # use add_module to keep key space clean 39 | self.add_module( 40 | str(idx), 41 | create_conv2d_pad( 42 | in_ch, out_ch, k, 
stride=stride, 43 | padding=padding, dilation=dilation, groups=conv_groups, **kwargs) 44 | ) 45 | self.splits = in_splits 46 | 47 | def forward(self, x): 48 | x_split = torch.split(x, self.splits, 1) 49 | x_out = [c(x_split[i]) for i, c in enumerate(self.values())] 50 | x = torch.cat(x_out, 1) 51 | return x 52 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | """ Setup 2 | """ 3 | from setuptools import setup, find_packages 4 | from codecs import open 5 | from os import path 6 | 7 | here = path.abspath(path.dirname(__file__)) 8 | 9 | # Get the long description from the README file 10 | with open(path.join(here, 'README.md'), encoding='utf-8') as f: 11 | long_description = f.read() 12 | 13 | exec(open('timm/version.py').read()) 14 | setup( 15 | name='timm', 16 | version=__version__, 17 | description='PyTorch Image Models', 18 | long_description=long_description, 19 | long_description_content_type='text/markdown', 20 | url='https://github.com/huggingface/pytorch-image-models', 21 | author='Ross Wightman', 22 | author_email='ross@huggingface.co', 23 | classifiers=[ 24 | # How mature is this project? 
Common values are 25 | # 3 - Alpha 26 | # 4 - Beta 27 | # 5 - Production/Stable 28 | 'Development Status :: 4 - Beta', 29 | 'Intended Audience :: Education', 30 | 'Intended Audience :: Science/Research', 31 | 'License :: OSI Approved :: Apache Software License', 32 | 'Programming Language :: Python :: 3.8', 33 | 'Programming Language :: Python :: 3.9', 34 | 'Programming Language :: Python :: 3.10', 35 | 'Programming Language :: Python :: 3.11', 36 | 'Topic :: Scientific/Engineering', 37 | 'Topic :: Scientific/Engineering :: Artificial Intelligence', 38 | 'Topic :: Software Development', 39 | 'Topic :: Software Development :: Libraries', 40 | 'Topic :: Software Development :: Libraries :: Python Modules', 41 | ], 42 | 43 | # Note that this is a string of words separated by whitespace, not a list. 44 | keywords='pytorch pretrained models efficientnet mobilenetv3 mnasnet resnet vision transformer vit', 45 | packages=find_packages(exclude=['convert', 'tests', 'results']), 46 | include_package_data=True, 47 | install_requires=['torch >= 1.7', 'torchvision', 'pyyaml', 'huggingface_hub', 'safetensors'], 48 | python_requires='>=3.7', 49 | ) 50 | 51 | -------------------------------------------------------------------------------- /hfdocs/source/hf_hub.mdx: -------------------------------------------------------------------------------- 1 | # Sharing and Loading Models From the Hugging Face Hub 2 | 3 | The `timm` library has a built-in integration with the Hugging Face Hub, making it easy to share and load models from the 🤗 Hub. 4 | 5 | In this short guide, we'll see how to: 6 | 1. Share a `timm` model on the Hub 7 | 2. How to load that model back from the Hub 8 | 9 | ## Authenticating 10 | 11 | First, you'll need to make sure you have the `huggingface_hub` package installed. 12 | 13 | ```bash 14 | pip install huggingface_hub 15 | ``` 16 | 17 | Then, you'll need to authenticate yourself. 
You can do this by running the following command: 18 | 19 | ```bash 20 | huggingface-cli login 21 | ``` 22 | 23 | Or, if you're using a notebook, you can use the `notebook_login` helper: 24 | 25 | ```py 26 | >>> from huggingface_hub import notebook_login 27 | >>> notebook_login() 28 | ``` 29 | 30 | ## Sharing a Model 31 | 32 | ```py 33 | >>> import timm 34 | >>> model = timm.create_model('resnet18', pretrained=True, num_classes=4) 35 | ``` 36 | 37 | Here is where you would normally train or fine-tune the model. We'll skip that for the sake of this tutorial. 38 | 39 | Let's pretend we've now fine-tuned the model. The next step would be to push it to the Hub! We can do this with the `timm.models.hub.push_to_hf_hub` function. 40 | 41 | ```py 42 | >>> model_cfg = dict(labels=['a', 'b', 'c', 'd']) 43 | >>> timm.models.hub.push_to_hf_hub(model, 'resnet18-random', model_config=model_cfg) 44 | ``` 45 | 46 | Running the above would push the model to `/resnet18-random` on the Hub. You can now share this model with your friends, or use it in your own code! 47 | 48 | ## Loading a Model 49 | 50 | Loading a model from the Hub is as simple as calling `timm.create_model` with the `pretrained` argument set to the name of the model you want to load. In this case, we'll use [`nateraw/resnet18-random`](https://huggingface.co/nateraw/resnet18-random), which is the model we just pushed to the Hub. 51 | 52 | ```py 53 | >>> model_reloaded = timm.create_model('hf_hub:nateraw/resnet18-random', pretrained=True) 54 | ``` 55 | -------------------------------------------------------------------------------- /docs/models/.templates/generate_readmes.py: -------------------------------------------------------------------------------- 1 | """ 2 | Run this script to generate the model-index files in `models` from the templates in `.templates/models`. 
3 | """ 4 | 5 | import argparse 6 | from pathlib import Path 7 | 8 | from jinja2 import Environment, FileSystemLoader 9 | 10 | import modelindex 11 | 12 | 13 | def generate_readmes(templates_path: Path, dest_path: Path): 14 | """Add the code snippet template to the readmes""" 15 | readme_templates_path = templates_path / "models" 16 | code_template_path = templates_path / "code_snippets.md" 17 | 18 | env = Environment( 19 | loader=FileSystemLoader([readme_templates_path, readme_templates_path.parent]), 20 | ) 21 | 22 | for readme in readme_templates_path.iterdir(): 23 | if readme.suffix == ".md": 24 | template = env.get_template(readme.name) 25 | 26 | # get the first model_name for this model family 27 | mi = modelindex.load(str(readme)) 28 | model_name = mi.models[0].name 29 | 30 | full_content = template.render(model_name=model_name) 31 | 32 | # generate full_readme 33 | with open(dest_path / readme.name, "w") as f: 34 | f.write(full_content) 35 | 36 | 37 | def main(): 38 | parser = argparse.ArgumentParser(description="Model index generation config") 39 | parser.add_argument( 40 | "-t", 41 | "--templates", 42 | default=Path(__file__).parent / ".templates", 43 | type=str, 44 | help="Location of the markdown templates", 45 | ) 46 | parser.add_argument( 47 | "-d", 48 | "--dest", 49 | default=Path(__file__).parent / "models", 50 | type=str, 51 | help="Destination folder that contains the generated model-index files.", 52 | ) 53 | args = parser.parse_args() 54 | templates_path = Path(args.templates) 55 | dest_readmes_path = Path(args.dest) 56 | 57 | generate_readmes( 58 | templates_path, 59 | dest_readmes_path, 60 | ) 61 | 62 | 63 | if __name__ == "__main__": 64 | main() 65 | -------------------------------------------------------------------------------- /docs/models/.templates/models/spnasnet.md: -------------------------------------------------------------------------------- 1 | # SPNASNet 2 | 3 | **Single-Path NAS** is a novel differentiable NAS method for 
designing hardware-efficient ConvNets in less than 4 hours. 4 | 5 | {% include 'code_snippets.md' %} 6 | 7 | ## How do I train this model? 8 | 9 | You can follow the [timm recipe scripts](https://rwightman.github.io/pytorch-image-models/scripts/) for training a new model afresh. 10 | 11 | ## Citation 12 | 13 | ```BibTeX 14 | @misc{stamoulis2019singlepath, 15 | title={Single-Path NAS: Designing Hardware-Efficient ConvNets in less than 4 Hours}, 16 | author={Dimitrios Stamoulis and Ruizhou Ding and Di Wang and Dimitrios Lymberopoulos and Bodhi Priyantha and Jie Liu and Diana Marculescu}, 17 | year={2019}, 18 | eprint={1904.02877}, 19 | archivePrefix={arXiv}, 20 | primaryClass={cs.LG} 21 | } 22 | ``` 23 | 24 | 63 | -------------------------------------------------------------------------------- /tests/test_layers.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | from timm.layers import create_act_layer, set_layer_config 5 | 6 | import importlib 7 | import os 8 | 9 | torch_backend = os.environ.get('TORCH_BACKEND') 10 | if torch_backend is not None: 11 | importlib.import_module(torch_backend) 12 | torch_device = os.environ.get('TORCH_DEVICE', 'cpu') 13 | 14 | class MLP(nn.Module): 15 | def __init__(self, act_layer="relu", inplace=True): 16 | super(MLP, self).__init__() 17 | self.fc1 = nn.Linear(1000, 100) 18 | self.act = create_act_layer(act_layer, inplace=inplace) 19 | self.fc2 = nn.Linear(100, 10) 20 | 21 | def forward(self, x): 22 | x = self.fc1(x) 23 | x = self.act(x) 24 | x = self.fc2(x) 25 | return x 26 | 27 | 28 | def _run_act_layer_grad(act_type, inplace=True): 29 | x = torch.rand(10, 1000) * 10 30 | m = MLP(act_layer=act_type, inplace=inplace) 31 | 32 | def _run(x, act_layer=''): 33 | if act_layer: 34 | # replace act layer if set 35 | m.act = create_act_layer(act_layer, inplace=inplace) 36 | out = m(x) 37 | l = (out - 0).pow(2).sum() 38 | return l 39 | 40 | x = 
x.to(device=torch_device) 41 | m.to(device=torch_device) 42 | 43 | out_me = _run(x) 44 | 45 | with set_layer_config(scriptable=True): 46 | out_jit = _run(x, act_type) 47 | 48 | assert torch.isclose(out_jit, out_me) 49 | 50 | with set_layer_config(no_jit=True): 51 | out_basic = _run(x, act_type) 52 | 53 | assert torch.isclose(out_basic, out_jit) 54 | 55 | 56 | def test_swish_grad(): 57 | for _ in range(100): 58 | _run_act_layer_grad('swish') 59 | 60 | 61 | def test_mish_grad(): 62 | for _ in range(100): 63 | _run_act_layer_grad('mish') 64 | 65 | 66 | def test_hard_sigmoid_grad(): 67 | for _ in range(100): 68 | _run_act_layer_grad('hard_sigmoid', inplace=None) 69 | 70 | 71 | def test_hard_swish_grad(): 72 | for _ in range(100): 73 | _run_act_layer_grad('hard_swish') 74 | 75 | 76 | def test_hard_mish_grad(): 77 | for _ in range(100): 78 | _run_act_layer_grad('hard_mish') 79 | -------------------------------------------------------------------------------- /docs/models/.templates/models/gloun-senet.md: -------------------------------------------------------------------------------- 1 | # (Gluon) SENet 2 | 3 | A **SENet** is a convolutional neural network architecture that employs [squeeze-and-excitation blocks](https://paperswithcode.com/method/squeeze-and-excitation-block) to enable the network to perform dynamic channel-wise feature recalibration. 4 | 5 | The weights from this model were ported from [Gluon](https://cv.gluon.ai/model_zoo/classification.html). 6 | 7 | {% include 'code_snippets.md' %} 8 | 9 | ## How do I train this model? 10 | 11 | You can follow the [timm recipe scripts](https://rwightman.github.io/pytorch-image-models/scripts/) for training a new model afresh. 
12 | 13 | ## Citation 14 | 15 | ```BibTeX 16 | @misc{hu2019squeezeandexcitation, 17 | title={Squeeze-and-Excitation Networks}, 18 | author={Jie Hu and Li Shen and Samuel Albanie and Gang Sun and Enhua Wu}, 19 | year={2019}, 20 | eprint={1709.01507}, 21 | archivePrefix={arXiv}, 22 | primaryClass={cs.CV} 23 | } 24 | ``` 25 | 26 | 64 | -------------------------------------------------------------------------------- /tests/test_utils.py: -------------------------------------------------------------------------------- 1 | from torch.nn.modules.batchnorm import BatchNorm2d 2 | from torchvision.ops.misc import FrozenBatchNorm2d 3 | 4 | import timm 5 | from timm.utils.model import freeze, unfreeze 6 | 7 | 8 | def test_freeze_unfreeze(): 9 | model = timm.create_model('resnet18') 10 | 11 | # Freeze all 12 | freeze(model) 13 | # Check top level module 14 | assert model.fc.weight.requires_grad == False 15 | # Check submodule 16 | assert model.layer1[0].conv1.weight.requires_grad == False 17 | # Check BN 18 | assert isinstance(model.layer1[0].bn1, FrozenBatchNorm2d) 19 | 20 | # Unfreeze all 21 | unfreeze(model) 22 | # Check top level module 23 | assert model.fc.weight.requires_grad == True 24 | # Check submodule 25 | assert model.layer1[0].conv1.weight.requires_grad == True 26 | # Check BN 27 | assert isinstance(model.layer1[0].bn1, BatchNorm2d) 28 | 29 | # Freeze some 30 | freeze(model, ['layer1', 'layer2.0']) 31 | # Check frozen 32 | assert model.layer1[0].conv1.weight.requires_grad == False 33 | assert isinstance(model.layer1[0].bn1, FrozenBatchNorm2d) 34 | assert model.layer2[0].conv1.weight.requires_grad == False 35 | # Check not frozen 36 | assert model.layer3[0].conv1.weight.requires_grad == True 37 | assert isinstance(model.layer3[0].bn1, BatchNorm2d) 38 | assert model.layer2[1].conv1.weight.requires_grad == True 39 | 40 | # Unfreeze some 41 | unfreeze(model, ['layer1', 'layer2.0']) 42 | # Check not frozen 43 | assert model.layer1[0].conv1.weight.requires_grad == True 
44 | assert isinstance(model.layer1[0].bn1, BatchNorm2d) 45 | assert model.layer2[0].conv1.weight.requires_grad == True 46 | 47 | # Freeze/unfreeze BN 48 | # From root 49 | freeze(model, ['layer1.0.bn1']) 50 | assert isinstance(model.layer1[0].bn1, FrozenBatchNorm2d) 51 | unfreeze(model, ['layer1.0.bn1']) 52 | assert isinstance(model.layer1[0].bn1, BatchNorm2d) 53 | # From direct parent 54 | freeze(model.layer1[0], ['bn1']) 55 | assert isinstance(model.layer1[0].bn1, FrozenBatchNorm2d) 56 | unfreeze(model.layer1[0], ['bn1']) 57 | assert isinstance(model.layer1[0].bn1, BatchNorm2d) -------------------------------------------------------------------------------- /timm/scheduler/multistep_lr.py: -------------------------------------------------------------------------------- 1 | """ MultiStep LR Scheduler 2 | 3 | Basic multi step LR schedule with warmup, noise. 4 | """ 5 | import torch 6 | import bisect 7 | from timm.scheduler.scheduler import Scheduler 8 | from typing import List 9 | 10 | class MultiStepLRScheduler(Scheduler): 11 | """ 12 | """ 13 | 14 | def __init__( 15 | self, 16 | optimizer: torch.optim.Optimizer, 17 | decay_t: List[int], 18 | decay_rate: float = 1., 19 | warmup_t=0, 20 | warmup_lr_init=0, 21 | warmup_prefix=True, 22 | t_in_epochs=True, 23 | noise_range_t=None, 24 | noise_pct=0.67, 25 | noise_std=1.0, 26 | noise_seed=42, 27 | initialize=True, 28 | ) -> None: 29 | super().__init__( 30 | optimizer, 31 | param_group_field="lr", 32 | t_in_epochs=t_in_epochs, 33 | noise_range_t=noise_range_t, 34 | noise_pct=noise_pct, 35 | noise_std=noise_std, 36 | noise_seed=noise_seed, 37 | initialize=initialize, 38 | ) 39 | 40 | self.decay_t = decay_t 41 | self.decay_rate = decay_rate 42 | self.warmup_t = warmup_t 43 | self.warmup_lr_init = warmup_lr_init 44 | self.warmup_prefix = warmup_prefix 45 | if self.warmup_t: 46 | self.warmup_steps = [(v - warmup_lr_init) / self.warmup_t for v in self.base_values] 47 | super().update_groups(self.warmup_lr_init) 48 | else: 49 
| self.warmup_steps = [1 for _ in self.base_values] 50 | 51 | def get_curr_decay_steps(self, t): 52 | # find where in the array t goes, 53 | # assumes self.decay_t is sorted 54 | return bisect.bisect_right(self.decay_t, t + 1) 55 | 56 | def _get_lr(self, t): 57 | if t < self.warmup_t: 58 | lrs = [self.warmup_lr_init + t * s for s in self.warmup_steps] 59 | else: 60 | if self.warmup_prefix: 61 | t = t - self.warmup_t 62 | lrs = [v * (self.decay_rate ** self.get_curr_decay_steps(t)) for v in self.base_values] 63 | return lrs 64 | -------------------------------------------------------------------------------- /timm/layers/test_time_pool.py: -------------------------------------------------------------------------------- 1 | """ Test Time Pooling (Average-Max Pool) 2 | 3 | Hacked together by / Copyright 2020 Ross Wightman 4 | """ 5 | 6 | import logging 7 | from torch import nn 8 | import torch.nn.functional as F 9 | 10 | from .adaptive_avgmax_pool import adaptive_avgmax_pool2d 11 | 12 | 13 | _logger = logging.getLogger(__name__) 14 | 15 | 16 | class TestTimePoolHead(nn.Module): 17 | def __init__(self, base, original_pool=7): 18 | super(TestTimePoolHead, self).__init__() 19 | self.base = base 20 | self.original_pool = original_pool 21 | base_fc = self.base.get_classifier() 22 | if isinstance(base_fc, nn.Conv2d): 23 | self.fc = base_fc 24 | else: 25 | self.fc = nn.Conv2d( 26 | self.base.num_features, self.base.num_classes, kernel_size=1, bias=True) 27 | self.fc.weight.data.copy_(base_fc.weight.data.view(self.fc.weight.size())) 28 | self.fc.bias.data.copy_(base_fc.bias.data.view(self.fc.bias.size())) 29 | self.base.reset_classifier(0) # delete original fc layer 30 | 31 | def forward(self, x): 32 | x = self.base.forward_features(x) 33 | x = F.avg_pool2d(x, kernel_size=self.original_pool, stride=1) 34 | x = self.fc(x) 35 | x = adaptive_avgmax_pool2d(x, 1) 36 | return x.view(x.size(0), -1) 37 | 38 | 39 | def apply_test_time_pool(model, config, use_test_size=False): 40 | 
test_time_pool = False 41 | if not hasattr(model, 'default_cfg') or not model.default_cfg: 42 | return model, False 43 | if use_test_size and 'test_input_size' in model.default_cfg: 44 | df_input_size = model.default_cfg['test_input_size'] 45 | else: 46 | df_input_size = model.default_cfg['input_size'] 47 | if config['input_size'][-1] > df_input_size[-1] and config['input_size'][-2] > df_input_size[-2]: 48 | _logger.info('Target input size %s > pretrained default %s, using test time pooling' % 49 | (str(config['input_size'][-2:]), str(df_input_size[-2:]))) 50 | model = TestTimePoolHead(model, original_pool=model.default_cfg['pool_size']) 51 | test_time_pool = True 52 | return model, test_time_pool 53 | -------------------------------------------------------------------------------- /timm/loss/binary_cross_entropy.py: -------------------------------------------------------------------------------- 1 | """ Binary Cross Entropy w/ a few extras 2 | 3 | Hacked together by / Copyright 2021 Ross Wightman 4 | """ 5 | from typing import Optional 6 | 7 | import torch 8 | import torch.nn as nn 9 | import torch.nn.functional as F 10 | 11 | 12 | class BinaryCrossEntropy(nn.Module): 13 | """ BCE with optional one-hot from dense targets, label smoothing, thresholding 14 | NOTE for experiments comparing CE to BCE /w label smoothing, may remove 15 | """ 16 | def __init__( 17 | self, smoothing=0.1, target_threshold: Optional[float] = None, weight: Optional[torch.Tensor] = None, 18 | reduction: str = 'mean', pos_weight: Optional[torch.Tensor] = None): 19 | super(BinaryCrossEntropy, self).__init__() 20 | assert 0. 
<= smoothing < 1.0 21 | self.smoothing = smoothing 22 | self.target_threshold = target_threshold 23 | self.reduction = reduction 24 | self.register_buffer('weight', weight) 25 | self.register_buffer('pos_weight', pos_weight) 26 | 27 | def forward(self, x: torch.Tensor, target: torch.Tensor) -> torch.Tensor: 28 | assert x.shape[0] == target.shape[0] 29 | if target.shape != x.shape: 30 | # NOTE currently assume smoothing or other label softening is applied upstream if targets are already sparse 31 | num_classes = x.shape[-1] 32 | # FIXME should off/on be different for smoothing w/ BCE? Other impl out there differ 33 | off_value = self.smoothing / num_classes 34 | on_value = 1. - self.smoothing + off_value 35 | target = target.long().view(-1, 1) 36 | target = torch.full( 37 | (target.size()[0], num_classes), 38 | off_value, 39 | device=x.device, dtype=x.dtype).scatter_(1, target, on_value) 40 | if self.target_threshold is not None: 41 | # Make target 0, or 1 if threshold set 42 | target = target.gt(self.target_threshold).to(dtype=target.dtype) 43 | return F.binary_cross_entropy_with_logits( 44 | x, target, 45 | self.weight, 46 | pos_weight=self.pos_weight, 47 | reduction=self.reduction) 48 | -------------------------------------------------------------------------------- /docs/models/.templates/models/nasnet.md: -------------------------------------------------------------------------------- 1 | # NASNet 2 | 3 | **NASNet** is a type of convolutional neural network discovered through neural architecture search. The building blocks consist of normal and reduction cells. 4 | 5 | {% include 'code_snippets.md' %} 6 | 7 | ## How do I train this model? 8 | 9 | You can follow the [timm recipe scripts](https://rwightman.github.io/pytorch-image-models/scripts/) for training a new model afresh. 
10 | 11 | ## Citation 12 | 13 | ```BibTeX 14 | @misc{zoph2018learning, 15 | title={Learning Transferable Architectures for Scalable Image Recognition}, 16 | author={Barret Zoph and Vijay Vasudevan and Jonathon Shlens and Quoc V. Le}, 17 | year={2018}, 18 | eprint={1707.07012}, 19 | archivePrefix={arXiv}, 20 | primaryClass={cs.CV} 21 | } 22 | ``` 23 | 24 | 71 | -------------------------------------------------------------------------------- /docs/models/.templates/models/gloun-xception.md: -------------------------------------------------------------------------------- 1 | # (Gluon) Xception 2 | 3 | **Xception** is a convolutional neural network architecture that relies solely on [depthwise separable convolution](https://paperswithcode.com/method/depthwise-separable-convolution) layers. 4 | 5 | The weights from this model were ported from [Gluon](https://cv.gluon.ai/model_zoo/classification.html). 6 | 7 | {% include 'code_snippets.md' %} 8 | 9 | ## How do I train this model? 10 | 11 | You can follow the [timm recipe scripts](https://rwightman.github.io/pytorch-image-models/scripts/) for training a new model afresh. 
12 | 13 | ## Citation 14 | 15 | ```BibTeX 16 | @misc{chollet2017xception, 17 | title={Xception: Deep Learning with Depthwise Separable Convolutions}, 18 | author={François Chollet}, 19 | year={2017}, 20 | eprint={1610.02357}, 21 | archivePrefix={arXiv}, 22 | primaryClass={cs.CV} 23 | } 24 | ``` 25 | 26 | 67 | -------------------------------------------------------------------------------- /timm/utils/cuda.py: -------------------------------------------------------------------------------- 1 | """ CUDA / AMP utils 2 | 3 | Hacked together by / Copyright 2020 Ross Wightman 4 | """ 5 | import torch 6 | 7 | try: 8 | from apex import amp 9 | has_apex = True 10 | except ImportError: 11 | amp = None 12 | has_apex = False 13 | 14 | from .clip_grad import dispatch_clip_grad 15 | 16 | 17 | class ApexScaler: 18 | state_dict_key = "amp" 19 | 20 | def __call__( 21 | self, 22 | loss, 23 | optimizer, 24 | clip_grad=None, 25 | clip_mode='norm', 26 | parameters=None, 27 | create_graph=False, 28 | need_update=True, 29 | ): 30 | with amp.scale_loss(loss, optimizer) as scaled_loss: 31 | scaled_loss.backward(create_graph=create_graph) 32 | if need_update: 33 | if clip_grad is not None: 34 | dispatch_clip_grad(amp.master_params(optimizer), clip_grad, mode=clip_mode) 35 | optimizer.step() 36 | 37 | def state_dict(self): 38 | if 'state_dict' in amp.__dict__: 39 | return amp.state_dict() 40 | 41 | def load_state_dict(self, state_dict): 42 | if 'load_state_dict' in amp.__dict__: 43 | amp.load_state_dict(state_dict) 44 | 45 | 46 | class NativeScaler: 47 | state_dict_key = "amp_scaler" 48 | 49 | def __init__(self): 50 | self._scaler = torch.cuda.amp.GradScaler() 51 | 52 | def __call__( 53 | self, 54 | loss, 55 | optimizer, 56 | clip_grad=None, 57 | clip_mode='norm', 58 | parameters=None, 59 | create_graph=False, 60 | need_update=True, 61 | ): 62 | self._scaler.scale(loss).backward(create_graph=create_graph) 63 | if need_update: 64 | if clip_grad is not None: 65 | assert parameters is not None 
66 | self._scaler.unscale_(optimizer) # unscale the gradients of optimizer's assigned params in-place 67 | dispatch_clip_grad(parameters, clip_grad, mode=clip_mode) 68 | self._scaler.step(optimizer) 69 | self._scaler.update() 70 | 71 | def state_dict(self): 72 | return self._scaler.state_dict() 73 | 74 | def load_state_dict(self, state_dict): 75 | self._scaler.load_state_dict(state_dict) 76 | -------------------------------------------------------------------------------- /hfdocs/source/installation.mdx: -------------------------------------------------------------------------------- 1 | # Installation 2 | 3 | Before you start, you'll need to set up your environment and install the appropriate packages. `timm` is tested on **Python 3+**. 4 | 5 | ## Virtual Environment 6 | 7 | You should install `timm` in a [virtual environment](https://docs.python.org/3/library/venv.html) to keep things tidy and avoid dependency conflicts. 8 | 9 | 1. Create and navigate to your project directory: 10 | 11 | ```bash 12 | mkdir ~/my-project 13 | cd ~/my-project 14 | ``` 15 | 16 | 2. Create a virtual environment inside your directory: 17 | 18 | ```bash 19 | python -m venv .env 20 | ``` 21 | 22 | 3. Activate and deactivate the virtual environment with the following commands: 23 | 24 | ```bash 25 | # Activate the virtual environment 26 | source .env/bin/activate 27 | 28 | # Deactivate the virtual environment 29 | deactivate 30 | ``` 31 | 32 | Once you've created your virtual environment, you can install `timm` in it.
33 | 34 | ## Using pip 35 | 36 | The most straightforward way to install `timm` is with pip: 37 | 38 | ```bash 39 | pip install timm 40 | ``` 41 | 42 | Alternatively, you can install `timm` from GitHub directly to get the latest, bleeding-edge version: 43 | 44 | ```bash 45 | pip install git+https://github.com/rwightman/pytorch-image-models.git 46 | ``` 47 | 48 | Run the following command to check if `timm` has been properly installed: 49 | 50 | ```bash 51 | python -c "from timm import list_models; print(list_models(pretrained=True)[:5])" 52 | ``` 53 | 54 | This command lists the first five pretrained models available in `timm` (which are sorted alphabetically). You should see output similar to the following (the exact names depend on your `timm` version): 55 | 56 | ```python 57 | ['adv_inception_v3', 'bat_resnext26ts', 'beit_base_patch16_224', 'beit_base_patch16_224_in22k', 'beit_base_patch16_384'] 58 | ``` 59 | 60 | ## From Source 61 | 62 | Building `timm` from source lets you make changes to the code base. To install from the source, clone the repository and install with the following commands: 63 | 64 | ```bash 65 | git clone https://github.com/rwightman/pytorch-image-models.git 66 | cd pytorch-image-models 67 | pip install -e . 68 | ``` 69 | 70 | Again, you can check if `timm` was properly installed with the following command: 71 | 72 | ```bash 73 | python -c "from timm import list_models; print(list_models(pretrained=True)[:5])" 74 | ``` 75 | -------------------------------------------------------------------------------- /docs/models/.templates/models/legacy-senet.md: -------------------------------------------------------------------------------- 1 | # (Legacy) SENet 2 | 3 | A **SENet** is a convolutional neural network architecture that employs [squeeze-and-excitation blocks](https://paperswithcode.com/method/squeeze-and-excitation-block) to enable the network to perform dynamic channel-wise feature recalibration. 4 | 5 | The weights from this model were ported from Gluon.
6 | 7 | {% include 'code_snippets.md' %} 8 | 9 | ## How do I train this model? 10 | 11 | You can follow the [timm recipe scripts](https://rwightman.github.io/pytorch-image-models/scripts/) for training a new model afresh. 12 | 13 | ## Citation 14 | 15 | ```BibTeX 16 | @misc{hu2019squeezeandexcitation, 17 | title={Squeeze-and-Excitation Networks}, 18 | author={Jie Hu and Li Shen and Samuel Albanie and Gang Sun and Enhua Wu}, 19 | year={2019}, 20 | eprint={1709.01507}, 21 | archivePrefix={arXiv}, 22 | primaryClass={cs.CV} 23 | } 24 | ``` 25 | 26 | 75 | -------------------------------------------------------------------------------- /docs/models/.templates/models/inception-v4.md: -------------------------------------------------------------------------------- 1 | # Inception v4 2 | 3 | **Inception-v4** is a convolutional neural network architecture that builds on previous iterations of the Inception family by simplifying the architecture and using more inception modules than [Inception-v3](https://paperswithcode.com/method/inception-v3). 4 | {% include 'code_snippets.md' %} 5 | 6 | ## How do I train this model? 7 | 8 | You can follow the [timm recipe scripts](https://rwightman.github.io/pytorch-image-models/scripts/) for training a new model afresh. 
9 | 10 | ## Citation 11 | 12 | ```BibTeX 13 | @misc{szegedy2016inceptionv4, 14 | title={Inception-v4, Inception-ResNet and the Impact of Residual Connections on Learning}, 15 | author={Christian Szegedy and Sergey Ioffe and Vincent Vanhoucke and Alex Alemi}, 16 | year={2016}, 17 | eprint={1602.07261}, 18 | archivePrefix={arXiv}, 19 | primaryClass={cs.CV} 20 | } 21 | ``` 22 | 23 | 72 | -------------------------------------------------------------------------------- /timm/utils/jit.py: -------------------------------------------------------------------------------- 1 | """ JIT scripting/tracing utils 2 | 3 | Hacked together by / Copyright 2020 Ross Wightman 4 | """ 5 | import os 6 | 7 | import torch 8 | 9 | 10 | def set_jit_legacy(): 11 | """ Set JIT executor to legacy w/ support for op fusion 12 | This is hopefully a temporary need in 1.5/1.5.1/1.6 to restore performance due to changes 13 | in the JIT exectutor. These API are not supported so could change. 14 | """ 15 | # 16 | assert hasattr(torch._C, '_jit_set_profiling_executor'), "Old JIT behavior doesn't exist!" 
17 | torch._C._jit_set_profiling_executor(False) 18 | torch._C._jit_set_profiling_mode(False) 19 | torch._C._jit_override_can_fuse_on_gpu(True) 20 | #torch._C._jit_set_texpr_fuser_enabled(True) 21 | 22 | 23 | def set_jit_fuser(fuser): 24 | if fuser == "te": 25 | # default fuser should be == 'te' 26 | torch._C._jit_set_profiling_executor(True) 27 | torch._C._jit_set_profiling_mode(True) 28 | torch._C._jit_override_can_fuse_on_cpu(False) 29 | torch._C._jit_override_can_fuse_on_gpu(True) 30 | torch._C._jit_set_texpr_fuser_enabled(True) 31 | try: 32 | torch._C._jit_set_nvfuser_enabled(False) 33 | except Exception: 34 | pass 35 | elif fuser == "old" or fuser == "legacy": 36 | torch._C._jit_set_profiling_executor(False) 37 | torch._C._jit_set_profiling_mode(False) 38 | torch._C._jit_override_can_fuse_on_gpu(True) 39 | torch._C._jit_set_texpr_fuser_enabled(False) 40 | try: 41 | torch._C._jit_set_nvfuser_enabled(False) 42 | except Exception: 43 | pass 44 | elif fuser == "nvfuser" or fuser == "nvf": 45 | os.environ['PYTORCH_NVFUSER_DISABLE_FALLBACK'] = '1' 46 | #os.environ['PYTORCH_NVFUSER_DISABLE_FMA'] = '1' 47 | #os.environ['PYTORCH_NVFUSER_JIT_OPT_LEVEL'] = '0' 48 | torch._C._jit_set_texpr_fuser_enabled(False) 49 | torch._C._jit_set_profiling_executor(True) 50 | torch._C._jit_set_profiling_mode(True) 51 | torch._C._jit_can_fuse_on_cpu() 52 | torch._C._jit_can_fuse_on_gpu() 53 | torch._C._jit_override_can_fuse_on_cpu(False) 54 | torch._C._jit_override_can_fuse_on_gpu(False) 55 | torch._C._jit_set_nvfuser_guard_mode(True) 56 | torch._C._jit_set_nvfuser_enabled(True) 57 | else: 58 | assert False, f"Invalid jit fuser ({fuser})" 59 | -------------------------------------------------------------------------------- /docs/models/.templates/models/skresnext.md: -------------------------------------------------------------------------------- 1 | # SK-ResNeXt 2 | 3 | **SK ResNeXt** is a variant of a [ResNeXt](https://www.paperswithcode.com/method/resnext) that employs a [Selective 
Kernel](https://paperswithcode.com/method/selective-kernel) unit. In general, all the large kernel convolutions in the original bottleneck blocks in ResNeXt are replaced by the proposed [SK convolutions](https://paperswithcode.com/method/selective-kernel-convolution), enabling the network to choose appropriate receptive field sizes in an adaptive manner. 4 | 5 | {% include 'code_snippets.md' %} 6 | 7 | ## How do I train this model? 8 | 9 | You can follow the [timm recipe scripts](https://rwightman.github.io/pytorch-image-models/scripts/) for training a new model afresh. 10 | 11 | ## Citation 12 | 13 | ```BibTeX 14 | @misc{li2019selective, 15 | title={Selective Kernel Networks}, 16 | author={Xiang Li and Wenhai Wang and Xiaolin Hu and Jian Yang}, 17 | year={2019}, 18 | eprint={1903.06586}, 19 | archivePrefix={arXiv}, 20 | primaryClass={cs.CV} 21 | } 22 | ``` 23 | 24 | 71 | -------------------------------------------------------------------------------- /docs/index.md: -------------------------------------------------------------------------------- 1 | # Getting Started 2 | 3 | ## Welcome 4 | 5 | Welcome to the `timm` documentation, a lean set of docs that covers the basics of `timm`. 6 | 7 | For a more comprehensive set of docs (currently under development), please visit [timmdocs](http://timm.fast.ai) by [Aman Arora](https://github.com/amaarora). 8 | 9 | ## Install 10 | 11 | The library can be installed with pip: 12 | 13 | ``` 14 | pip install timm 15 | ``` 16 | 17 | I update the PyPI (pip) packages when I'm confident there are no significant model regressions from previous releases. If you want to pip install the bleeding edge from GitHub, use: 18 | ``` 19 | pip install git+https://github.com/rwightman/pytorch-image-models.git 20 | ``` 21 | 22 | !!!
info "Conda Environment" 23 | All development and testing has been done in Conda Python 3 environments on Linux x86-64 systems, specifically 3.7, 3.8, 3.9, and 3.10. 24 | 25 | Little to no care has been taken to be Python 2.x friendly, and Python 2.x will not be supported. If you run into any challenges running on Windows or other operating systems, I'm definitely open to looking into those issues so long as it's in a reproducible (read Conda) environment. 26 | 27 | PyTorch versions 1.9, 1.10, 1.11 have been tested with the latest versions of this code. 28 | 29 | I've tried to keep the dependencies minimal; the setup follows the PyTorch default install instructions for Conda: 30 | ``` 31 | conda create -n torch-env 32 | conda activate torch-env 33 | conda install pytorch torchvision cudatoolkit=11.3 -c pytorch 34 | conda install pyyaml 35 | ``` 36 | 37 | ## Load a Pretrained Model 38 | 39 | Pretrained models can be loaded using `timm.create_model`: 40 | 41 | ```python 42 | import timm 43 | 44 | m = timm.create_model('mobilenetv3_large_100', pretrained=True) 45 | m.eval() 46 | ``` 47 | 48 | ## List Models with Pretrained Weights 49 | ```python 50 | import timm 51 | from pprint import pprint 52 | model_names = timm.list_models(pretrained=True) 53 | pprint(model_names) 54 | >>> ['adv_inception_v3', 55 | 'cspdarknet53', 56 | 'cspresnext50', 57 | 'densenet121', 58 | 'densenet161', 59 | 'densenet169', 60 | 'densenet201', 61 | 'densenetblur121d', 62 | 'dla34', 63 | 'dla46_c', 64 | ... 65 | ] 66 | ``` 67 | 68 | ## List Model Architectures by Wildcard 69 | ```python 70 | import timm 71 | from pprint import pprint 72 | model_names = timm.list_models('*resne*t*') 73 | pprint(model_names) 74 | >>> ['cspresnet50', 75 | 'cspresnet50d', 76 | 'cspresnet50w', 77 | 'cspresnext50', 78 | ...
79 | ] 80 | ``` 81 | -------------------------------------------------------------------------------- /docs/models/.templates/models/pnasnet.md: -------------------------------------------------------------------------------- 1 | # PNASNet 2 | 3 | **Progressive Neural Architecture Search**, or **PNAS**, is a method for learning the structure of convolutional neural networks (CNNs). It uses a sequential model-based optimization (SMBO) strategy, where we search the space of cell structures, starting with simple (shallow) models and progressing to complex ones, pruning out unpromising structures as we go. 4 | 5 | {% include 'code_snippets.md' %} 6 | 7 | ## How do I train this model? 8 | 9 | You can follow the [timm recipe scripts](https://rwightman.github.io/pytorch-image-models/scripts/) for training a new model afresh. 10 | 11 | ## Citation 12 | 13 | ```BibTeX 14 | @misc{liu2018progressive, 15 | title={Progressive Neural Architecture Search}, 16 | author={Chenxi Liu and Barret Zoph and Maxim Neumann and Jonathon Shlens and Wei Hua and Li-Jia Li and Li Fei-Fei and Alan Yuille and Jonathan Huang and Kevin Murphy}, 17 | year={2018}, 18 | eprint={1712.00559}, 19 | archivePrefix={arXiv}, 20 | primaryClass={cs.CV} 21 | } 22 | ``` 23 | 24 | 72 | -------------------------------------------------------------------------------- /docs/models/.templates/models/inception-resnet-v2.md: -------------------------------------------------------------------------------- 1 | # Inception ResNet v2 2 | 3 | **Inception-ResNet-v2** is a convolutional neural architecture that builds on the Inception family of architectures but incorporates [residual connections](https://paperswithcode.com/method/residual-connection) (replacing the filter concatenation stage of the Inception architecture). 4 | 5 | {% include 'code_snippets.md' %} 6 | 7 | ## How do I train this model? 
8 | 9 | You can follow the [timm recipe scripts](https://rwightman.github.io/pytorch-image-models/scripts/) for training a new model afresh. 10 | 11 | ## Citation 12 | 13 | ```BibTeX 14 | @misc{szegedy2016inceptionv4, 15 | title={Inception-v4, Inception-ResNet and the Impact of Residual Connections on Learning}, 16 | author={Christian Szegedy and Sergey Ioffe and Vincent Vanhoucke and Alex Alemi}, 17 | year={2016}, 18 | eprint={1602.07261}, 19 | archivePrefix={arXiv}, 20 | primaryClass={cs.CV} 21 | } 22 | ``` 23 | 24 | 73 | -------------------------------------------------------------------------------- /.github/workflows/tests.yml: -------------------------------------------------------------------------------- 1 | name: Python tests 2 | 3 | on: 4 | push: 5 | branches: [ main ] 6 | pull_request: 7 | branches: [ main ] 8 | 9 | env: 10 | OMP_NUM_THREADS: 2 11 | MKL_NUM_THREADS: 2 12 | 13 | jobs: 14 | test: 15 | name: Run tests on ${{ matrix.os }} with Python ${{ matrix.python }} 16 | strategy: 17 | matrix: 18 | os: [ubuntu-latest] 19 | python: ['3.10', '3.11'] 20 | torch: [{base: '1.13.0', vision: '0.14.0'}, {base: '2.1.0', vision: '0.16.0'}] 21 | testmarker: ['-k "not test_models"', '-m base', '-m cfg', '-m torchscript', '-m features', '-m fxforward', '-m fxbackward'] 22 | exclude: 23 | - python: '3.11' 24 | torch: {base: '1.13.0', vision: '0.14.0'} 25 | runs-on: ${{ matrix.os }} 26 | 27 | steps: 28 | - uses: actions/checkout@v2 29 | - name: Set up Python ${{ matrix.python }} 30 | uses: actions/setup-python@v1 31 | with: 32 | python-version: ${{ matrix.python }} 33 | - name: Install testing dependencies 34 | run: | 35 | python -m pip install --upgrade pip 36 | pip install -r requirements-dev.txt 37 | - name: Install torch on mac 38 | if: startsWith(matrix.os, 'macOS') 39 | run: pip install --no-cache-dir torch==${{ matrix.torch.base }} torchvision==${{ matrix.torch.vision }} 40 | - name: Install torch on Windows 41 | if: startsWith(matrix.os, 'windows') 42 | run: 
pip install --no-cache-dir torch==${{ matrix.torch.base }} torchvision==${{ matrix.torch.vision }} 43 | - name: Install torch on ubuntu 44 | if: startsWith(matrix.os, 'ubuntu') 45 | run: | 46 | sudo sed -i 's/azure\.//' /etc/apt/sources.list 47 | sudo apt update 48 | sudo apt install -y google-perftools 49 | pip install --no-cache-dir torch==${{ matrix.torch.base }}+cpu torchvision==${{ matrix.torch.vision }}+cpu -f https://download.pytorch.org/whl/torch_stable.html 50 | - name: Install requirements 51 | run: | 52 | pip install -r requirements.txt 53 | - name: Run tests on Windows 54 | if: startsWith(matrix.os, 'windows') 55 | env: 56 | PYTHONDONTWRITEBYTECODE: 1 57 | run: | 58 | pytest -vv tests 59 | - name: Run '${{ matrix.testmarker }}' tests on Linux / Mac 60 | if: ${{ !startsWith(matrix.os, 'windows') }} 61 | env: 62 | LD_PRELOAD: /usr/lib/x86_64-linux-gnu/libtcmalloc.so.4 63 | PYTHONDONTWRITEBYTECODE: 1 64 | run: | 65 | pytest -vv --forked --durations=0 ${{ matrix.testmarker }} tests 66 | -------------------------------------------------------------------------------- /timm/optim/sgdp.py: -------------------------------------------------------------------------------- 1 | """ 2 | SGDP Optimizer Implementation copied from https://github.com/clovaai/AdamP/blob/master/adamp/sgdp.py 3 | 4 | Paper: `Slowing Down the Weight Norm Increase in Momentum-based Optimizers` - https://arxiv.org/abs/2006.08217 5 | Code: https://github.com/clovaai/AdamP 6 | 7 | Copyright (c) 2020-present NAVER Corp. 
8 | MIT license 9 | """ 10 | 11 | import torch 12 | import torch.nn.functional as F 13 | from torch.optim.optimizer import Optimizer, required 14 | import math 15 | 16 | from .adamp import projection 17 | 18 | 19 | class SGDP(Optimizer): 20 | def __init__(self, params, lr=required, momentum=0, dampening=0, 21 | weight_decay=0, nesterov=False, eps=1e-8, delta=0.1, wd_ratio=0.1): 22 | defaults = dict( 23 | lr=lr, momentum=momentum, dampening=dampening, weight_decay=weight_decay, 24 | nesterov=nesterov, eps=eps, delta=delta, wd_ratio=wd_ratio) 25 | super(SGDP, self).__init__(params, defaults) 26 | 27 | @torch.no_grad() 28 | def step(self, closure=None): 29 | loss = None 30 | if closure is not None: 31 | with torch.enable_grad(): 32 | loss = closure() 33 | 34 | for group in self.param_groups: 35 | weight_decay = group['weight_decay'] 36 | momentum = group['momentum'] 37 | dampening = group['dampening'] 38 | nesterov = group['nesterov'] 39 | 40 | for p in group['params']: 41 | if p.grad is None: 42 | continue 43 | grad = p.grad 44 | state = self.state[p] 45 | 46 | # State initialization 47 | if len(state) == 0: 48 | state['momentum'] = torch.zeros_like(p) 49 | 50 | # SGD 51 | buf = state['momentum'] 52 | buf.mul_(momentum).add_(grad, alpha=1. - dampening) 53 | if nesterov: 54 | d_p = grad + momentum * buf 55 | else: 56 | d_p = buf 57 | 58 | # Projection 59 | wd_ratio = 1. 60 | if len(p.shape) > 1: 61 | d_p, wd_ratio = projection(p, grad, d_p, group['delta'], group['wd_ratio'], group['eps']) 62 | 63 | # Weight decay 64 | if weight_decay != 0: 65 | p.mul_(1. 
- group['lr'] * group['weight_decay'] * wd_ratio / (1-momentum)) 66 | 67 | # Step 68 | p.add_(d_p, alpha=-group['lr']) 69 | 70 | return loss 71 | -------------------------------------------------------------------------------- /docs/models/.templates/models/csp-resnet.md: -------------------------------------------------------------------------------- 1 | # CSP-ResNet 2 | 3 | **CSPResNet** is a convolutional neural network where we apply the Cross Stage Partial Network (CSPNet) approach to [ResNet](https://paperswithcode.com/method/resnet). The CSPNet partitions the feature map of the base layer into two parts and then merges them through a cross-stage hierarchy. The use of a split and merge strategy allows for more gradient flow through the network. 4 | 5 | {% include 'code_snippets.md' %} 6 | 7 | ## How do I train this model? 8 | 9 | You can follow the [timm recipe scripts](https://rwightman.github.io/pytorch-image-models/scripts/) for training a new model afresh. 10 | 11 | ## Citation 12 | 13 | ```BibTeX 14 | @misc{wang2019cspnet, 15 | title={CSPNet: A New Backbone that can Enhance Learning Capability of CNN}, 16 | author={Chien-Yao Wang and Hong-Yuan Mark Liao and I-Hau Yeh and Yueh-Hua Wu and Ping-Yang Chen and Jun-Wei Hsieh}, 17 | year={2019}, 18 | eprint={1911.11929}, 19 | archivePrefix={arXiv}, 20 | primaryClass={cs.CV} 21 | } 22 | ``` 23 | 24 | 77 | -------------------------------------------------------------------------------- /timm/data/dataset_info.py: -------------------------------------------------------------------------------- 1 | from abc import ABC, abstractmethod 2 | from typing import Dict, List, Optional, Union 3 | 4 | 5 | class DatasetInfo(ABC): 6 | 7 | def __init__(self): 8 | pass 9 | 10 | @abstractmethod 11 | def num_classes(self): 12 | pass 13 | 14 | @abstractmethod 15 | def label_names(self): 16 | pass 17 | 18 | @abstractmethod 19 | def label_descriptions(self, detailed: bool = False, as_dict: bool = False) -> Union[List[str], Dict[str, 
str]]: 20 | pass 21 | 22 | @abstractmethod 23 | def index_to_label_name(self, index) -> str: 24 | pass 25 | 26 | @abstractmethod 27 | def index_to_description(self, index: int, detailed: bool = False) -> str: 28 | pass 29 | 30 | @abstractmethod 31 | def label_name_to_description(self, label: str, detailed: bool = False) -> str: 32 | pass 33 | 34 | 35 | class CustomDatasetInfo(DatasetInfo): 36 | """ DatasetInfo that wraps passed values for custom datasets.""" 37 | 38 | def __init__( 39 | self, 40 | label_names: Union[List[str], Dict[int, str]], 41 | label_descriptions: Optional[Dict[str, str]] = None 42 | ): 43 | super().__init__() 44 | assert len(label_names) > 0 45 | self._label_names = label_names # label index => label name mapping 46 | self._label_descriptions = label_descriptions # label name => label description mapping 47 | if self._label_descriptions is not None: 48 | # validate descriptions (label names required) 49 | assert isinstance(self._label_descriptions, dict) 50 | for n in self._label_names: 51 | assert n in self._label_descriptions 52 | 53 | def num_classes(self): 54 | return len(self._label_names) 55 | 56 | def label_names(self): 57 | return self._label_names 58 | 59 | def label_descriptions(self, detailed: bool = False, as_dict: bool = False) -> Union[List[str], Dict[str, str]]: 60 | return self._label_descriptions 61 | 62 | def label_name_to_description(self, label: str, detailed: bool = False) -> str: 63 | if self._label_descriptions: 64 | return self._label_descriptions[label] 65 | return label # return label name itself if a descriptions is not present 66 | 67 | def index_to_label_name(self, index) -> str: 68 | assert 0 <= index < len(self._label_names) 69 | return self._label_names[index] 70 | 71 | def index_to_description(self, index: int, detailed: bool = False) -> str: 72 | label = self.index_to_label_name(index) 73 | return self.label_name_to_description(label, detailed=detailed) 74 | 
-------------------------------------------------------------------------------- /timm/layers/interpolate.py: -------------------------------------------------------------------------------- 1 | """ Interpolation helpers for timm layers 2 | 3 | RegularGridInterpolator from https://github.com/sbarratt/torch_interpolations 4 | Copyright Shane Barratt, Apache 2.0 license 5 | """ 6 | import torch 7 | from itertools import product 8 | 9 | 10 | class RegularGridInterpolator: 11 | """ Interpolate data defined on a rectilinear grid with even or uneven spacing. 12 | Produces similar results to scipy RegularGridInterpolator or interp2d 13 | in 'linear' mode. 14 | 15 | Taken from https://github.com/sbarratt/torch_interpolations 16 | """ 17 | 18 | def __init__(self, points, values): 19 | self.points = points 20 | self.values = values 21 | 22 | assert isinstance(self.points, tuple) or isinstance(self.points, list) 23 | assert isinstance(self.values, torch.Tensor) 24 | 25 | self.ms = list(self.values.shape) 26 | self.n = len(self.points) 27 | 28 | assert len(self.ms) == self.n 29 | 30 | for i, p in enumerate(self.points): 31 | assert isinstance(p, torch.Tensor) 32 | assert p.shape[0] == self.values.shape[i] 33 | 34 | def __call__(self, points_to_interp): 35 | assert self.points is not None 36 | assert self.values is not None 37 | 38 | assert len(points_to_interp) == len(self.points) 39 | K = points_to_interp[0].shape[0] 40 | for x in points_to_interp: 41 | assert x.shape[0] == K 42 | 43 | idxs = [] 44 | dists = [] 45 | overalls = [] 46 | for p, x in zip(self.points, points_to_interp): 47 | idx_right = torch.bucketize(x, p) 48 | idx_right[idx_right >= p.shape[0]] = p.shape[0] - 1 49 | idx_left = (idx_right - 1).clamp(0, p.shape[0] - 1) 50 | dist_left = x - p[idx_left] 51 | dist_right = p[idx_right] - x 52 | dist_left[dist_left < 0] = 0. 53 | dist_right[dist_right < 0] = 0. 54 | both_zero = (dist_left == 0) & (dist_right == 0) 55 | dist_left[both_zero] = dist_right[both_zero] = 1. 
56 | 57 | idxs.append((idx_left, idx_right)) 58 | dists.append((dist_left, dist_right)) 59 | overalls.append(dist_left + dist_right) 60 | 61 | numerator = 0. 62 | for indexer in product([0, 1], repeat=self.n): 63 | as_s = [idx[onoff] for onoff, idx in zip(indexer, idxs)] 64 | bs_s = [dist[1 - onoff] for onoff, dist in zip(indexer, dists)] 65 | numerator += self.values[as_s] * \ 66 | torch.prod(torch.stack(bs_s), dim=0) 67 | denominator = torch.prod(torch.stack(overalls), dim=0) 68 | return numerator / denominator 69 | -------------------------------------------------------------------------------- /docs/models/.templates/code_snippets.md: -------------------------------------------------------------------------------- 1 | ## How do I use this model on an image? 2 | To load a pretrained model: 3 | 4 | ```python 5 | import timm 6 | model = timm.create_model('{{ model_name }}', pretrained=True) 7 | model.eval() 8 | ``` 9 | 10 | To load and preprocess the image: 11 | ```python 12 | import urllib 13 | from PIL import Image 14 | from timm.data import resolve_data_config 15 | from timm.data.transforms_factory import create_transform 16 | 17 | config = resolve_data_config({}, model=model) 18 | transform = create_transform(**config) 19 | 20 | url, filename = ("https://github.com/pytorch/hub/raw/master/images/dog.jpg", "dog.jpg") 21 | urllib.request.urlretrieve(url, filename) 22 | img = Image.open(filename).convert('RGB') 23 | tensor = transform(img).unsqueeze(0) # transform and add batch dimension 24 | ``` 25 | 26 | To get the model predictions: 27 | ```python 28 | import torch 29 | with torch.no_grad(): 30 | out = model(tensor) 31 | probabilities = torch.nn.functional.softmax(out[0], dim=0) 32 | print(probabilities.shape) 33 | # prints: torch.Size([1000]) 34 | ``` 35 | 36 | To get the top-5 predictions class names: 37 | ```python 38 | # Get imagenet class mappings 39 | url, filename = ("https://raw.githubusercontent.com/pytorch/hub/master/imagenet_classes.txt", 
"imagenet_classes.txt") 40 | urllib.request.urlretrieve(url, filename) 41 | with open("imagenet_classes.txt", "r") as f: 42 | categories = [s.strip() for s in f.readlines()] 43 | 44 | # Print top categories per image 45 | top5_prob, top5_catid = torch.topk(probabilities, 5) 46 | for i in range(top5_prob.size(0)): 47 | print(categories[top5_catid[i]], top5_prob[i].item()) 48 | # prints one class name and its probability per line; the top-5 (name, probability) pairs look like: 49 | # [('Samoyed', 0.6425196528434753), ('Pomeranian', 0.04062102362513542), ('keeshond', 0.03186424449086189), ('white wolf', 0.01739676296710968), ('Eskimo dog', 0.011717947199940681)] 50 | ``` 51 | 52 | Replace the model name with the variant you want to use, e.g. `{{ model_name }}`. You can find the IDs in the model summaries at the top of this page. 53 | 54 | To extract image features with this model, follow the [timm feature extraction examples](https://rwightman.github.io/pytorch-image-models/feature_extraction/); just change the name of the model you want to use. 55 | 56 | ## How do I finetune this model? 57 | You can finetune any of the pre-trained models just by changing the classifier (the last layer). 58 | ```python 59 | model = timm.create_model('{{ model_name }}', pretrained=True, num_classes=NUM_FINETUNE_CLASSES) 60 | ``` 61 | 62 | To finetune on your own dataset, you have to write a training loop or adapt [timm's training 63 | script](https://github.com/rwightman/pytorch-image-models/blob/master/train.py) to use your dataset.
63 | -------------------------------------------------------------------------------- /timm/layers/global_context.py: -------------------------------------------------------------------------------- 1 | """ Global Context Attention Block 2 | 3 | Paper: `GCNet: Non-local Networks Meet Squeeze-Excitation Networks and Beyond` 4 | - https://arxiv.org/abs/1904.11492 5 | 6 | Official code consulted as reference: https://github.com/xvjiarui/GCNet 7 | 8 | Hacked together by / Copyright 2021 Ross Wightman 9 | """ 10 | from torch import nn as nn 11 | import torch.nn.functional as F 12 | 13 | from .create_act import create_act_layer, get_act_layer 14 | from .helpers import make_divisible 15 | from .mlp import ConvMlp 16 | from .norm import LayerNorm2d 17 | 18 | 19 | class GlobalContext(nn.Module): 20 | 21 | def __init__(self, channels, use_attn=True, fuse_add=False, fuse_scale=True, init_last_zero=False, 22 | rd_ratio=1./8, rd_channels=None, rd_divisor=1, act_layer=nn.ReLU, gate_layer='sigmoid'): 23 | super(GlobalContext, self).__init__() 24 | act_layer = get_act_layer(act_layer) 25 | 26 | self.conv_attn = nn.Conv2d(channels, 1, kernel_size=1, bias=True) if use_attn else None 27 | 28 | if rd_channels is None: 29 | rd_channels = make_divisible(channels * rd_ratio, rd_divisor, round_limit=0.) 
30 | if fuse_add: 31 | self.mlp_add = ConvMlp(channels, rd_channels, act_layer=act_layer, norm_layer=LayerNorm2d) 32 | else: 33 | self.mlp_add = None 34 | if fuse_scale: 35 | self.mlp_scale = ConvMlp(channels, rd_channels, act_layer=act_layer, norm_layer=LayerNorm2d) 36 | else: 37 | self.mlp_scale = None 38 | 39 | self.gate = create_act_layer(gate_layer) 40 | self.init_last_zero = init_last_zero 41 | self.reset_parameters() 42 | 43 | def reset_parameters(self): 44 | if self.conv_attn is not None: 45 | nn.init.kaiming_normal_(self.conv_attn.weight, mode='fan_in', nonlinearity='relu') 46 | if self.mlp_add is not None: 47 | nn.init.zeros_(self.mlp_add.fc2.weight) 48 | 49 | def forward(self, x): 50 | B, C, H, W = x.shape 51 | 52 | if self.conv_attn is not None: 53 | attn = self.conv_attn(x).reshape(B, 1, H * W) # (B, 1, H * W) 54 | attn = F.softmax(attn, dim=-1).unsqueeze(3) # (B, 1, H * W, 1) 55 | context = x.reshape(B, C, H * W).unsqueeze(1) @ attn 56 | context = context.view(B, C, 1, 1) 57 | else: 58 | context = x.mean(dim=(2, 3), keepdim=True) 59 | 60 | if self.mlp_scale is not None: 61 | mlp_x = self.mlp_scale(context) 62 | x = x * self.gate(mlp_x) 63 | if self.mlp_add is not None: 64 | mlp_x = self.mlp_add(context) 65 | x = x + mlp_x 66 | 67 | return x 68 | -------------------------------------------------------------------------------- /docs/models/.templates/models/csp-resnext.md: -------------------------------------------------------------------------------- 1 | # CSP-ResNeXt 2 | 3 | **CSPResNeXt** is a convolutional neural network where we apply the Cross Stage Partial Network (CSPNet) approach to [ResNeXt](https://paperswithcode.com/method/resnext). The CSPNet partitions the feature map of the base layer into two parts and then merges them through a cross-stage hierarchy. The use of a split and merge strategy allows for more gradient flow through the network. 4 | 5 | {% include 'code_snippets.md' %} 6 | 7 | ## How do I train this model? 
8 | 9 | You can follow the [timm recipe scripts](https://rwightman.github.io/pytorch-image-models/scripts/) for training a new model afresh. 10 | 11 | ## Citation 12 | 13 | ```BibTeX 14 | @misc{wang2019cspnet, 15 | title={CSPNet: A New Backbone that can Enhance Learning Capability of CNN}, 16 | author={Chien-Yao Wang and Hong-Yuan Mark Liao and I-Hau Yeh and Yueh-Hua Wu and Ping-Yang Chen and Jun-Wei Hsieh}, 17 | year={2019}, 18 | eprint={1911.11929}, 19 | archivePrefix={arXiv}, 20 | primaryClass={cs.CV} 21 | } 22 | ``` 23 | 24 | 78 | -------------------------------------------------------------------------------- /docs/models/.templates/models/fbnet.md: -------------------------------------------------------------------------------- 1 | # FBNet 2 | 3 | **FBNet** is a type of convolutional neural architecture discovered through [DNAS](https://paperswithcode.com/method/dnas) neural architecture search. It utilises a basic type of image model block inspired by [MobileNetv2](https://paperswithcode.com/method/mobilenetv2) that utilises depthwise convolutions and an inverted residual structure (see components). 4 | 5 | The principal building block is the [FBNet Block](https://paperswithcode.com/method/fbnet-block). 6 | 7 | {% include 'code_snippets.md' %} 8 | 9 | ## How do I train this model? 10 | 11 | You can follow the [timm recipe scripts](https://rwightman.github.io/pytorch-image-models/scripts/) for training a new model afresh.
12 | 13 | ## Citation 14 | 15 | ```BibTeX 16 | @misc{wu2019fbnet, 17 | title={FBNet: Hardware-Aware Efficient ConvNet Design via Differentiable Neural Architecture Search}, 18 | author={Bichen Wu and Xiaoliang Dai and Peizhao Zhang and Yanghan Wang and Fei Sun and Yiming Wu and Yuandong Tian and Peter Vajda and Yangqing Jia and Kurt Keutzer}, 19 | year={2019}, 20 | eprint={1812.03443}, 21 | archivePrefix={arXiv}, 22 | primaryClass={cs.CV} 23 | } 24 | ``` 25 | 26 | 77 | -------------------------------------------------------------------------------- /timm/data/readers/reader_hfds.py: -------------------------------------------------------------------------------- 1 | """ Dataset reader that wraps Hugging Face datasets 2 | 3 | Hacked together by / Copyright 2022 Ross Wightman 4 | """ 5 | import io 6 | import math 7 | import torch 8 | import torch.distributed as dist 9 | from PIL import Image 10 | 11 | try: 12 | import datasets 13 | except ImportError as e: 14 | print("Please install Hugging Face datasets package `pip install datasets`.") 15 | exit(1) 16 | from .class_map import load_class_map 17 | from .reader import Reader 18 | 19 | 20 | def get_class_labels(info, label_key='label'): 21 | if 'label' not in info.features: 22 | return {} 23 | class_label = info.features[label_key] 24 | class_to_idx = {n: class_label.str2int(n) for n in class_label.names} 25 | return class_to_idx 26 | 27 | 28 | class ReaderHfds(Reader): 29 | 30 | def __init__( 31 | self, 32 | root, 33 | name, 34 | split='train', 35 | class_map=None, 36 | label_key='label', 37 | download=False, 38 | ): 39 | """ 40 | """ 41 | super().__init__() 42 | self.root = root 43 | self.split = split 44 | self.dataset = datasets.load_dataset( 45 | name, # 'name' maps to path arg in hf datasets 46 | split=split, 47 | cache_dir=self.root, # timm doesn't expect hidden cache dir for datasets, specify a path 48 | ) 49 | # leave decode for caller, plus we want easy access to original path names... 
50 | self.dataset = self.dataset.cast_column('image', datasets.Image(decode=False)) 51 | 52 | self.label_key = label_key 53 | self.remap_class = False 54 | if class_map: 55 | self.class_to_idx = load_class_map(class_map) 56 | self.remap_class = True 57 | else: 58 | self.class_to_idx = get_class_labels(self.dataset.info, self.label_key) 59 | self.split_info = self.dataset.info.splits[split] 60 | self.num_samples = self.split_info.num_examples 61 | 62 | def __getitem__(self, index): 63 | item = self.dataset[index] 64 | image = item['image'] 65 | if 'bytes' in image and image['bytes']: 66 | image = io.BytesIO(image['bytes']) 67 | else: 68 | assert 'path' in image and image['path'] 69 | image = open(image['path'], 'rb') 70 | label = item[self.label_key] 71 | if self.remap_class: 72 | label = self.class_to_idx[label] 73 | return image, label 74 | 75 | def __len__(self): 76 | return len(self.dataset) 77 | 78 | def _filename(self, index, basename=False, absolute=False): 79 | item = self.dataset[index] 80 | return item['image']['path'] 81 | -------------------------------------------------------------------------------- /docs/models/.templates/models/res2next.md: -------------------------------------------------------------------------------- 1 | # Res2NeXt 2 | 3 | **Res2NeXt** is an image model that employs a variation on [ResNeXt](https://paperswithcode.com/method/resnext) bottleneck residual blocks. The motivation is to be able to represent features at multiple scales. This is achieved through a novel building block for CNNs that constructs hierarchical residual-like connections within one single residual block. This represents multi-scale features at a granular level and increases the range of receptive fields for each network layer. 4 | 5 | {% include 'code_snippets.md' %} 6 | 7 | ## How do I train this model? 8 | 9 | You can follow the [timm recipe scripts](https://rwightman.github.io/pytorch-image-models/scripts/) for training a new model afresh. 
10 | 11 | ## Citation 12 | 13 | ```BibTeX 14 | @article{Gao_2021, 15 | title={Res2Net: A New Multi-Scale Backbone Architecture}, 16 | volume={43}, 17 | ISSN={1939-3539}, 18 | url={http://dx.doi.org/10.1109/TPAMI.2019.2938758}, 19 | DOI={10.1109/tpami.2019.2938758}, 20 | number={2}, 21 | journal={IEEE Transactions on Pattern Analysis and Machine Intelligence}, 22 | publisher={Institute of Electrical and Electronics Engineers (IEEE)}, 23 | author={Gao, Shang-Hua and Cheng, Ming-Ming and Zhao, Kai and Zhang, Xin-Yu and Yang, Ming-Hsuan and Torr, Philip}, 24 | year={2021}, 25 | month={Feb}, 26 | pages={652–662} 27 | } 28 | ``` 29 | 30 | 76 | -------------------------------------------------------------------------------- /timm/layers/filter_response_norm.py: -------------------------------------------------------------------------------- 1 | """ Filter Response Norm in PyTorch 2 | 3 | Based on `Filter Response Normalization Layer` - https://arxiv.org/abs/1911.09737 4 | 5 | Hacked together by / Copyright 2021 Ross Wightman 6 | """ 7 | import torch 8 | import torch.nn as nn 9 | 10 | from .create_act import create_act_layer 11 | from .trace_utils import _assert 12 | 13 | 14 | def inv_instance_rms(x, eps: float = 1e-5): 15 | rms = x.square().float().mean(dim=(2, 3), keepdim=True).add(eps).rsqrt().to(x.dtype) 16 | return rms.expand(x.shape) 17 | 18 | 19 | class FilterResponseNormTlu2d(nn.Module): 20 | def __init__(self, num_features, apply_act=True, eps=1e-5, rms=True, **_): 21 | super(FilterResponseNormTlu2d, self).__init__() 22 | self.apply_act = apply_act # apply activation (non-linearity) 23 | self.rms = rms 24 | self.eps = eps 25 | self.weight = nn.Parameter(torch.ones(num_features)) 26 | self.bias = nn.Parameter(torch.zeros(num_features)) 27 | self.tau = nn.Parameter(torch.zeros(num_features)) if apply_act else None 28 | self.reset_parameters() 29 | 30 | def reset_parameters(self): 31 | nn.init.ones_(self.weight) 32 | nn.init.zeros_(self.bias) 33 | if self.tau is not 
None: 34 | nn.init.zeros_(self.tau) 35 | 36 | def forward(self, x): 37 | _assert(x.dim() == 4, 'expected 4D input') 38 | x_dtype = x.dtype 39 | v_shape = (1, -1, 1, 1) 40 | x = x * inv_instance_rms(x, self.eps) 41 | x = x * self.weight.view(v_shape).to(dtype=x_dtype) + self.bias.view(v_shape).to(dtype=x_dtype) 42 | return torch.maximum(x, self.tau.reshape(v_shape).to(dtype=x_dtype)) if self.tau is not None else x 43 | 44 | 45 | class FilterResponseNormAct2d(nn.Module): 46 | def __init__(self, num_features, apply_act=True, act_layer=nn.ReLU, inplace=None, rms=True, eps=1e-5, **_): 47 | super(FilterResponseNormAct2d, self).__init__() 48 | if act_layer is not None and apply_act: 49 | self.act = create_act_layer(act_layer, inplace=inplace) 50 | else: 51 | self.act = nn.Identity() 52 | self.rms = rms 53 | self.eps = eps 54 | self.weight = nn.Parameter(torch.ones(num_features)) 55 | self.bias = nn.Parameter(torch.zeros(num_features)) 56 | self.reset_parameters() 57 | 58 | def reset_parameters(self): 59 | nn.init.ones_(self.weight) 60 | nn.init.zeros_(self.bias) 61 | 62 | def forward(self, x): 63 | _assert(x.dim() == 4, 'expected 4D input') 64 | x_dtype = x.dtype 65 | v_shape = (1, -1, 1, 1) 66 | x = x * inv_instance_rms(x, self.eps) 67 | x = x * self.weight.view(v_shape).to(dtype=x_dtype) + self.bias.view(v_shape).to(dtype=x_dtype) 68 | return self.act(x) 69 | -------------------------------------------------------------------------------- /docs/models/.templates/models/csp-darknet.md: -------------------------------------------------------------------------------- 1 | # CSP-DarkNet 2 | 3 | **CSPDarknet53** is a convolutional neural network and backbone for object detection that uses [DarkNet-53](https://paperswithcode.com/method/darknet-53). It employs a CSPNet strategy to partition the feature map of the base layer into two parts and then merges them through a cross-stage hierarchy. 
The use of a split and merge strategy allows for more gradient flow through the network. 4 | 5 | This CNN is used as the backbone for [YOLOv4](https://paperswithcode.com/method/yolov4). 6 | 7 | {% include 'code_snippets.md' %} 8 | 9 | ## How do I train this model? 10 | 11 | You can follow the [timm recipe scripts](https://rwightman.github.io/pytorch-image-models/scripts/) for training a new model afresh. 12 | 13 | ## Citation 14 | 15 | ```BibTeX 16 | @misc{bochkovskiy2020yolov4, 17 | title={YOLOv4: Optimal Speed and Accuracy of Object Detection}, 18 | author={Alexey Bochkovskiy and Chien-Yao Wang and Hong-Yuan Mark Liao}, 19 | year={2020}, 20 | eprint={2004.10934}, 21 | archivePrefix={arXiv}, 22 | primaryClass={cs.CV} 23 | } 24 | ``` 25 | 26 | 82 | -------------------------------------------------------------------------------- /results/generate_csv_results.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import pandas as pd 3 | 4 | 5 | results = { 6 | 'results-imagenet.csv': [ 7 | 'results-imagenet-real.csv', 8 | 'results-imagenetv2-matched-frequency.csv', 9 | 'results-sketch.csv' 10 | ], 11 | 'results-imagenet-a-clean.csv': [ 12 | 'results-imagenet-a.csv', 13 | ], 14 | 'results-imagenet-r-clean.csv': [ 15 | 'results-imagenet-r.csv', 16 | ], 17 | } 18 | 19 | 20 | def diff(base_df, test_csv): 21 | base_models = base_df['model'].values 22 | test_df = pd.read_csv(test_csv) 23 | test_models = test_df['model'].values 24 | 25 | rank_diff = np.zeros_like(test_models, dtype='object') 26 | top1_diff = np.zeros_like(test_models, dtype='object') 27 | top5_diff = np.zeros_like(test_models, dtype='object') 28 | 29 | for rank, model in enumerate(test_models): 30 | if model in base_models: 31 | base_rank = int(np.where(base_models == model)[0]) 32 | top1_d = test_df['top1'][rank] - base_df['top1'][base_rank] 33 | top5_d = test_df['top5'][rank] - base_df['top5'][base_rank] 34 | 35 | # rank_diff 36 | if rank == 
base_rank: 37 | rank_diff[rank] = f'0' 38 | elif rank > base_rank: 39 | rank_diff[rank] = f'-{rank - base_rank}' 40 | else: 41 | rank_diff[rank] = f'+{base_rank - rank}' 42 | 43 | # top1_diff 44 | if top1_d >= .0: 45 | top1_diff[rank] = f'+{top1_d:.3f}' 46 | else: 47 | top1_diff[rank] = f'-{abs(top1_d):.3f}' 48 | 49 | # top5_diff 50 | if top5_d >= .0: 51 | top5_diff[rank] = f'+{top5_d:.3f}' 52 | else: 53 | top5_diff[rank] = f'-{abs(top5_d):.3f}' 54 | 55 | else: 56 | rank_diff[rank] = '' 57 | top1_diff[rank] = '' 58 | top5_diff[rank] = '' 59 | 60 | test_df['top1_diff'] = top1_diff 61 | test_df['top5_diff'] = top5_diff 62 | test_df['rank_diff'] = rank_diff 63 | 64 | test_df['param_count'] = test_df['param_count'].map('{:,.2f}'.format) 65 | test_df.sort_values(['top1', 'top5', 'model'], ascending=[False, False, True], inplace=True) 66 | test_df.to_csv(test_csv, index=False, float_format='%.3f') 67 | 68 | 69 | for base_results, test_results in results.items(): 70 | base_df = pd.read_csv(base_results) 71 | base_df.sort_values(['top1', 'top5', 'model'], ascending=[False, False, True], inplace=True) 72 | for test_csv in test_results: 73 | diff(base_df, test_csv) 74 | base_df['param_count'] = base_df['param_count'].map('{:,.2f}'.format) 75 | base_df.to_csv(base_results, index=False, float_format='%.3f') 76 | -------------------------------------------------------------------------------- /timm/layers/activations_jit.py: -------------------------------------------------------------------------------- 1 | """ Activations 2 | 3 | A collection of jit-scripted activations fn and modules with a common interface so that they can 4 | easily be swapped. All have an `inplace` arg even if not used. 5 | 6 | All jit scripted activations are lacking in-place variations on purpose, scripted kernel fusion does not 7 | currently work across in-place op boundaries, thus performance is equal to or less than the non-scripted 8 | versions if they contain in-place ops. 
9 | 10 | Hacked together by / Copyright 2020 Ross Wightman 11 | """ 12 | 13 | import torch 14 | from torch import nn as nn 15 | from torch.nn import functional as F 16 | 17 | 18 | @torch.jit.script 19 | def swish_jit(x, inplace: bool = False): # `inplace` accepted for interface parity; ignored 20 | """Swish - Described in: https://arxiv.org/abs/1710.05941 21 | """ 22 | return x.mul(x.sigmoid()) 23 | 24 | 25 | @torch.jit.script 26 | def mish_jit(x, _inplace: bool = False): # NOTE: spelled `_inplace` here while siblings use `inplace`; unused either way 27 | """Mish: A Self Regularized Non-Monotonic Neural Activation Function - https://arxiv.org/abs/1908.08681 28 | """ 29 | return x.mul(F.softplus(x).tanh()) 30 | 31 | 32 | class SwishJit(nn.Module): 33 | def __init__(self, inplace: bool = False): # `inplace` accepted for interface parity; ignored 34 | super(SwishJit, self).__init__() 35 | 36 | def forward(self, x): 37 | return swish_jit(x) 38 | 39 | 40 | class MishJit(nn.Module): 41 | def __init__(self, inplace: bool = False): # `inplace` accepted for interface parity; ignored 42 | super(MishJit, self).__init__() 43 | 44 | def forward(self, x): 45 | return mish_jit(x) 46 | 47 | 48 | @torch.jit.script 49 | def hard_sigmoid_jit(x, inplace: bool = False): 50 | # return F.relu6(x + 3.) / 6. 51 | return (x + 3).clamp(min=0, max=6).div(6.) # clamp seems ever so slightly faster? 52 | 53 | 54 | class HardSigmoidJit(nn.Module): 55 | def __init__(self, inplace: bool = False): # `inplace` accepted for interface parity; ignored 56 | super(HardSigmoidJit, self).__init__() 57 | 58 | def forward(self, x): 59 | return hard_sigmoid_jit(x) 60 | 61 | 62 | @torch.jit.script 63 | def hard_swish_jit(x, inplace: bool = False): 64 | # return x * (F.relu6(x + 3.) / 6) 65 | return x * (x + 3).clamp(min=0, max=6).div(6.) # clamp seems ever so slightly faster?
66 | 67 | 68 | class HardSwishJit(nn.Module): 69 | def __init__(self, inplace: bool = False): 70 | super(HardSwishJit, self).__init__() 71 | 72 | def forward(self, x): 73 | return hard_swish_jit(x) 74 | 75 | 76 | @torch.jit.script 77 | def hard_mish_jit(x, inplace: bool = False): 78 | """ Hard Mish 79 | Experimental, based on notes by Mish author Diganta Misra at 80 | https://github.com/digantamisra98/H-Mish/blob/0da20d4bc58e696b6803f2523c58d3c8a82782d0/README.md 81 | """ 82 | return 0.5 * x * (x + 2).clamp(min=0, max=2) 83 | 84 | 85 | class HardMishJit(nn.Module): 86 | def __init__(self, inplace: bool = False): 87 | super(HardMishJit, self).__init__() 88 | 89 | def forward(self, x): 90 | return hard_mish_jit(x) 91 | -------------------------------------------------------------------------------- /timm/layers/pos_embed.py: -------------------------------------------------------------------------------- 1 | """ Position Embedding Utilities 2 | 3 | Hacked together by / Copyright 2022 Ross Wightman 4 | """ 5 | import logging 6 | import math 7 | from typing import List, Tuple, Optional, Union 8 | 9 | import torch 10 | import torch.nn.functional as F 11 | 12 | from .helpers import to_2tuple 13 | 14 | _logger = logging.getLogger(__name__) 15 | 16 | 17 | def resample_abs_pos_embed( 18 | posemb, 19 | new_size: List[int], 20 | old_size: Optional[List[int]] = None, 21 | num_prefix_tokens: int = 1, 22 | interpolation: str = 'bicubic', 23 | antialias: bool = True, 24 | verbose: bool = False, 25 | ): 26 | # sort out sizes, assume square if old size not provided 27 | num_pos_tokens = posemb.shape[1] 28 | num_new_tokens = new_size[0] * new_size[1] + num_prefix_tokens 29 | if num_new_tokens == num_pos_tokens and new_size[0] == new_size[1]: 30 | return posemb 31 | 32 | if old_size is None: 33 | hw = int(math.sqrt(num_pos_tokens - num_prefix_tokens)) 34 | old_size = hw, hw 35 | 36 | if num_prefix_tokens: 37 | posemb_prefix, posemb = posemb[:, :num_prefix_tokens], posemb[:, 
num_prefix_tokens:] 38 | else: 39 | posemb_prefix, posemb = None, posemb 40 | 41 | # do the interpolation 42 | embed_dim = posemb.shape[-1] 43 | orig_dtype = posemb.dtype 44 | posemb = posemb.float() # interpolate needs float32 45 | posemb = posemb.reshape(1, old_size[0], old_size[1], -1).permute(0, 3, 1, 2) 46 | posemb = F.interpolate(posemb, size=new_size, mode=interpolation, antialias=antialias) 47 | posemb = posemb.permute(0, 2, 3, 1).reshape(1, -1, embed_dim) 48 | posemb = posemb.to(orig_dtype) 49 | 50 | # add back extra (class, etc) prefix tokens 51 | if posemb_prefix is not None: 52 | posemb = torch.cat([posemb_prefix, posemb], dim=1) 53 | 54 | if not torch.jit.is_scripting() and verbose: 55 | _logger.info(f'Resized position embedding: {old_size} to {new_size}.') 56 | 57 | return posemb 58 | 59 | 60 | def resample_abs_pos_embed_nhwc( 61 | posemb, 62 | new_size: List[int], 63 | interpolation: str = 'bicubic', 64 | antialias: bool = True, 65 | verbose: bool = False, 66 | ): 67 | if new_size[0] == posemb.shape[-3] and new_size[1] == posemb.shape[-2]: 68 | return posemb 69 | 70 | orig_dtype = posemb.dtype 71 | posemb = posemb.float() 72 | # do the interpolation 73 | posemb = posemb.reshape(1, posemb.shape[-3], posemb.shape[-2], posemb.shape[-1]).permute(0, 3, 1, 2) 74 | posemb = F.interpolate(posemb, size=new_size, mode=interpolation, antialias=antialias) 75 | posemb = posemb.permute(0, 2, 3, 1).to(orig_dtype) 76 | 77 | if not torch.jit.is_scripting() and verbose: 78 | _logger.info(f'Resized position embedding: {posemb.shape[-3:-1]} to {new_size}.') 79 | 80 | return posemb 81 | -------------------------------------------------------------------------------- /timm/layers/separable_conv.py: -------------------------------------------------------------------------------- 1 | """ Depthwise Separable Conv Modules 2 | 3 | Basic DWS convs. 
Other variations of DWS exist with batch norm or activations between the 4 | DW and PW convs such as the Depthwise modules in MobileNetV2 / EfficientNet and Xception. 5 | 6 | Hacked together by / Copyright 2020 Ross Wightman 7 | """ 8 | from torch import nn as nn 9 | 10 | from .create_conv2d import create_conv2d 11 | from .create_norm_act import get_norm_act_layer 12 | 13 | 14 | class SeparableConvNormAct(nn.Module): 15 | """ Separable Conv w/ trailing Norm and Activation 16 | """ 17 | def __init__(self, in_channels, out_channels, kernel_size=3, stride=1, dilation=1, padding='', bias=False, 18 | channel_multiplier=1.0, pw_kernel_size=1, norm_layer=nn.BatchNorm2d, act_layer=nn.ReLU, 19 | apply_act=True, drop_layer=None): 20 | super(SeparableConvNormAct, self).__init__() 21 | 22 | self.conv_dw = create_conv2d( # depthwise conv; channel_multiplier scales its output channel count 23 | in_channels, int(in_channels * channel_multiplier), kernel_size, 24 | stride=stride, dilation=dilation, padding=padding, depthwise=True) 25 | 26 | self.conv_pw = create_conv2d( # pointwise conv projecting to out_channels 27 | int(in_channels * channel_multiplier), out_channels, pw_kernel_size, padding=padding, bias=bias) 28 | 29 | norm_act_layer = get_norm_act_layer(norm_layer, act_layer) 30 | norm_kwargs = dict(drop_layer=drop_layer) if drop_layer is not None else {} 31 | self.bn = norm_act_layer(out_channels, apply_act=apply_act, **norm_kwargs) 32 | 33 | @property 34 | def in_channels(self): 35 | return self.conv_dw.in_channels 36 | 37 | @property 38 | def out_channels(self): 39 | return self.conv_pw.out_channels 40 | 41 | def forward(self, x): 42 | x = self.conv_dw(x) 43 | x = self.conv_pw(x) 44 | x = self.bn(x) 45 | return x 46 | 47 | 48 | SeparableConvBnAct = SeparableConvNormAct # alias of the same class (presumably kept for backward compatibility) 49 | 50 | 51 | class SeparableConv2d(nn.Module): 52 | """ Separable Conv 53 | """ 54 | def __init__(self, in_channels, out_channels, kernel_size=3, stride=1, dilation=1, padding='', bias=False, 55 | 
channel_multiplier=1.0, pw_kernel_size=1): 56 | super(SeparableConv2d, self).__init__() 57 | 58 | self.conv_dw =
create_conv2d( 59 | in_channels, int(in_channels * channel_multiplier), kernel_size, 60 | stride=stride, dilation=dilation, padding=padding, depthwise=True) 61 | 62 | self.conv_pw = create_conv2d( 63 | int(in_channels * channel_multiplier), out_channels, pw_kernel_size, padding=padding, bias=bias) 64 | 65 | @property 66 | def in_channels(self): 67 | return self.conv_dw.in_channels 68 | 69 | @property 70 | def out_channels(self): 71 | return self.conv_pw.out_channels 72 | 73 | def forward(self, x): 74 | x = self.conv_dw(x) 75 | x = self.conv_pw(x) 76 | return x 77 | -------------------------------------------------------------------------------- /timm/optim/lookahead.py: -------------------------------------------------------------------------------- 1 | """ Lookahead Optimizer Wrapper. 2 | Implementation modified from: https://github.com/alphadl/lookahead.pytorch 3 | Paper: `Lookahead Optimizer: k steps forward, 1 step back` - https://arxiv.org/abs/1907.08610 4 | 5 | Hacked together by / Copyright 2020 Ross Wightman 6 | """ 7 | from collections import OrderedDict 8 | from typing import Callable, Dict 9 | 10 | import torch 11 | from torch.optim.optimizer import Optimizer 12 | from collections import defaultdict 13 | 14 | 15 | class Lookahead(Optimizer): 16 | def __init__(self, base_optimizer, alpha=0.5, k=6): 17 | # NOTE super().__init__() not called on purpose 18 | self._optimizer_step_pre_hooks: Dict[int, Callable] = OrderedDict() 19 | self._optimizer_step_post_hooks: Dict[int, Callable] = OrderedDict() 20 | if not 0.0 <= alpha <= 1.0: 21 | raise ValueError(f'Invalid slow update rate: {alpha}') 22 | if not 1 <= k: 23 | raise ValueError(f'Invalid lookahead steps: {k}') 24 | defaults = dict(lookahead_alpha=alpha, lookahead_k=k, lookahead_step=0) 25 | self._base_optimizer = base_optimizer 26 | self.param_groups = base_optimizer.param_groups 27 | self.defaults = base_optimizer.defaults 28 | self.defaults.update(defaults) 29 | self.state = defaultdict(dict) # slow buffers are kept in the base optimizer's state 
(see update_slow), not here 30 | #
manually add our defaults to the param groups 31 | for name, default in defaults.items(): 32 | for group in self._base_optimizer.param_groups: 33 | group.setdefault(name, default) 34 | 35 | @torch.no_grad() 36 | def update_slow(self, group): # slow += alpha * (fast - slow), then fast <- slow 37 | for fast_p in group["params"]: 38 | if fast_p.grad is None: 39 | continue 40 | param_state = self._base_optimizer.state[fast_p] 41 | if 'lookahead_slow_buff' not in param_state: 42 | param_state['lookahead_slow_buff'] = torch.empty_like(fast_p) 43 | param_state['lookahead_slow_buff'].copy_(fast_p) # first sync seeds slow weights from the current fast weights 44 | slow = param_state['lookahead_slow_buff'] 45 | slow.add_(fast_p - slow, alpha=group['lookahead_alpha']) 46 | fast_p.copy_(slow) 47 | 48 | def sync_lookahead(self): 49 | for group in self._base_optimizer.param_groups: 50 | self.update_slow(group) 51 | 52 | @torch.no_grad() 53 | def step(self, closure=None): 54 | loss = self._base_optimizer.step(closure) 55 | for group in self._base_optimizer.param_groups: 56 | group['lookahead_step'] += 1 57 | if group['lookahead_step'] % group['lookahead_k'] == 0: # slow update every k base steps 58 | self.update_slow(group) 59 | return loss 60 | 61 | def state_dict(self): # slow buffers live in base optimizer state, so they are included here 62 | return self._base_optimizer.state_dict() 63 | 64 | def load_state_dict(self, state_dict): 65 | self._base_optimizer.load_state_dict(state_dict) 66 | self.param_groups = self._base_optimizer.param_groups 67 | -------------------------------------------------------------------------------- /timm/data/readers/reader_image_tar.py: -------------------------------------------------------------------------------- 1 | """ A dataset reader that reads single tarfile based datasets 2 | 3 | This reader can read datasets consisting of a single tarfile containing images. 4 | I am planning to deprecate it in favour of ReaderImageInTar.
5 | 6 | Hacked together by / Copyright 2020 Ross Wightman 7 | """ 8 | import os 9 | import tarfile 10 | 11 | from timm.utils.misc import natural_key 12 | 13 | from .class_map import load_class_map 14 | from .img_extensions import get_img_extensions 15 | from .reader import Reader 16 | 17 | 18 | def extract_tarinfo(tarfile, class_to_idx=None, sort=True): # NOTE: `tarfile` here is an open archive (shadows the module name); each image's label is its parent folder name 19 | extensions = get_img_extensions(as_set=True) 20 | files = [] 21 | labels = [] 22 | for ti in tarfile.getmembers(): 23 | if not ti.isfile(): 24 | continue 25 | dirname, basename = os.path.split(ti.path) 26 | label = os.path.basename(dirname) 27 | ext = os.path.splitext(basename)[1] 28 | if ext.lower() in extensions: 29 | files.append(ti) 30 | labels.append(label) 31 | if class_to_idx is None: 32 | unique_labels = set(labels) 33 | sorted_labels = list(sorted(unique_labels, key=natural_key)) 34 | class_to_idx = {c: idx for idx, c in enumerate(sorted_labels)} 35 | tarinfo_and_targets = [(f, class_to_idx[l]) for f, l in zip(files, labels) if l in class_to_idx] # members whose folder is not in class_to_idx are dropped 36 | if sort: 37 | tarinfo_and_targets = sorted(tarinfo_and_targets, key=lambda k: natural_key(k[0].path)) 38 | return tarinfo_and_targets, class_to_idx 39 | 40 | 41 | class ReaderImageTar(Reader): 42 | """ Single tarfile dataset where classes are mapped to folders within tar 43 | NOTE: This class is being deprecated in favour of the more capable ReaderImageInTar that can 44 | operate on folders of tars or tars in tars.
45 | """ 46 | def __init__(self, root, class_map=''): 47 | super().__init__() 48 | 49 | class_to_idx = None 50 | if class_map: 51 | class_to_idx = load_class_map(class_map, root) 52 | assert os.path.isfile(root) 53 | self.root = root 54 | 55 | with tarfile.open(root) as tf: # cannot keep this open across processes, reopen later 56 | self.samples, self.class_to_idx = extract_tarinfo(tf, class_to_idx) 57 | self.imgs = self.samples 58 | self.tarfile = None # lazy init in __getitem__ 59 | 60 | def __getitem__(self, index): 61 | if self.tarfile is None: 62 | self.tarfile = tarfile.open(self.root) # opened on first access in whichever process calls this 63 | tarinfo, target = self.samples[index] 64 | fileobj = self.tarfile.extractfile(tarinfo) 65 | return fileobj, target 66 | 67 | def __len__(self): 68 | return len(self.samples) 69 | 70 | def _filename(self, index, basename=False, absolute=False): # NOTE(review): `absolute` is not applied; names are paths inside the tar 71 | filename = self.samples[index][0].name 72 | if basename: 73 | filename = os.path.basename(filename) 74 | return filename 75 | -------------------------------------------------------------------------------- /docs/models/.templates/models/gloun-inception-v3.md: -------------------------------------------------------------------------------- 1 | # (Gluon) Inception v3 2 | 3 | **Inception v3** is a convolutional neural network architecture from the Inception family that makes several improvements including using [Label Smoothing](https://paperswithcode.com/method/label-smoothing), Factorized 7 x 7 convolutions, and the use of an [auxiliary classifer](https://paperswithcode.com/method/auxiliary-classifier) to propagate label information lower down the network (along with the use of batch normalization for layers in the sidehead). The key building block is an [Inception Module](https://paperswithcode.com/method/inception-v3-module). 
4 | 5 | The weights from this model were ported from [Gluon](https://cv.gluon.ai/model_zoo/classification.html). 6 | 7 | {% include 'code_snippets.md' %} 8 | 9 | ## How do I train this model?
10 | 11 | You can follow the [timm recipe scripts](https://rwightman.github.io/pytorch-image-models/scripts/) for training a new model afresh. 12 | 13 | ## Citation 14 | 15 | ```BibTeX 16 | @article{DBLP:journals/corr/SzegedyVISW15, 17 | author = {Christian Szegedy and 18 | Vincent Vanhoucke and 19 | Sergey Ioffe and 20 | Jonathon Shlens and 21 | Zbigniew Wojna}, 22 | title = {Rethinking the Inception Architecture for Computer Vision}, 23 | journal = {CoRR}, 24 | volume = {abs/1512.00567}, 25 | year = {2015}, 26 | url = {http://arxiv.org/abs/1512.00567}, 27 | archivePrefix = {arXiv}, 28 | eprint = {1512.00567}, 29 | timestamp = {Mon, 13 Aug 2018 16:49:07 +0200}, 30 | biburl = {https://dblp.org/rec/journals/corr/SzegedyVISW15.bib}, 31 | bibsource = {dblp computer science bibliography, https://dblp.org} 32 | } 33 | ``` 34 | 35 | 79 | -------------------------------------------------------------------------------- /timm/data/_info/imagenet_a_synsets.txt: -------------------------------------------------------------------------------- 1 | n01498041 2 | n01531178 3 | n01534433 4 | n01558993 5 | n01580077 6 | n01614925 7 | n01616318 8 | n01631663 9 | n01641577 10 | n01669191 11 | n01677366 12 | n01687978 13 | n01694178 14 | n01698640 15 | n01735189 16 | n01770081 17 | n01770393 18 | n01774750 19 | n01784675 20 | n01819313 21 | n01820546 22 | n01833805 23 | n01843383 24 | n01847000 25 | n01855672 26 | n01882714 27 | n01910747 28 | n01914609 29 | n01924916 30 | n01944390 31 | n01985128 32 | n01986214 33 | n02007558 34 | n02009912 35 | n02037110 36 | n02051845 37 | n02077923 38 | n02085620 39 | n02099601 40 | n02106550 41 | n02106662 42 | n02110958 43 | n02119022 44 | n02123394 45 | n02127052 46 | n02129165 47 | n02133161 48 | n02137549 49 | n02165456 50 | n02174001 51 | n02177972 52 | n02190166 53 | n02206856 54 | n02219486 55 | n02226429 56 | n02231487 57 | n02233338 58 | n02236044 59 | n02259212 60 | n02268443 61 | n02279972 62 | n02280649 63 | n02281787 64 | n02317335 
65 | n02325366 66 | n02346627 67 | n02356798 68 | n02361337 69 | n02410509 70 | n02445715 71 | n02454379 72 | n02486410 73 | n02492035 74 | n02504458 75 | n02655020 76 | n02669723 77 | n02672831 78 | n02676566 79 | n02690373 80 | n02701002 81 | n02730930 82 | n02777292 83 | n02782093 84 | n02787622 85 | n02793495 86 | n02797295 87 | n02802426 88 | n02814860 89 | n02815834 90 | n02837789 91 | n02879718 92 | n02883205 93 | n02895154 94 | n02906734 95 | n02948072 96 | n02951358 97 | n02980441 98 | n02992211 99 | n02999410 100 | n03014705 101 | n03026506 102 | n03124043 103 | n03125729 104 | n03187595 105 | n03196217 106 | n03223299 107 | n03250847 108 | n03255030 109 | n03291819 110 | n03325584 111 | n03355925 112 | n03384352 113 | n03388043 114 | n03417042 115 | n03443371 116 | n03444034 117 | n03445924 118 | n03452741 119 | n03483316 120 | n03584829 121 | n03590841 122 | n03594945 123 | n03617480 124 | n03666591 125 | n03670208 126 | n03717622 127 | n03720891 128 | n03721384 129 | n03724870 130 | n03775071 131 | n03788195 132 | n03804744 133 | n03837869 134 | n03840681 135 | n03854065 136 | n03888257 137 | n03891332 138 | n03935335 139 | n03982430 140 | n04019541 141 | n04033901 142 | n04039381 143 | n04067472 144 | n04086273 145 | n04099969 146 | n04118538 147 | n04131690 148 | n04133789 149 | n04141076 150 | n04146614 151 | n04147183 152 | n04179913 153 | n04208210 154 | n04235860 155 | n04252077 156 | n04252225 157 | n04254120 158 | n04270147 159 | n04275548 160 | n04310018 161 | n04317175 162 | n04344873 163 | n04347754 164 | n04355338 165 | n04366367 166 | n04376876 167 | n04389033 168 | n04399382 169 | n04442312 170 | n04456115 171 | n04482393 172 | n04507155 173 | n04509417 174 | n04532670 175 | n04540053 176 | n04554684 177 | n04562935 178 | n04591713 179 | n04606251 180 | n07583066 181 | n07695742 182 | n07697313 183 | n07697537 184 | n07714990 185 | n07718472 186 | n07720875 187 | n07734744 188 | n07749582 189 | n07753592 190 | n07760859 191 | n07768694 
192 | n07831146 193 | n09229709 194 | n09246464 195 | n09472597 196 | n09835506 197 | n11879895 198 | n12057211 199 | n12144580 200 | n12267677 201 | -------------------------------------------------------------------------------- /timm/data/_info/imagenet_r_synsets.txt: -------------------------------------------------------------------------------- 1 | n01443537 2 | n01484850 3 | n01494475 4 | n01498041 5 | n01514859 6 | n01518878 7 | n01531178 8 | n01534433 9 | n01614925 10 | n01616318 11 | n01630670 12 | n01632777 13 | n01644373 14 | n01677366 15 | n01694178 16 | n01748264 17 | n01770393 18 | n01774750 19 | n01784675 20 | n01806143 21 | n01820546 22 | n01833805 23 | n01843383 24 | n01847000 25 | n01855672 26 | n01860187 27 | n01882714 28 | n01910747 29 | n01944390 30 | n01983481 31 | n01986214 32 | n02007558 33 | n02009912 34 | n02051845 35 | n02056570 36 | n02066245 37 | n02071294 38 | n02077923 39 | n02085620 40 | n02086240 41 | n02088094 42 | n02088238 43 | n02088364 44 | n02088466 45 | n02091032 46 | n02091134 47 | n02092339 48 | n02094433 49 | n02096585 50 | n02097298 51 | n02098286 52 | n02099601 53 | n02099712 54 | n02102318 55 | n02106030 56 | n02106166 57 | n02106550 58 | n02106662 59 | n02108089 60 | n02108915 61 | n02109525 62 | n02110185 63 | n02110341 64 | n02110958 65 | n02112018 66 | n02112137 67 | n02113023 68 | n02113624 69 | n02113799 70 | n02114367 71 | n02117135 72 | n02119022 73 | n02123045 74 | n02128385 75 | n02128757 76 | n02129165 77 | n02129604 78 | n02130308 79 | n02134084 80 | n02138441 81 | n02165456 82 | n02190166 83 | n02206856 84 | n02219486 85 | n02226429 86 | n02233338 87 | n02236044 88 | n02268443 89 | n02279972 90 | n02317335 91 | n02325366 92 | n02346627 93 | n02356798 94 | n02363005 95 | n02364673 96 | n02391049 97 | n02395406 98 | n02398521 99 | n02410509 100 | n02423022 101 | n02437616 102 | n02445715 103 | n02447366 104 | n02480495 105 | n02480855 106 | n02481823 107 | n02483362 108 | n02486410 109 | n02510455 110 | 
n02526121 111 | n02607072 112 | n02655020 113 | n02672831 114 | n02701002 115 | n02749479 116 | n02769748 117 | n02793495 118 | n02797295 119 | n02802426 120 | n02808440 121 | n02814860 122 | n02823750 123 | n02841315 124 | n02843684 125 | n02883205 126 | n02906734 127 | n02909870 128 | n02939185 129 | n02948072 130 | n02950826 131 | n02951358 132 | n02966193 133 | n02980441 134 | n02992529 135 | n03124170 136 | n03272010 137 | n03345487 138 | n03372029 139 | n03424325 140 | n03452741 141 | n03467068 142 | n03481172 143 | n03494278 144 | n03495258 145 | n03498962 146 | n03594945 147 | n03602883 148 | n03630383 149 | n03649909 150 | n03676483 151 | n03710193 152 | n03773504 153 | n03775071 154 | n03888257 155 | n03930630 156 | n03947888 157 | n04086273 158 | n04118538 159 | n04133789 160 | n04141076 161 | n04146614 162 | n04147183 163 | n04192698 164 | n04254680 165 | n04266014 166 | n04275548 167 | n04310018 168 | n04325704 169 | n04347754 170 | n04389033 171 | n04409515 172 | n04465501 173 | n04487394 174 | n04522168 175 | n04536866 176 | n04552348 177 | n04591713 178 | n07614500 179 | n07693725 180 | n07695742 181 | n07697313 182 | n07697537 183 | n07714571 184 | n07714990 185 | n07718472 186 | n07720875 187 | n07734744 188 | n07742313 189 | n07745940 190 | n07749582 191 | n07753275 192 | n07753592 193 | n07768694 194 | n07873807 195 | n07880968 196 | n07920052 197 | n09472597 198 | n09835506 199 | n10565667 200 | n12267677 201 | -------------------------------------------------------------------------------- /docs/models/.templates/models/inception-v3.md: -------------------------------------------------------------------------------- 1 | # Inception v3 2 | 3 | **Inception v3** is a convolutional neural network architecture from the Inception family that makes several improvements including using [Label Smoothing](https://paperswithcode.com/method/label-smoothing), Factorized 7 x 7 convolutions, and the use of an [auxiliary 
classifer](https://paperswithcode.com/method/auxiliary-classifier) to propagate label information lower down the network (along with the use of batch normalization for layers in the sidehead). The key building block is an [Inception Module](https://paperswithcode.com/method/inception-v3-module). 4 | 5 | {% include 'code_snippets.md' %} 6 | 7 | ## How do I train this model? 8 | 9 | You can follow the [timm recipe scripts](https://rwightman.github.io/pytorch-image-models/scripts/) for training a new model afresh. 10 | 11 | ## Citation 12 | 13 | ```BibTeX 14 | @article{DBLP:journals/corr/SzegedyVISW15, 15 | author = {Christian Szegedy and 16 | Vincent Vanhoucke and 17 | Sergey Ioffe and 18 | Jonathon Shlens and 19 | Zbigniew Wojna}, 20 | title = {Rethinking the Inception Architecture for Computer Vision}, 21 | journal = {CoRR}, 22 | volume = {abs/1512.00567}, 23 | year = {2015}, 24 | url = {http://arxiv.org/abs/1512.00567}, 25 | archivePrefix = {arXiv}, 26 | eprint = {1512.00567}, 27 | timestamp = {Mon, 13 Aug 2018 16:49:07 +0200}, 28 | biburl = {https://dblp.org/rec/journals/corr/SzegedyVISW15.bib}, 29 | bibsource = {dblp computer science bibliography, https://dblp.org} 30 | } 31 | ``` 32 | 33 | 86 | -------------------------------------------------------------------------------- /timm/layers/padding.py: -------------------------------------------------------------------------------- 1 | """ Padding Helpers 2 | 3 | Hacked together by / Copyright 2020 Ross Wightman 4 | """ 5 | import math 6 | from typing import List, Tuple 7 | 8 | import torch 9 | import torch.nn.functional as F 10 | 11 | 12 | # Calculate symmetric padding for a convolution 13 | def get_padding(kernel_size: int, stride: int = 1, dilation: int = 1, **_) -> int: 14 | padding = ((stride - 1) + dilation * (kernel_size - 1)) // 2 15 | return padding 16 | 17 | 18 | # Calculate asymmetric TensorFlow-like 'SAME' padding for a convolution 19 | def get_same_padding(x: int, kernel_size: int, stride: int, 
dilation: int): 20 | if isinstance(x, torch.Tensor): 21 | return torch.clamp(((x / stride).ceil() - 1) * stride + (kernel_size - 1) * dilation + 1 - x, min=0) 22 | else: 23 | return max((math.ceil(x / stride) - 1) * stride + (kernel_size - 1) * dilation + 1 - x, 0) 24 | 25 | 26 | # Can SAME padding for given args be done statically? 27 | def is_static_pad(kernel_size: int, stride: int = 1, dilation: int = 1, **_): 28 | return stride == 1 and (dilation * (kernel_size - 1)) % 2 == 0 29 | 30 | 31 | def pad_same_arg( 32 | input_size: List[int], 33 | kernel_size: List[int], 34 | stride: List[int], 35 | dilation: List[int] = (1, 1), 36 | ) -> List[int]: 37 | ih, iw = input_size 38 | kh, kw = kernel_size 39 | pad_h = get_same_padding(ih, kh, stride[0], dilation[0]) 40 | pad_w = get_same_padding(iw, kw, stride[1], dilation[1]) 41 | return [pad_w // 2, pad_w - pad_w // 2, pad_h // 2, pad_h - pad_h // 2] 42 | 43 | 44 | # Dynamically pad input x with 'SAME' padding for conv with specified args 45 | def pad_same( 46 | x, 47 | kernel_size: List[int], 48 | stride: List[int], 49 | dilation: List[int] = (1, 1), 50 | value: float = 0, 51 | ): 52 | ih, iw = x.size()[-2:] 53 | pad_h = get_same_padding(ih, kernel_size[0], stride[0], dilation[0]) 54 | pad_w = get_same_padding(iw, kernel_size[1], stride[1], dilation[1]) 55 | x = F.pad(x, (pad_w // 2, pad_w - pad_w // 2, pad_h // 2, pad_h - pad_h // 2), value=value) 56 | return x 57 | 58 | 59 | def get_padding_value(padding, kernel_size, **kwargs) -> Tuple[Tuple, bool]: 60 | dynamic = False 61 | if isinstance(padding, str): 62 | # for any string padding, the padding will be calculated for you, one of three ways 63 | padding = padding.lower() 64 | if padding == 'same': 65 | # TF compatible 'SAME' padding, has a performance and GPU memory allocation impact 66 | if is_static_pad(kernel_size, **kwargs): 67 | # static case, no extra overhead 68 | padding = get_padding(kernel_size, **kwargs) 69 | else: 70 | # dynamic 'SAME' padding, has 
runtime/GPU memory overhead 71 | padding = 0 72 | dynamic = True 73 | elif padding == 'valid': 74 | # 'VALID' padding, same as padding=0 75 | padding = 0 76 | else: 77 | # Default to PyTorch style 'same'-ish symmetric padding 78 | padding = get_padding(kernel_size, **kwargs) 79 | return padding, dynamic 80 | -------------------------------------------------------------------------------- /docs/models/.templates/models/ese-vovnet.md: -------------------------------------------------------------------------------- 1 | # ESE-VoVNet 2 | 3 | **VoVNet** is a convolutional neural network that seeks to make [DenseNet](https://paperswithcode.com/method/densenet) more efficient by concatenating all features only once in the last feature map, which makes input size constant and enables enlarging new output channel. 4 | 5 | Read about [one-shot aggregation here](https://paperswithcode.com/method/one-shot-aggregation). 6 | 7 | {% include 'code_snippets.md' %} 8 | 9 | ## How do I train this model? 10 | 11 | You can follow the [timm recipe scripts](https://rwightman.github.io/pytorch-image-models/scripts/) for training a new model afresh. 
12 | 13 | ## Citation 14 | 15 | ```BibTeX 16 | @misc{lee2019energy, 17 | title={An Energy and GPU-Computation Efficient Backbone Network for Real-Time Object Detection}, 18 | author={Youngwan Lee and Joong-won Hwang and Sangrok Lee and Yuseok Bae and Jongyoul Park}, 19 | year={2019}, 20 | eprint={1904.09730}, 21 | archivePrefix={arXiv}, 22 | primaryClass={cs.CV} 23 | } 24 | ``` 25 | 26 | 93 | -------------------------------------------------------------------------------- /timm/layers/pool2d_same.py: -------------------------------------------------------------------------------- 1 | """ AvgPool2d w/ Same Padding 2 | 3 | Hacked together by / Copyright 2020 Ross Wightman 4 | """ 5 | import torch 6 | import torch.nn as nn 7 | import torch.nn.functional as F 8 | from typing import List, Tuple, Optional 9 | 10 | from .helpers import to_2tuple 11 | from .padding import pad_same, get_padding_value 12 | 13 | 14 | def avg_pool2d_same(x, kernel_size: List[int], stride: List[int], padding: List[int] = (0, 0), 15 | ceil_mode: bool = False, count_include_pad: bool = True): 16 | # FIXME how to deal with count_include_pad vs not for external padding? 
17 | x = pad_same(x, kernel_size, stride) 18 | return F.avg_pool2d(x, kernel_size, stride, (0, 0), ceil_mode, count_include_pad) 19 | 20 | 21 | class AvgPool2dSame(nn.AvgPool2d): 22 | """ Tensorflow like 'SAME' wrapper for 2D average pooling 23 | """ 24 | def __init__(self, kernel_size: int, stride=None, padding=0, ceil_mode=False, count_include_pad=True): 25 | kernel_size = to_2tuple(kernel_size) 26 | stride = to_2tuple(stride) 27 | super(AvgPool2dSame, self).__init__(kernel_size, stride, (0, 0), ceil_mode, count_include_pad) 28 | 29 | def forward(self, x): 30 | x = pad_same(x, self.kernel_size, self.stride) 31 | return F.avg_pool2d( 32 | x, self.kernel_size, self.stride, self.padding, self.ceil_mode, self.count_include_pad) 33 | 34 | 35 | def max_pool2d_same( 36 | x, kernel_size: List[int], stride: List[int], padding: List[int] = (0, 0), 37 | dilation: List[int] = (1, 1), ceil_mode: bool = False): 38 | x = pad_same(x, kernel_size, stride, value=-float('inf')) 39 | return F.max_pool2d(x, kernel_size, stride, (0, 0), dilation, ceil_mode) 40 | 41 | 42 | class MaxPool2dSame(nn.MaxPool2d): 43 | """ Tensorflow like 'SAME' wrapper for 2D max pooling 44 | """ 45 | def __init__(self, kernel_size: int, stride=None, padding=0, dilation=1, ceil_mode=False): 46 | kernel_size = to_2tuple(kernel_size) 47 | stride = to_2tuple(stride) 48 | dilation = to_2tuple(dilation) 49 | super(MaxPool2dSame, self).__init__(kernel_size, stride, (0, 0), dilation, ceil_mode) 50 | 51 | def forward(self, x): 52 | x = pad_same(x, self.kernel_size, self.stride, value=-float('inf')) 53 | return F.max_pool2d(x, self.kernel_size, self.stride, (0, 0), self.dilation, self.ceil_mode) 54 | 55 | 56 | def create_pool2d(pool_type, kernel_size, stride=None, **kwargs): 57 | stride = stride or kernel_size 58 | padding = kwargs.pop('padding', '') 59 | padding, is_dynamic = get_padding_value(padding, kernel_size, stride=stride, **kwargs) 60 | if is_dynamic: 61 | if pool_type == 'avg': 62 | return 
AvgPool2dSame(kernel_size, stride=stride, **kwargs) 63 | elif pool_type == 'max': 64 | return MaxPool2dSame(kernel_size, stride=stride, **kwargs) 65 | else: 66 | assert False, f'Unsupported pool type {pool_type}' 67 | else: 68 | if pool_type == 'avg': 69 | return nn.AvgPool2d(kernel_size, stride=stride, padding=padding, **kwargs) 70 | elif pool_type == 'max': 71 | return nn.MaxPool2d(kernel_size, stride=stride, padding=padding, **kwargs) 72 | else: 73 | assert False, f'Unsupported pool type {pool_type}' 74 | -------------------------------------------------------------------------------- /docs/models/.templates/models/tf-inception-v3.md: -------------------------------------------------------------------------------- 1 | # (Tensorflow) Inception v3 2 | 3 | **Inception v3** is a convolutional neural network architecture from the Inception family that makes several improvements including using [Label Smoothing](https://paperswithcode.com/method/label-smoothing), Factorized 7 x 7 convolutions, and the use of an [auxiliary classifer](https://paperswithcode.com/method/auxiliary-classifier) to propagate label information lower down the network (along with the use of batch normalization for layers in the sidehead). The key building block is an [Inception Module](https://paperswithcode.com/method/inception-v3-module). 4 | 5 | The weights from this model were ported from [Tensorflow/Models](https://github.com/tensorflow/models). 6 | 7 | {% include 'code_snippets.md' %} 8 | 9 | ## How do I train this model? 10 | 11 | You can follow the [timm recipe scripts](https://rwightman.github.io/pytorch-image-models/scripts/) for training a new model afresh. 
12 | 13 | ## Citation 14 | 15 | ```BibTeX 16 | @article{DBLP:journals/corr/SzegedyVISW15, 17 | author = {Christian Szegedy and 18 | Vincent Vanhoucke and 19 | Sergey Ioffe and 20 | Jonathon Shlens and 21 | Zbigniew Wojna}, 22 | title = {Rethinking the Inception Architecture for Computer Vision}, 23 | journal = {CoRR}, 24 | volume = {abs/1512.00567}, 25 | year = {2015}, 26 | url = {http://arxiv.org/abs/1512.00567}, 27 | archivePrefix = {arXiv}, 28 | eprint = {1512.00567}, 29 | timestamp = {Mon, 13 Aug 2018 16:49:07 +0200}, 30 | biburl = {https://dblp.org/rec/journals/corr/SzegedyVISW15.bib}, 31 | bibsource = {dblp computer science bibliography, https://dblp.org} 32 | } 33 | ``` 34 | 35 | 88 | -------------------------------------------------------------------------------- /timm/layers/split_attn.py: -------------------------------------------------------------------------------- 1 | """ Split Attention Conv2d (for ResNeSt Models) 2 | 3 | Paper: `ResNeSt: Split-Attention Networks` - /https://arxiv.org/abs/2004.08955 4 | 5 | Adapted from original PyTorch impl at https://github.com/zhanghang1989/ResNeSt 6 | 7 | Modified for torchscript compat, performance, and consistency with timm by Ross Wightman 8 | """ 9 | import torch 10 | import torch.nn.functional as F 11 | from torch import nn 12 | 13 | from .helpers import make_divisible 14 | 15 | 16 | class RadixSoftmax(nn.Module): 17 | def __init__(self, radix, cardinality): 18 | super(RadixSoftmax, self).__init__() 19 | self.radix = radix 20 | self.cardinality = cardinality 21 | 22 | def forward(self, x): 23 | batch = x.size(0) 24 | if self.radix > 1: 25 | x = x.view(batch, self.cardinality, self.radix, -1).transpose(1, 2) 26 | x = F.softmax(x, dim=1) 27 | x = x.reshape(batch, -1) 28 | else: 29 | x = torch.sigmoid(x) 30 | return x 31 | 32 | 33 | class SplitAttn(nn.Module): 34 | """Split-Attention (aka Splat) 35 | """ 36 | def __init__(self, in_channels, out_channels=None, kernel_size=3, stride=1, padding=None, 37 | 
dilation=1, groups=1, bias=False, radix=2, rd_ratio=0.25, rd_channels=None, rd_divisor=8, 38 | act_layer=nn.ReLU, norm_layer=None, drop_layer=None, **kwargs): 39 | super(SplitAttn, self).__init__() 40 | out_channels = out_channels or in_channels 41 | self.radix = radix 42 | mid_chs = out_channels * radix 43 | if rd_channels is None: 44 | attn_chs = make_divisible(in_channels * radix * rd_ratio, min_value=32, divisor=rd_divisor) 45 | else: 46 | attn_chs = rd_channels * radix 47 | 48 | padding = kernel_size // 2 if padding is None else padding 49 | self.conv = nn.Conv2d( 50 | in_channels, mid_chs, kernel_size, stride, padding, dilation, 51 | groups=groups * radix, bias=bias, **kwargs) 52 | self.bn0 = norm_layer(mid_chs) if norm_layer else nn.Identity() 53 | self.drop = drop_layer() if drop_layer is not None else nn.Identity() 54 | self.act0 = act_layer(inplace=True) 55 | self.fc1 = nn.Conv2d(out_channels, attn_chs, 1, groups=groups) 56 | self.bn1 = norm_layer(attn_chs) if norm_layer else nn.Identity() 57 | self.act1 = act_layer(inplace=True) 58 | self.fc2 = nn.Conv2d(attn_chs, mid_chs, 1, groups=groups) 59 | self.rsoftmax = RadixSoftmax(radix, groups) 60 | 61 | def forward(self, x): 62 | x = self.conv(x) 63 | x = self.bn0(x) 64 | x = self.drop(x) 65 | x = self.act0(x) 66 | 67 | B, RC, H, W = x.shape 68 | if self.radix > 1: 69 | x = x.reshape((B, self.radix, RC // self.radix, H, W)) 70 | x_gap = x.sum(dim=1) 71 | else: 72 | x_gap = x 73 | x_gap = x_gap.mean((2, 3), keepdim=True) 74 | x_gap = self.fc1(x_gap) 75 | x_gap = self.bn1(x_gap) 76 | x_gap = self.act1(x_gap) 77 | x_attn = self.fc2(x_gap) 78 | 79 | x_attn = self.rsoftmax(x_attn).view(B, -1, 1, 1) 80 | if self.radix > 1: 81 | out = (x * x_attn.reshape((B, self.radix, RC // self.radix, 1, 1))).sum(dim=1) 82 | else: 83 | out = x * x_attn 84 | return out.contiguous() 85 | -------------------------------------------------------------------------------- /docs/models/.templates/models/wide-resnet.md: 
-------------------------------------------------------------------------------- 1 | # Wide ResNet 2 | 3 | **Wide Residual Networks** are a variant on [ResNets](https://paperswithcode.com/method/resnet) where we decrease depth and increase the width of residual networks. This is achieved through the use of [wide residual blocks](https://paperswithcode.com/method/wide-residual-block). 4 | 5 | {% include 'code_snippets.md' %} 6 | 7 | ## How do I train this model? 8 | 9 | You can follow the [timm recipe scripts](https://rwightman.github.io/pytorch-image-models/scripts/) for training a new model afresh. 10 | 11 | ## Citation 12 | 13 | ```BibTeX 14 | @article{DBLP:journals/corr/ZagoruykoK16, 15 | author = {Sergey Zagoruyko and 16 | Nikos Komodakis}, 17 | title = {Wide Residual Networks}, 18 | journal = {CoRR}, 19 | volume = {abs/1605.07146}, 20 | year = {2016}, 21 | url = {http://arxiv.org/abs/1605.07146}, 22 | archivePrefix = {arXiv}, 23 | eprint = {1605.07146}, 24 | timestamp = {Mon, 13 Aug 2018 16:46:42 +0200}, 25 | biburl = {https://dblp.org/rec/journals/corr/ZagoruykoK16.bib}, 26 | bibsource = {dblp computer science bibliography, https://dblp.org} 27 | } 28 | ``` 29 | 30 | 103 | -------------------------------------------------------------------------------- /timm/models/layers/__init__.py: -------------------------------------------------------------------------------- 1 | # NOTE timm.models.layers is DEPRECATED, please use timm.layers, this is here to reduce breakages in transition 2 | from timm.layers.activations import * 3 | from timm.layers.adaptive_avgmax_pool import \ 4 | adaptive_avgmax_pool2d, select_adaptive_pool2d, AdaptiveAvgMaxPool2d, SelectAdaptivePool2d 5 | from timm.layers.attention_pool2d import AttentionPool2d, RotAttentionPool2d, RotaryEmbedding 6 | from timm.layers.blur_pool import BlurPool2d 7 | from timm.layers.classifier import ClassifierHead, create_classifier 8 | from timm.layers.cond_conv2d import CondConv2d, get_condconv_initializer 9 | 
from timm.layers.config import is_exportable, is_scriptable, is_no_jit, set_exportable, set_scriptable, set_no_jit,\ 10 | set_layer_config 11 | from timm.layers.conv2d_same import Conv2dSame, conv2d_same 12 | from timm.layers.conv_bn_act import ConvNormAct, ConvNormActAa, ConvBnAct 13 | from timm.layers.create_act import create_act_layer, get_act_layer, get_act_fn 14 | from timm.layers.create_attn import get_attn, create_attn 15 | from timm.layers.create_conv2d import create_conv2d 16 | from timm.layers.create_norm import get_norm_layer, create_norm_layer 17 | from timm.layers.create_norm_act import get_norm_act_layer, create_norm_act_layer, get_norm_act_layer 18 | from timm.layers.drop import DropBlock2d, DropPath, drop_block_2d, drop_path 19 | from timm.layers.eca import EcaModule, CecaModule, EfficientChannelAttn, CircularEfficientChannelAttn 20 | from timm.layers.evo_norm import EvoNorm2dB0, EvoNorm2dB1, EvoNorm2dB2,\ 21 | EvoNorm2dS0, EvoNorm2dS0a, EvoNorm2dS1, EvoNorm2dS1a, EvoNorm2dS2, EvoNorm2dS2a 22 | from timm.layers.fast_norm import is_fast_norm, set_fast_norm, fast_group_norm, fast_layer_norm 23 | from timm.layers.filter_response_norm import FilterResponseNormTlu2d, FilterResponseNormAct2d 24 | from timm.layers.gather_excite import GatherExcite 25 | from timm.layers.global_context import GlobalContext 26 | from timm.layers.helpers import to_ntuple, to_2tuple, to_3tuple, to_4tuple, make_divisible, extend_tuple 27 | from timm.layers.inplace_abn import InplaceAbn 28 | from timm.layers.linear import Linear 29 | from timm.layers.mixed_conv2d import MixedConv2d 30 | from timm.layers.mlp import Mlp, GluMlp, GatedMlp, ConvMlp 31 | from timm.layers.non_local_attn import NonLocalAttn, BatNonLocalAttn 32 | from timm.layers.norm import GroupNorm, GroupNorm1, LayerNorm, LayerNorm2d 33 | from timm.layers.norm_act import BatchNormAct2d, GroupNormAct, convert_sync_batchnorm 34 | from timm.layers.padding import get_padding, get_same_padding, pad_same 35 | from 
timm.layers.patch_embed import PatchEmbed 36 | from timm.layers.pool2d_same import AvgPool2dSame, create_pool2d 37 | from timm.layers.squeeze_excite import SEModule, SqueezeExcite, EffectiveSEModule, EffectiveSqueezeExcite 38 | from timm.layers.selective_kernel import SelectiveKernel 39 | from timm.layers.separable_conv import SeparableConv2d, SeparableConvNormAct 40 | from timm.layers.space_to_depth import SpaceToDepthModule 41 | from timm.layers.split_attn import SplitAttn 42 | from timm.layers.split_batchnorm import SplitBatchNorm2d, convert_splitbn_model 43 | from timm.layers.std_conv import StdConv2d, StdConv2dSame, ScaledStdConv2d, ScaledStdConv2dSame 44 | from timm.layers.test_time_pool import TestTimePoolHead, apply_test_time_pool 45 | from timm.layers.trace_utils import _assert, _float_to_int 46 | from timm.layers.weight_init import trunc_normal_, trunc_normal_tf_, variance_scaling_, lecun_normal_ 47 | 48 | import warnings 49 | warnings.warn(f"Importing from {__name__} is deprecated, please import via timm.layers", DeprecationWarning) 50 | -------------------------------------------------------------------------------- /timm/models/__init__.py: -------------------------------------------------------------------------------- 1 | from .beit import * 2 | from .byoanet import * 3 | from .byobnet import * 4 | from .cait import * 5 | from .coat import * 6 | from .convit import * 7 | from .convmixer import * 8 | from .convnext import * 9 | from .crossvit import * 10 | from .cspnet import * 11 | from .davit import * 12 | from .deit import * 13 | from .densenet import * 14 | from .dla import * 15 | from .dpn import * 16 | from .edgenext import * 17 | from .efficientformer import * 18 | from .efficientformer_v2 import * 19 | from .efficientnet import * 20 | from .efficientvit_mit import * 21 | from .efficientvit_msra import * 22 | from .eva import * 23 | from .fastvit import * 24 | from .focalnet import * 25 | from .gcvit import * 26 | from .ghostnet import * 27 | 
from .hardcorenas import * 28 | from .hrnet import * 29 | from .inception_next import * 30 | from .inception_resnet_v2 import * 31 | from .inception_v3 import * 32 | from .inception_v4 import * 33 | from .levit import * 34 | from .maxxvit import * 35 | from .metaformer import * 36 | from .mlp_mixer import * 37 | from .mobilenetv3 import * 38 | from .mobilevit import * 39 | from .mvitv2 import * 40 | from .nasnet import * 41 | from .nest import * 42 | from .nfnet import * 43 | from .pit import * 44 | from .pnasnet import * 45 | from .pvt_v2 import * 46 | from .regnet import * 47 | from .repghost import * 48 | from .repvit import * 49 | from .res2net import * 50 | from .resnest import * 51 | from .resnet import * 52 | from .resnetv2 import * 53 | from .rexnet import * 54 | from .selecsls import * 55 | from .senet import * 56 | from .sequencer import * 57 | from .sknet import * 58 | from .swin_transformer import * 59 | from .swin_transformer_v2 import * 60 | from .swin_transformer_v2_cr import * 61 | from .tiny_vit import * 62 | from .tnt import * 63 | from .tresnet import * 64 | from .twins import * 65 | from .vgg import * 66 | from .visformer import * 67 | from .vision_transformer import * 68 | from .vision_transformer_hybrid import * 69 | from .vision_transformer_relpos import * 70 | from .vision_transformer_sam import * 71 | from .volo import * 72 | from .vovnet import * 73 | from .xception import * 74 | from .xception_aligned import * 75 | from .xcit import * 76 | 77 | from ._builder import build_model_with_cfg, load_pretrained, load_custom_pretrained, resolve_pretrained_cfg, \ 78 | set_pretrained_download_progress, set_pretrained_check_hash 79 | from ._factory import create_model, parse_model_name, safe_model_name 80 | from ._features import FeatureInfo, FeatureHooks, FeatureHookNet, FeatureListNet, FeatureDictNet 81 | from ._features_fx import FeatureGraphNet, GraphExtractNet, create_feature_extractor, \ 82 | register_notrace_module, is_notrace_module, 
get_notrace_modules, \ 83 | register_notrace_function, is_notrace_function, get_notrace_functions 84 | from ._helpers import clean_state_dict, load_state_dict, load_checkpoint, remap_state_dict, resume_checkpoint 85 | from ._hub import load_model_config_from_hf, load_state_dict_from_hf, push_to_hf_hub 86 | from ._manipulate import model_parameters, named_apply, named_modules, named_modules_with_params, \ 87 | group_modules, group_parameters, checkpoint_seq, adapt_input_conv 88 | from ._pretrained import PretrainedCfg, DefaultCfg, filter_pretrained_cfg 89 | from ._prune import adapt_model_from_string 90 | from ._registry import split_model_name_tag, get_arch_name, generate_default_cfgs, register_model, \ 91 | register_model_deprecations, model_entrypoint, list_models, list_pretrained, get_deprecated_models, \ 92 | is_model, list_modules, is_model_in_modules, is_model_pretrained, get_pretrained_cfg, get_pretrained_cfg_value 93 | -------------------------------------------------------------------------------- /docs/models/.templates/models/ensemble-adversarial.md: -------------------------------------------------------------------------------- 1 | # Ensemble Adversarial Inception ResNet v2 2 | 3 | **Inception-ResNet-v2** is a convolutional neural architecture that builds on the Inception family of architectures but incorporates [residual connections](https://paperswithcode.com/method/residual-connection) (replacing the filter concatenation stage of the Inception architecture). 4 | 5 | This particular model was trained for study of adversarial examples (adversarial training). 6 | 7 | The weights from this model were ported from [Tensorflow/Models](https://github.com/tensorflow/models). 8 | 9 | {% include 'code_snippets.md' %} 10 | 11 | ## How do I train this model? 12 | 13 | You can follow the [timm recipe scripts](https://rwightman.github.io/pytorch-image-models/scripts/) for training a new model afresh.
14 | 15 | ## Citation 16 | 17 | ```BibTeX 18 | @article{DBLP:journals/corr/abs-1804-00097, 19 | author = {Alexey Kurakin and 20 | Ian J. Goodfellow and 21 | Samy Bengio and 22 | Yinpeng Dong and 23 | Fangzhou Liao and 24 | Ming Liang and 25 | Tianyu Pang and 26 | Jun Zhu and 27 | Xiaolin Hu and 28 | Cihang Xie and 29 | Jianyu Wang and 30 | Zhishuai Zhang and 31 | Zhou Ren and 32 | Alan L. Yuille and 33 | Sangxia Huang and 34 | Yao Zhao and 35 | Yuzhe Zhao and 36 | Zhonglin Han and 37 | Junjiajia Long and 38 | Yerkebulan Berdibekov and 39 | Takuya Akiba and 40 | Seiya Tokui and 41 | Motoki Abe}, 42 | title = {Adversarial Attacks and Defences Competition}, 43 | journal = {CoRR}, 44 | volume = {abs/1804.00097}, 45 | year = {2018}, 46 | url = {http://arxiv.org/abs/1804.00097}, 47 | archivePrefix = {arXiv}, 48 | eprint = {1804.00097}, 49 | timestamp = {Thu, 31 Oct 2019 16:31:22 +0100}, 50 | biburl = {https://dblp.org/rec/journals/corr/abs-1804-00097.bib}, 51 | bibsource = {dblp computer science bibliography, https://dblp.org} 52 | } 53 | ``` 54 | 55 | 99 | -------------------------------------------------------------------------------- /timm/data/readers/reader_image_folder.py: -------------------------------------------------------------------------------- 1 | """ A dataset reader that extracts images from folders 2 | 3 | Folders are scanned recursively to find image files. Labels are based 4 | on the folder hierarchy, just leaf folders by default. 
5 | 6 | Hacked together by / Copyright 2020 Ross Wightman 7 | """ 8 | import os 9 | from typing import Dict, List, Optional, Set, Tuple, Union 10 | 11 | from timm.utils.misc import natural_key 12 | 13 | from .class_map import load_class_map 14 | from .img_extensions import get_img_extensions 15 | from .reader import Reader 16 | 17 | 18 | def find_images_and_targets( 19 | folder: str, 20 | types: Optional[Union[List, Tuple, Set]] = None, 21 | class_to_idx: Optional[Dict] = None, 22 | leaf_name_only: bool = True, 23 | sort: bool = True 24 | ): 25 | """ Walk folder recursively to discover images and map them to classes by folder names. 26 | 27 | Args: 28 | folder: root of folder to recrusively search 29 | types: types (file extensions) to search for in path 30 | class_to_idx: specify mapping for class (folder name) to class index if set 31 | leaf_name_only: use only leaf-name of folder walk for class names 32 | sort: re-sort found images by name (for consistent ordering) 33 | 34 | Returns: 35 | A list of image and target tuples, class_to_idx mapping 36 | """ 37 | types = get_img_extensions(as_set=True) if not types else set(types) 38 | labels = [] 39 | filenames = [] 40 | for root, subdirs, files in os.walk(folder, topdown=False, followlinks=True): 41 | rel_path = os.path.relpath(root, folder) if (root != folder) else '' 42 | label = os.path.basename(rel_path) if leaf_name_only else rel_path.replace(os.path.sep, '_') 43 | for f in files: 44 | base, ext = os.path.splitext(f) 45 | if ext.lower() in types: 46 | filenames.append(os.path.join(root, f)) 47 | labels.append(label) 48 | if class_to_idx is None: 49 | # building class index 50 | unique_labels = set(labels) 51 | sorted_labels = list(sorted(unique_labels, key=natural_key)) 52 | class_to_idx = {c: idx for idx, c in enumerate(sorted_labels)} 53 | images_and_targets = [(f, class_to_idx[l]) for f, l in zip(filenames, labels) if l in class_to_idx] 54 | if sort: 55 | images_and_targets = sorted(images_and_targets, 
key=lambda k: natural_key(k[0])) 56 | return images_and_targets, class_to_idx 57 | 58 | 59 | class ReaderImageFolder(Reader): 60 | 61 | def __init__( 62 | self, 63 | root, 64 | class_map=''): 65 | super().__init__() 66 | 67 | self.root = root 68 | class_to_idx = None 69 | if class_map: 70 | class_to_idx = load_class_map(class_map, root) 71 | self.samples, self.class_to_idx = find_images_and_targets(root, class_to_idx=class_to_idx) 72 | if len(self.samples) == 0: 73 | raise RuntimeError( 74 | f'Found 0 images in subfolders of {root}. ' 75 | f'Supported image extensions are {", ".join(get_img_extensions())}') 76 | 77 | def __getitem__(self, index): 78 | path, target = self.samples[index] 79 | return open(path, 'rb'), target 80 | 81 | def __len__(self): 82 | return len(self.samples) 83 | 84 | def _filename(self, index, basename=False, absolute=False): 85 | filename = self.samples[index][0] 86 | if basename: 87 | filename = os.path.basename(filename) 88 | elif not absolute: 89 | filename = os.path.relpath(filename, self.root) 90 | return filename 91 | -------------------------------------------------------------------------------- /timm/layers/inplace_abn.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn as nn 3 | 4 | try: 5 | from inplace_abn.functions import inplace_abn, inplace_abn_sync 6 | has_iabn = True 7 | except ImportError: 8 | has_iabn = False 9 | 10 | def inplace_abn(x, weight, bias, running_mean, running_var, 11 | training=True, momentum=0.1, eps=1e-05, activation="leaky_relu", activation_param=0.01): 12 | raise ImportError( 13 | "Please install InplaceABN:'pip install git+https://github.com/mapillary/inplace_abn.git@v1.0.12'") 14 | 15 | def inplace_abn_sync(**kwargs): 16 | inplace_abn(**kwargs) 17 | 18 | 19 | class InplaceAbn(nn.Module): 20 | """Activated Batch Normalization 21 | 22 | This gathers a BatchNorm and an activation function in a single module 23 | 24 | Parameters 25 | 
---------- 26 | num_features : int 27 | Number of feature channels in the input and output. 28 | eps : float 29 | Small constant to prevent numerical issues. 30 | momentum : float 31 | Momentum factor applied to compute running statistics. 32 | affine : bool 33 | If `True` apply learned scale and shift transformation after normalization. 34 | act_layer : str or nn.Module type 35 | Name or type of the activation functions, one of: `leaky_relu`, `elu` 36 | act_param : float 37 | Negative slope for the `leaky_relu` activation. 38 | """ 39 | 40 | def __init__(self, num_features, eps=1e-5, momentum=0.1, affine=True, apply_act=True, 41 | act_layer="leaky_relu", act_param=0.01, drop_layer=None): 42 | super(InplaceAbn, self).__init__() 43 | self.num_features = num_features 44 | self.affine = affine 45 | self.eps = eps 46 | self.momentum = momentum 47 | if apply_act: 48 | if isinstance(act_layer, str): 49 | assert act_layer in ('leaky_relu', 'elu', 'identity', '') 50 | self.act_name = act_layer if act_layer else 'identity' 51 | else: 52 | # convert act layer passed as type to string 53 | if act_layer == nn.ELU: 54 | self.act_name = 'elu' 55 | elif act_layer == nn.LeakyReLU: 56 | self.act_name = 'leaky_relu' 57 | elif act_layer is None or act_layer == nn.Identity: 58 | self.act_name = 'identity' 59 | else: 60 | assert False, f'Invalid act layer {act_layer.__name__} for IABN' 61 | else: 62 | self.act_name = 'identity' 63 | self.act_param = act_param 64 | if self.affine: 65 | self.weight = nn.Parameter(torch.ones(num_features)) 66 | self.bias = nn.Parameter(torch.zeros(num_features)) 67 | else: 68 | self.register_parameter('weight', None) 69 | self.register_parameter('bias', None) 70 | self.register_buffer('running_mean', torch.zeros(num_features)) 71 | self.register_buffer('running_var', torch.ones(num_features)) 72 | self.reset_parameters() 73 | 74 | def reset_parameters(self): 75 | nn.init.constant_(self.running_mean, 0) 76 | nn.init.constant_(self.running_var, 1) 77 | if 
self.affine: 78 | nn.init.constant_(self.weight, 1) 79 | nn.init.constant_(self.bias, 0) 80 | 81 | def forward(self, x): 82 | output = inplace_abn( 83 | x, self.weight, self.bias, self.running_mean, self.running_var, 84 | self.training, self.momentum, self.eps, self.act_name, self.act_param) 85 | if isinstance(output, tuple): 86 | output = output[0] 87 | return output 88 | --------------------------------------------------------------------------------