├── timm ├── py.typed ├── version.py ├── data │ ├── readers │ │ ├── __init__.py │ │ ├── shared_count.py │ │ ├── reader.py │ │ ├── class_map.py │ │ ├── img_extensions.py │ │ ├── reader_factory.py │ │ ├── reader_image_tar.py │ │ ├── reader_hfds.py │ │ └── reader_image_folder.py │ ├── constants.py │ ├── _info │ │ ├── mini_imagenet_indices.txt │ │ ├── mini_imagenet_synsets.txt │ │ ├── imagenet_r_indices.txt │ │ ├── imagenet_a_indices.txt │ │ ├── imagenet_a_synsets.txt │ │ └── imagenet_r_synsets.txt │ ├── __init__.py │ ├── real_labels.py │ └── dataset_info.py ├── models │ ├── hub.py │ ├── factory.py │ ├── features.py │ ├── registry.py │ ├── fx_features.py │ ├── helpers.py │ ├── layers │ │ └── __init__.py │ ├── _features_fx.py │ └── _pretrained.py ├── utils │ ├── random.py │ ├── clip_grad.py │ ├── __init__.py │ ├── metrics.py │ ├── log.py │ ├── misc.py │ ├── summary.py │ ├── agc.py │ ├── decay_batch.py │ ├── jit.py │ ├── cuda.py │ ├── attention_extract.py │ └── onnx.py ├── loss │ ├── __init__.py │ ├── cross_entropy.py │ ├── jsd.py │ ├── binary_cross_entropy.py │ └── asymmetric_loss.py ├── scheduler │ ├── __init__.py │ ├── step_lr.py │ ├── multistep_lr.py │ ├── plateau_lr.py │ └── tanh_lr.py ├── layers │ ├── trace_utils.py │ ├── pool1d.py │ ├── linear.py │ ├── typing.py │ ├── space_to_depth.py │ ├── helpers.py │ ├── format.py │ ├── grn.py │ ├── layer_scale.py │ ├── create_conv2d.py │ ├── grid.py │ ├── median_pool.py │ ├── test_time_pool.py │ ├── create_norm.py │ ├── mixed_conv2d.py │ ├── _fx.py │ ├── interpolate.py │ ├── pos_embed.py │ ├── global_context.py │ ├── filter_response_norm.py │ ├── patch_dropout.py │ ├── conv_bn_act.py │ ├── padding.py │ ├── split_batchnorm.py │ ├── inplace_abn.py │ ├── pool2d_same.py │ ├── separable_conv.py │ ├── split_attn.py │ └── create_attn.py ├── optim │ ├── optim_factory.py │ ├── _types.py │ ├── __init__.py │ ├── sgdp.py │ └── lookahead.py ├── task │ ├── __init__.py │ ├── classification.py │ └── task.py └── __init__.py ├── tests └── __init__.py ├── .gitattributes ├── .github ├── FUNDING.yml ├── ISSUE_TEMPLATE │ ├── config.yml │ ├── feature_request.md │ └── bug_report.md └── workflows │ ├── trufflehog.yml │ ├── upload_pr_documentation.yml │ ├── build_documentation.yml │ ├── build_pr_documentation.yml │ └── tests.yml ├── requirements-dev.txt ├── MANIFEST.in ├── hubconf.py ├── requirements.txt ├── distributed_train.sh ├── hfdocs ├── source │ ├── reference │ │ ├── models.mdx │ │ ├── data.mdx │ │ ├── schedulers.mdx │ │ └── optimizers.mdx │ ├── index.mdx │ ├── hf_hub.mdx │ ├── installation.mdx │ └── hparams.mdx └── README.md ├── setup.cfg ├── CITATION.cff ├── CLAUDE.md ├── .gitignore ├── pyproject.toml ├── UPGRADING.md └── results └── generate_csv_results.py /timm/py.typed: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /tests/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /timm/version.py: -------------------------------------------------------------------------------- 1 | __version__ = '1.0.23.dev0' 2 | -------------------------------------------------------------------------------- /.gitattributes: -------------------------------------------------------------------------------- 1 | *.ipynb linguist-documentation 2 | -------------------------------------------------------------------------------- /.github/FUNDING.yml: 
-------------------------------------------------------------------------------- 1 | # These are supported funding model platforms 2 | github: rwightman 3 | -------------------------------------------------------------------------------- /requirements-dev.txt: -------------------------------------------------------------------------------- 1 | pytest 2 | pytest-timeout 3 | pytest-xdist 4 | pytest-forked 5 | expecttest 6 | -------------------------------------------------------------------------------- /timm/data/readers/__init__.py: -------------------------------------------------------------------------------- 1 | from .reader_factory import create_reader 2 | from .img_extensions import * 3 | -------------------------------------------------------------------------------- /MANIFEST.in: -------------------------------------------------------------------------------- 1 | include timm/models/_pruned/*.txt 2 | include timm/data/_info/*.txt 3 | include timm/data/_info/*.json 4 | -------------------------------------------------------------------------------- /hubconf.py: -------------------------------------------------------------------------------- 1 | dependencies = ['torch'] 2 | import timm 3 | globals().update(timm.models._registry._model_entrypoints) 4 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | torch>=1.7 2 | torchvision 3 | pyyaml 4 | huggingface_hub>=0.17.0 5 | safetensors>=0.2 6 | numpy 7 | -------------------------------------------------------------------------------- /distributed_train.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | NUM_PROC=$1 3 | shift 4 | torchrun --nproc_per_node=$NUM_PROC train.py "$@" 5 | 6 | -------------------------------------------------------------------------------- /hfdocs/source/reference/models.mdx: -------------------------------------------------------------------------------- 1 | # Models 2 | 3 | [[autodoc]] timm.create_model 4 | 5 | [[autodoc]] timm.list_models 6 | -------------------------------------------------------------------------------- /timm/models/hub.py: -------------------------------------------------------------------------------- 1 | from ._hub import * 2 | 3 | import warnings 4 | warnings.warn(f"Importing from {__name__} is deprecated, please import via timm.models", FutureWarning) 5 | -------------------------------------------------------------------------------- /timm/models/factory.py: -------------------------------------------------------------------------------- 1 | from ._factory import * 2 | 3 | import warnings 4 | warnings.warn(f"Importing from {__name__} is deprecated, please import via timm.models", FutureWarning) 5 | -------------------------------------------------------------------------------- /timm/models/features.py: -------------------------------------------------------------------------------- 1 | from ._features import * 2 | 3 | import warnings 4 | warnings.warn(f"Importing from {__name__} is deprecated, please import via timm.models", FutureWarning) 5 | -------------------------------------------------------------------------------- /timm/models/registry.py: -------------------------------------------------------------------------------- 1 | from ._registry import * 2 | 3 | import warnings 4 | warnings.warn(f"Importing from {__name__} is deprecated, please import via timm.models", FutureWarning) 5 | 
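A note on `hubconf.py` above: by copying every registered entrypoint into module globals, it lets `torch.hub` instantiate any timm model straight from the GitHub repo. A minimal sketch of that usage (assumes network access; the model name must exist in the registry):

```python
import torch

# hubconf.py exposes each registered timm model as a hub entrypoint;
# extra kwargs such as pretrained are forwarded to the entrypoint.
model = torch.hub.load(
    'huggingface/pytorch-image-models',  # repo containing hubconf.py
    'resnet50',                          # entrypoint name from the model registry
    pretrained=True,
)
model.eval()
```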
-------------------------------------------------------------------------------- /timm/models/fx_features.py: -------------------------------------------------------------------------------- 1 | from ._features_fx import * 2 | 3 | import warnings 4 | warnings.warn(f"Importing from {__name__} is deprecated, please import via timm.models", FutureWarning) 5 | -------------------------------------------------------------------------------- /setup.cfg: -------------------------------------------------------------------------------- 1 | [dist_conda] 2 | 3 | conda_name_differences = 'torch:pytorch' 4 | channels = pytorch 5 | noarch = True 6 | 7 | [metadata] 8 | 9 | url = https://github.com/huggingface/pytorch-image-models -------------------------------------------------------------------------------- /hfdocs/source/reference/data.mdx: -------------------------------------------------------------------------------- 1 | # Data 2 | 3 | [[autodoc]] timm.data.create_dataset 4 | 5 | [[autodoc]] timm.data.create_loader 6 | 7 | [[autodoc]] timm.data.create_transform 8 | 9 | [[autodoc]] timm.data.resolve_data_config -------------------------------------------------------------------------------- /timm/utils/random.py: -------------------------------------------------------------------------------- 1 | import random 2 | import numpy as np 3 | import torch 4 | 5 | 6 | def random_seed(seed=42, rank=0): 7 | torch.manual_seed(seed + rank) 8 | np.random.seed(seed + rank) 9 | random.seed(seed + rank) 10 | -------------------------------------------------------------------------------- /timm/models/helpers.py: -------------------------------------------------------------------------------- 1 | from ._builder import * 2 | from ._helpers import * 3 | from ._manipulate import * 4 | from ._prune import * 5 | 6 | import warnings 7 | warnings.warn(f"Importing from {__name__} is deprecated, please import via timm.models", FutureWarning) 8 | -------------------------------------------------------------------------------- /timm/loss/__init__.py: -------------------------------------------------------------------------------- 1 | from .asymmetric_loss import AsymmetricLossMultiLabel, AsymmetricLossSingleLabel 2 | from .binary_cross_entropy import BinaryCrossEntropy 3 | from .cross_entropy import LabelSmoothingCrossEntropy, SoftTargetCrossEntropy 4 | from .jsd import JsdCrossEntropy 5 | -------------------------------------------------------------------------------- /.github/ISSUE_TEMPLATE/config.yml: -------------------------------------------------------------------------------- 1 | blank_issues_enabled: false 2 | contact_links: 3 | - name: Community Discussions 4 | url: https://github.com/rwightman/pytorch-image-models/discussions 5 | about: Hparam requests in issues will be ignored! Issues are for features and bugs. Questions can be asked in Discussions.
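The data reference page above lists `create_transform` and `resolve_data_config` without showing how they are typically combined. A minimal sketch of the usual pairing (model name is illustrative):

```python
import timm
from timm.data import resolve_data_config, create_transform

model = timm.create_model('resnet50', pretrained=True)
# Pull input size, interpolation, and mean/std from the model's pretrained config...
config = resolve_data_config({}, model=model)
# ...and build the matching eval-time preprocessing pipeline.
transform = create_transform(**config)
```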
6 | -------------------------------------------------------------------------------- /hfdocs/README.md: -------------------------------------------------------------------------------- 1 | # Hugging Face Timm Docs 2 | 3 | ## Getting Started 4 | 5 | ``` 6 | pip install git+https://github.com/huggingface/doc-builder.git@main#egg=hf-doc-builder 7 | pip install watchdog black 8 | ``` 9 | 10 | ## Preview the Docs Locally 11 | 12 | ``` 13 | doc-builder preview timm hfdocs/source 14 | ``` 15 | -------------------------------------------------------------------------------- /.github/workflows/trufflehog.yml: -------------------------------------------------------------------------------- 1 | on: 2 | push: 3 | 4 | name: Secret Leaks 5 | 6 | jobs: 7 | trufflehog: 8 | runs-on: ubuntu-latest 9 | steps: 10 | - name: Checkout code 11 | uses: actions/checkout@v4 12 | with: 13 | fetch-depth: 0 14 | - name: Secret Scanning 15 | uses: trufflesecurity/trufflehog@main 16 | -------------------------------------------------------------------------------- /CITATION.cff: -------------------------------------------------------------------------------- 1 | cff-version: "1.2.0" 2 | message: "If you use this software, please cite it as below." 3 | title: "PyTorch Image Models" 4 | doi: "10.5281/zenodo.4414861" 5 | authors: 6 | - family-names: Wightman 7 | given-names: Ross 8 | version: 1.0.11 9 | year: "2019" 10 | url: "https://github.com/huggingface/pytorch-image-models" 11 | license: "Apache-2.0" -------------------------------------------------------------------------------- /timm/scheduler/__init__.py: -------------------------------------------------------------------------------- 1 | from .cosine_lr import CosineLRScheduler 2 | from .multistep_lr import MultiStepLRScheduler 3 | from .plateau_lr import PlateauLRScheduler 4 | from .poly_lr import PolyLRScheduler 5 | from .step_lr import StepLRScheduler 6 | from .tanh_lr import TanhLRScheduler 7 | 8 | from .scheduler_factory import create_scheduler, create_scheduler_v2, scheduler_kwargs 9 | -------------------------------------------------------------------------------- /timm/data/readers/shared_count.py: -------------------------------------------------------------------------------- 1 | from multiprocessing import Value 2 | 3 | 4 | class SharedCount: 5 | def __init__(self, epoch: int = 0): 6 | self.shared_epoch = Value('i', epoch) 7 | 8 | @property 9 | def value(self): 10 | return self.shared_epoch.value 11 | 12 | @value.setter 13 | def value(self, epoch): 14 | self.shared_epoch.value = epoch 15 | -------------------------------------------------------------------------------- /timm/layers/trace_utils.py: -------------------------------------------------------------------------------- 1 | try: 2 | from torch import _assert 3 | except ImportError: 4 | def _assert(condition: bool, message: str): 5 | assert condition, message 6 | 7 | 8 | def _float_to_int(x: float) -> int: 9 | """ 10 | Symbolic tracing helper to substitute for inbuilt `int`.
11 | Hint: Inbuilt `int` can't accept an argument of type `Proxy` 12 | """ 13 | return int(x) 14 | -------------------------------------------------------------------------------- /timm/optim/optim_factory.py: -------------------------------------------------------------------------------- 1 | # lots of uses of these functions directly, ala 'import timm.optim.optim_factory as optim_factory', fun :/ 2 | 3 | from ._optim_factory import create_optimizer, create_optimizer_v2, optimizer_kwargs 4 | from ._param_groups import param_groups_layer_decay, param_groups_weight_decay, group_parameters, _layer_map, _group 5 | 6 | import warnings 7 | warnings.warn(f"Importing from {__name__} is deprecated, please import via timm.optim", FutureWarning) 8 | -------------------------------------------------------------------------------- /.github/workflows/upload_pr_documentation.yml: -------------------------------------------------------------------------------- 1 | name: Upload PR Documentation 2 | 3 | on: 4 | workflow_run: 5 | workflows: ["Build PR Documentation"] 6 | types: 7 | - completed 8 | 9 | jobs: 10 | build: 11 | uses: huggingface/doc-builder/.github/workflows/upload_pr_documentation.yml@main 12 | with: 13 | package_name: timm 14 | secrets: 15 | hf_token: ${{ secrets.HF_DOC_BUILD_PUSH }} 16 | comment_bot_token: ${{ secrets.COMMENT_BOT_TOKEN }} -------------------------------------------------------------------------------- /timm/data/constants.py: -------------------------------------------------------------------------------- 1 | DEFAULT_CROP_PCT = 0.875 2 | DEFAULT_CROP_MODE = 'center' 3 | IMAGENET_DEFAULT_MEAN = (0.485, 0.456, 0.406) 4 | IMAGENET_DEFAULT_STD = (0.229, 0.224, 0.225) 5 | IMAGENET_INCEPTION_MEAN = (0.5, 0.5, 0.5) 6 | IMAGENET_INCEPTION_STD = (0.5, 0.5, 0.5) 7 | IMAGENET_DPN_MEAN = (124 / 255, 117 / 255, 104 / 255) 8 | IMAGENET_DPN_STD = tuple([1 / (.0167 * 255)] * 3) 9 | OPENAI_CLIP_MEAN = (0.48145466, 0.4578275, 0.40821073) 10 | OPENAI_CLIP_STD = (0.26862954, 0.26130258, 0.27577711) 11 | -------------------------------------------------------------------------------- /timm/data/readers/reader.py: -------------------------------------------------------------------------------- 1 | from abc import abstractmethod 2 | 3 | 4 | class Reader: 5 | def __init__(self): 6 | pass 7 | 8 | @abstractmethod 9 | def _filename(self, index, basename=False, absolute=False): 10 | pass 11 | 12 | def filename(self, index, basename=False, absolute=False): 13 | return self._filename(index, basename=basename, absolute=absolute) 14 | 15 | def filenames(self, basename=False, absolute=False): 16 | return [self._filename(index, basename=basename, absolute=absolute) for index in range(len(self))] 17 | 18 | -------------------------------------------------------------------------------- /.github/workflows/build_documentation.yml: -------------------------------------------------------------------------------- 1 | name: Build documentation 2 | 3 | on: 4 | push: 5 | branches: 6 | - main 7 | - doc-builder* 8 | - v*-release 9 | 10 | jobs: 11 | build: 12 | uses: huggingface/doc-builder/.github/workflows/build_main_documentation.yml@main 13 | with: 14 | commit_sha: ${{ github.sha }} 15 | package: pytorch-image-models 16 | package_name: timm 17 | path_to_docs: pytorch-image-models/hfdocs/source 18 | version_tag_suffix: "" 19 | secrets: 20 | hf_token: ${{ secrets.HF_DOC_BUILD_PUSH }} 21 | -------------------------------------------------------------------------------- /.github/workflows/build_pr_documentation.yml: 
-------------------------------------------------------------------------------- 1 | name: Build PR Documentation 2 | 3 | on: 4 | pull_request: 5 | 6 | concurrency: 7 | group: ${{ github.workflow }}-${{ github.head_ref || github.run_id }} 8 | cancel-in-progress: true 9 | 10 | jobs: 11 | build: 12 | uses: huggingface/doc-builder/.github/workflows/build_pr_documentation.yml@main 13 | with: 14 | commit_sha: ${{ github.event.pull_request.head.sha }} 15 | pr_number: ${{ github.event.number }} 16 | package: pytorch-image-models 17 | package_name: timm 18 | path_to_docs: pytorch-image-models/hfdocs/source 19 | version_tag_suffix: "" 20 | -------------------------------------------------------------------------------- /timm/task/__init__.py: -------------------------------------------------------------------------------- 1 | """Training task abstractions for timm. 2 | 3 | This module provides task-based abstractions for training loops where each task 4 | encapsulates both the forward pass and loss computation, returning a dictionary 5 | with loss components and outputs for logging. 6 | """ 7 | from .task import TrainingTask 8 | from .classification import ClassificationTask 9 | from .distillation import DistillationTeacher, LogitDistillationTask, FeatureDistillationTask 10 | 11 | __all__ = [ 12 | 'TrainingTask', 13 | 'ClassificationTask', 14 | 'DistillationTeacher', 15 | 'LogitDistillationTask', 16 | 'FeatureDistillationTask', 17 | ] 18 | -------------------------------------------------------------------------------- /timm/__init__.py: -------------------------------------------------------------------------------- 1 | from .version import __version__ as __version__ 2 | from .layers import ( 3 | is_scriptable as is_scriptable, 4 | is_exportable as is_exportable, 5 | set_scriptable as set_scriptable, 6 | set_exportable as set_exportable, 7 | ) 8 | from .models import ( 9 | create_model as create_model, 10 | list_models as list_models, 11 | list_pretrained as list_pretrained, 12 | is_model as is_model, 13 | list_modules as list_modules, 14 | model_entrypoint as model_entrypoint, 15 | is_model_pretrained as is_model_pretrained, 16 | get_pretrained_cfg as get_pretrained_cfg, 17 | get_pretrained_cfg_value as get_pretrained_cfg_value, 18 | ) 19 | -------------------------------------------------------------------------------- /hfdocs/source/reference/schedulers.mdx: -------------------------------------------------------------------------------- 1 | # Learning Rate Schedulers 2 | 3 | This page contains the API reference documentation for learning rate schedulers included in `timm`. 
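A short usage sketch before the listing; `create_scheduler_v2` builds a scheduler directly from an optimizer plus keyword arguments (hyperparameter values here are illustrative):

```python
import torch
import timm
from timm.scheduler import create_scheduler_v2

model = timm.create_model('resnet18')
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)

# Returns the scheduler plus the resolved total epoch count.
scheduler, num_epochs = create_scheduler_v2(
    optimizer,
    sched='cosine',
    num_epochs=100,
    warmup_epochs=5,
)

for epoch in range(num_epochs):
    # ... train one epoch ...
    scheduler.step(epoch + 1)  # timm schedulers are stepped explicitly with the epoch index
```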
4 | 5 | ## Schedulers 6 | 7 | ### Factory functions 8 | 9 | [[autodoc]] timm.scheduler.scheduler_factory.create_scheduler 10 | [[autodoc]] timm.scheduler.scheduler_factory.create_scheduler_v2 11 | 12 | ### Scheduler Classes 13 | 14 | [[autodoc]] timm.scheduler.cosine_lr.CosineLRScheduler 15 | [[autodoc]] timm.scheduler.multistep_lr.MultiStepLRScheduler 16 | [[autodoc]] timm.scheduler.plateau_lr.PlateauLRScheduler 17 | [[autodoc]] timm.scheduler.poly_lr.PolyLRScheduler 18 | [[autodoc]] timm.scheduler.step_lr.StepLRScheduler 19 | [[autodoc]] timm.scheduler.tanh_lr.TanhLRScheduler 20 | -------------------------------------------------------------------------------- /timm/layers/pool1d.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | 4 | def global_pool_nlc( 5 | x: torch.Tensor, 6 | pool_type: str = 'token', 7 | num_prefix_tokens: int = 1, 8 | reduce_include_prefix: bool = False, 9 | ): 10 | if not pool_type: 11 | return x 12 | 13 | if pool_type == 'token': 14 | x = x[:, 0] # class token 15 | else: 16 | x = x if reduce_include_prefix else x[:, num_prefix_tokens:] 17 | if pool_type == 'avg': 18 | x = x.mean(dim=1) 19 | elif pool_type == 'avgmax': 20 | x = 0.5 * (x.amax(dim=1) + x.mean(dim=1)) 21 | elif pool_type == 'max': 22 | x = x.amax(dim=1) 23 | else: 24 | assert not pool_type, f'Unknown pool type {pool_type}' 25 | 26 | return x -------------------------------------------------------------------------------- /timm/layers/linear.py: -------------------------------------------------------------------------------- 1 | """ Linear layer (alternate definition) 2 | """ 3 | import torch 4 | import torch.nn.functional as F 5 | from torch import nn as nn 6 | 7 | 8 | class Linear(nn.Linear): 9 | r"""Applies a linear transformation to the incoming data: :math:`y = xA^T + b` 10 | 11 | Wraps torch.nn.Linear to support AMP + torchscript usage by manually casting 12 | weight & bias to input.dtype to work around an issue w/ torch.addmm in this use case. 13 | """ 14 | def forward(self, input: torch.Tensor) -> torch.Tensor: 15 | if torch.jit.is_scripting(): 16 | bias = self.bias.to(dtype=input.dtype) if self.bias is not None else None 17 | return F.linear(input, self.weight.to(dtype=input.dtype), bias=bias) 18 | else: 19 | return F.linear(input, self.weight, self.bias) 20 | -------------------------------------------------------------------------------- /.github/ISSUE_TEMPLATE/feature_request.md: -------------------------------------------------------------------------------- 1 | --- 2 | name: Feature request 3 | about: Suggest an idea for this project. Hparam requests, training help are not feature requests. 4 | The discussion forum is available for asking questions or seeking help from the community. 5 | title: "[FEATURE] Feature title..." 6 | labels: enhancement 7 | assignees: '' 8 | 9 | --- 10 | 11 | **Is your feature request related to a problem? Please describe.** 12 | A clear and concise description of what the problem is. 13 | 14 | **Describe the solution you'd like** 15 | A clear and concise description of what you want to happen. 16 | 17 | **Describe alternatives you've considered** 18 | A clear and concise description of any alternative solutions or features you've considered. 19 | 20 | **Additional context** 21 | Add any other context or screenshots about the feature request here. 
22 | -------------------------------------------------------------------------------- /timm/optim/_types.py: -------------------------------------------------------------------------------- 1 | from typing import Any, Dict, Iterable, Union, Protocol, Type 2 | try: 3 | from typing import TypeAlias 4 | except ImportError: 5 | from typing_extensions import TypeAlias 6 | try: 7 | from typing import TypeVar 8 | except ImportError: 9 | from typing_extensions import TypeVar 10 | 11 | import torch 12 | import torch.optim 13 | 14 | try: 15 | from torch.optim.optimizer import ParamsT 16 | except (ImportError, TypeError): 17 | ParamsT: TypeAlias = Union[Iterable[torch.Tensor], Iterable[Dict[str, Any]]] 18 | 19 | 20 | OptimType = Type[torch.optim.Optimizer] 21 | 22 | 23 | class OptimizerCallable(Protocol): 24 | """Protocol for optimizer constructor signatures.""" 25 | 26 | def __call__(self, params: ParamsT, **kwargs) -> torch.optim.Optimizer: ... 27 | 28 | 29 | __all__ = ['ParamsT', 'OptimType', 'OptimizerCallable'] -------------------------------------------------------------------------------- /timm/utils/clip_grad.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | from timm.utils.agc import adaptive_clip_grad 4 | 5 | 6 | def dispatch_clip_grad(parameters, value: float, mode: str = 'norm', norm_type: float = 2.0): 7 | """ Dispatch to gradient clipping method 8 | 9 | Args: 10 | parameters (Iterable): model parameters to clip 11 | value (float): clipping value/factor/norm, mode dependant 12 | mode (str): clipping mode, one of 'norm', 'value', 'agc' 13 | norm_type (float): p-norm, default 2.0 14 | """ 15 | if mode == 'norm': 16 | torch.nn.utils.clip_grad_norm_(parameters, value, norm_type=norm_type) 17 | elif mode == 'value': 18 | torch.nn.utils.clip_grad_value_(parameters, value) 19 | elif mode == 'agc': 20 | adaptive_clip_grad(parameters, value, norm_type=norm_type) 21 | else: 22 | assert False, f"Unknown clip mode ({mode})." 
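A brief sketch of where `dispatch_clip_grad` sits in a training step, between backward and the optimizer update (the tiny model and clip value are placeholders):

```python
import torch
from timm.utils import dispatch_clip_grad

model = torch.nn.Linear(16, 4)
loss = model(torch.randn(8, 16)).sum()
loss.backward()
# mode is one of 'norm', 'value', or 'agc', matching the dispatch above
dispatch_clip_grad(model.parameters(), value=1.0, mode='norm')
```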
23 | 24 | -------------------------------------------------------------------------------- /timm/utils/__init__.py: -------------------------------------------------------------------------------- 1 | from .agc import adaptive_clip_grad 2 | from .attention_extract import AttentionExtract 3 | from .checkpoint_saver import CheckpointSaver 4 | from .clip_grad import dispatch_clip_grad 5 | from .cuda import ApexScaler, NativeScaler 6 | from .decay_batch import decay_batch_step, check_batch_size_retry 7 | from .distributed import distribute_bn, reduce_tensor, init_distributed_device,\ 8 | world_info_from_env, is_distributed_env, is_primary 9 | from .jit import set_jit_legacy, set_jit_fuser 10 | from .log import setup_default_logging, FormatterNoInfo 11 | from .metrics import AverageMeter, accuracy 12 | from .misc import natural_key, add_bool_arg, ParseKwargs 13 | from .model import unwrap_model, get_state_dict, freeze, unfreeze, reparameterize_model 14 | from .model_ema import ModelEma, ModelEmaV2, ModelEmaV3 15 | from .random import random_seed 16 | from .summary import update_summary, get_outdir 17 | -------------------------------------------------------------------------------- /CLAUDE.md: -------------------------------------------------------------------------------- 1 | # CLAUDE.md - PyTorch Image Models (timm) 2 | 3 | ## Build/Test Commands 4 | - Install: `python -m pip install -e .` 5 | - Run tests: `pytest tests/` 6 | - Run specific test: `pytest tests/test_models.py::test_specific_function -v` 7 | - Run tests in parallel: `pytest -n 4 tests/` 8 | - Filter tests: `pytest -k "substring-to-match" tests/` 9 | 10 | ## Code Style Guidelines 11 | - Line length: 120 chars 12 | - Indentation: 4-space hanging indents, arguments should have an extra level of indent, use 'sadface' (closing parenthesis and colon on a separate line) 13 | - Typing: Use PEP484 type annotations in function signatures 14 | - Docstrings: Google style (do not duplicate type annotations and defaults) 15 | - Imports: Standard library first, then third-party, then local 16 | - Function naming: snake_case 17 | - Class naming: PascalCase 18 | - Error handling: Use try/except with specific exceptions 19 | - Conditional expressions: Use parentheses for complex expressions -------------------------------------------------------------------------------- /timm/data/_info/mini_imagenet_indices.txt: -------------------------------------------------------------------------------- 1 | 12 2 | 15 3 | 51 4 | 64 5 | 70 6 | 96 7 | 99 8 | 107 9 | 111 10 | 121 11 | 149 12 | 166 13 | 173 14 | 176 15 | 207 16 | 214 17 | 228 18 | 242 19 | 244 20 | 245 21 | 249 22 | 251 23 | 256 24 | 266 25 | 270 26 | 275 27 | 279 28 | 291 29 | 299 30 | 301 31 | 306 32 | 310 33 | 359 34 | 364 35 | 392 36 | 403 37 | 412 38 | 427 39 | 440 40 | 454 41 | 471 42 | 476 43 | 478 44 | 484 45 | 494 46 | 502 47 | 503 48 | 507 49 | 519 50 | 524 51 | 533 52 | 538 53 | 546 54 | 553 55 | 556 56 | 567 57 | 569 58 | 584 59 | 597 60 | 602 61 | 604 62 | 605 63 | 629 64 | 655 65 | 657 66 | 659 67 | 683 68 | 687 69 | 702 70 | 709 71 | 713 72 | 735 73 | 741 74 | 758 75 | 779 76 | 781 77 | 800 78 | 801 79 | 807 80 | 815 81 | 819 82 | 847 83 | 854 84 | 858 85 | 860 86 | 880 87 | 881 88 | 883 89 | 909 90 | 912 91 | 914 92 | 919 93 | 925 94 | 927 95 | 934 96 | 950 97 | 972 98 | 973 99 | 997 100 | 998 101 | -------------------------------------------------------------------------------- /timm/data/readers/class_map.py: 
-------------------------------------------------------------------------------- 1 | import os 2 | import pickle 3 | 4 | 5 | def load_class_map(map_or_filename, root=''): 6 | if isinstance(map_or_filename, dict): 7 | assert map_or_filename, 'class_map dict must be non-empty' 8 | return map_or_filename 9 | class_map_path = map_or_filename 10 | if not os.path.exists(class_map_path): 11 | class_map_path = os.path.join(root, class_map_path) 12 | assert os.path.exists(class_map_path), 'Cannot locate specified class map file (%s)' % map_or_filename 13 | class_map_ext = os.path.splitext(map_or_filename)[-1].lower() 14 | if class_map_ext == '.txt': 15 | with open(class_map_path) as f: 16 | class_to_idx = {v.strip(): k for k, v in enumerate(f)} 17 | elif class_map_ext == '.pkl': 18 | with open(class_map_path, 'rb') as f: 19 | class_to_idx = pickle.load(f) 20 | else: 21 | assert False, f'Unsupported class map file extension ({class_map_ext}).' 22 | return class_to_idx 23 | 24 | -------------------------------------------------------------------------------- /timm/utils/metrics.py: -------------------------------------------------------------------------------- 1 | """ Eval metrics and related 2 | 3 | Hacked together by / Copyright 2020 Ross Wightman 4 | """ 5 | 6 | 7 | class AverageMeter: 8 | """Computes and stores the average and current value""" 9 | def __init__(self): 10 | self.reset() 11 | 12 | def reset(self): 13 | self.val = 0 14 | self.avg = 0 15 | self.sum = 0 16 | self.count = 0 17 | 18 | def update(self, val, n=1): 19 | self.val = val 20 | self.sum += val * n 21 | self.count += n 22 | self.avg = self.sum / self.count 23 | 24 | 25 | def accuracy(output, target, topk=(1,)): 26 | """Computes the accuracy over the k top predictions for the specified values of k""" 27 | maxk = min(max(topk), output.size()[1]) 28 | batch_size = target.size(0) 29 | _, pred = output.topk(maxk, 1, True, True) 30 | pred = pred.t() 31 | correct = pred.eq(target.reshape(1, -1).expand_as(pred)) 32 | return [correct[:min(k, maxk)].reshape(-1).float().sum(0) * 100. / batch_size for k in topk] 33 | -------------------------------------------------------------------------------- /timm/layers/typing.py: -------------------------------------------------------------------------------- 1 | from contextlib import nullcontext 2 | from functools import wraps 3 | from typing import Callable, Optional, Tuple, Type, TypeVar, Union, overload, ContextManager 4 | 5 | import torch 6 | 7 | __all__ = ["LayerType", "PadType", "nullwrap", "disable_compiler"] 8 | 9 | 10 | LayerType = Union[str, Callable, Type[torch.nn.Module]] 11 | PadType = Union[str, int, Tuple[int, int]] 12 | 13 | F = TypeVar("F", bound=Callable[..., object]) 14 | 15 | 16 | @overload 17 | def nullwrap(fn: F) -> F: ... # decorator form 18 | 19 | @overload 20 | def nullwrap(fn: None = ...) -> ContextManager: ...
# context‑manager form 21 | 22 | def nullwrap(fn: Optional[F] = None): 23 | # as a context manager 24 | if fn is None: 25 | return nullcontext() # `with nullwrap():` 26 | 27 | # as a decorator 28 | @wraps(fn) 29 | def wrapper(*args, **kwargs): 30 | return fn(*args, **kwargs) 31 | return wrapper # `@nullwrap` 32 | 33 | 34 | disable_compiler = getattr(getattr(torch, "compiler", None), "disable", None) or nullwrap 35 | -------------------------------------------------------------------------------- /.github/ISSUE_TEMPLATE/bug_report.md: -------------------------------------------------------------------------------- 1 | --- 2 | name: Bug report 3 | about: Create a bug report to help us improve. Issues are for reporting bugs or requesting 4 | features, the discussion forum is available for asking questions or seeking help 5 | from the community. 6 | title: "[BUG] Issue title..." 7 | labels: bug 8 | assignees: rwightman 9 | 10 | --- 11 | 12 | **Describe the bug** 13 | A clear and concise description of what the bug is. 14 | 15 | **To Reproduce** 16 | Steps to reproduce the behavior: 17 | 1. 18 | 2. 19 | 20 | **Expected behavior** 21 | A clear and concise description of what you expected to happen. 22 | 23 | **Screenshots** 24 | If applicable, add screenshots to help explain your problem. 25 | 26 | **Desktop (please complete the following information):** 27 | - OS: [e.g. Windows 10, Ubuntu 18.04] 28 | - This repository version [e.g. pip 0.3.1 or commit ref] 29 | - PyTorch version w/ CUDA/cuDNN [e.g. from `conda list`, 1.7.0 py3.8_cuda11.0.221_cudnn8.0.3_0] 30 | 31 | **Additional context** 32 | Add any other context about the problem here. 33 | -------------------------------------------------------------------------------- /timm/utils/log.py: -------------------------------------------------------------------------------- 1 | """ Logging helpers 2 | 3 | Hacked together by / Copyright 2020 Ross Wightman 4 | """ 5 | import logging 6 | import logging.handlers 7 | 8 | 9 | class FormatterNoInfo(logging.Formatter): 10 | def __init__(self, fmt='%(levelname)s: %(message)s'): 11 | logging.Formatter.__init__(self, fmt) 12 | 13 | def format(self, record): 14 | if record.levelno == logging.INFO: 15 | return str(record.getMessage()) 16 | return logging.Formatter.format(self, record) 17 | 18 | 19 | def setup_default_logging(default_level=logging.INFO, log_path=''): 20 | console_handler = logging.StreamHandler() 21 | console_handler.setFormatter(FormatterNoInfo()) 22 | logging.root.addHandler(console_handler) 23 | logging.root.setLevel(default_level) 24 | if log_path: 25 | file_handler = logging.handlers.RotatingFileHandler(log_path, maxBytes=(1024 ** 2 * 2), backupCount=3) 26 | file_formatter = logging.Formatter("%(asctime)s - %(name)20s: [%(levelname)8s] - %(message)s") 27 | file_handler.setFormatter(file_formatter) 28 | logging.root.addHandler(file_handler) 29 | -------------------------------------------------------------------------------- /hfdocs/source/reference/optimizers.mdx: -------------------------------------------------------------------------------- 1 | # Optimization 2 | 3 | This page contains the API reference documentation for learning rate optimizers included in `timm`. 
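As with the scheduler page, a short usage sketch may help before the class listing (hyperparameter values illustrative):

```python
import timm
from timm.optim import create_optimizer_v2, list_optimizers

print(list_optimizers()[:5])  # discover registered optimizer names

model = timm.create_model('resnet18')
# opt selects the optimizer by name; lr and weight_decay pass through to it
optimizer = create_optimizer_v2(model, opt='adamw', lr=1e-3, weight_decay=0.05)
```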
4 | 5 | ## Optimizers 6 | 7 | ### Factory functions 8 | 9 | [[autodoc]] timm.optim.create_optimizer_v2 10 | [[autodoc]] timm.optim.list_optimizers 11 | [[autodoc]] timm.optim.get_optimizer_class 12 | 13 | ### Optimizer Classes 14 | 15 | [[autodoc]] timm.optim.adabelief.AdaBelief 16 | [[autodoc]] timm.optim.adafactor.Adafactor 17 | [[autodoc]] timm.optim.adafactor_bv.AdafactorBigVision 18 | [[autodoc]] timm.optim.adahessian.Adahessian 19 | [[autodoc]] timm.optim.adamp.AdamP 20 | [[autodoc]] timm.optim.adan.Adan 21 | [[autodoc]] timm.optim.adopt.Adopt 22 | [[autodoc]] timm.optim.lamb.Lamb 23 | [[autodoc]] timm.optim.laprop.LaProp 24 | [[autodoc]] timm.optim.lars.Lars 25 | [[autodoc]] timm.optim.lion.Lion 26 | [[autodoc]] timm.optim.lookahead.Lookahead 27 | [[autodoc]] timm.optim.madgrad.MADGRAD 28 | [[autodoc]] timm.optim.mars.Mars 29 | [[autodoc]] timm.optim.nadamw.NAdamW 30 | [[autodoc]] timm.optim.nvnovograd.NvNovoGrad 31 | [[autodoc]] timm.optim.rmsprop_tf.RMSpropTF 32 | [[autodoc]] timm.optim.sgdp.SGDP 33 | [[autodoc]] timm.optim.sgdw.SGDW -------------------------------------------------------------------------------- /timm/layers/space_to_depth.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | 5 | class SpaceToDepth(nn.Module): 6 | bs: torch.jit.Final[int] 7 | 8 | def __init__(self, block_size: int = 4): 9 | super().__init__() 10 | assert block_size == 4 11 | self.bs = block_size 12 | 13 | def forward(self, x): 14 | N, C, H, W = x.size() 15 | x = x.view(N, C, H // self.bs, self.bs, W // self.bs, self.bs) # (N, C, H//bs, bs, W//bs, bs) 16 | x = x.permute(0, 3, 5, 1, 2, 4).contiguous() # (N, bs, bs, C, H//bs, W//bs) 17 | x = x.view(N, C * self.bs * self.bs, H // self.bs, W // self.bs) # (N, C*bs^2, H//bs, W//bs) 18 | return x 19 | 20 | 21 | class DepthToSpace(nn.Module): 22 | 23 | def __init__(self, block_size): 24 | super().__init__() 25 | self.bs = block_size 26 | 27 | def forward(self, x): 28 | N, C, H, W = x.size() 29 | x = x.view(N, self.bs, self.bs, C // (self.bs ** 2), H, W) # (N, bs, bs, C//bs^2, H, W) 30 | x = x.permute(0, 3, 4, 1, 5, 2).contiguous() # (N, C//bs^2, H, bs, W, bs) 31 | x = x.view(N, C // (self.bs ** 2), H * self.bs, W * self.bs) # (N, C//bs^2, H * bs, W * bs) 32 | return x 33 | -------------------------------------------------------------------------------- /timm/utils/misc.py: -------------------------------------------------------------------------------- 1 | """ Misc utils 2 | 3 | Hacked together by / Copyright 2020 Ross Wightman 4 | """ 5 | import argparse 6 | import ast 7 | import re 8 | 9 | 10 | def natural_key(string_): 11 | """See http://www.codinghorror.com/blog/archives/001018.html""" 12 | return [int(s) if s.isdigit() else s for s in re.split(r'(\d+)', string_.lower())] 13 | 14 | 15 | def add_bool_arg(parser, name, default=False, help=''): 16 | dest_name = name.replace('-', '_') 17 | group = parser.add_mutually_exclusive_group(required=False) 18 | group.add_argument('--' + name, dest=dest_name, action='store_true', help=help) 19 | group.add_argument('--no-' + name, dest=dest_name, action='store_false', help=help) 20 | parser.set_defaults(**{dest_name: default}) 21 | 22 | 23 | class ParseKwargs(argparse.Action): 24 | def __call__(self, parser, namespace, values, option_string=None): 25 | kw = {} 26 | for value in values: 27 | key, value = value.split('=') 28 | try: 29 | kw[key] = ast.literal_eval(value) 30 | except ValueError: 31 | kw[key] = str(value) # fallback to 
string (avoid need to escape on command line) 32 | setattr(namespace, self.dest, kw) 33 | -------------------------------------------------------------------------------- /timm/layers/helpers.py: -------------------------------------------------------------------------------- 1 | """ Layer/Module Helpers 2 | 3 | Hacked together by / Copyright 2020 Ross Wightman 4 | """ 5 | from itertools import repeat 6 | import collections.abc 7 | 8 | 9 | # From PyTorch internals 10 | def _ntuple(n): 11 | def parse(x): 12 | if isinstance(x, collections.abc.Iterable) and not isinstance(x, str): 13 | return tuple(x) 14 | return tuple(repeat(x, n)) 15 | return parse 16 | 17 | 18 | to_1tuple = _ntuple(1) 19 | to_2tuple = _ntuple(2) 20 | to_3tuple = _ntuple(3) 21 | to_4tuple = _ntuple(4) 22 | to_ntuple = _ntuple 23 | 24 | 25 | def make_divisible(v, divisor=8, min_value=None, round_limit=.9): 26 | min_value = min_value or divisor 27 | new_v = max(min_value, int(v + divisor / 2) // divisor * divisor) 28 | # Make sure that round down does not go down by more than 10%. 29 | if new_v < round_limit * v: 30 | new_v += divisor 31 | return new_v 32 | 33 | 34 | def extend_tuple(x, n): 35 | # pads a tuple to specified n by padding with last value 36 | if not isinstance(x, (tuple, list)): 37 | x = (x,) 38 | else: 39 | x = tuple(x) 40 | pad_n = n - len(x) 41 | if pad_n <= 0: 42 | return x[:n] 43 | return x + (x[-1],) * pad_n 44 | -------------------------------------------------------------------------------- /timm/loss/cross_entropy.py: -------------------------------------------------------------------------------- 1 | """ Cross Entropy w/ smoothing or soft targets 2 | 3 | Hacked together by / Copyright 2021 Ross Wightman 4 | """ 5 | 6 | import torch 7 | import torch.nn as nn 8 | import torch.nn.functional as F 9 | 10 | 11 | class LabelSmoothingCrossEntropy(nn.Module): 12 | """ NLL loss with label smoothing. 13 | """ 14 | def __init__(self, smoothing=0.1): 15 | super(LabelSmoothingCrossEntropy, self).__init__() 16 | assert smoothing < 1.0 17 | self.smoothing = smoothing 18 | self.confidence = 1. 
- smoothing 19 | 20 | def forward(self, x: torch.Tensor, target: torch.Tensor) -> torch.Tensor: 21 | logprobs = F.log_softmax(x, dim=-1) 22 | nll_loss = -logprobs.gather(dim=-1, index=target.unsqueeze(1)) 23 | nll_loss = nll_loss.squeeze(1) 24 | smooth_loss = -logprobs.mean(dim=-1) 25 | loss = self.confidence * nll_loss + self.smoothing * smooth_loss 26 | return loss.mean() 27 | 28 | 29 | class SoftTargetCrossEntropy(nn.Module): 30 | 31 | def __init__(self): 32 | super(SoftTargetCrossEntropy, self).__init__() 33 | 34 | def forward(self, x: torch.Tensor, target: torch.Tensor) -> torch.Tensor: 35 | loss = torch.sum(-target * F.log_softmax(x, dim=-1), dim=-1) 36 | return loss.mean() 37 | -------------------------------------------------------------------------------- /timm/data/__init__.py: -------------------------------------------------------------------------------- 1 | from .auto_augment import RandAugment, AutoAugment, rand_augment_ops, auto_augment_policy,\ 2 | rand_augment_transform, auto_augment_transform 3 | from .config import resolve_data_config, resolve_model_data_config 4 | from .constants import * 5 | from .dataset import ImageDataset, IterableImageDataset, AugMixDataset 6 | from .dataset_factory import create_dataset 7 | from .dataset_info import DatasetInfo, CustomDatasetInfo 8 | from .imagenet_info import ImageNetInfo, infer_imagenet_subset 9 | from .loader import create_loader 10 | from .mixup import Mixup, FastCollateMixup 11 | from .naflex_dataset import NaFlexMapDatasetWrapper, calculate_naflex_batch_size 12 | from .naflex_loader import create_naflex_loader 13 | from .naflex_mixup import NaFlexMixup, pairwise_mixup_target, mix_batch_variable_size 14 | from .naflex_transforms import ( 15 | ResizeToSequence, 16 | CenterCropToSequence, 17 | RandomCropToSequence, 18 | RandomResizedCropToSequence, 19 | ResizeKeepRatioToSequence, 20 | Patchify, 21 | patchify_image, 22 | ) 23 | from .readers import create_reader 24 | from .readers import get_img_extensions, is_img_extension, set_img_extensions, add_img_extensions, del_img_extensions 25 | from .real_labels import RealLabelsImagenet 26 | from .transforms import * 27 | from .transforms_factory import create_transform 28 | -------------------------------------------------------------------------------- /timm/layers/format.py: -------------------------------------------------------------------------------- 1 | from enum import Enum 2 | from typing import Union 3 | 4 | import torch 5 | 6 | 7 | class Format(str, Enum): 8 | NCHW = 'NCHW' 9 | NHWC = 'NHWC' 10 | NCL = 'NCL' 11 | NLC = 'NLC' 12 | 13 | 14 | FormatT = Union[str, Format] 15 | 16 | 17 | def get_spatial_dim(fmt: FormatT): 18 | fmt = Format(fmt) 19 | if fmt is Format.NLC: 20 | dim = (1,) 21 | elif fmt is Format.NCL: 22 | dim = (2,) 23 | elif fmt is Format.NHWC: 24 | dim = (1, 2) 25 | else: 26 | dim = (2, 3) 27 | return dim 28 | 29 | 30 | def get_channel_dim(fmt: FormatT): 31 | fmt = Format(fmt) 32 | if fmt is Format.NHWC: 33 | dim = 3 34 | elif fmt is Format.NLC: 35 | dim = 2 36 | else: 37 | dim = 1 38 | return dim 39 | 40 | 41 | def nchw_to(x: torch.Tensor, fmt: Format): 42 | if fmt == Format.NHWC: 43 | x = x.permute(0, 2, 3, 1) 44 | elif fmt == Format.NLC: 45 | x = x.flatten(2).transpose(1, 2) 46 | elif fmt == Format.NCL: 47 | x = x.flatten(2) 48 | return x 49 | 50 | 51 | def nhwc_to(x: torch.Tensor, fmt: Format): 52 | if fmt == Format.NCHW: 53 | x = x.permute(0, 3, 1, 2) 54 | elif fmt == Format.NLC: 55 | x = x.flatten(1, 2) 56 | elif fmt == Format.NCL: 57 | x = x.flatten(1, 
2).transpose(1, 2) 58 | return x 59 | -------------------------------------------------------------------------------- /timm/optim/__init__.py: -------------------------------------------------------------------------------- 1 | from .adabelief import AdaBelief 2 | from .adafactor import Adafactor 3 | from .adafactor_bv import AdafactorBigVision 4 | from .adahessian import Adahessian 5 | from .adamp import AdamP 6 | from .adamw import AdamWLegacy 7 | from .adan import Adan 8 | from .adopt import Adopt 9 | from .lamb import Lamb 10 | from .laprop import LaProp 11 | from .lars import Lars 12 | from .lion import Lion 13 | from .lookahead import Lookahead 14 | from .madgrad import MADGRAD 15 | from .mars import Mars 16 | from .muon import Muon 17 | from .nadam import NAdamLegacy 18 | from .nadamw import NAdamW 19 | from .nvnovograd import NvNovoGrad 20 | from .radam import RAdamLegacy 21 | from .rmsprop_tf import RMSpropTF 22 | from .sgdp import SGDP 23 | from .sgdw import SGDW 24 | 25 | # bring common torch.optim Optimizers into timm.optim namespace for consistency 26 | from torch.optim import Adadelta, Adagrad, Adamax, Adam, AdamW, RMSprop, SGD 27 | try: 28 | # in case any very old torch versions being used 29 | from torch.optim import NAdam, RAdam 30 | except ImportError: 31 | pass 32 | 33 | from ._optim_factory import list_optimizers, get_optimizer_class, get_optimizer_info, OptimInfo, OptimizerRegistry, \ 34 | create_optimizer_v2, create_optimizer, optimizer_kwargs 35 | from ._param_groups import param_groups_layer_decay, param_groups_weight_decay, auto_group_layers 36 | -------------------------------------------------------------------------------- /timm/utils/summary.py: -------------------------------------------------------------------------------- 1 | """ Summary utilities 2 | 3 | Hacked together by / Copyright 2020 Ross Wightman 4 | """ 5 | import csv 6 | import os 7 | from collections import OrderedDict 8 | try: 9 | import wandb 10 | except ImportError: 11 | pass 12 | 13 | 14 | def get_outdir(path, *paths, inc=False): 15 | outdir = os.path.join(path, *paths) 16 | if not os.path.exists(outdir): 17 | os.makedirs(outdir) 18 | elif inc: 19 | count = 1 20 | outdir_inc = outdir + '-' + str(count) 21 | while os.path.exists(outdir_inc): 22 | count = count + 1 23 | outdir_inc = outdir + '-' + str(count) 24 | assert count < 100 25 | outdir = outdir_inc 26 | os.makedirs(outdir) 27 | return outdir 28 | 29 | 30 | def update_summary( 31 | epoch, 32 | train_metrics, 33 | eval_metrics, 34 | filename, 35 | lr=None, 36 | write_header=False, 37 | log_wandb=False, 38 | ): 39 | rowd = OrderedDict(epoch=epoch) 40 | rowd.update([('train_' + k, v) for k, v in train_metrics.items()]) 41 | if eval_metrics: 42 | rowd.update([('eval_' + k, v) for k, v in eval_metrics.items()]) 43 | if lr is not None: 44 | rowd['lr'] = lr 45 | if log_wandb: 46 | wandb.log(rowd) 47 | with open(filename, mode='a') as cf: 48 | dw = csv.DictWriter(cf, fieldnames=rowd.keys()) 49 | if write_header: # first iteration (epoch == 1 can't be used) 50 | dw.writeheader() 51 | dw.writerow(rowd) 52 | -------------------------------------------------------------------------------- /timm/data/_info/mini_imagenet_synsets.txt: -------------------------------------------------------------------------------- 1 | n01532829 2 | n01558993 3 | n01704323 4 | n01749939 5 | n01770081 6 | n01843383 7 | n01855672 8 | n01910747 9 | n01930112 10 | n01981276 11 | n02074367 12 | n02089867 13 | n02091244 14 | n02091831 15 | n02099601 16 | n02101006 17 | 
n02105505 18 | n02108089 19 | n02108551 20 | n02108915 21 | n02110063 22 | n02110341 23 | n02111277 24 | n02113712 25 | n02114548 26 | n02116738 27 | n02120079 28 | n02129165 29 | n02138441 30 | n02165456 31 | n02174001 32 | n02219486 33 | n02443484 34 | n02457408 35 | n02606052 36 | n02687172 37 | n02747177 38 | n02795169 39 | n02823428 40 | n02871525 41 | n02950826 42 | n02966193 43 | n02971356 44 | n02981792 45 | n03017168 46 | n03047690 47 | n03062245 48 | n03075370 49 | n03127925 50 | n03146219 51 | n03207743 52 | n03220513 53 | n03272010 54 | n03337140 55 | n03347037 56 | n03400231 57 | n03417042 58 | n03476684 59 | n03527444 60 | n03535780 61 | n03544143 62 | n03584254 63 | n03676483 64 | n03770439 65 | n03773504 66 | n03775546 67 | n03838899 68 | n03854065 69 | n03888605 70 | n03908618 71 | n03924679 72 | n03980874 73 | n03998194 74 | n04067472 75 | n04146614 76 | n04149813 77 | n04243546 78 | n04251144 79 | n04258138 80 | n04275548 81 | n04296562 82 | n04389033 83 | n04418357 84 | n04435653 85 | n04443257 86 | n04509417 87 | n04515003 88 | n04522168 89 | n04596742 90 | n04604644 91 | n04612504 92 | n06794110 93 | n07584110 94 | n07613480 95 | n07697537 96 | n07747607 97 | n09246464 98 | n09256479 99 | n13054560 100 | n13133613 101 | -------------------------------------------------------------------------------- /hfdocs/source/index.mdx: -------------------------------------------------------------------------------- 1 | # timm 2 | 3 | 4 | 5 | `timm` is a library containing SOTA computer vision models, layers, utilities, optimizers, schedulers, data-loaders, augmentations, and training/evaluation scripts. 6 | 7 | It comes packaged with >700 pretrained models, and is designed to be flexible and easy to use. 8 | 9 | Read the [quick start guide](quickstart) to get up and running with the `timm` library. You will learn how to load, discover, and use pretrained models included in the library. 10 | 11 |
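A minimal sketch of what the quick start covers, creating a pretrained model and running a forward pass (model name illustrative; weights download on first use):

```python
import torch
import timm

model = timm.create_model('mobilenetv3_large_100', pretrained=True)
model.eval()
with torch.no_grad():
    logits = model(torch.randn(1, 3, 224, 224))
print(logits.shape)  # torch.Size([1, 1000]) for the ImageNet-1k head
```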
12 | **Tutorials:** Learn the basics and become familiar with timm. Start here if you are using timm for the first time!
13 |
14 | **Reference:** Technical descriptions of how timm classes and methods work.
23 | -------------------------------------------------------------------------------- /timm/layers/grn.py: -------------------------------------------------------------------------------- 1 | """ Global Response Normalization Module 2 | 3 | Based on the GRN layer presented in 4 | `ConvNeXt-V2 - Co-designing and Scaling ConvNets with Masked Autoencoders` - https://arxiv.org/abs/2301.00808 5 | 6 | This implementation 7 | * works for both NCHW and NHWC tensor layouts 8 | * uses affine param names matching existing torch norm layers 9 | * slightly improves eager mode performance via fused addcmul 10 | 11 | Hacked together by / Copyright 2023 Ross Wightman 12 | """ 13 | 14 | import torch 15 | from torch import nn as nn 16 | 17 | 18 | class GlobalResponseNorm(nn.Module): 19 | """ Global Response Normalization layer 20 | """ 21 | def __init__( 22 | self, 23 | dim: int, 24 | eps: float = 1e-6, 25 | channels_last: bool = True, 26 | device=None, 27 | dtype=None, 28 | ): 29 | dd = {'device': device, 'dtype': dtype} 30 | super().__init__() 31 | self.eps = eps 32 | if channels_last: 33 | self.spatial_dim = (1, 2) 34 | self.channel_dim = -1 35 | self.wb_shape = (1, 1, 1, -1) 36 | else: 37 | self.spatial_dim = (2, 3) 38 | self.channel_dim = 1 39 | self.wb_shape = (1, -1, 1, 1) 40 | 41 | self.weight = nn.Parameter(torch.zeros(dim, **dd)) 42 | self.bias = nn.Parameter(torch.zeros(dim, **dd)) 43 | 44 | def forward(self, x): 45 | x_g = x.norm(p=2, dim=self.spatial_dim, keepdim=True) 46 | x_n = x_g / (x_g.mean(dim=self.channel_dim, keepdim=True) + self.eps) 47 | return x + torch.addcmul(self.bias.view(self.wb_shape), self.weight.view(self.wb_shape), x * x_n) 48 | -------------------------------------------------------------------------------- /timm/data/readers/img_extensions.py: -------------------------------------------------------------------------------- 1 | from copy import deepcopy 2 | 3 | __all__ = ['get_img_extensions', 'is_img_extension', 'set_img_extensions', 'add_img_extensions', 'del_img_extensions'] 4 | 5 | 6 | IMG_EXTENSIONS = ('.png', '.jpg', '.jpeg') # singleton, kept public for bwd compat use 7 | _IMG_EXTENSIONS_SET = set(IMG_EXTENSIONS) # set version, private, kept in sync 8 | 9 | 10 | def _set_extensions(extensions): 11 | global IMG_EXTENSIONS 12 | global _IMG_EXTENSIONS_SET 13 | dedupe = set() # NOTE de-duping tuple while keeping original order 14 | IMG_EXTENSIONS = tuple(x for x in extensions if x not in dedupe and not dedupe.add(x)) 15 | _IMG_EXTENSIONS_SET = set(extensions) 16 | 17 | 18 | def _valid_extension(x: str): 19 | return x and isinstance(x, str) and len(x) >= 2 and x.startswith('.') 20 | 21 | 22 | def is_img_extension(ext): 23 | return ext in _IMG_EXTENSIONS_SET 24 | 25 | 26 | def get_img_extensions(as_set=False): 27 | return deepcopy(_IMG_EXTENSIONS_SET if as_set else IMG_EXTENSIONS) 28 | 29 | 30 | def set_img_extensions(extensions): 31 | assert len(extensions) 32 | for x in extensions: 33 | assert _valid_extension(x) 34 | _set_extensions(extensions) 35 | 36 | 37 | def add_img_extensions(ext): 38 | if not isinstance(ext, (list, tuple, set)): 39 | ext = (ext,) 40 | for x in ext: 41 | assert _valid_extension(x) 42 | extensions = IMG_EXTENSIONS + tuple(ext) 43 | _set_extensions(extensions) 44 | 45 | 46 | def del_img_extensions(ext): 47 | if not isinstance(ext, (list, tuple, set)): 48 | ext = (ext,) 49 | extensions = tuple(x for x in IMG_EXTENSIONS if x not in ext) 50 | _set_extensions(extensions) 51 | 
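These extension helpers are re-exported through `timm.data` (see `timm/data/__init__.py` earlier in this listing); a short sketch of inspecting and mutating the recognized set:

```python
from timm.data import get_img_extensions, add_img_extensions, is_img_extension, del_img_extensions

print(get_img_extensions())   # ('.png', '.jpg', '.jpeg') by default
add_img_extensions('.webp')   # accepts a single extension or a list/tuple/set
assert is_img_extension('.webp')
del_img_extensions('.webp')   # restore the default set
```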
-------------------------------------------------------------------------------- /timm/layers/layer_scale.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | 4 | 5 | class LayerScale(nn.Module): 6 | """ LayerScale on tensors with channels in last-dim. 7 | """ 8 | def __init__( 9 | self, 10 | dim: int, 11 | init_values: float = 1e-5, 12 | inplace: bool = False, 13 | device=None, 14 | dtype=None, 15 | ) -> None: 16 | super().__init__() 17 | self.init_values = init_values 18 | self.inplace = inplace 19 | self.gamma = nn.Parameter(torch.empty(dim, device=device, dtype=dtype)) 20 | 21 | self.reset_parameters() 22 | 23 | def reset_parameters(self): 24 | torch.nn.init.constant_(self.gamma, self.init_values) 25 | 26 | def forward(self, x: torch.Tensor) -> torch.Tensor: 27 | return x.mul_(self.gamma) if self.inplace else x * self.gamma 28 | 29 | 30 | class LayerScale2d(nn.Module): 31 | """ LayerScale for tensors with torch 2D NCHW layout. 32 | """ 33 | def __init__( 34 | self, 35 | dim: int, 36 | init_values: float = 1e-5, 37 | inplace: bool = False, 38 | device=None, 39 | dtype=None, 40 | ): 41 | super().__init__() 42 | self.init_values = init_values 43 | self.inplace = inplace 44 | self.gamma = nn.Parameter(torch.empty(dim, device=device, dtype=dtype)) 45 | 46 | self.reset_parameters() 47 | 48 | def reset_parameters(self): 49 | torch.nn.init.constant_(self.gamma, self.init_values) 50 | 51 | def forward(self, x): 52 | gamma = self.gamma.view(1, -1, 1, 1) 53 | return x.mul_(gamma) if self.inplace else x * gamma 54 | 55 | -------------------------------------------------------------------------------- /timm/loss/jsd.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | from .cross_entropy import LabelSmoothingCrossEntropy 6 | 7 | 8 | class JsdCrossEntropy(nn.Module): 9 | """ Jensen-Shannon Divergence + Cross-Entropy Loss 10 | 11 | Based on impl here: https://github.com/google-research/augmix/blob/master/imagenet.py 12 | From paper: 'AugMix: A Simple Data Processing Method to Improve Robustness and Uncertainty - 13 | https://arxiv.org/abs/1912.02781 14 | 15 | Hacked together by / Copyright 2020 Ross Wightman 16 | """ 17 | def __init__(self, num_splits=3, alpha=12, smoothing=0.1): 18 | super().__init__() 19 | self.num_splits = num_splits 20 | self.alpha = alpha 21 | if smoothing is not None and smoothing > 0: 22 | self.cross_entropy_loss = LabelSmoothingCrossEntropy(smoothing) 23 | else: 24 | self.cross_entropy_loss = torch.nn.CrossEntropyLoss() 25 | 26 | def __call__(self, output, target): 27 | split_size = output.shape[0] // self.num_splits 28 | assert split_size * self.num_splits == output.shape[0] 29 | logits_split = torch.split(output, split_size) 30 | 31 | # Cross-entropy is only computed on clean images 32 | loss = self.cross_entropy_loss(logits_split[0], target[:split_size]) 33 | probs = [F.softmax(logits, dim=1) for logits in logits_split] 34 | 35 | # Clamp mixture distribution to avoid exploding KL divergence 36 | logp_mixture = torch.clamp(torch.stack(probs).mean(axis=0), 1e-7, 1).log() 37 | loss += self.alpha * sum([F.kl_div( 38 | logp_mixture, p_split, reduction='batchmean') for p_split in probs]) / len(probs) 39 | return loss 40 | -------------------------------------------------------------------------------- /timm/layers/create_conv2d.py: 
-------------------------------------------------------------------------------- 1 | """ Create Conv2d Factory Method 2 | 3 | Hacked together by / Copyright 2020 Ross Wightman 4 | """ 5 | 6 | from .mixed_conv2d import MixedConv2d 7 | from .cond_conv2d import CondConv2d 8 | from .conv2d_same import create_conv2d_pad 9 | 10 | 11 | def create_conv2d(in_channels, out_channels, kernel_size, **kwargs): 12 | """ Select a 2d convolution implementation based on arguments 13 | Creates and returns one of torch.nn.Conv2d, Conv2dSame, MixedConv2d, or CondConv2d. 14 | 15 | Used extensively by EfficientNet, MobileNetv3 and related networks. 16 | """ 17 | if isinstance(kernel_size, list): 18 | assert 'num_experts' not in kwargs # MixNet + CondConv combo not supported currently 19 | if 'groups' in kwargs: 20 | groups = kwargs.pop('groups') 21 | if groups == in_channels: 22 | kwargs['depthwise'] = True 23 | else: 24 | assert groups == 1 25 | # We're going to use only lists for defining the MixedConv2d kernel groups, 26 | # ints, tuples, other iterables will continue to pass to normal conv and specify h, w. 27 | m = MixedConv2d(in_channels, out_channels, kernel_size, **kwargs) 28 | else: 29 | depthwise = kwargs.pop('depthwise', False) 30 | # for DW out_channels must be multiple of in_channels as must have out_channels % groups == 0 31 | groups = in_channels if depthwise else kwargs.pop('groups', 1) 32 | if 'num_experts' in kwargs and kwargs['num_experts'] > 0: 33 | m = CondConv2d(in_channels, out_channels, kernel_size, groups=groups, **kwargs) 34 | else: 35 | m = create_conv2d_pad(in_channels, out_channels, kernel_size, groups=groups, **kwargs) 36 | return m 37 | -------------------------------------------------------------------------------- /timm/utils/agc.py: -------------------------------------------------------------------------------- 1 | """ Adaptive Gradient Clipping 2 | 3 | An impl of AGC, as per (https://arxiv.org/abs/2102.06171): 4 | 5 | @article{brock2021high, 6 | author={Andrew Brock and Soham De and Samuel L. 
Smith and Karen Simonyan}, 7 | title={High-Performance Large-Scale Image Recognition Without Normalization}, 8 | journal={arXiv preprint arXiv:2102.06171}, 9 | year={2021} 10 | } 11 | 12 | Code references: 13 | * Official JAX impl (paper authors): https://github.com/deepmind/deepmind-research/tree/master/nfnets 14 | * Phil Wang's PyTorch gist: https://gist.github.com/lucidrains/0d6560077edac419ab5d3aa29e674d5c 15 | 16 | Hacked together by / Copyright 2021 Ross Wightman 17 | """ 18 | import torch 19 | 20 | 21 | def unitwise_norm(x, norm_type=2.0): 22 | if x.ndim <= 1: 23 | return x.norm(norm_type) 24 | else: 25 | # works for nn.ConvNd and nn.Linear where the output dim is first in the kernel/weight tensor 26 | # might need special cases for other weights (possibly MHA) where this may not be true 27 | return x.norm(norm_type, dim=tuple(range(1, x.ndim)), keepdim=True) 28 | 29 | 30 | def adaptive_clip_grad(parameters, clip_factor=0.01, eps=1e-3, norm_type=2.0): 31 | if isinstance(parameters, torch.Tensor): 32 | parameters = [parameters] 33 | for p in parameters: 34 | if p.grad is None: 35 | continue 36 | p_data = p.detach() 37 | g_data = p.grad.detach() 38 | max_norm = unitwise_norm(p_data, norm_type=norm_type).clamp_(min=eps).mul_(clip_factor) 39 | grad_norm = unitwise_norm(g_data, norm_type=norm_type) 40 | clipped_grad = g_data * (max_norm / grad_norm.clamp(min=1e-6)) 41 | new_grads = torch.where(grad_norm < max_norm, g_data, clipped_grad) 42 | p.grad.detach().copy_(new_grads) 43 | -------------------------------------------------------------------------------- /timm/layers/grid.py: -------------------------------------------------------------------------------- 1 | from typing import Tuple 2 | 3 | import torch 4 | 5 | 6 | def ndgrid(*tensors) -> Tuple[torch.Tensor, ...]: 7 | """generate N-D grid in dimension order. 8 | 9 | The ndgrid function is like meshgrid except that the order of the first two input arguments is switched. 10 | 11 | That is, the statement 12 | [X1,X2,X3] = ndgrid(x1,x2,x3) 13 | 14 | produces the same result as 15 | 16 | [X2,X1,X3] = meshgrid(x2,x1,x3) 17 | 18 | This naming is based on MATLAB; the purpose is to avoid confusion now that torch's change has moved 19 | torch.meshgrid behaviour from matching ndgrid ('ij') indexing to the numpy meshgrid default of ('xy'). 20 | 21 | """ 22 | try: 23 | return torch.meshgrid(*tensors, indexing='ij') 24 | except TypeError: 25 | # old PyTorch < 1.10 will follow this path as it does not have the indexing arg, 26 | # the old behaviour of meshgrid was 'ij' 27 | return torch.meshgrid(*tensors) 28 | 29 | 30 | def meshgrid(*tensors) -> Tuple[torch.Tensor, ...]: 31 | """generate N-D grid in spatial dim order. 32 | 33 | The meshgrid function is similar to ndgrid except that the order of the 34 | first two input and output arguments is switched. 35 | 36 | That is, the statement 37 | 38 | [X,Y,Z] = meshgrid(x,y,z) 39 | produces the same result as 40 | 41 | [Y,X,Z] = ndgrid(y,x,z) 42 | Because of this, meshgrid is better suited to problems in two- or three-dimensional Cartesian space, 43 | while ndgrid is better suited to multidimensional problems that aren't spatially based. 44 | """ 45 | 46 | # NOTE: this will throw in PyTorch < 1.10 as meshgrid did not support the indexing arg or have the 47 | # capability of generating a grid in xy order before then.
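# Added illustration (comments only; a sketch assuming PyTorch >= 1.10):
#   ndgrid(torch.arange(2), torch.arange(3))   -> two tensors of shape (2, 3) ('ij' order)
#   meshgrid(torch.arange(2), torch.arange(3)) -> two tensors of shape (3, 2) ('xy' order)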
48 | return torch.meshgrid(*tensors, indexing='xy') 49 | 50 | -------------------------------------------------------------------------------- /timm/layers/median_pool.py: -------------------------------------------------------------------------------- 1 | """ Median Pool 2 | Hacked together by / Copyright 2020 Ross Wightman 3 | """ 4 | import torch.nn as nn 5 | import torch.nn.functional as F 6 | from .helpers import to_2tuple, to_4tuple 7 | 8 | 9 | class MedianPool2d(nn.Module): 10 | """ Median pool (usable as median filter when stride=1) module. 11 | 12 | Args: 13 | kernel_size: size of pooling kernel, int or 2-tuple 14 | stride: pool stride, int or 2-tuple 15 | padding: pool padding, int or 4-tuple (l, r, t, b) as in pytorch F.pad 16 | same: override padding and enforce same padding, boolean 17 | """ 18 | def __init__(self, kernel_size=3, stride=1, padding=0, same=False): 19 | super().__init__() 20 | self.k = to_2tuple(kernel_size) 21 | self.stride = to_2tuple(stride) 22 | self.padding = to_4tuple(padding) # convert to l, r, t, b 23 | self.same = same 24 | 25 | def _padding(self, x): 26 | if self.same: 27 | ih, iw = x.size()[2:] 28 | if ih % self.stride[0] == 0: 29 | ph = max(self.k[0] - self.stride[0], 0) 30 | else: 31 | ph = max(self.k[0] - (ih % self.stride[0]), 0) 32 | if iw % self.stride[1] == 0: 33 | pw = max(self.k[1] - self.stride[1], 0) 34 | else: 35 | pw = max(self.k[1] - (iw % self.stride[1]), 0) 36 | pl = pw // 2 37 | pr = pw - pl 38 | pt = ph // 2 39 | pb = ph - pt 40 | padding = (pl, pr, pt, pb) 41 | else: 42 | padding = self.padding 43 | return padding 44 | 45 | def forward(self, x): 46 | x = F.pad(x, self._padding(x), mode='reflect') 47 | x = x.unfold(2, self.k[0], self.stride[0]).unfold(3, self.k[1], self.stride[1]) 48 | x = x.contiguous().view(x.size()[:4] + (-1,)).median(dim=-1)[0] 49 | return x 50 | -------------------------------------------------------------------------------- /timm/utils/decay_batch.py: -------------------------------------------------------------------------------- 1 | """ Batch size decay and retry helpers. 
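These helpers support stepping a batch size down and retrying after a failure (e.g. CUDA OOM): decay_batch_step picks the next smaller size, while check_batch_size_retry decides whether a retry is worthwhile.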
2 | 3 | Copyright 2022 Ross Wightman 4 | """ 5 | import math 6 | 7 | 8 | def decay_batch_step(batch_size, num_intra_steps=2, no_odd=False): 9 | """ power of two batch-size decay with intra steps 10 | 11 | Decay by stepping between powers of 2: 12 | * determine power-of-2 floor of current batch size (base batch size) 13 | * divide above value by num_intra_steps to determine step size 14 | * floor batch_size to nearest multiple of step_size (from base batch size) 15 | Examples: 16 | num_steps == 4 --> 64, 56, 48, 40, 32, 28, 24, 20, 16, 14, 12, 10, 8, 7, 6, 5, 4, 3, 2, 1 17 | num_steps (no_odd=True) == 4 --> 64, 56, 48, 40, 32, 28, 24, 20, 16, 14, 12, 10, 8, 6, 4, 2 18 | num_steps == 2 --> 64, 48, 32, 24, 16, 12, 8, 6, 4, 3, 2, 1 19 | num_steps == 1 --> 64, 32, 16, 8, 4, 2, 1 20 | """ 21 | if batch_size <= 1: 22 | # return 0 for stopping value so easy to use in loop 23 | return 0 24 | base_batch_size = int(2 ** (math.log(batch_size - 1) // math.log(2))) 25 | step_size = max(base_batch_size // num_intra_steps, 1) 26 | batch_size = base_batch_size + ((batch_size - base_batch_size - 1) // step_size) * step_size 27 | if no_odd and batch_size % 2: 28 | batch_size -= 1 29 | return batch_size 30 | 31 | 32 | def check_batch_size_retry(error_str): 33 | """ check failure error string for conditions where batch decay retry should not be attempted 34 | """ 35 | error_str = error_str.lower() 36 | if 'required rank' in error_str: 37 | # Errors involving phrase 'required rank' typically happen when a conv is used that's 38 | # not compatible with channels_last memory format. 39 | return False 40 | if 'illegal' in error_str: 41 | # 'Illegal memory access' errors in CUDA typically leave process in unusable state 42 | return False 43 | return True 44 | -------------------------------------------------------------------------------- /timm/data/_info/imagenet_r_indices.txt: -------------------------------------------------------------------------------- 1 | 1 2 | 2 3 | 4 4 | 6 5 | 8 6 | 9 7 | 11 8 | 13 9 | 22 10 | 23 11 | 26 12 | 29 13 | 31 14 | 39 15 | 47 16 | 63 17 | 71 18 | 76 19 | 79 20 | 84 21 | 90 22 | 94 23 | 96 24 | 97 25 | 99 26 | 100 27 | 105 28 | 107 29 | 113 30 | 122 31 | 125 32 | 130 33 | 132 34 | 144 35 | 145 36 | 147 37 | 148 38 | 150 39 | 151 40 | 155 41 | 160 42 | 161 43 | 162 44 | 163 45 | 171 46 | 172 47 | 178 48 | 187 49 | 195 50 | 199 51 | 203 52 | 207 53 | 208 54 | 219 55 | 231 56 | 232 57 | 234 58 | 235 59 | 242 60 | 245 61 | 247 62 | 250 63 | 251 64 | 254 65 | 259 66 | 260 67 | 263 68 | 265 69 | 267 70 | 269 71 | 276 72 | 277 73 | 281 74 | 288 75 | 289 76 | 291 77 | 292 78 | 293 79 | 296 80 | 299 81 | 301 82 | 308 83 | 309 84 | 310 85 | 311 86 | 314 87 | 315 88 | 319 89 | 323 90 | 327 91 | 330 92 | 334 93 | 335 94 | 337 95 | 338 96 | 340 97 | 341 98 | 344 99 | 347 100 | 353 101 | 355 102 | 361 103 | 362 104 | 365 105 | 366 106 | 367 107 | 368 108 | 372 109 | 388 110 | 390 111 | 393 112 | 397 113 | 401 114 | 407 115 | 413 116 | 414 117 | 425 118 | 428 119 | 430 120 | 435 121 | 437 122 | 441 123 | 447 124 | 448 125 | 457 126 | 462 127 | 463 128 | 469 129 | 470 130 | 471 131 | 472 132 | 476 133 | 483 134 | 487 135 | 515 136 | 546 137 | 555 138 | 558 139 | 570 140 | 579 141 | 583 142 | 587 143 | 593 144 | 594 145 | 596 146 | 609 147 | 613 148 | 617 149 | 621 150 | 629 151 | 637 152 | 657 153 | 658 154 | 701 155 | 717 156 | 724 157 | 763 158 | 768 159 | 774 160 | 776 161 | 779 162 | 780 163 | 787 164 | 805 165 | 812 166 | 815 167 | 820 168 | 824 169 | 833 170 | 847 171 | 852 172 | 866 173 | 
875 174 | 883 175 | 889 176 | 895 177 | 907 178 | 928 179 | 931 180 | 932 181 | 933 182 | 934 183 | 936 184 | 937 185 | 943 186 | 945 187 | 947 188 | 948 189 | 949 190 | 951 191 | 953 192 | 954 193 | 957 194 | 963 195 | 965 196 | 967 197 | 980 198 | 981 199 | 983 200 | 988 201 | -------------------------------------------------------------------------------- /timm/data/_info/imagenet_a_indices.txt: -------------------------------------------------------------------------------- 1 | 6 2 | 11 3 | 13 4 | 15 5 | 17 6 | 22 7 | 23 8 | 27 9 | 30 10 | 37 11 | 39 12 | 42 13 | 47 14 | 50 15 | 57 16 | 70 17 | 71 18 | 76 19 | 79 20 | 89 21 | 90 22 | 94 23 | 96 24 | 97 25 | 99 26 | 105 27 | 107 28 | 108 29 | 110 30 | 113 31 | 124 32 | 125 33 | 130 34 | 132 35 | 143 36 | 144 37 | 150 38 | 151 39 | 207 40 | 234 41 | 235 42 | 254 43 | 277 44 | 283 45 | 287 46 | 291 47 | 295 48 | 298 49 | 301 50 | 306 51 | 307 52 | 308 53 | 309 54 | 310 55 | 311 56 | 313 57 | 314 58 | 315 59 | 317 60 | 319 61 | 323 62 | 324 63 | 326 64 | 327 65 | 330 66 | 334 67 | 335 68 | 336 69 | 347 70 | 361 71 | 363 72 | 372 73 | 378 74 | 386 75 | 397 76 | 400 77 | 401 78 | 402 79 | 404 80 | 407 81 | 411 82 | 416 83 | 417 84 | 420 85 | 425 86 | 428 87 | 430 88 | 437 89 | 438 90 | 445 91 | 456 92 | 457 93 | 461 94 | 462 95 | 470 96 | 472 97 | 483 98 | 486 99 | 488 100 | 492 101 | 496 102 | 514 103 | 516 104 | 528 105 | 530 106 | 539 107 | 542 108 | 543 109 | 549 110 | 552 111 | 557 112 | 561 113 | 562 114 | 569 115 | 572 116 | 573 117 | 575 118 | 579 119 | 589 120 | 606 121 | 607 122 | 609 123 | 614 124 | 626 125 | 627 126 | 640 127 | 641 128 | 642 129 | 643 130 | 658 131 | 668 132 | 677 133 | 682 134 | 684 135 | 687 136 | 701 137 | 704 138 | 719 139 | 736 140 | 746 141 | 749 142 | 752 143 | 758 144 | 763 145 | 765 146 | 768 147 | 773 148 | 774 149 | 776 150 | 779 151 | 780 152 | 786 153 | 792 154 | 797 155 | 802 156 | 803 157 | 804 158 | 813 159 | 815 160 | 820 161 | 823 162 | 831 163 | 833 164 | 835 165 | 839 166 | 845 167 | 847 168 | 850 169 | 859 170 | 862 171 | 870 172 | 879 173 | 880 174 | 888 175 | 890 176 | 897 177 | 900 178 | 907 179 | 913 180 | 924 181 | 932 182 | 933 183 | 934 184 | 937 185 | 943 186 | 945 187 | 947 188 | 951 189 | 954 190 | 956 191 | 957 192 | 959 193 | 971 194 | 972 195 | 980 196 | 981 197 | 984 198 | 986 199 | 987 200 | 988 201 | -------------------------------------------------------------------------------- /timm/data/real_labels.py: -------------------------------------------------------------------------------- 1 | """ Real labels evaluator for ImageNet 2 | Paper: `Are we done with ImageNet?` - https://arxiv.org/abs/2006.07159 3 | Based on Numpy example at https://github.com/google-research/reassessed-imagenet 4 | 5 | Hacked together by / Copyright 2020 Ross Wightman 6 | """ 7 | import os 8 | import json 9 | import numpy as np 10 | import pkgutil 11 | 12 | 13 | class RealLabelsImagenet: 14 | 15 | def __init__(self, filenames, real_json=None, topk=(1, 5)): 16 | if real_json is not None: 17 | with open(real_json) as real_labels: 18 | real_labels = json.load(real_labels) 19 | else: 20 | real_labels = json.loads( 21 | pkgutil.get_data(__name__, os.path.join('_info', 'imagenet_real_labels.json')).decode('utf-8')) 22 | real_labels = {f'ILSVRC2012_val_{i + 1:08d}.JPEG': labels for i, labels in enumerate(real_labels)} 23 | self.real_labels = real_labels 24 | self.filenames = filenames 25 | assert len(self.filenames) == len(self.real_labels) 26 | self.topk = topk 27 | self.is_correct = {k: [] for k in topk} 28 
| self.sample_idx = 0 29 | 30 | def add_result(self, output): 31 | maxk = max(self.topk) 32 | _, pred_batch = output.topk(maxk, 1, True, True) 33 | pred_batch = pred_batch.cpu().numpy() 34 | for pred in pred_batch: 35 | filename = self.filenames[self.sample_idx] 36 | filename = os.path.basename(filename) 37 | if self.real_labels[filename]: 38 | for k in self.topk: 39 | self.is_correct[k].append( 40 | any([p in self.real_labels[filename] for p in pred[:k]])) 41 | self.sample_idx += 1 42 | 43 | def get_accuracy(self, k=None): 44 | if k is None: 45 | return {k: float(np.mean(self.is_correct[k])) * 100 for k in self.topk} 46 | else: 47 | return float(np.mean(self.is_correct[k])) * 100 48 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | *.egg-info/ 24 | .installed.cfg 25 | *.egg 26 | MANIFEST 27 | 28 | # PyInstaller 29 | # Usually these files are written by a python script from a template 30 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 31 | *.manifest 32 | *.spec 33 | 34 | # Installer logs 35 | pip-log.txt 36 | pip-delete-this-directory.txt 37 | 38 | # Unit test / coverage reports 39 | htmlcov/ 40 | .tox/ 41 | .coverage 42 | .coverage.* 43 | .cache 44 | nosetests.xml 45 | coverage.xml 46 | *.cover 47 | .hypothesis/ 48 | .pytest_cache/ 49 | 50 | # Translations 51 | *.mo 52 | *.pot 53 | 54 | # Django stuff: 55 | *.log 56 | local_settings.py 57 | db.sqlite3 58 | 59 | # Flask stuff: 60 | instance/ 61 | .webassets-cache 62 | 63 | # Scrapy stuff: 64 | .scrapy 65 | 66 | # Sphinx documentation 67 | docs/_build/ 68 | 69 | # PyBuilder 70 | target/ 71 | 72 | # Jupyter Notebook 73 | .ipynb_checkpoints 74 | 75 | # pyenv 76 | .python-version 77 | 78 | # celery beat schedule file 79 | celerybeat-schedule 80 | 81 | # SageMath parsed files 82 | *.sage.py 83 | 84 | # Environments 85 | .env 86 | .venv 87 | env/ 88 | venv/ 89 | ENV/ 90 | env.bak/ 91 | venv.bak/ 92 | 93 | # Spyder project settings 94 | .spyderproject 95 | .spyproject 96 | 97 | # Rope project settings 98 | .ropeproject 99 | 100 | # PyCharm 101 | .idea 102 | 103 | output/ 104 | 105 | # PyTorch weights 106 | *.tar 107 | *.pth 108 | *.pt 109 | *.torch 110 | *.gz 111 | Untitled.ipynb 112 | Testing notebook.ipynb 113 | 114 | # Root dir exclusions 115 | /*.csv 116 | /*.yaml 117 | /*.json 118 | /*.jpg 119 | /*.png 120 | /*.zip 121 | /*.tar.* -------------------------------------------------------------------------------- /timm/scheduler/step_lr.py: -------------------------------------------------------------------------------- 1 | """ Step Scheduler 2 | 3 | Basic step LR schedule with warmup, noise. 
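A usage sketch (epoch-based stepping; train_one_epoch / num_epochs are hypothetical placeholders):
    scheduler = StepLRScheduler(optimizer, decay_t=30, decay_rate=0.1, warmup_t=5, warmup_lr_init=1e-5)
    for epoch in range(num_epochs):
        train_one_epoch(...)
        scheduler.step(epoch + 1)  # base LRs scale by 0.1 ** (t // 30) after the 5 warmup epochs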
4 | 5 | Hacked together by / Copyright 2020 Ross Wightman 6 | """ 7 | import math 8 | import torch 9 | from typing import List 10 | 11 | 12 | from .scheduler import Scheduler 13 | 14 | 15 | class StepLRScheduler(Scheduler): 16 | """ Step LR schedule w/ warmup: after warmup, base LRs are scaled by decay_rate ** (t // decay_t). 17 | """ 18 | 19 | def __init__( 20 | self, 21 | optimizer: torch.optim.Optimizer, 22 | decay_t: float, 23 | decay_rate: float = 1., 24 | warmup_t=0, 25 | warmup_lr_init=0, 26 | warmup_prefix=True, 27 | t_in_epochs=True, 28 | noise_range_t=None, 29 | noise_pct=0.67, 30 | noise_std=1.0, 31 | noise_seed=42, 32 | initialize=True, 33 | ) -> None: 34 | super().__init__( 35 | optimizer, 36 | param_group_field="lr", 37 | t_in_epochs=t_in_epochs, 38 | noise_range_t=noise_range_t, 39 | noise_pct=noise_pct, 40 | noise_std=noise_std, 41 | noise_seed=noise_seed, 42 | initialize=initialize, 43 | ) 44 | 45 | self.decay_t = decay_t 46 | self.decay_rate = decay_rate 47 | self.warmup_t = warmup_t 48 | self.warmup_lr_init = warmup_lr_init 49 | self.warmup_prefix = warmup_prefix 50 | if self.warmup_t: 51 | self.warmup_steps = [(v - warmup_lr_init) / self.warmup_t for v in self.base_values] 52 | super().update_groups(self.warmup_lr_init) 53 | else: 54 | self.warmup_steps = [1 for _ in self.base_values] 55 | 56 | def _get_lr(self, t: int) -> List[float]: 57 | if t < self.warmup_t: 58 | lrs = [self.warmup_lr_init + t * s for s in self.warmup_steps] 59 | else: 60 | if self.warmup_prefix: 61 | t = t - self.warmup_t 62 | lrs = [v * (self.decay_rate ** (t // self.decay_t)) for v in self.base_values] 63 | return lrs 64 | -------------------------------------------------------------------------------- /hfdocs/source/hf_hub.mdx: -------------------------------------------------------------------------------- 1 | # Sharing and Loading Models From the Hugging Face Hub 2 | 3 | The `timm` library has a built-in integration with the Hugging Face Hub, making it easy to share and load models from the 🤗 Hub. 4 | 5 | In this short guide, we'll see how to: 6 | 1. Share a `timm` model on the Hub 7 | 2. Load that model back from the Hub 8 | 9 | ## Authenticating 10 | 11 | First, you'll need to make sure you have the `huggingface_hub` package installed. 12 | 13 | ```bash 14 | pip install huggingface_hub 15 | ``` 16 | 17 | Then, you'll need to authenticate yourself. You can do this by running the following command: 18 | 19 | ```bash 20 | huggingface-cli login 21 | ``` 22 | 23 | Or, if you're using a notebook, you can use the `notebook_login` helper: 24 | 25 | ```py 26 | >>> from huggingface_hub import notebook_login 27 | >>> notebook_login() 28 | ``` 29 | 30 | ## Sharing a Model 31 | 32 | ```py 33 | >>> import timm 34 | >>> model = timm.create_model('resnet18', pretrained=True, num_classes=4) 35 | ``` 36 | 37 | Here is where you would normally train or fine-tune the model. We'll skip that for the sake of this tutorial. 38 | 39 | Let's pretend we've now fine-tuned the model. The next step would be to push it to the Hub! We can do this with the `timm.models.push_to_hf_hub` function. 40 | 41 | ```py 42 | >>> model_cfg = dict(label_names=['a', 'b', 'c', 'd']) 43 | >>> timm.models.push_to_hf_hub(model, 'resnet18-random', model_config=model_cfg) 44 | ``` 45 | 46 | Running the above would push the model to `<your-username>/resnet18-random` on the Hub. You can now share this model with your friends, or use it in your own code! 47 | 48 | ## Loading a Model 49 | 50 | Loading a model from the Hub is as simple as calling `timm.create_model` with `pretrained=True` and the model name prefixed with `hf_hub:` followed by the Hub repository you want to load.
In this case, we'll use [`nateraw/resnet18-random`](https://huggingface.co/nateraw/resnet18-random), which is the model we just pushed to the Hub. 51 | 52 | ```py 53 | >>> model_reloaded = timm.create_model('hf_hub:nateraw/resnet18-random', pretrained=True) 54 | ``` 55 | -------------------------------------------------------------------------------- /timm/data/readers/reader_factory.py: -------------------------------------------------------------------------------- 1 | import os 2 | from typing import Optional 3 | 4 | from .reader_image_folder import ReaderImageFolder 5 | from .reader_image_in_tar import ReaderImageInTar 6 | 7 | 8 | def create_reader( 9 | name: str, 10 | root: Optional[str] = None, 11 | split: str = 'train', 12 | **kwargs, 13 | ): 14 | kwargs = {k: v for k, v in kwargs.items() if v is not None} 15 | name = name.lower() 16 | name = name.split('/', 1) 17 | prefix = '' 18 | if len(name) > 1: 19 | prefix = name[0] 20 | name = name[-1] 21 | 22 | # FIXME the additional features are only supported by ReaderHfds for now. 23 | additional_features = kwargs.pop("additional_features", None) 24 | 25 | # FIXME improve the selection right now just tfds prefix or fallback path, will need options to 26 | # explicitly select other options shortly 27 | if prefix == 'hfds': 28 | from .reader_hfds import ReaderHfds # defer Hf datasets import 29 | reader = ReaderHfds(name=name, root=root, split=split, additional_features=additional_features, **kwargs) 30 | elif prefix == 'hfids': 31 | from .reader_hfids import ReaderHfids # defer HF datasets import 32 | reader = ReaderHfids(name=name, root=root, split=split, **kwargs) 33 | elif prefix == 'tfds': 34 | from .reader_tfds import ReaderTfds # defer tensorflow import 35 | reader = ReaderTfds(name=name, root=root, split=split, **kwargs) 36 | elif prefix == 'wds': 37 | from .reader_wds import ReaderWds 38 | kwargs.pop('download', False) 39 | reader = ReaderWds(root=root, name=name, split=split, **kwargs) 40 | else: 41 | assert os.path.exists(root) 42 | # default fallback path (backwards compat), use image tar if root is a .tar file, otherwise image folder 43 | # FIXME support split here or in reader? 
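# Illustrative calls (a sketch, not part of the original file; dataset names/paths are hypothetical):
#   create_reader('hfds/imagenet-1k', root=None, split='validation') -> ReaderHfds (requires HF datasets)
#   create_reader('', root='/data/train') -> ReaderImageFolder, or ReaderImageInTar when root is a .tar file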
44 | if os.path.isfile(root) and os.path.splitext(root)[1] == '.tar': 45 | reader = ReaderImageInTar(root, **kwargs) 46 | else: 47 | reader = ReaderImageFolder(root, **kwargs) 48 | return reader 49 | -------------------------------------------------------------------------------- /timm/layers/test_time_pool.py: -------------------------------------------------------------------------------- 1 | """ Test Time Pooling (Average-Max Pool) 2 | 3 | Hacked together by / Copyright 2020 Ross Wightman 4 | """ 5 | 6 | import logging 7 | from torch import nn 8 | import torch.nn.functional as F 9 | 10 | from .adaptive_avgmax_pool import adaptive_avgmax_pool2d 11 | 12 | 13 | _logger = logging.getLogger(__name__) 14 | 15 | 16 | class TestTimePoolHead(nn.Module): 17 | def __init__(self, base, original_pool=7): 18 | super().__init__() 19 | self.base = base 20 | self.original_pool = original_pool 21 | base_fc = self.base.get_classifier() 22 | if isinstance(base_fc, nn.Conv2d): 23 | self.fc = base_fc 24 | else: 25 | self.fc = nn.Conv2d( 26 | self.base.num_features, self.base.num_classes, kernel_size=1, bias=True) 27 | self.fc.weight.data.copy_(base_fc.weight.data.view(self.fc.weight.size())) 28 | self.fc.bias.data.copy_(base_fc.bias.data.view(self.fc.bias.size())) 29 | self.base.reset_classifier(0) # delete original fc layer 30 | 31 | def forward(self, x): 32 | x = self.base.forward_features(x) 33 | x = F.avg_pool2d(x, kernel_size=self.original_pool, stride=1) 34 | x = self.fc(x) 35 | x = adaptive_avgmax_pool2d(x, 1) 36 | return x.view(x.size(0), -1) 37 | 38 | 39 | def apply_test_time_pool(model, config, use_test_size=False): 40 | test_time_pool = False 41 | if not hasattr(model, 'default_cfg') or not model.default_cfg: 42 | return model, False 43 | if use_test_size and 'test_input_size' in model.default_cfg: 44 | df_input_size = model.default_cfg['test_input_size'] 45 | else: 46 | df_input_size = model.default_cfg['input_size'] 47 | if config['input_size'][-1] > df_input_size[-1] and config['input_size'][-2] > df_input_size[-2]: 48 | _logger.info('Target input size %s > pretrained default %s, using test time pooling' % 49 | (str(config['input_size'][-2:]), str(df_input_size[-2:]))) 50 | model = TestTimePoolHead(model, original_pool=model.default_cfg['pool_size']) 51 | test_time_pool = True 52 | return model, test_time_pool 53 | -------------------------------------------------------------------------------- /timm/scheduler/multistep_lr.py: -------------------------------------------------------------------------------- 1 | """ MultiStep LR Scheduler 2 | 3 | Basic multi step LR schedule with warmup, noise. 
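A usage sketch: MultiStepLRScheduler(optimizer, decay_t=[30, 60, 90], decay_rate=0.1) multiplies each group's base LR by 0.1 once each milestone in decay_t has passed; warmup and noise options behave as in StepLRScheduler.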
4 | """ 5 | import torch 6 | import bisect 7 | from timm.scheduler.scheduler import Scheduler 8 | from typing import List 9 | 10 | class MultiStepLRScheduler(Scheduler): 11 | """ MultiStep LR schedule w/ warmup: base LRs are scaled by decay_rate once each milestone in decay_t has passed. 12 | """ 13 | 14 | def __init__( 15 | self, 16 | optimizer: torch.optim.Optimizer, 17 | decay_t: List[int], 18 | decay_rate: float = 1., 19 | warmup_t=0, 20 | warmup_lr_init=0, 21 | warmup_prefix=True, 22 | t_in_epochs=True, 23 | noise_range_t=None, 24 | noise_pct=0.67, 25 | noise_std=1.0, 26 | noise_seed=42, 27 | initialize=True, 28 | ) -> None: 29 | super().__init__( 30 | optimizer, 31 | param_group_field="lr", 32 | t_in_epochs=t_in_epochs, 33 | noise_range_t=noise_range_t, 34 | noise_pct=noise_pct, 35 | noise_std=noise_std, 36 | noise_seed=noise_seed, 37 | initialize=initialize, 38 | ) 39 | 40 | self.decay_t = decay_t 41 | self.decay_rate = decay_rate 42 | self.warmup_t = warmup_t 43 | self.warmup_lr_init = warmup_lr_init 44 | self.warmup_prefix = warmup_prefix 45 | if self.warmup_t: 46 | self.warmup_steps = [(v - warmup_lr_init) / self.warmup_t for v in self.base_values] 47 | super().update_groups(self.warmup_lr_init) 48 | else: 49 | self.warmup_steps = [1 for _ in self.base_values] 50 | 51 | def get_curr_decay_steps(self, t): 52 | # find where in the array t goes, 53 | # assumes self.decay_t is sorted 54 | return bisect.bisect_right(self.decay_t, t + 1) 55 | 56 | def _get_lr(self, t: int) -> List[float]: 57 | if t < self.warmup_t: 58 | lrs = [self.warmup_lr_init + t * s for s in self.warmup_steps] 59 | else: 60 | if self.warmup_prefix: 61 | t = t - self.warmup_t 62 | lrs = [v * (self.decay_rate ** self.get_curr_decay_steps(t)) for v in self.base_values] 63 | return lrs 64 | -------------------------------------------------------------------------------- /hfdocs/source/installation.mdx: -------------------------------------------------------------------------------- 1 | # Installation 2 | 3 | Before you start, you'll need to set up your environment and install the appropriate packages. `timm` is tested on **Python 3.8+**. 4 | 5 | ## Virtual Environment 6 | 7 | You should install `timm` in a [virtual environment](https://docs.python.org/3/library/venv.html) to keep things tidy and avoid dependency conflicts. 8 | 9 | 1. Create and navigate to your project directory: 10 | 11 | ```bash 12 | mkdir ~/my-project 13 | cd ~/my-project 14 | ``` 15 | 16 | 2. Start a virtual environment inside your directory: 17 | 18 | ```bash 19 | python -m venv .env 20 | ``` 21 | 22 | 3. Activate and deactivate the virtual environment with the following commands: 23 | 24 | ```bash 25 | # Activate the virtual environment 26 | source .env/bin/activate 27 | 28 | # Deactivate the virtual environment 29 | deactivate 30 | ``` 31 | 32 | Once you've created your virtual environment, you can install `timm` in it. 33 | 34 | ## Using pip 35 | 36 | The most straightforward way to install `timm` is with pip: 37 | 38 | ```bash 39 | pip install timm 40 | ``` 41 | 42 | Alternatively, you can install `timm` from GitHub directly to get the latest, bleeding-edge version: 43 | 44 | ```bash 45 | pip install git+https://github.com/rwightman/pytorch-image-models.git 46 | ``` 47 | 48 | Run the following command to check if `timm` has been properly installed: 49 | 50 | ```bash 51 | python -c "from timm import list_models; print(list_models(pretrained=True)[:5])" 52 | ``` 53 | 54 | This command lists the first five pretrained models available in `timm` (which are sorted alphabetically).
You should see the following output: 55 | 56 | ```python 57 | ['adv_inception_v3', 'bat_resnext26ts', 'beit_base_patch16_224', 'beit_base_patch16_224_in22k', 'beit_base_patch16_384'] 58 | ``` 59 | 60 | ## From Source 61 | 62 | Building `timm` from source lets you make changes to the code base. To install from source, clone the repository and install with the following commands: 63 | 64 | ```bash 65 | git clone https://github.com/rwightman/pytorch-image-models.git 66 | cd pytorch-image-models 67 | pip install -e . 68 | ``` 69 | 70 | Again, you can check if `timm` was properly installed with the following command: 71 | 72 | ```bash 73 | python -c "from timm import list_models; print(list_models(pretrained=True)[:5])" 74 | ``` 75 | -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [build-system] 2 | requires = ["pdm-backend"] 3 | build-backend = "pdm.backend" 4 | 5 | [project] 6 | name = "timm" 7 | authors = [ 8 | {name = "Ross Wightman", email = "ross@huggingface.co"}, 9 | ] 10 | description = "PyTorch Image Models" 11 | readme = "README.md" 12 | requires-python = ">=3.8" 13 | keywords = ["pytorch", "image-classification"] 14 | license = {text = "Apache-2.0"} 15 | classifiers = [ 16 | 'Development Status :: 5 - Production/Stable', 17 | 'Intended Audience :: Education', 18 | 'Intended Audience :: Science/Research', 19 | 'License :: OSI Approved :: Apache Software License', 20 | 'Programming Language :: Python :: 3.8', 21 | 'Programming Language :: Python :: 3.9', 22 | 'Programming Language :: Python :: 3.10', 23 | 'Programming Language :: Python :: 3.11', 24 | 'Programming Language :: Python :: 3.12', 25 | 'Topic :: Scientific/Engineering', 26 | 'Topic :: Scientific/Engineering :: Artificial Intelligence', 27 | 'Topic :: Software Development', 28 | 'Topic :: Software Development :: Libraries', 29 | 'Topic :: Software Development :: Libraries :: Python Modules', 30 | ] 31 | dependencies = [ 32 | 'torch', 33 | 'torchvision', 34 | 'pyyaml', 35 | 'huggingface_hub', 36 | 'safetensors', 37 | ] 38 | dynamic = ["version"] 39 | 40 | [project.urls] 41 | homepage = "https://github.com/huggingface/pytorch-image-models" 42 | documentation = "https://huggingface.co/docs/timm/en/index" 43 | repository = "https://github.com/huggingface/pytorch-image-models" 44 | 45 | [tool.pdm.dev-dependencies] 46 | test = [ 47 | 'pytest', 48 | 'pytest-timeout', 49 | 'pytest-xdist', 50 | 'pytest-forked', 51 | 'expecttest', 52 | ] 53 | 54 | [tool.pdm.version] 55 | source = "file" 56 | path = "timm/version.py" 57 | 58 | [tool.pytest.ini_options] 59 | testpaths = ['tests'] 60 | markers = [ 61 | "base: marker for model tests using the basic setup", 62 | "cfg: marker for model tests checking the config", 63 | "torchscript: marker for model tests using torchscript", 64 | "features: marker for model tests checking feature extraction", 65 | "fxforward: marker for model tests using torch fx (only forward)", 66 | "fxbackward: marker for model tests using torch fx (only backward)", 67 | ] -------------------------------------------------------------------------------- /timm/layers/create_norm.py: -------------------------------------------------------------------------------- 1 | """ Norm Layer Factory 2 | 3 | Create norm modules by string (to mirror create_act and create_norm_act fns) 4 | 5 | Copyright 2022 Ross Wightman 6 | """ 7 | import functools 8 | import types 9 | from typing import Type 10 | 11
| import torch.nn as nn 12 | 13 | from .norm import ( 14 | GroupNorm, 15 | GroupNorm1, 16 | LayerNorm, 17 | LayerNorm2d, 18 | LayerNormFp32, 19 | LayerNorm2dFp32, 20 | RmsNorm, 21 | RmsNorm2d, 22 | RmsNormFp32, 23 | RmsNorm2dFp32, 24 | SimpleNorm, 25 | SimpleNorm2d, 26 | SimpleNormFp32, 27 | SimpleNorm2dFp32, 28 | ) 29 | from torchvision.ops.misc import FrozenBatchNorm2d 30 | 31 | _NORM_MAP = dict( 32 | batchnorm=nn.BatchNorm2d, 33 | batchnorm2d=nn.BatchNorm2d, 34 | batchnorm1d=nn.BatchNorm1d, 35 | groupnorm=GroupNorm, 36 | groupnorm1=GroupNorm1, 37 | layernorm=LayerNorm, 38 | layernorm2d=LayerNorm2d, 39 | layernormfp32=LayerNormFp32, 40 | layernorm2dfp32=LayerNorm2dFp32, 41 | rmsnorm=RmsNorm, 42 | rmsnorm2d=RmsNorm2d, 43 | rmsnormfp32=RmsNormFp32, 44 | rmsnorm2dfp32=RmsNorm2dFp32, 45 | simplenorm=SimpleNorm, 46 | simplenorm2d=SimpleNorm2d, 47 | simplenormfp32=SimpleNormFp32, 48 | simplenorm2dfp32=SimpleNorm2dFp32, 49 | frozenbatchnorm2d=FrozenBatchNorm2d, 50 | ) 51 | _NORM_TYPES = set(_NORM_MAP.values()) 52 | 53 | 54 | def create_norm_layer(layer_name, num_features, **kwargs): 55 | layer = get_norm_layer(layer_name) 56 | layer_instance = layer(num_features, **kwargs) 57 | return layer_instance 58 | 59 | 60 | def get_norm_layer(norm_layer): 61 | if norm_layer is None: 62 | return None 63 | assert isinstance(norm_layer, (type, str, types.FunctionType, functools.partial)) 64 | norm_kwargs = {} 65 | 66 | # unbind partial fn, so args can be rebound later 67 | if isinstance(norm_layer, functools.partial): 68 | norm_kwargs.update(norm_layer.keywords) 69 | norm_layer = norm_layer.func 70 | 71 | if isinstance(norm_layer, str): 72 | if not norm_layer: 73 | return None 74 | layer_name = norm_layer.replace('_', '').lower() 75 | norm_layer = _NORM_MAP[layer_name] 76 | else: 77 | pass  # already a norm class, fn, or partial (unwrapped above) 78 | 79 | if norm_kwargs: 80 | norm_layer = functools.partial(norm_layer, **norm_kwargs) # bind/rebind args 81 | return norm_layer 82 | -------------------------------------------------------------------------------- /timm/utils/jit.py: -------------------------------------------------------------------------------- 1 | """ JIT scripting/tracing utils 2 | 3 | Hacked together by / Copyright 2020 Ross Wightman 4 | """ 5 | import os 6 | 7 | import torch 8 | 9 | 10 | def set_jit_legacy(): 11 | """ Set JIT executor to legacy w/ support for op fusion 12 | This is hopefully a temporary need in 1.5/1.5.1/1.6 to restore performance due to changes 13 | in the JIT executor. These APIs are not supported, so they could change. 14 | """ 15 | # 16 | assert hasattr(torch._C, '_jit_set_profiling_executor'), "Old JIT behavior doesn't exist!"
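# Added note: the three calls below disable the profiling executor and profiling mode, then re-enable GPU op fusion, restoring the legacy fuser behaviour described in the docstring.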
17 | torch._C._jit_set_profiling_executor(False) 18 | torch._C._jit_set_profiling_mode(False) 19 | torch._C._jit_override_can_fuse_on_gpu(True) 20 | #torch._C._jit_set_texpr_fuser_enabled(True) 21 | 22 | 23 | def set_jit_fuser(fuser): 24 | if fuser == "te": 25 | # default fuser should be == 'te' 26 | torch._C._jit_set_profiling_executor(True) 27 | torch._C._jit_set_profiling_mode(True) 28 | torch._C._jit_override_can_fuse_on_cpu(False) 29 | torch._C._jit_override_can_fuse_on_gpu(True) 30 | torch._C._jit_set_texpr_fuser_enabled(True) 31 | try: 32 | torch._C._jit_set_nvfuser_enabled(False) 33 | except Exception: 34 | pass 35 | elif fuser == "old" or fuser == "legacy": 36 | torch._C._jit_set_profiling_executor(False) 37 | torch._C._jit_set_profiling_mode(False) 38 | torch._C._jit_override_can_fuse_on_gpu(True) 39 | torch._C._jit_set_texpr_fuser_enabled(False) 40 | try: 41 | torch._C._jit_set_nvfuser_enabled(False) 42 | except Exception: 43 | pass 44 | elif fuser == "nvfuser" or fuser == "nvf": 45 | os.environ['PYTORCH_NVFUSER_DISABLE_FALLBACK'] = '1' 46 | #os.environ['PYTORCH_NVFUSER_DISABLE_FMA'] = '1' 47 | #os.environ['PYTORCH_NVFUSER_JIT_OPT_LEVEL'] = '0' 48 | torch._C._jit_set_texpr_fuser_enabled(False) 49 | torch._C._jit_set_profiling_executor(True) 50 | torch._C._jit_set_profiling_mode(True) 51 | torch._C._jit_can_fuse_on_cpu() 52 | torch._C._jit_can_fuse_on_gpu() 53 | torch._C._jit_override_can_fuse_on_cpu(False) 54 | torch._C._jit_override_can_fuse_on_gpu(False) 55 | torch._C._jit_set_nvfuser_guard_mode(True) 56 | torch._C._jit_set_nvfuser_enabled(True) 57 | else: 58 | assert False, f"Invalid jit fuser ({fuser})" 59 | -------------------------------------------------------------------------------- /timm/layers/mixed_conv2d.py: -------------------------------------------------------------------------------- 1 | """ PyTorch Mixed Convolution 2 | 3 | Paper: MixConv: Mixed Depthwise Convolutional Kernels (https://arxiv.org/abs/1907.09595) 4 | 5 | Hacked together by / Copyright 2020 Ross Wightman 6 | """ 7 | from typing import List, Union 8 | 9 | import torch 10 | from torch import nn as nn 11 | 12 | from .conv2d_same import create_conv2d_pad 13 | 14 | 15 | def _split_channels(num_chan, num_groups): 16 | split = [num_chan // num_groups for _ in range(num_groups)] 17 | split[0] += num_chan - sum(split) 18 | return split 19 | 20 | 21 | class MixedConv2d(nn.ModuleDict): 22 | """ Mixed Grouped Convolution 23 | 24 | Based on MDConv and GroupedConv in MixNet impl: 25 | https://github.com/tensorflow/tpu/blob/master/models/official/mnasnet/mixnet/custom_layers.py 26 | """ 27 | def __init__( 28 | self, 29 | in_channels: int, 30 | out_channels: int, 31 | kernel_size: Union[int, List[int]] = 3, 32 | stride: int = 1, 33 | padding: str = '', 34 | dilation: int = 1, 35 | depthwise: bool = False, 36 | **kwargs 37 | ): 38 | super().__init__() 39 | 40 | kernel_size = kernel_size if isinstance(kernel_size, list) else [kernel_size] 41 | num_groups = len(kernel_size) 42 | in_splits = _split_channels(in_channels, num_groups) 43 | out_splits = _split_channels(out_channels, num_groups) 44 | self.in_channels = sum(in_splits) 45 | self.out_channels = sum(out_splits) 46 | for idx, (k, in_ch, out_ch) in enumerate(zip(kernel_size, in_splits, out_splits)): 47 | conv_groups = in_ch if depthwise else 1 48 | # use add_module to keep key space clean 49 | self.add_module( 50 | str(idx), 51 | create_conv2d_pad( 52 | in_ch, 53 | out_ch, 54 | k, 55 | stride=stride, 56 | padding=padding, 57 | dilation=dilation, 58 | 
groups=conv_groups, 59 | **kwargs, 60 | ) 61 | ) 62 | self.splits = in_splits 63 | 64 | def forward(self, x): 65 | x_split = torch.split(x, self.splits, 1) 66 | x_out = [c(x_split[i]) for i, c in enumerate(self.values())] 67 | x = torch.cat(x_out, 1) 68 | return x 69 | -------------------------------------------------------------------------------- /timm/utils/cuda.py: -------------------------------------------------------------------------------- 1 | """ CUDA / AMP utils 2 | 3 | Hacked together by / Copyright 2020 Ross Wightman 4 | """ 5 | import torch 6 | 7 | try: 8 | from apex import amp 9 | has_apex = True 10 | except ImportError: 11 | amp = None 12 | has_apex = False 13 | 14 | from .clip_grad import dispatch_clip_grad 15 | 16 | 17 | class ApexScaler: 18 | state_dict_key = "amp" 19 | 20 | def __call__( 21 | self, 22 | loss, 23 | optimizer, 24 | clip_grad=None, 25 | clip_mode='norm', 26 | parameters=None, 27 | create_graph=False, 28 | need_update=True, 29 | ): 30 | with amp.scale_loss(loss, optimizer) as scaled_loss: 31 | scaled_loss.backward(create_graph=create_graph) 32 | if need_update: 33 | if clip_grad is not None: 34 | dispatch_clip_grad(amp.master_params(optimizer), clip_grad, mode=clip_mode) 35 | optimizer.step() 36 | 37 | def state_dict(self): 38 | if 'state_dict' in amp.__dict__: 39 | return amp.state_dict() 40 | 41 | def load_state_dict(self, state_dict): 42 | if 'load_state_dict' in amp.__dict__: 43 | amp.load_state_dict(state_dict) 44 | 45 | 46 | class NativeScaler: 47 | state_dict_key = "amp_scaler" 48 | 49 | def __init__(self, device='cuda'): 50 | try: 51 | self._scaler = torch.amp.GradScaler(device=device) 52 | except (AttributeError, TypeError) as e: 53 | self._scaler = torch.cuda.amp.GradScaler() 54 | 55 | def __call__( 56 | self, 57 | loss, 58 | optimizer, 59 | clip_grad=None, 60 | clip_mode='norm', 61 | parameters=None, 62 | create_graph=False, 63 | need_update=True, 64 | ): 65 | self._scaler.scale(loss).backward(create_graph=create_graph) 66 | if need_update: 67 | if clip_grad is not None: 68 | assert parameters is not None 69 | self._scaler.unscale_(optimizer) # unscale the gradients of optimizer's assigned params in-place 70 | dispatch_clip_grad(parameters, clip_grad, mode=clip_mode) 71 | self._scaler.step(optimizer) 72 | self._scaler.update() 73 | 74 | def state_dict(self): 75 | return self._scaler.state_dict() 76 | 77 | def load_state_dict(self, state_dict): 78 | self._scaler.load_state_dict(state_dict) 79 | -------------------------------------------------------------------------------- /timm/layers/_fx.py: -------------------------------------------------------------------------------- 1 | from typing import Callable, Dict, List, Optional, Union, Tuple, Type 2 | 3 | import torch 4 | from torch import nn 5 | 6 | try: 7 | # NOTE we wrap torchvision fns to use timm leaf / no trace definitions 8 | from torchvision.models.feature_extraction import create_feature_extractor as _create_feature_extractor 9 | from torchvision.models.feature_extraction import get_graph_node_names as _get_graph_node_names 10 | has_fx_feature_extraction = True 11 | except ImportError: 12 | has_fx_feature_extraction = False 13 | 14 | 15 | __all__ = [ 16 | 'register_notrace_module', 17 | 'is_notrace_module', 18 | 'get_notrace_modules', 19 | 'register_notrace_function', 20 | 'is_notrace_function', 21 | 'get_notrace_functions', 22 | 'create_feature_extractor', 23 | 'get_graph_node_names', 24 | ] 25 | 26 | # modules to treat as leafs when tracing 27 | _leaf_modules = set() 28 | 29 | 30 | def 
register_notrace_module(module: Type[nn.Module]): 31 | """ 32 | Any module not under timm.models.layers should get this decorator if we don't want to trace through it. 33 | """ 34 | _leaf_modules.add(module) 35 | return module 36 | 37 | 38 | def is_notrace_module(module: Type[nn.Module]): 39 | return module in _leaf_modules 40 | 41 | 42 | def get_notrace_modules(): 43 | return list(_leaf_modules) 44 | 45 | 46 | # Functions we want to autowrap (treat them as leaves) 47 | _autowrap_functions = set() 48 | 49 | 50 | def register_notrace_function(name_or_fn): 51 | _autowrap_functions.add(name_or_fn) 52 | return name_or_fn 53 | 54 | 55 | def is_notrace_function(func: Callable): 56 | return func in _autowrap_functions 57 | 58 | 59 | def get_notrace_functions(): 60 | return list(_autowrap_functions) 61 | 62 | 63 | def get_graph_node_names(model: nn.Module) -> Tuple[List[str], List[str]]: 64 | return _get_graph_node_names( 65 | model, 66 | tracer_kwargs={ 67 | 'leaf_modules': list(_leaf_modules), 68 | 'autowrap_functions': list(_autowrap_functions) 69 | } 70 | ) 71 | 72 | 73 | def create_feature_extractor(model: nn.Module, return_nodes: Union[Dict[str, str], List[str]]): 74 | assert has_fx_feature_extraction, 'Please update to PyTorch 1.10+, torchvision 0.11+ for FX feature extraction' 75 | return _create_feature_extractor( 76 | model, return_nodes, 77 | tracer_kwargs={ 78 | 'leaf_modules': list(_leaf_modules), 79 | 'autowrap_functions': list(_autowrap_functions) 80 | } 81 | ) -------------------------------------------------------------------------------- /UPGRADING.md: -------------------------------------------------------------------------------- 1 | # Upgrading from previous versions 2 | 3 | I generally try to maintain code interface and especially model weight compatibility across many `timm` versions. Sometimes there are exceptions. 4 | 5 | ## Checkpoint remapping 6 | 7 | Pretrained weight remapping is handled by `checkpoint_filter_fn` in a model implementation module. This remaps old pretrained checkpoints to new, and also 3rd party (original) checkpoints to `timm` format if the model was modified when brought into `timm`. 8 | 9 | The `checkpoint_filter_fn` is automatically called when loading pretrained weights via `pretrained=True`, but it can also be invoked manually by calling the fn directly with the current model instance and the old state dict. 10 | 11 | ## Upgrading from 0.6 and earlier 12 | 13 | Many changes were made since the 0.6.x stable releases. They were previewed in 0.8.x dev releases but not everyone transitioned. 14 | * `timm.models.layers` moved to `timm.layers`: 15 | * `from timm.models.layers import name` will still work via deprecation mapping (but please transition to `timm.layers`). 16 | * `import timm.models.layers.module` or `from timm.models.layers.module import name` needs to be changed now. 17 | * Builder, helper, non-model modules in `timm.models` have a `_` prefix added, i.e. `timm.models.helpers` -> `timm.models._helpers`; there are temporary deprecation mapping files, but those will be removed. 18 | * All models now support `architecture.pretrained_tag` naming (ex `resnet50.rsb_a1`). 19 | * The pretrained_tag is the specific weight variant (different head) for the architecture. 20 | * Using only `architecture` defaults to the first weights in the default_cfgs for that model architecture. 21 | * In adding pretrained tags, many model names that existed to differentiate were renamed to use the tag (ex: `vit_base_patch16_224_in21k` -> `vit_base_patch16_224.augreg_in21k`).
There are deprecation mappings for these. 22 | * A number of models had their checkpoints remapped to match architecture changes needed to better support `features_only=True`; there are `checkpoint_filter_fn` functions in any model module that was remapped. These can be passed to `timm.models.load_checkpoint(..., filter_fn=timm.models.swin_transformer_v2.checkpoint_filter_fn)` to remap your existing checkpoint. 23 | * The Hugging Face Hub (https://huggingface.co/timm) is now the primary source for `timm` weights. Model cards include links to papers, original sources, and licenses. 24 | * The previous 0.6.x code can be cloned from the [0.6.x](https://github.com/rwightman/pytorch-image-models/tree/0.6.x) branch or installed via pip with a pinned version. 25 | -------------------------------------------------------------------------------- /timm/data/dataset_info.py: -------------------------------------------------------------------------------- 1 | from abc import ABC, abstractmethod 2 | from typing import Dict, List, Optional, Union 3 | 4 | 5 | class DatasetInfo(ABC): 6 | 7 | def __init__(self): 8 | pass 9 | 10 | @abstractmethod 11 | def num_classes(self): 12 | pass 13 | 14 | @abstractmethod 15 | def label_names(self): 16 | pass 17 | 18 | @abstractmethod 19 | def label_descriptions(self, detailed: bool = False, as_dict: bool = False) -> Union[List[str], Dict[str, str]]: 20 | pass 21 | 22 | @abstractmethod 23 | def index_to_label_name(self, index) -> str: 24 | pass 25 | 26 | @abstractmethod 27 | def index_to_description(self, index: int, detailed: bool = False) -> str: 28 | pass 29 | 30 | @abstractmethod 31 | def label_name_to_description(self, label: str, detailed: bool = False) -> str: 32 | pass 33 | 34 | 35 | class CustomDatasetInfo(DatasetInfo): 36 | """ DatasetInfo that wraps passed values for custom datasets.""" 37 | 38 | def __init__( 39 | self, 40 | label_names: Union[List[str], Dict[int, str]], 41 | label_descriptions: Optional[Dict[str, str]] = None 42 | ): 43 | super().__init__() 44 | assert len(label_names) > 0 45 | self._label_names = label_names # label index => label name mapping 46 | self._label_descriptions = label_descriptions # label name => label description mapping 47 | if self._label_descriptions is not None: 48 | # validate descriptions (label names required) 49 | assert isinstance(self._label_descriptions, dict) 50 | for n in self._label_names: 51 | assert n in self._label_descriptions 52 | 53 | def num_classes(self): 54 | return len(self._label_names) 55 | 56 | def label_names(self): 57 | return self._label_names 58 | 59 | def label_descriptions(self, detailed: bool = False, as_dict: bool = False) -> Union[List[str], Dict[str, str]]: 60 | return self._label_descriptions 61 | 62 | def label_name_to_description(self, label: str, detailed: bool = False) -> str: 63 | if self._label_descriptions: 64 | return self._label_descriptions[label] 65 | return label # return the label name itself if a description is not present 66 | 67 | def index_to_label_name(self, index) -> str: 68 | assert 0 <= index < len(self._label_names) 69 | return self._label_names[index] 70 | 71 | def index_to_description(self, index: int, detailed: bool = False) -> str: 72 | label = self.index_to_label_name(index) 73 | return self.label_name_to_description(label, detailed=detailed) 74 | -------------------------------------------------------------------------------- /.github/workflows/tests.yml: -------------------------------------------------------------------------------- 1 | name: Python tests 2 | 3 | on: 4 | push: 5 | branches: [
main ] 6 | pull_request: 7 | branches: [ main ] 8 | 9 | env: 10 | OMP_NUM_THREADS: 2 11 | MKL_NUM_THREADS: 2 12 | 13 | jobs: 14 | test: 15 | name: Run tests on ${{ matrix.os }} with Python ${{ matrix.python }} 16 | strategy: 17 | matrix: 18 | os: [ubuntu-latest] 19 | python: ['3.10', '3.13'] 20 | torch: [{base: '1.13.0', vision: '0.14.0'}, {base: '2.9.1', vision: '0.24.1'}] 21 | testmarker: ['-k "not test_models"', '-m base', '-m cfg', '-m torchscript', '-m features', '-m fxforward', '-m fxbackward'] 22 | exclude: 23 | - python: '3.13' 24 | torch: {base: '1.13.0', vision: '0.14.0'} 25 | runs-on: ${{ matrix.os }} 26 | 27 | steps: 28 | - uses: actions/checkout@v2 29 | - name: Set up Python ${{ matrix.python }} 30 | uses: actions/setup-python@v1 31 | with: 32 | python-version: ${{ matrix.python }} 33 | - name: Install testing dependencies 34 | run: | 35 | python -m pip install --upgrade pip 36 | pip install -r requirements-dev.txt 37 | - name: Install torch on mac 38 | if: startsWith(matrix.os, 'macOS') 39 | run: pip install --no-cache-dir torch==${{ matrix.torch.base }} torchvision==${{ matrix.torch.vision }} 40 | - name: Install torch on Windows 41 | if: startsWith(matrix.os, 'windows') 42 | run: pip install --no-cache-dir torch==${{ matrix.torch.base }} torchvision==${{ matrix.torch.vision }} 43 | - name: Install torch on ubuntu 44 | if: startsWith(matrix.os, 'ubuntu') 45 | run: | 46 | sudo sed -i 's/azure\.//' /etc/apt/sources.list 47 | sudo apt update 48 | sudo apt install -y google-perftools 49 | pip install --no-cache-dir torch==${{ matrix.torch.base }}+cpu torchvision==${{ matrix.torch.vision }}+cpu --index-url https://download.pytorch.org/whl/cpu 50 | - name: Install requirements 51 | run: | 52 | pip install -r requirements.txt 53 | - name: Force old numpy for old torch 54 | if: ${{ matrix.torch.base == '1.13.0' }} 55 | run: pip install --upgrade 'numpy<2.0' 56 | - name: Run tests on Windows 57 | if: startsWith(matrix.os, 'windows') 58 | env: 59 | PYTHONDONTWRITEBYTECODE: 1 60 | run: | 61 | pytest -vv tests 62 | - name: Run '${{ matrix.testmarker }}' tests on Linux / Mac 63 | if: ${{ !startsWith(matrix.os, 'windows') }} 64 | env: 65 | LD_PRELOAD: /usr/lib/x86_64-linux-gnu/libtcmalloc.so.4 66 | PYTHONDONTWRITEBYTECODE: 1 67 | run: | 68 | pytest -vv --forked --durations=0 ${{ matrix.testmarker }} tests 69 | -------------------------------------------------------------------------------- /timm/layers/interpolate.py: -------------------------------------------------------------------------------- 1 | """ Interpolation helpers for timm layers 2 | 3 | RegularGridInterpolator from https://github.com/sbarratt/torch_interpolations 4 | Copyright Shane Barratt, Apache 2.0 license 5 | """ 6 | import torch 7 | from itertools import product 8 | 9 | 10 | class RegularGridInterpolator: 11 | """ Interpolate data defined on a rectilinear grid with even or uneven spacing. 12 | Produces similar results to scipy RegularGridInterpolator or interp2d 13 | in 'linear' mode. 
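A usage sketch (shapes inferred from the implementation below):
    points = (torch.linspace(0, 1, 5), torch.linspace(0, 1, 4))
    values = torch.rand(5, 4)
    rgi = RegularGridInterpolator(points, values)
    out = rgi([torch.tensor([0.5]), torch.tensor([0.25])])  # 1D result of shape (1,)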
14 | 15 | Taken from https://github.com/sbarratt/torch_interpolations 16 | """ 17 | 18 | def __init__(self, points, values): 19 | self.points = points 20 | self.values = values 21 | 22 | assert isinstance(self.points, (tuple, list)) 23 | assert isinstance(self.values, torch.Tensor) 24 | 25 | self.ms = list(self.values.shape) 26 | self.n = len(self.points) 27 | 28 | assert len(self.ms) == self.n 29 | 30 | for i, p in enumerate(self.points): 31 | assert isinstance(p, torch.Tensor) 32 | assert p.shape[0] == self.values.shape[i] 33 | 34 | def __call__(self, points_to_interp): 35 | assert self.points is not None 36 | assert self.values is not None 37 | 38 | assert len(points_to_interp) == len(self.points) 39 | K = points_to_interp[0].shape[0] 40 | for x in points_to_interp: 41 | assert x.shape[0] == K 42 | 43 | idxs = [] 44 | dists = [] 45 | overalls = [] 46 | for p, x in zip(self.points, points_to_interp): 47 | idx_right = torch.bucketize(x, p) 48 | idx_right[idx_right >= p.shape[0]] = p.shape[0] - 1 49 | idx_left = (idx_right - 1).clamp(0, p.shape[0] - 1) 50 | dist_left = x - p[idx_left] 51 | dist_right = p[idx_right] - x 52 | dist_left[dist_left < 0] = 0. 53 | dist_right[dist_right < 0] = 0. 54 | both_zero = (dist_left == 0) & (dist_right == 0) 55 | dist_left[both_zero] = dist_right[both_zero] = 1. 56 | 57 | idxs.append((idx_left, idx_right)) 58 | dists.append((dist_left, dist_right)) 59 | overalls.append(dist_left + dist_right) 60 | 61 | numerator = 0. 62 | for indexer in product([0, 1], repeat=self.n): 63 | as_s = [idx[onoff] for onoff, idx in zip(indexer, idxs)] 64 | bs_s = [dist[1 - onoff] for onoff, dist in zip(indexer, dists)] 65 | numerator += self.values[as_s] * \ 66 | torch.prod(torch.stack(bs_s), dim=0) 67 | denominator = torch.prod(torch.stack(overalls), dim=0) 68 | return numerator / denominator 69 | -------------------------------------------------------------------------------- /timm/loss/binary_cross_entropy.py: -------------------------------------------------------------------------------- 1 | """ Binary Cross Entropy w/ a few extras 2 | 3 | Hacked together by / Copyright 2021 Ross Wightman 4 | """ 5 | from typing import Optional, Union 6 | 7 | import torch 8 | import torch.nn as nn 9 | import torch.nn.functional as F 10 | 11 | 12 | class BinaryCrossEntropy(nn.Module): 13 | """ BCE with optional one-hot from dense targets, label smoothing, thresholding 14 | NOTE for experiments comparing CE to BCE w/ label smoothing, may remove 15 | """ 16 | def __init__( 17 | self, 18 | smoothing=0.1, 19 | target_threshold: Optional[float] = None, 20 | weight: Optional[torch.Tensor] = None, 21 | reduction: str = 'mean', 22 | sum_classes: bool = False, 23 | pos_weight: Optional[Union[torch.Tensor, float]] = None, 24 | ): 25 | super().__init__() 26 | assert 0.
<= smoothing < 1.0 27 | if pos_weight is not None: 28 | if not isinstance(pos_weight, torch.Tensor): 29 | pos_weight = torch.tensor(pos_weight) 30 | self.smoothing = smoothing 31 | self.target_threshold = target_threshold 32 | self.reduction = 'none' if sum_classes else reduction 33 | self.sum_classes = sum_classes 34 | self.register_buffer('weight', weight) 35 | self.register_buffer('pos_weight', pos_weight) 36 | 37 | def forward(self, x: torch.Tensor, target: torch.Tensor) -> torch.Tensor: 38 | batch_size = x.shape[0] 39 | assert batch_size == target.shape[0] 40 | 41 | if target.shape != x.shape: 42 | # NOTE currently assume smoothing or other label softening is applied upstream if targets are already sparse 43 | num_classes = x.shape[-1] 44 | # FIXME should off/on be different for smoothing w/ BCE? Other impl out there differ 45 | off_value = self.smoothing / num_classes 46 | on_value = 1. - self.smoothing + off_value 47 | target = target.long().view(-1, 1) 48 | target = torch.full( 49 | (batch_size, num_classes), 50 | off_value, 51 | device=x.device, dtype=x.dtype).scatter_(1, target, on_value) 52 | 53 | if self.target_threshold is not None: 54 | # Make target 0, or 1 if threshold set 55 | target = target.gt(self.target_threshold).to(dtype=target.dtype) 56 | 57 | loss = F.binary_cross_entropy_with_logits( 58 | x, target, 59 | self.weight, 60 | pos_weight=self.pos_weight, 61 | reduction=self.reduction, 62 | ) 63 | if self.sum_classes: 64 | loss = loss.sum(-1).mean() 65 | return loss 66 | -------------------------------------------------------------------------------- /timm/optim/sgdp.py: -------------------------------------------------------------------------------- 1 | """ 2 | SGDP Optimizer Implementation copied from https://github.com/clovaai/AdamP/blob/master/adamp/sgdp.py 3 | 4 | Paper: `Slowing Down the Weight Norm Increase in Momentum-based Optimizers` - https://arxiv.org/abs/2006.08217 5 | Code: https://github.com/clovaai/AdamP 6 | 7 | Copyright (c) 2020-present NAVER Corp. 8 | MIT license 9 | """ 10 | 11 | import torch 12 | import torch.nn.functional as F 13 | from torch.optim.optimizer import Optimizer, required 14 | import math 15 | 16 | from .adamp import projection 17 | 18 | 19 | class SGDP(Optimizer): 20 | def __init__( 21 | self, 22 | params, 23 | lr=required, 24 | momentum=0, 25 | dampening=0, 26 | weight_decay=0, 27 | nesterov=False, 28 | eps=1e-8, 29 | delta=0.1, 30 | wd_ratio=0.1 31 | ): 32 | defaults = dict( 33 | lr=lr, 34 | momentum=momentum, 35 | dampening=dampening, 36 | weight_decay=weight_decay, 37 | nesterov=nesterov, 38 | eps=eps, 39 | delta=delta, 40 | wd_ratio=wd_ratio, 41 | ) 42 | super(SGDP, self).__init__(params, defaults) 43 | 44 | @torch.no_grad() 45 | def step(self, closure=None): 46 | loss = None 47 | if closure is not None: 48 | with torch.enable_grad(): 49 | loss = closure() 50 | 51 | for group in self.param_groups: 52 | weight_decay = group['weight_decay'] 53 | momentum = group['momentum'] 54 | dampening = group['dampening'] 55 | nesterov = group['nesterov'] 56 | 57 | for p in group['params']: 58 | if p.grad is None: 59 | continue 60 | grad = p.grad 61 | state = self.state[p] 62 | 63 | # State initialization 64 | if len(state) == 0: 65 | state['momentum'] = torch.zeros_like(p) 66 | 67 | # SGD 68 | buf = state['momentum'] 69 | buf.mul_(momentum).add_(grad, alpha=1. - dampening) 70 | if nesterov: 71 | d_p = grad + momentum * buf 72 | else: 73 | d_p = buf 74 | 75 | # Projection 76 | wd_ratio = 1. 
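# Added note: for multi-dim params, projection() below may treat the weight as scale-invariant, remove the component of d_p along the weight direction, and return a reduced wd_ratio so weight decay is scaled down for those params (see the AdamP paper).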
77 | if len(p.shape) > 1: 78 | d_p, wd_ratio = projection(p, grad, d_p, group['delta'], group['wd_ratio'], group['eps']) 79 | 80 | # Weight decay 81 | if weight_decay != 0: 82 | p.mul_(1. - group['lr'] * group['weight_decay'] * wd_ratio / (1-momentum)) 83 | 84 | # Step 85 | p.add_(d_p, alpha=-group['lr']) 86 | 87 | return loss 88 | -------------------------------------------------------------------------------- /timm/optim/lookahead.py: -------------------------------------------------------------------------------- 1 | """ Lookahead Optimizer Wrapper. 2 | Implementation modified from: https://github.com/alphadl/lookahead.pytorch 3 | Paper: `Lookahead Optimizer: k steps forward, 1 step back` - https://arxiv.org/abs/1907.08610 4 | 5 | Hacked together by / Copyright 2020 Ross Wightman 6 | """ 7 | from collections import OrderedDict 8 | from typing import Callable, Dict 9 | 10 | import torch 11 | from torch.optim.optimizer import Optimizer 12 | from collections import defaultdict 13 | 14 | 15 | class Lookahead(Optimizer): 16 | def __init__(self, base_optimizer, alpha=0.5, k=6): 17 | # NOTE super().__init__() not called on purpose 18 | self._optimizer_step_pre_hooks: Dict[int, Callable] = OrderedDict() 19 | self._optimizer_step_post_hooks: Dict[int, Callable] = OrderedDict() 20 | if not 0.0 <= alpha <= 1.0: 21 | raise ValueError(f'Invalid slow update rate: {alpha}') 22 | if not 1 <= k: 23 | raise ValueError(f'Invalid lookahead steps: {k}') 24 | defaults = dict(lookahead_alpha=alpha, lookahead_k=k, lookahead_step=0) 25 | self._base_optimizer = base_optimizer 26 | self.param_groups = base_optimizer.param_groups 27 | self.defaults = base_optimizer.defaults 28 | self.defaults.update(defaults) 29 | self.state = defaultdict(dict) 30 | # manually add our defaults to the param groups 31 | for name, default in defaults.items(): 32 | for group in self._base_optimizer.param_groups: 33 | group.setdefault(name, default) 34 | 35 | @torch.no_grad() 36 | def update_slow(self, group): 37 | for fast_p in group["params"]: 38 | if fast_p.grad is None: 39 | continue 40 | param_state = self._base_optimizer.state[fast_p] 41 | if 'lookahead_slow_buff' not in param_state: 42 | param_state['lookahead_slow_buff'] = torch.empty_like(fast_p) 43 | param_state['lookahead_slow_buff'].copy_(fast_p) 44 | slow = param_state['lookahead_slow_buff'] 45 | slow.add_(fast_p - slow, alpha=group['lookahead_alpha']) 46 | fast_p.copy_(slow) 47 | 48 | def sync_lookahead(self): 49 | for group in self._base_optimizer.param_groups: 50 | self.update_slow(group) 51 | 52 | @torch.no_grad() 53 | def step(self, closure=None): 54 | loss = self._base_optimizer.step(closure) 55 | for group in self._base_optimizer.param_groups: 56 | group['lookahead_step'] += 1 57 | if group['lookahead_step'] % group['lookahead_k'] == 0: 58 | self.update_slow(group) 59 | return loss 60 | 61 | def state_dict(self): 62 | return self._base_optimizer.state_dict() 63 | 64 | def load_state_dict(self, state_dict): 65 | self._base_optimizer.load_state_dict(state_dict) 66 | self.param_groups = self._base_optimizer.param_groups 67 | -------------------------------------------------------------------------------- /timm/data/readers/reader_image_tar.py: -------------------------------------------------------------------------------- 1 | """ A dataset reader that reads single tarfile based datasets 2 | 3 | This reader can read datasets consisting of a single tarfile containing images. 4 | I am planning to deprecate it in favour of ReaderImageInTar.
5 | 6 | Hacked together by / Copyright 2020 Ross Wightman 7 | """ 8 | import os 9 | import tarfile 10 | 11 | from timm.utils.misc import natural_key 12 | 13 | from .class_map import load_class_map 14 | from .img_extensions import get_img_extensions 15 | from .reader import Reader 16 | 17 | 18 | def extract_tarinfo(tarfile, class_to_idx=None, sort=True): 19 | extensions = get_img_extensions(as_set=True) 20 | files = [] 21 | labels = [] 22 | for ti in tarfile.getmembers(): 23 | if not ti.isfile(): 24 | continue 25 | dirname, basename = os.path.split(ti.path) 26 | label = os.path.basename(dirname) 27 | ext = os.path.splitext(basename)[1] 28 | if ext.lower() in extensions: 29 | files.append(ti) 30 | labels.append(label) 31 | if class_to_idx is None: 32 | unique_labels = set(labels) 33 | sorted_labels = list(sorted(unique_labels, key=natural_key)) 34 | class_to_idx = {c: idx for idx, c in enumerate(sorted_labels)} 35 | tarinfo_and_targets = [(f, class_to_idx[l]) for f, l in zip(files, labels) if l in class_to_idx] 36 | if sort: 37 | tarinfo_and_targets = sorted(tarinfo_and_targets, key=lambda k: natural_key(k[0].path)) 38 | return tarinfo_and_targets, class_to_idx 39 | 40 | 41 | class ReaderImageTar(Reader): 42 | """ Single tarfile dataset where classes are mapped to folders within tar 43 | NOTE: This class is being deprecated in favour of the more capable ReaderImageInTar that can 44 | operate on folders of tars or tars in tars. 45 | """ 46 | def __init__(self, root, class_map=''): 47 | super().__init__() 48 | 49 | class_to_idx = None 50 | if class_map: 51 | class_to_idx = load_class_map(class_map, root) 52 | assert os.path.isfile(root) 53 | self.root = root 54 | 55 | with tarfile.open(root) as tf: # cannot keep this open across processes, reopen later 56 | self.samples, self.class_to_idx = extract_tarinfo(tf, class_to_idx) 57 | self.imgs = self.samples 58 | self.tarfile = None # lazy init in __getitem__ 59 | 60 | def __getitem__(self, index): 61 | if self.tarfile is None: 62 | self.tarfile = tarfile.open(self.root) 63 | tarinfo, target = self.samples[index] 64 | fileobj = self.tarfile.extractfile(tarinfo) 65 | return fileobj, target 66 | 67 | def __len__(self): 68 | return len(self.samples) 69 | 70 | def _filename(self, index, basename=False, absolute=False): 71 | filename = self.samples[index][0].name 72 | if basename: 73 | filename = os.path.basename(filename) 74 | return filename 75 | -------------------------------------------------------------------------------- /timm/layers/pos_embed.py: -------------------------------------------------------------------------------- 1 | """ Position Embedding Utilities 2 | 3 | Hacked together by / Copyright 2022 Ross Wightman 4 | """ 5 | import logging 6 | import math 7 | from typing import List, Tuple, Optional, Union 8 | 9 | import torch 10 | import torch.nn.functional as F 11 | 12 | from ._fx import register_notrace_function 13 | 14 | _logger = logging.getLogger(__name__) 15 | 16 | 17 | @torch.fx.wrap 18 | @register_notrace_function 19 | def resample_abs_pos_embed( 20 | posemb: torch.Tensor, 21 | new_size: List[int], 22 | old_size: Optional[List[int]] = None, 23 | num_prefix_tokens: int = 1, 24 | interpolation: str = 'bicubic', 25 | antialias: bool = True, 26 | verbose: bool = False, 27 | ): 28 | # sort out sizes, assume square if old size not provided 29 | num_pos_tokens = posemb.shape[1] 30 | num_new_tokens = new_size[0] * new_size[1] + num_prefix_tokens 31 | if num_new_tokens == num_pos_tokens and new_size[0] == new_size[1]: 32 | return posemb 33 
| 34 | if old_size is None: 35 | hw = int(math.sqrt(num_pos_tokens - num_prefix_tokens)) 36 | old_size = hw, hw 37 | 38 | if num_prefix_tokens: 39 | posemb_prefix, posemb = posemb[:, :num_prefix_tokens], posemb[:, num_prefix_tokens:] 40 | else: 41 | posemb_prefix, posemb = None, posemb 42 | 43 | # do the interpolation 44 | embed_dim = posemb.shape[-1] 45 | orig_dtype = posemb.dtype 46 | posemb = posemb.float() # interpolate needs float32 47 | posemb = posemb.reshape(1, old_size[0], old_size[1], -1).permute(0, 3, 1, 2) 48 | posemb = F.interpolate(posemb, size=new_size, mode=interpolation, antialias=antialias) 49 | posemb = posemb.permute(0, 2, 3, 1).reshape(1, -1, embed_dim) 50 | posemb = posemb.to(orig_dtype) 51 | 52 | # add back extra (class, etc) prefix tokens 53 | if posemb_prefix is not None: 54 | posemb = torch.cat([posemb_prefix, posemb], dim=1) 55 | 56 | if not torch.jit.is_scripting() and verbose: 57 | _logger.info(f'Resized position embedding: {old_size} to {new_size}.') 58 | 59 | return posemb 60 | 61 | 62 | @torch.fx.wrap 63 | @register_notrace_function 64 | def resample_abs_pos_embed_nhwc( 65 | posemb: torch.Tensor, 66 | new_size: List[int], 67 | interpolation: str = 'bicubic', 68 | antialias: bool = True, 69 | verbose: bool = False, 70 | ): 71 | if new_size[0] == posemb.shape[-3] and new_size[1] == posemb.shape[-2]: 72 | return posemb 73 | 74 | orig_dtype = posemb.dtype 75 | posemb = posemb.float() 76 | posemb = posemb.reshape(1, posemb.shape[-3], posemb.shape[-2], posemb.shape[-1]).permute(0, 3, 1, 2) 77 | posemb = F.interpolate(posemb, size=new_size, mode=interpolation, antialias=antialias) 78 | posemb = posemb.permute(0, 2, 3, 1).to(orig_dtype) 79 | 80 | if not torch.jit.is_scripting() and verbose: 81 | _logger.info(f'Resized position embedding: {posemb.shape[-3:-1]} to {new_size}.') 82 | 83 | return posemb 84 | -------------------------------------------------------------------------------- /timm/task/classification.py: -------------------------------------------------------------------------------- 1 | """Classification training task.""" 2 | import logging 3 | from typing import Callable, Dict, Optional, Union 4 | 5 | import torch 6 | import torch.nn as nn 7 | 8 | from .task import TrainingTask 9 | 10 | _logger = logging.getLogger(__name__) 11 | 12 | 13 | class ClassificationTask(TrainingTask): 14 | """Standard supervised classification task. 15 | 16 | Simple task that performs a forward pass through the model and computes 17 | the classification loss. 
18 | 19 | Args: 20 | model: The model to train 21 | criterion: Loss function (e.g., CrossEntropyLoss) 22 | device: Device for task tensors/buffers 23 | dtype: Dtype for task tensors/buffers 24 | verbose: Enable info logging 25 | 26 | Example: 27 | >>> task = ClassificationTask(model, nn.CrossEntropyLoss(), device=torch.device('cuda')) 28 | >>> result = task(input, target) 29 | >>> result['loss'].backward() 30 | """ 31 | 32 | def __init__( 33 | self, 34 | model: nn.Module, 35 | criterion: Union[nn.Module, Callable], 36 | device: Optional[torch.device] = None, 37 | dtype: Optional[torch.dtype] = None, 38 | verbose: bool = True, 39 | ): 40 | super().__init__(device=device, dtype=dtype, verbose=verbose) 41 | self.model = model 42 | self.criterion = criterion 43 | 44 | if self.verbose: 45 | loss_name = getattr(criterion, '__name__', None) or type(criterion).__name__ 46 | _logger.info(f"ClassificationTask: criterion={loss_name}") 47 | 48 | def prepare_distributed( 49 | self, 50 | device_ids: Optional[list] = None, 51 | **ddp_kwargs 52 | ) -> 'ClassificationTask': 53 | """Prepare task for distributed training. 54 | 55 | Wraps the model in DistributedDataParallel (DDP). 56 | 57 | Args: 58 | device_ids: List of device IDs for DDP (e.g., [local_rank]) 59 | **ddp_kwargs: Additional arguments passed to DistributedDataParallel 60 | 61 | Returns: 62 | self (for method chaining) 63 | """ 64 | from torch.nn.parallel import DistributedDataParallel as DDP 65 | self.model = DDP(self.model, device_ids=device_ids, **ddp_kwargs) 66 | return self 67 | 68 | def forward( 69 | self, 70 | input: torch.Tensor, 71 | target: torch.Tensor, 72 | ) -> Dict[str, torch.Tensor]: 73 | """Forward pass through model and compute classification loss. 74 | 75 | Args: 76 | input: Input tensor [B, C, H, W] 77 | target: Target labels [B] 78 | 79 | Returns: 80 | Dictionary containing: 81 | - 'loss': Classification loss 82 | - 'output': Model logits 83 | """ 84 | output = self.model(input) 85 | loss = self.criterion(output, target) 86 | 87 | return { 88 | 'loss': loss, 89 | 'output': output, 90 | } 91 | -------------------------------------------------------------------------------- /results/generate_csv_results.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import pandas as pd 3 | 4 | 5 | results = { 6 | 'results-imagenet.csv': [ 7 | 'results-imagenet-real.csv', 8 | 'results-imagenetv2-matched-frequency.csv', 9 | 'results-sketch.csv' 10 | ], 11 | 'results-imagenet-a-clean.csv': [ 12 | 'results-imagenet-a.csv', 13 | ], 14 | 'results-imagenet-r-clean.csv': [ 15 | 'results-imagenet-r.csv', 16 | ], 17 | } 18 | 19 | 20 | def diff(base_df, test_csv): 21 | base_df['mi'] = base_df.model + '-' + base_df.img_size.astype('str') 22 | base_models = base_df['mi'].values 23 | test_df = pd.read_csv(test_csv) 24 | test_df['mi'] = test_df.model + '-' + test_df.img_size.astype('str') 25 | test_models = test_df['mi'].values 26 | 27 | rank_diff = np.zeros_like(test_models, dtype='object') 28 | top1_diff = np.zeros_like(test_models, dtype='object') 29 | top5_diff = np.zeros_like(test_models, dtype='object') 30 | 31 | for rank, model in enumerate(test_models): 32 | if model in base_models: 33 | base_rank = int(np.where(base_models == model)[0]) 34 | top1_d = test_df['top1'][rank] - base_df['top1'][base_rank] 35 | top5_d = test_df['top5'][rank] - base_df['top5'][base_rank] 36 | 37 | # rank_diff 38 | if rank == base_rank: 39 | rank_diff[rank] = f'0' 40 | elif rank > base_rank: 41 | 
rank_diff[rank] = f'-{rank - base_rank}' 42 | else: 43 | rank_diff[rank] = f'+{base_rank - rank}' 44 | 45 | # top1_diff 46 | if top1_d >= .0: 47 | top1_diff[rank] = f'+{top1_d:.3f}' 48 | else: 49 | top1_diff[rank] = f'-{abs(top1_d):.3f}' 50 | 51 | # top5_diff 52 | if top5_d >= .0: 53 | top5_diff[rank] = f'+{top5_d:.3f}' 54 | else: 55 | top5_diff[rank] = f'-{abs(top5_d):.3f}' 56 | 57 | else: 58 | rank_diff[rank] = '' 59 | top1_diff[rank] = '' 60 | top5_diff[rank] = '' 61 | 62 | test_df['top1_diff'] = top1_diff 63 | test_df['top5_diff'] = top5_diff 64 | test_df['rank_diff'] = rank_diff 65 | 66 | test_df.drop('mi', axis=1, inplace=True) 67 | base_df.drop('mi', axis=1, inplace=True) 68 | test_df['param_count'] = test_df['param_count'].map('{:,.2f}'.format) 69 | test_df.sort_values(['top1', 'top5', 'model'], ascending=[False, False, True], inplace=True) 70 | test_df.to_csv(test_csv, index=False, float_format='%.3f') 71 | 72 | 73 | for base_results, test_results in results.items(): 74 | base_df = pd.read_csv(base_results) 75 | base_df.sort_values(['top1', 'top5', 'model'], ascending=[False, False, True], inplace=True) 76 | for test_csv in test_results: 77 | diff(base_df, test_csv) 78 | base_df['param_count'] = base_df['param_count'].map('{:,.2f}'.format) 79 | base_df.to_csv(base_results, index=False, float_format='%.3f') 80 | -------------------------------------------------------------------------------- /timm/data/_info/imagenet_a_synsets.txt: -------------------------------------------------------------------------------- 1 | n01498041 2 | n01531178 3 | n01534433 4 | n01558993 5 | n01580077 6 | n01614925 7 | n01616318 8 | n01631663 9 | n01641577 10 | n01669191 11 | n01677366 12 | n01687978 13 | n01694178 14 | n01698640 15 | n01735189 16 | n01770081 17 | n01770393 18 | n01774750 19 | n01784675 20 | n01819313 21 | n01820546 22 | n01833805 23 | n01843383 24 | n01847000 25 | n01855672 26 | n01882714 27 | n01910747 28 | n01914609 29 | n01924916 30 | n01944390 31 | n01985128 32 | n01986214 33 | n02007558 34 | n02009912 35 | n02037110 36 | n02051845 37 | n02077923 38 | n02085620 39 | n02099601 40 | n02106550 41 | n02106662 42 | n02110958 43 | n02119022 44 | n02123394 45 | n02127052 46 | n02129165 47 | n02133161 48 | n02137549 49 | n02165456 50 | n02174001 51 | n02177972 52 | n02190166 53 | n02206856 54 | n02219486 55 | n02226429 56 | n02231487 57 | n02233338 58 | n02236044 59 | n02259212 60 | n02268443 61 | n02279972 62 | n02280649 63 | n02281787 64 | n02317335 65 | n02325366 66 | n02346627 67 | n02356798 68 | n02361337 69 | n02410509 70 | n02445715 71 | n02454379 72 | n02486410 73 | n02492035 74 | n02504458 75 | n02655020 76 | n02669723 77 | n02672831 78 | n02676566 79 | n02690373 80 | n02701002 81 | n02730930 82 | n02777292 83 | n02782093 84 | n02787622 85 | n02793495 86 | n02797295 87 | n02802426 88 | n02814860 89 | n02815834 90 | n02837789 91 | n02879718 92 | n02883205 93 | n02895154 94 | n02906734 95 | n02948072 96 | n02951358 97 | n02980441 98 | n02992211 99 | n02999410 100 | n03014705 101 | n03026506 102 | n03124043 103 | n03125729 104 | n03187595 105 | n03196217 106 | n03223299 107 | n03250847 108 | n03255030 109 | n03291819 110 | n03325584 111 | n03355925 112 | n03384352 113 | n03388043 114 | n03417042 115 | n03443371 116 | n03444034 117 | n03445924 118 | n03452741 119 | n03483316 120 | n03584829 121 | n03590841 122 | n03594945 123 | n03617480 124 | n03666591 125 | n03670208 126 | n03717622 127 | n03720891 128 | n03721384 129 | n03724870 130 | n03775071 131 | n03788195 132 | n03804744 133 | 
n03837869 134 | n03840681 135 | n03854065 136 | n03888257 137 | n03891332 138 | n03935335 139 | n03982430 140 | n04019541 141 | n04033901 142 | n04039381 143 | n04067472 144 | n04086273 145 | n04099969 146 | n04118538 147 | n04131690 148 | n04133789 149 | n04141076 150 | n04146614 151 | n04147183 152 | n04179913 153 | n04208210 154 | n04235860 155 | n04252077 156 | n04252225 157 | n04254120 158 | n04270147 159 | n04275548 160 | n04310018 161 | n04317175 162 | n04344873 163 | n04347754 164 | n04355338 165 | n04366367 166 | n04376876 167 | n04389033 168 | n04399382 169 | n04442312 170 | n04456115 171 | n04482393 172 | n04507155 173 | n04509417 174 | n04532670 175 | n04540053 176 | n04554684 177 | n04562935 178 | n04591713 179 | n04606251 180 | n07583066 181 | n07695742 182 | n07697313 183 | n07697537 184 | n07714990 185 | n07718472 186 | n07720875 187 | n07734744 188 | n07749582 189 | n07753592 190 | n07760859 191 | n07768694 192 | n07831146 193 | n09229709 194 | n09246464 195 | n09472597 196 | n09835506 197 | n11879895 198 | n12057211 199 | n12144580 200 | n12267677 201 | -------------------------------------------------------------------------------- /timm/data/_info/imagenet_r_synsets.txt: -------------------------------------------------------------------------------- 1 | n01443537 2 | n01484850 3 | n01494475 4 | n01498041 5 | n01514859 6 | n01518878 7 | n01531178 8 | n01534433 9 | n01614925 10 | n01616318 11 | n01630670 12 | n01632777 13 | n01644373 14 | n01677366 15 | n01694178 16 | n01748264 17 | n01770393 18 | n01774750 19 | n01784675 20 | n01806143 21 | n01820546 22 | n01833805 23 | n01843383 24 | n01847000 25 | n01855672 26 | n01860187 27 | n01882714 28 | n01910747 29 | n01944390 30 | n01983481 31 | n01986214 32 | n02007558 33 | n02009912 34 | n02051845 35 | n02056570 36 | n02066245 37 | n02071294 38 | n02077923 39 | n02085620 40 | n02086240 41 | n02088094 42 | n02088238 43 | n02088364 44 | n02088466 45 | n02091032 46 | n02091134 47 | n02092339 48 | n02094433 49 | n02096585 50 | n02097298 51 | n02098286 52 | n02099601 53 | n02099712 54 | n02102318 55 | n02106030 56 | n02106166 57 | n02106550 58 | n02106662 59 | n02108089 60 | n02108915 61 | n02109525 62 | n02110185 63 | n02110341 64 | n02110958 65 | n02112018 66 | n02112137 67 | n02113023 68 | n02113624 69 | n02113799 70 | n02114367 71 | n02117135 72 | n02119022 73 | n02123045 74 | n02128385 75 | n02128757 76 | n02129165 77 | n02129604 78 | n02130308 79 | n02134084 80 | n02138441 81 | n02165456 82 | n02190166 83 | n02206856 84 | n02219486 85 | n02226429 86 | n02233338 87 | n02236044 88 | n02268443 89 | n02279972 90 | n02317335 91 | n02325366 92 | n02346627 93 | n02356798 94 | n02363005 95 | n02364673 96 | n02391049 97 | n02395406 98 | n02398521 99 | n02410509 100 | n02423022 101 | n02437616 102 | n02445715 103 | n02447366 104 | n02480495 105 | n02480855 106 | n02481823 107 | n02483362 108 | n02486410 109 | n02510455 110 | n02526121 111 | n02607072 112 | n02655020 113 | n02672831 114 | n02701002 115 | n02749479 116 | n02769748 117 | n02793495 118 | n02797295 119 | n02802426 120 | n02808440 121 | n02814860 122 | n02823750 123 | n02841315 124 | n02843684 125 | n02883205 126 | n02906734 127 | n02909870 128 | n02939185 129 | n02948072 130 | n02950826 131 | n02951358 132 | n02966193 133 | n02980441 134 | n02992529 135 | n03124170 136 | n03272010 137 | n03345487 138 | n03372029 139 | n03424325 140 | n03452741 141 | n03467068 142 | n03481172 143 | n03494278 144 | n03495258 145 | n03498962 146 | n03594945 147 | n03602883 148 | n03630383 149 
| n03649909 150 | n03676483 151 | n03710193 152 | n03773504 153 | n03775071 154 | n03888257 155 | n03930630 156 | n03947888 157 | n04086273 158 | n04118538 159 | n04133789 160 | n04141076 161 | n04146614 162 | n04147183 163 | n04192698 164 | n04254680 165 | n04266014 166 | n04275548 167 | n04310018 168 | n04325704 169 | n04347754 170 | n04389033 171 | n04409515 172 | n04465501 173 | n04487394 174 | n04522168 175 | n04536866 176 | n04552348 177 | n04591713 178 | n07614500 179 | n07693725 180 | n07695742 181 | n07697313 182 | n07697537 183 | n07714571 184 | n07714990 185 | n07718472 186 | n07720875 187 | n07734744 188 | n07742313 189 | n07745940 190 | n07749582 191 | n07753275 192 | n07753592 193 | n07768694 194 | n07873807 195 | n07880968 196 | n07920052 197 | n09472597 198 | n09835506 199 | n10565667 200 | n12267677 201 | -------------------------------------------------------------------------------- /timm/layers/global_context.py: -------------------------------------------------------------------------------- 1 | """ Global Context Attention Block 2 | 3 | Paper: `GCNet: Non-local Networks Meet Squeeze-Excitation Networks and Beyond` 4 | - https://arxiv.org/abs/1904.11492 5 | 6 | Official code consulted as reference: https://github.com/xvjiarui/GCNet 7 | 8 | Hacked together by / Copyright 2021 Ross Wightman 9 | """ 10 | from typing import Optional, Tuple, Type, Union 11 | 12 | from torch import nn as nn 13 | import torch.nn.functional as F 14 | 15 | from .create_act import create_act_layer, get_act_layer 16 | from .helpers import make_divisible 17 | from .mlp import ConvMlp 18 | from .norm import LayerNorm2d 19 | 20 | 21 | class GlobalContext(nn.Module): 22 | 23 | def __init__( 24 | self, 25 | channels: int, 26 | use_attn: bool = True, 27 | fuse_add: bool = False, 28 | fuse_scale: bool = True, 29 | init_last_zero: bool = False, 30 | rd_ratio: float = 1./8, 31 | rd_channels: Optional[int] = None, 32 | rd_divisor: int = 1, 33 | act_layer: Type[nn.Module] = nn.ReLU, 34 | gate_layer: Union[str, Type[nn.Module]] = 'sigmoid', 35 | device=None, 36 | dtype=None 37 | ): 38 | dd = {'device': device, 'dtype': dtype} 39 | super().__init__() 40 | act_layer = get_act_layer(act_layer) 41 | 42 | self.conv_attn = nn.Conv2d(channels, 1, kernel_size=1, bias=True, **dd) if use_attn else None 43 | 44 | if rd_channels is None: 45 | rd_channels = make_divisible(channels * rd_ratio, rd_divisor, round_limit=0.) 
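        # NOTE (editor's note): the pooled context computed in forward() is fused back into x
        # by channel-wise scaling (mlp_scale + gate) and/or addition (mlp_add); both MLPs
        # bottleneck to rd_channels, following the GCNet design referenced above.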
46 | if fuse_add: 47 | self.mlp_add = ConvMlp(channels, rd_channels, act_layer=act_layer, norm_layer=LayerNorm2d, **dd) 48 | else: 49 | self.mlp_add = None 50 | if fuse_scale: 51 | self.mlp_scale = ConvMlp(channels, rd_channels, act_layer=act_layer, norm_layer=LayerNorm2d, **dd) 52 | else: 53 | self.mlp_scale = None 54 | 55 | self.gate = create_act_layer(gate_layer) 56 | self.init_last_zero = init_last_zero 57 | 58 | self.reset_parameters() 59 | 60 | def reset_parameters(self): 61 | if self.conv_attn is not None: 62 | nn.init.kaiming_normal_(self.conv_attn.weight, mode='fan_in', nonlinearity='relu') 63 | if self.mlp_add is not None: 64 | nn.init.zeros_(self.mlp_add.fc2.weight) 65 | 66 | def forward(self, x): 67 | B, C, H, W = x.shape 68 | 69 | if self.conv_attn is not None: 70 | attn = self.conv_attn(x).reshape(B, 1, H * W) # (B, 1, H * W) 71 | attn = F.softmax(attn, dim=-1).unsqueeze(3) # (B, 1, H * W, 1) 72 | context = x.reshape(B, C, H * W).unsqueeze(1) @ attn 73 | context = context.view(B, C, 1, 1) 74 | else: 75 | context = x.mean(dim=(2, 3), keepdim=True) 76 | 77 | if self.mlp_scale is not None: 78 | mlp_x = self.mlp_scale(context) 79 | x = x * self.gate(mlp_x) 80 | if self.mlp_add is not None: 81 | mlp_x = self.mlp_add(context) 82 | x = x + mlp_x 83 | 84 | return x 85 | -------------------------------------------------------------------------------- /timm/layers/filter_response_norm.py: -------------------------------------------------------------------------------- 1 | """ Filter Response Norm in PyTorch 2 | 3 | Based on `Filter Response Normalization Layer` - https://arxiv.org/abs/1911.09737 4 | 5 | Hacked together by / Copyright 2021 Ross Wightman 6 | """ 7 | from typing import Optional, Type 8 | 9 | import torch 10 | import torch.nn as nn 11 | 12 | from .create_act import create_act_layer 13 | from .trace_utils import _assert 14 | 15 | 16 | def inv_instance_rms(x, eps: float = 1e-5): 17 | rms = x.square().float().mean(dim=(2, 3), keepdim=True).add(eps).rsqrt().to(x.dtype) 18 | return rms.expand(x.shape) 19 | 20 | 21 | class FilterResponseNormTlu2d(nn.Module): 22 | def __init__( 23 | self, 24 | num_features: int, 25 | apply_act: bool = True, 26 | eps: float = 1e-5, 27 | rms: bool = True, 28 | device=None, 29 | dtype=None, 30 | **_, 31 | ): 32 | dd = {'device': device, 'dtype': dtype} 33 | super().__init__() 34 | self.apply_act = apply_act # apply activation (non-linearity) 35 | self.rms = rms 36 | self.eps = eps 37 | self.weight = nn.Parameter(torch.empty(num_features, **dd)) 38 | self.bias = nn.Parameter(torch.empty(num_features, **dd)) 39 | self.tau = nn.Parameter(torch.empty(num_features, **dd)) if apply_act else None 40 | 41 | self.reset_parameters() 42 | 43 | def reset_parameters(self): 44 | nn.init.ones_(self.weight) 45 | nn.init.zeros_(self.bias) 46 | if self.tau is not None: 47 | nn.init.zeros_(self.tau) 48 | 49 | def forward(self, x): 50 | _assert(x.dim() == 4, 'expected 4D input') 51 | x_dtype = x.dtype 52 | v_shape = (1, -1, 1, 1) 53 | x = x * inv_instance_rms(x, self.eps) 54 | x = x * self.weight.view(v_shape).to(dtype=x_dtype) + self.bias.view(v_shape).to(dtype=x_dtype) 55 | return torch.maximum(x, self.tau.reshape(v_shape).to(dtype=x_dtype)) if self.tau is not None else x 56 | 57 | 58 | class FilterResponseNormAct2d(nn.Module): 59 | def __init__( 60 | self, 61 | num_features: int, 62 | apply_act: bool = True, 63 | act_layer: Type[nn.Module] = nn.ReLU, 64 | inplace: Optional[bool] = None, 65 | rms: bool = True, 66 | eps: float = 1e-5, 67 | device=None, 68 | 
dtype=None, 69 | **_, 70 | ): 71 | dd = {'device': device, 'dtype': dtype} 72 | super().__init__() 73 | if act_layer is not None and apply_act: 74 | self.act = create_act_layer(act_layer, inplace=inplace) 75 | else: 76 | self.act = nn.Identity() 77 | self.rms = rms 78 | self.eps = eps 79 | self.weight = nn.Parameter(torch.empty(num_features, **dd)) 80 | self.bias = nn.Parameter(torch.empty(num_features, **dd)) 81 | 82 | self.reset_parameters() 83 | 84 | def reset_parameters(self): 85 | nn.init.ones_(self.weight) 86 | nn.init.zeros_(self.bias) 87 | 88 | def forward(self, x): 89 | _assert(x.dim() == 4, 'expected 4D input') 90 | x_dtype = x.dtype 91 | v_shape = (1, -1, 1, 1) 92 | x = x * inv_instance_rms(x, self.eps) 93 | x = x * self.weight.view(v_shape).to(dtype=x_dtype) + self.bias.view(v_shape).to(dtype=x_dtype) 94 | return self.act(x) 95 | -------------------------------------------------------------------------------- /timm/models/layers/__init__.py: -------------------------------------------------------------------------------- 1 | # NOTE timm.models.layers is DEPRECATED; please use timm.layers. This shim exists to reduce breakage during the transition 2 | from timm.layers.activations import * 3 | from timm.layers.adaptive_avgmax_pool import \ 4 | adaptive_avgmax_pool2d, select_adaptive_pool2d, AdaptiveAvgMaxPool2d, SelectAdaptivePool2d 5 | from timm.layers.attention_pool2d import AttentionPool2d, RotAttentionPool2d, RotaryEmbedding 6 | from timm.layers.blur_pool import BlurPool2d 7 | from timm.layers.classifier import ClassifierHead, create_classifier 8 | from timm.layers.cond_conv2d import CondConv2d, get_condconv_initializer 9 | from timm.layers.config import is_exportable, is_scriptable, is_no_jit, set_exportable, set_scriptable, set_no_jit,\ 10 | set_layer_config 11 | from timm.layers.conv2d_same import Conv2dSame, conv2d_same 12 | from timm.layers.conv_bn_act import ConvNormAct, ConvNormActAa, ConvBnAct 13 | from timm.layers.create_act import create_act_layer, get_act_layer, get_act_fn 14 | from timm.layers.create_attn import get_attn, create_attn 15 | from timm.layers.create_conv2d import create_conv2d 16 | from timm.layers.create_norm import get_norm_layer, create_norm_layer 17 | from timm.layers.create_norm_act import get_norm_act_layer, create_norm_act_layer 18 | from timm.layers.drop import DropBlock2d, DropPath, drop_block_2d, drop_path 19 | from timm.layers.eca import EcaModule, CecaModule, EfficientChannelAttn, CircularEfficientChannelAttn 20 | from timm.layers.evo_norm import EvoNorm2dB0, EvoNorm2dB1, EvoNorm2dB2,\ 21 | EvoNorm2dS0, EvoNorm2dS0a, EvoNorm2dS1, EvoNorm2dS1a, EvoNorm2dS2, EvoNorm2dS2a 22 | from timm.layers.fast_norm import is_fast_norm, set_fast_norm, fast_group_norm, fast_layer_norm 23 | from timm.layers.filter_response_norm import FilterResponseNormTlu2d, FilterResponseNormAct2d 24 | from timm.layers.gather_excite import GatherExcite 25 | from timm.layers.global_context import GlobalContext 26 | from timm.layers.helpers import to_ntuple, to_2tuple, to_3tuple, to_4tuple, make_divisible, extend_tuple 27 | from timm.layers.inplace_abn import InplaceAbn 28 | from timm.layers.linear import Linear 29 | from timm.layers.mixed_conv2d import MixedConv2d 30 | from timm.layers.mlp import Mlp, GluMlp, GatedMlp, ConvMlp 31 | from timm.layers.non_local_attn import NonLocalAttn, BatNonLocalAttn 32 | from timm.layers.norm import GroupNorm, GroupNorm1, LayerNorm, LayerNorm2d 33 | from timm.layers.norm_act import BatchNormAct2d, GroupNormAct,
convert_sync_batchnorm 34 | from timm.layers.padding import get_padding, get_same_padding, pad_same 35 | from timm.layers.patch_embed import PatchEmbed 36 | from timm.layers.pool2d_same import AvgPool2dSame, create_pool2d 37 | from timm.layers.squeeze_excite import SEModule, SqueezeExcite, EffectiveSEModule, EffectiveSqueezeExcite 38 | from timm.layers.selective_kernel import SelectiveKernel 39 | from timm.layers.separable_conv import SeparableConv2d, SeparableConvNormAct 40 | from timm.layers.split_attn import SplitAttn 41 | from timm.layers.split_batchnorm import SplitBatchNorm2d, convert_splitbn_model 42 | from timm.layers.std_conv import StdConv2d, StdConv2dSame, ScaledStdConv2d, ScaledStdConv2dSame 43 | from timm.layers.test_time_pool import TestTimePoolHead, apply_test_time_pool 44 | from timm.layers.trace_utils import _assert, _float_to_int 45 | from timm.layers.weight_init import trunc_normal_, trunc_normal_tf_, variance_scaling_, lecun_normal_ 46 | 47 | import warnings 48 | warnings.warn(f"Importing from {__name__} is deprecated, please import via timm.layers", FutureWarning) 49 | -------------------------------------------------------------------------------- /timm/utils/attention_extract.py: -------------------------------------------------------------------------------- 1 | import fnmatch 2 | import re 3 | from collections import OrderedDict 4 | from typing import Optional, List 5 | 6 | import torch 7 | 8 | 9 | class AttentionExtract(torch.nn.Module): 10 | # defaults should cover a significant number of timm models with attention maps. 11 | default_node_names = ['*attn.softmax'] 12 | default_module_names = ['*attn_drop'] 13 | 14 | def __init__( 15 | self, 16 | model: torch.nn.Module, 17 | names: Optional[List[str]] = None, 18 | mode: str = 'eval', 19 | method: str = 'fx', 20 | hook_type: str = 'forward', 21 | use_regex: bool = False, 22 | ): 23 | """ Extract attention maps (or other activations) from a model by name. 24 | 25 | Args: 26 | model: Instantiated model to extract from. 27 | names: List of concrete or wildcard names to extract. Names are nodes for fx and modules for hooks. 28 | mode: 'train' or 'eval' model mode. 29 | method: 'fx' or 'hook' extraction method. 30 | hook_type: 'forward' or 'forward_pre' hooks used.
31 | use_regex: Use regex instead of fnmatch 32 | """ 33 | super().__init__() 34 | assert mode in ('train', 'eval') 35 | if mode == 'train': 36 | model = model.train() 37 | else: 38 | model = model.eval() 39 | 40 | assert method in ('fx', 'hook') 41 | if method == 'fx': 42 | # names are activation node names 43 | from timm.models._features_fx import get_graph_node_names, GraphExtractNet 44 | 45 | node_names = get_graph_node_names(model)[0 if mode == 'train' else 1] 46 | names = names or self.default_node_names 47 | if use_regex: 48 | regexes = [re.compile(r) for r in names] 49 | matched = [g for g in node_names if any([r.match(g) for r in regexes])] 50 | else: 51 | matched = [g for g in node_names if any([fnmatch.fnmatch(g, n) for n in names])] 52 | if not matched: 53 | raise RuntimeError(f'No node names found matching {names}.') 54 | 55 | self.model = GraphExtractNet(model, matched, return_dict=True) 56 | self.hooks = None 57 | else: 58 | # names are module names 59 | assert hook_type in ('forward', 'forward_pre') 60 | from timm.models._features import FeatureHooks 61 | 62 | module_names = [n for n, m in model.named_modules()] 63 | names = names or self.default_module_names 64 | if use_regex: 65 | regexes = [re.compile(r) for r in names] 66 | matched = [m for m in module_names if any([r.match(m) for r in regexes])] 67 | else: 68 | matched = [m for m in module_names if any([fnmatch.fnmatch(m, n) for n in names])] 69 | if not matched: 70 | raise RuntimeError(f'No module names found matching {names}.') 71 | 72 | self.model = model 73 | self.hooks = FeatureHooks(matched, model.named_modules(), default_hook_type=hook_type) 74 | 75 | self.names = matched 76 | self.mode = mode 77 | self.method = method 78 | 79 | def forward(self, x): 80 | if self.hooks is not None: 81 | self.model(x) 82 | output = self.hooks.get_output(device=x.device) 83 | else: 84 | output = self.model(x) 85 | return output 86 | -------------------------------------------------------------------------------- /timm/layers/patch_dropout.py: -------------------------------------------------------------------------------- 1 | from typing import Optional, Tuple, Union 2 | 3 | import torch 4 | import torch.nn as nn 5 | 6 | 7 | def patch_dropout_forward( 8 | x: torch.Tensor, 9 | prob: float, 10 | num_prefix_tokens: int, 11 | ordered: bool, 12 | training: bool, 13 | ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]: 14 | """ 15 | Common forward logic for patch dropout. 16 | 17 | Args: 18 | x: Input tensor of shape (B, L, D) 19 | prob: Dropout probability 20 | num_prefix_tokens: Number of prefix tokens to preserve 21 | ordered: Whether to maintain patch order 22 | training: Whether in training mode 23 | 24 | Returns: 25 | Tuple of (output tensor, keep_indices or None) 26 | """ 27 | if not training or prob == 0.: 28 | return x, None 29 | 30 | if num_prefix_tokens: 31 | prefix_tokens, x = x[:, :num_prefix_tokens], x[:, num_prefix_tokens:] 32 | else: 33 | prefix_tokens = None 34 | 35 | B = x.shape[0] 36 | L = x.shape[1] 37 | num_keep = max(1, int(L * (1. 
- prob))) 38 | keep_indices = torch.argsort(torch.randn(B, L, device=x.device), dim=-1)[:, :num_keep] 39 | 40 | if ordered: 41 | # NOTE does not need to maintain patch order in typical transformer use, 42 | # but possibly useful for debug / visualization 43 | keep_indices = keep_indices.sort(dim=-1)[0] 44 | 45 | x = x.gather(1, keep_indices.unsqueeze(-1).expand((-1, -1) + x.shape[2:])) 46 | 47 | if prefix_tokens is not None: 48 | x = torch.cat((prefix_tokens, x), dim=1) 49 | 50 | return x, keep_indices 51 | 52 | 53 | class PatchDropout(nn.Module): 54 | """ 55 | Patch Dropout without returning indices. 56 | https://arxiv.org/abs/2212.00794 and https://arxiv.org/pdf/2208.07220 57 | """ 58 | 59 | def __init__( 60 | self, 61 | prob: float = 0.5, 62 | num_prefix_tokens: int = 1, 63 | ordered: bool = False, 64 | ): 65 | super().__init__() 66 | assert 0 <= prob < 1. 67 | self.prob = prob 68 | self.num_prefix_tokens = num_prefix_tokens # exclude CLS token (or other prefix tokens) 69 | self.ordered = ordered 70 | 71 | def forward(self, x: torch.Tensor) -> torch.Tensor: 72 | output, _ = patch_dropout_forward( 73 | x, 74 | self.prob, 75 | self.num_prefix_tokens, 76 | self.ordered, 77 | self.training 78 | ) 79 | return output 80 | 81 | 82 | class PatchDropoutWithIndices(nn.Module): 83 | """ 84 | Patch Dropout that returns both output and keep indices. 85 | https://arxiv.org/abs/2212.00794 and https://arxiv.org/pdf/2208.07220 86 | """ 87 | 88 | def __init__( 89 | self, 90 | prob: float = 0.5, 91 | num_prefix_tokens: int = 1, 92 | ordered: bool = False, 93 | ): 94 | super().__init__() 95 | assert 0 <= prob < 1. 96 | self.prob = prob 97 | self.num_prefix_tokens = num_prefix_tokens # exclude CLS token (or other prefix tokens) 98 | self.ordered = ordered 99 | 100 | def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, Optional[torch.Tensor]]: 101 | return patch_dropout_forward( 102 | x, 103 | self.prob, 104 | self.num_prefix_tokens, 105 | self.ordered, 106 | self.training 107 | ) 108 | -------------------------------------------------------------------------------- /timm/layers/conv_bn_act.py: -------------------------------------------------------------------------------- 1 | """ Conv2d + BN + Act 2 | 3 | Hacked together by / Copyright 2020 Ross Wightman 4 | """ 5 | from typing import Any, Dict, Optional, Type 6 | 7 | from torch import nn as nn 8 | 9 | from .typing import LayerType, PadType 10 | from .blur_pool import create_aa 11 | from .create_conv2d import create_conv2d 12 | from .create_norm_act import get_norm_act_layer 13 | 14 | 15 | class ConvNormAct(nn.Module): 16 | def __init__( 17 | self, 18 | in_channels: int, 19 | out_channels: int, 20 | kernel_size: int = 1, 21 | stride: int = 1, 22 | padding: PadType = '', 23 | dilation: int = 1, 24 | groups: int = 1, 25 | bias: bool = False, 26 | apply_norm: bool = True, 27 | apply_act: bool = True, 28 | norm_layer: LayerType = nn.BatchNorm2d, 29 | act_layer: Optional[LayerType] = nn.ReLU, 30 | aa_layer: Optional[LayerType] = None, 31 | drop_layer: Optional[Type[nn.Module]] = None, 32 | conv_kwargs: Optional[Dict[str, Any]] = None, 33 | norm_kwargs: Optional[Dict[str, Any]] = None, 34 | act_kwargs: Optional[Dict[str, Any]] = None, 35 | device=None, 36 | dtype=None, 37 | ): 38 | dd = {'device': device, 'dtype': dtype} 39 | super().__init__() 40 | conv_kwargs = {**dd, **(conv_kwargs or {})} 41 | norm_kwargs = {**dd, **(norm_kwargs or {})} 42 | act_kwargs = act_kwargs or {} 43 | use_aa = aa_layer is not None and stride > 1 44 | 45 | self.conv = 
create_conv2d( 46 | in_channels, 47 | out_channels, 48 | kernel_size, 49 | stride=1 if use_aa else stride, 50 | padding=padding, 51 | dilation=dilation, 52 | groups=groups, 53 | bias=bias, 54 | **conv_kwargs, 55 | ) 56 | 57 | if apply_norm: 58 | # NOTE for backwards compatibility with models that use separate norm and act layer definitions 59 | norm_act_layer = get_norm_act_layer(norm_layer, act_layer) 60 | # NOTE for backwards (weight) compatibility, norm layer name remains `.bn` 61 | if drop_layer: 62 | norm_kwargs['drop_layer'] = drop_layer 63 | self.bn = norm_act_layer( 64 | out_channels, 65 | apply_act=apply_act, 66 | act_kwargs=act_kwargs, 67 | **norm_kwargs, 68 | ) 69 | else: 70 | self.bn = nn.Sequential() 71 | if drop_layer: 72 | self.bn.add_module('drop', drop_layer()) 73 | 74 | 75 | self.aa = create_aa( 76 | aa_layer, 77 | out_channels, 78 | stride=stride, 79 | enable=use_aa, 80 | noop=None, 81 | **dd, 82 | ) 83 | 84 | @property 85 | def in_channels(self): 86 | return self.conv.in_channels 87 | 88 | @property 89 | def out_channels(self): 90 | return self.conv.out_channels 91 | 92 | def forward(self, x): 93 | x = self.conv(x) 94 | x = self.bn(x) 95 | aa = getattr(self, 'aa', None) 96 | if aa is not None: 97 | x = self.aa(x) 98 | return x 99 | 100 | 101 | ConvBnAct = ConvNormAct 102 | ConvNormActAa = ConvNormAct # backwards compat, when they were separate 103 | -------------------------------------------------------------------------------- /hfdocs/source/hparams.mdx: -------------------------------------------------------------------------------- 1 | # HParams 2 | Over the years, many `timm` models have been trained with varying hyper-parameters as the library and models evolved. I don't have a record of every run, but I have recorded the hparams for many of them, and these can serve as very good starting points. 3 | 4 | ## Tags 5 | Most `timm`-trained models have an identifier in their pretrained tag that relates them (roughly) to a family / version of hparams I've used over the years. 6 | 7 | | Tag(s) | Description | Optimizer | LR Schedule | Other Notes | 8 | |--------|-------------|-----------|-------------|-------------| 9 | | `a1h` | Based on [ResNet Strikes Back](https://arxiv.org/abs/2110.00476) `A1` recipe | LAMB | Cosine with warmup | Stronger dropout, stochastic depth, and RandAugment than paper `A1` recipe | 10 | | `ah` | Based on [ResNet Strikes Back](https://arxiv.org/abs/2110.00476) `A1` recipe | LAMB | Cosine with warmup | No CutMix.
Stronger dropout, stochastic depth, and RandAugment than paper `A1` recipe | 11 | | `a1`, `a2`, `a3` | ResNet Strikes Back `A{1,2,3}` recipe | LAMB with BCE loss | Cosine with warmup | — | 12 | | `b1`, `b2`, `b1k`, `b2k` | Based on [ResNet Strikes Back](https://arxiv.org/abs/2110.00476) `B` recipe (equivalent to `timm` `RA2` recipes) | RMSProp (TF 1.0 behaviour) | Step (exponential decay w/ staircase) with warmup | — | 13 | | `c`, `c1`, `c2`, `c3` | Based on [ResNet Strikes Back](https://arxiv.org/abs/2110.00476) `C` recipes | SGD (Nesterov) with AGC | Cosine with warmup | — | 14 | | `ch` | Based on [ResNet Strikes Back](https://arxiv.org/abs/2110.00476) `C` recipes | SGD (Nesterov) with AGC | Cosine with warmup | Stronger dropout, stochastic depth, and RandAugment than paper `C1`/`C2` recipes | 15 | | `d`, `d1`, `d2` | Based on [ResNet Strikes Back](https://arxiv.org/abs/2110.00476) `D` recipe | AdamW with BCE loss | Cosine with warmup | — | 16 | | `sw` | Based on Swin Transformer train/pretrain recipe (basis of DeiT and ConvNeXt recipes) | AdamW with gradient clipping, EMA | Cosine with warmup | — | 17 | | `ra`, `ra2`, `ra3`, `racm`, `raa` | RandAugment recipes. Inspired by EfficientNet RandAugment recipes. Covered by `B` recipe in [ResNet Strikes Back](https://arxiv.org/abs/2110.00476). | RMSProp (TF 1.0 behaviour), EMA | Step (exponential decay w/ staircase) with warmup | — | 18 | | `ra4` | RandAugment v4. Inspired by MobileNetV4 hparams. | — | — | — | 19 | | `am` | AugMix recipe | SGD (Nesterov) with JSD loss | Cosine with warmup | — | 20 | | `ram` | AugMix (with RandAugment) recipe | SGD (Nesterov) with JSD loss | Cosine with warmup | — | 21 | | `bt` | Bag-of-Tricks recipe | SGD (Nesterov) | Cosine with warmup | — | 22 | 23 | ## Config File Gists 24 | I've collected several of the hparam families in a series of gists. These can be downloaded and used via the `--config hparam.yaml` argument of the `timm` train script. Some adjustment is always required to match the LR to the effective global batch size.
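For example (illustrative only; the config file name is hypothetical, substitute one downloaded from the gists below): `python train.py --config ra2_rn50.yaml --data-dir /path/to/imagenet`.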
25 | 26 | | Tag | Key Model Architectures | Gist Link | 27 | |-----|------------------------|-----------| 28 | | `ra2` | ResNet, EfficientNet, RegNet, NFNet | [Link](https://gist.github.com/rwightman/07839a82d0f50e42840168bc43df70b3) | 29 | | `ra3` | RegNet | [Link](https://gist.github.com/rwightman/37252f8d7d850a94e43f1fcb7b3b8322) | 30 | | `ra4` | MobileNetV4 | [Link](https://gist.github.com/rwightman/f6705cb65c03daeebca8aa129b1b94ad) | 31 | | `sw` | ViT, ConvNeXt, CoAtNet, MaxViT | [Link](https://gist.github.com/rwightman/943c0fe59293b44024bbd2d5d23e6303) | 32 | | `sbb` | ViT | [Link](https://gist.github.com/rwightman/fb37c339efd2334177ff99a8083ebbc4) | 33 | | — | Tiny Test Models | [Link](https://gist.github.com/rwightman/9ba8efc39a546426e99055720d2f705f) | 34 | -------------------------------------------------------------------------------- /timm/loss/asymmetric_loss.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | 5 | class AsymmetricLossMultiLabel(nn.Module): 6 | def __init__(self, gamma_neg=4, gamma_pos=1, clip=0.05, eps=1e-8, disable_torch_grad_focal_loss=False): 7 | super(AsymmetricLossMultiLabel, self).__init__() 8 | 9 | self.gamma_neg = gamma_neg 10 | self.gamma_pos = gamma_pos 11 | self.clip = clip 12 | self.disable_torch_grad_focal_loss = disable_torch_grad_focal_loss 13 | self.eps = eps 14 | 15 | def forward(self, x, y): 16 | """ 17 | Parameters 18 | ---------- 19 | x: input logits 20 | y: targets (multi-label binarized vector) 21 | """ 22 | 23 | # Calculating Probabilities 24 | x_sigmoid = torch.sigmoid(x) 25 | xs_pos = x_sigmoid 26 | xs_neg = 1 - x_sigmoid 27 | 28 | # Asymmetric Clipping 29 | if self.clip is not None and self.clip > 0: 30 | xs_neg = (xs_neg + self.clip).clamp(max=1) 31 | 32 | # Basic CE calculation 33 | los_pos = y * torch.log(xs_pos.clamp(min=self.eps)) 34 | los_neg = (1 - y) * torch.log(xs_neg.clamp(min=self.eps)) 35 | loss = los_pos + los_neg 36 | 37 | # Asymmetric Focusing 38 | if self.gamma_neg > 0 or self.gamma_pos > 0: 39 | if self.disable_torch_grad_focal_loss: 40 | torch.set_grad_enabled(False) 41 | pt0 = xs_pos * y 42 | pt1 = xs_neg * (1 - y) # pt = p if t > 0 else 1-p 43 | pt = pt0 + pt1 44 | one_sided_gamma = self.gamma_pos * y + self.gamma_neg * (1 - y) 45 | one_sided_w = torch.pow(1 - pt, one_sided_gamma) 46 | if self.disable_torch_grad_focal_loss: 47 | torch.set_grad_enabled(True) 48 | loss *= one_sided_w 49 | 50 | return -loss.sum() 51 | 52 | 53 | class AsymmetricLossSingleLabel(nn.Module): 54 | def __init__(self, gamma_pos=1, gamma_neg=4, eps: float = 0.1, reduction='mean'): 55 | super(AsymmetricLossSingleLabel, self).__init__() 56 | 57 | self.eps = eps 58 | self.logsoftmax = nn.LogSoftmax(dim=-1) 59 | self.targets_classes = [] # prevent repeated gpu memory allocation 60 | self.gamma_pos = gamma_pos 61 | self.gamma_neg = gamma_neg 62 | self.reduction = reduction 63 | 64 | def forward(self, inputs, target, reduction=None): 65 | """ 66 | Parameters 67 | ---------- 68 | inputs: input logits 69 | target: target class indices 70 | """ 71 | 72 | num_classes = inputs.size()[-1] 73 | log_preds = self.logsoftmax(inputs) 74 | self.targets_classes = torch.zeros_like(inputs).scatter_(1, target.long().unsqueeze(1), 1) 75 | 76 | # ASL weights 77 | targets = self.targets_classes 78 | anti_targets = 1 - targets 79 | xs_pos = torch.exp(log_preds) 80 | xs_neg = 1 - xs_pos 81 | xs_pos = xs_pos * targets 82 | xs_neg = xs_neg * anti_targets 83 | asymmetric_w = torch.pow(1 - xs_pos
- xs_neg, 84 | self.gamma_pos * targets + self.gamma_neg * anti_targets) 85 | log_preds = log_preds * asymmetric_w 86 | 87 | if self.eps > 0: # label smoothing 88 | self.targets_classes = self.targets_classes.mul(1 - self.eps).add(self.eps / num_classes) 89 | 90 | # loss calculation 91 | loss = - self.targets_classes.mul(log_preds) 92 | 93 | loss = loss.sum(dim=-1) 94 | if self.reduction == 'mean': 95 | loss = loss.mean() 96 | 97 | return loss 98 | -------------------------------------------------------------------------------- /timm/data/readers/reader_hfds.py: -------------------------------------------------------------------------------- 1 | """ Dataset reader that wraps Hugging Face datasets 2 | 3 | Hacked together by / Copyright 2022 Ross Wightman 4 | """ 5 | import io 6 | import math 7 | from typing import Optional 8 | 9 | import torch 10 | import torch.distributed as dist 11 | from PIL import Image 12 | 13 | try: 14 | import datasets 15 | except ImportError as e: 16 | print("Please install the Hugging Face datasets package: `pip install datasets`.") 17 | raise e 18 | from .class_map import load_class_map 19 | from .reader import Reader 20 | 21 | 22 | def get_class_labels(info, label_key='label'): 23 | if label_key not in info.features: 24 | return {} 25 | class_label = info.features[label_key] 26 | class_to_idx = {n: class_label.str2int(n) for n in class_label.names} 27 | return class_to_idx 28 | 29 | 30 | class ReaderHfds(Reader): 31 | 32 | def __init__( 33 | self, 34 | name: str, 35 | root: Optional[str] = None, 36 | split: str = 'train', 37 | class_map: Optional[dict] = None, 38 | input_key: str = 'image', 39 | target_key: str = 'label', 40 | additional_features: Optional[list[str]] = None, 41 | download: bool = False, 42 | trust_remote_code: bool = False 43 | ): 44 | """Create a reader backed by a Hugging Face `datasets` dataset. 45 | """ 46 | super().__init__() 47 | self.root = root 48 | self.split = split 49 | self.dataset = datasets.load_dataset( 50 | name, # 'name' maps to path arg in hf datasets 51 | split=split, 52 | cache_dir=self.root, # timm doesn't expect hidden cache dir for datasets, specify a path if root set 53 | trust_remote_code=trust_remote_code 54 | ) 55 | # leave decode for caller, plus we want easy access to original path names...
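        # (editor's note) with decode=False each item becomes a {'bytes': ..., 'path': ...} dict,
        # which __getitem__ below turns into an open file object for the caller to decode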
56 | self.dataset = self.dataset.cast_column(input_key, datasets.Image(decode=False)) 57 | 58 | self.image_key = input_key 59 | self.label_key = target_key 60 | self.remap_class = False 61 | if class_map: 62 | self.class_to_idx = load_class_map(class_map) 63 | self.remap_class = True 64 | else: 65 | self.class_to_idx = get_class_labels(self.dataset.info, self.label_key) 66 | self.split_info = self.dataset.info.splits[split] 67 | self.num_samples = self.split_info.num_examples 68 | 69 | if additional_features is not None: 70 | if isinstance(additional_features, list): 71 | self.additional_features = additional_features 72 | else: 73 | self.additional_features = [additional_features] 74 | else: 75 | self.additional_features = None 76 | 77 | def __getitem__(self, index): 78 | item = self.dataset[index] 79 | image = item[self.image_key] 80 | 81 | if 'bytes' in image and image['bytes']: 82 | image = io.BytesIO(image['bytes']) 83 | else: 84 | assert 'path' in image and image['path'] 85 | image = open(image['path'], 'rb') 86 | 87 | label = item[self.label_key] 88 | if self.remap_class: 89 | label = self.class_to_idx[label] 90 | 91 | if self.additional_features is not None: 92 | features = [item[feat] for feat in self.additional_features] 93 | return image, label, *features 94 | else: 95 | return image, label 96 | 97 | def __len__(self): 98 | return len(self.dataset) 99 | 100 | def _filename(self, index, basename=False, absolute=False): 101 | item = self.dataset[index] 102 | return item[self.image_key]['path'] 103 | -------------------------------------------------------------------------------- /timm/models/_features_fx.py: -------------------------------------------------------------------------------- 1 | """ PyTorch FX Based Feature Extraction Helpers 2 | Using https://pytorch.org/vision/stable/feature_extraction.html 3 | """ 4 | from typing import Callable, Dict, List, Optional, Union, Tuple, Type 5 | 6 | import torch 7 | from torch import nn 8 | 9 | from timm.layers import ( 10 | create_feature_extractor, 11 | get_graph_node_names, 12 | register_notrace_module, 13 | register_notrace_function, 14 | is_notrace_module, 15 | is_notrace_function, 16 | get_notrace_functions, 17 | get_notrace_modules, 18 | Format, 19 | ) 20 | from ._features import _get_feature_info, _get_return_layers 21 | 22 | 23 | 24 | __all__ = [ 25 | 'register_notrace_module', 26 | 'is_notrace_module', 27 | 'get_notrace_modules', 28 | 'register_notrace_function', 29 | 'is_notrace_function', 30 | 'get_notrace_functions', 31 | 'create_feature_extractor', 32 | 'get_graph_node_names', 33 | 'FeatureGraphNet', 34 | 'GraphExtractNet', 35 | ] 36 | 37 | 38 | class FeatureGraphNet(nn.Module): 39 | """ A FX Graph based feature extractor that works with the model feature_info metadata 40 | """ 41 | return_dict: torch.jit.Final[bool] 42 | 43 | def __init__( 44 | self, 45 | model: nn.Module, 46 | out_indices: Tuple[int, ...], 47 | out_map: Optional[Dict] = None, 48 | output_fmt: str = 'NCHW', 49 | return_dict: bool = False, 50 | ): 51 | super().__init__() 52 | self.feature_info = _get_feature_info(model, out_indices) 53 | if out_map is not None: 54 | assert len(out_map) == len(out_indices) 55 | self.output_fmt = Format(output_fmt) 56 | return_nodes = _get_return_layers(self.feature_info, out_map) 57 | self.graph_module = create_feature_extractor(model, return_nodes) 58 | self.return_dict = return_dict 59 | 60 | def forward(self, x): 61 | out = self.graph_module(x) 62 | if self.return_dict: 63 | return out 64 | return list(out.values()) 65 | 
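# Example (editor's sketch, names illustrative): extracting an intermediate activation
# with the helpers in this module; GraphExtractNet is defined just below.
#
#   import timm, torch
#   model = timm.create_model('resnet18')
#   eval_nodes = get_graph_node_names(model)[1]
#   net = GraphExtractNet(model, return_nodes=eval_nodes[-2:-1])
#   feat = net(torch.randn(1, 3, 224, 224))  # single tensor since squeeze_out=True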
66 | 67 | class GraphExtractNet(nn.Module): 68 | """ A standalone feature extraction wrapper that maps dict -> list or single tensor 69 | NOTE: 70 | * create_feature_extractor can be used directly if dictionary output is desired 71 | * unlike FeatureGraphNet, this is intended to be used standalone and not with model feature_info 72 | metadata for builtin feature extraction mode 73 | 74 | 75 | Args: 76 | model: model to extract features from 77 | return_nodes: node names to return features from (dict or list) 78 | squeeze_out: if only one output, and output in list format, flatten to single tensor 79 | return_dict: return as dictionary from extractor with node names as keys, ignores squeeze_out arg 80 | """ 81 | return_dict: torch.jit.Final[bool] 82 | 83 | def __init__( 84 | self, 85 | model: nn.Module, 86 | return_nodes: Union[Dict[str, str], List[str]], 87 | squeeze_out: bool = True, 88 | return_dict: bool = False, 89 | ): 90 | super().__init__() 91 | self.squeeze_out = squeeze_out 92 | self.graph_module = create_feature_extractor(model, return_nodes) 93 | self.return_dict = return_dict 94 | 95 | def forward(self, x) -> Union[List[torch.Tensor], torch.Tensor]: 96 | out = self.graph_module(x) 97 | if self.return_dict: 98 | return out 99 | out = list(out.values()) 100 | return out[0] if self.squeeze_out and len(out) == 1 else out 101 | -------------------------------------------------------------------------------- /timm/layers/padding.py: -------------------------------------------------------------------------------- 1 | """ Padding Helpers 2 | 3 | Hacked together by / Copyright 2020 Ross Wightman 4 | """ 5 | import math 6 | from typing import List, Tuple, Union 7 | 8 | import torch 9 | import torch.nn.functional as F 10 | 11 | from .helpers import to_2tuple 12 | 13 | 14 | # Calculate symmetric padding for a convolution 15 | def get_padding(kernel_size: int, stride: int = 1, dilation: int = 1, **_) -> Union[int, List[int]]: 16 | if any([isinstance(v, (tuple, list)) for v in [kernel_size, stride, dilation]]): 17 | kernel_size, stride, dilation = to_2tuple(kernel_size), to_2tuple(stride), to_2tuple(dilation) 18 | return [get_padding(*a) for a in zip(kernel_size, stride, dilation)] 19 | padding = ((stride - 1) + dilation * (kernel_size - 1)) // 2 20 | return padding 21 | 22 | 23 | # Calculate asymmetric TensorFlow-like 'SAME' padding for a convolution 24 | def get_same_padding(x: Union[int, torch.Tensor], kernel_size: int, stride: int, dilation: int): 25 | if isinstance(x, torch.Tensor): 26 | return torch.clamp(((x / stride).ceil() - 1) * stride + (kernel_size - 1) * dilation + 1 - x, min=0) 27 | else: 28 | return max((math.ceil(x / stride) - 1) * stride + (kernel_size - 1) * dilation + 1 - x, 0) 29 | 30 | 31 | # Can SAME padding for given args be done statically?
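# (editor's note) It can when stride == 1 and dilation * (kernel_size - 1) is even, i.e. the
# total 'SAME' padding is independent of input size and splits symmetrically at build time.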
32 | def is_static_pad(kernel_size: int, stride: int = 1, dilation: int = 1, **_): 33 | if any([isinstance(v, (tuple, list)) for v in [kernel_size, stride, dilation]]): 34 | kernel_size, stride, dilation = to_2tuple(kernel_size), to_2tuple(stride), to_2tuple(dilation) 35 | return all([is_static_pad(*a) for a in zip(kernel_size, stride, dilation)]) 36 | return stride == 1 and (dilation * (kernel_size - 1)) % 2 == 0 37 | 38 | 39 | def pad_same_arg( 40 | input_size: List[int], 41 | kernel_size: List[int], 42 | stride: List[int], 43 | dilation: List[int] = (1, 1), 44 | ) -> List[int]: 45 | ih, iw = input_size 46 | kh, kw = kernel_size 47 | pad_h = get_same_padding(ih, kh, stride[0], dilation[0]) 48 | pad_w = get_same_padding(iw, kw, stride[1], dilation[1]) 49 | return [pad_w // 2, pad_w - pad_w // 2, pad_h // 2, pad_h - pad_h // 2] 50 | 51 | 52 | # Dynamically pad input x with 'SAME' padding for conv with specified args 53 | def pad_same( 54 | x, 55 | kernel_size: List[int], 56 | stride: List[int], 57 | dilation: List[int] = (1, 1), 58 | value: float = 0, 59 | ): 60 | ih, iw = x.size()[-2:] 61 | pad_h = get_same_padding(ih, kernel_size[0], stride[0], dilation[0]) 62 | pad_w = get_same_padding(iw, kernel_size[1], stride[1], dilation[1]) 63 | x = F.pad(x, (pad_w // 2, pad_w - pad_w // 2, pad_h // 2, pad_h - pad_h // 2), value=value) 64 | return x 65 | 66 | 67 | def get_padding_value(padding, kernel_size, **kwargs) -> Tuple[Tuple, bool]: 68 | dynamic = False 69 | if isinstance(padding, str): 70 | # for any string padding, the padding will be calculated for you, one of three ways 71 | padding = padding.lower() 72 | if padding == 'same': 73 | # TF compatible 'SAME' padding, has a performance and GPU memory allocation impact 74 | if is_static_pad(kernel_size, **kwargs): 75 | # static case, no extra overhead 76 | padding = get_padding(kernel_size, **kwargs) 77 | else: 78 | # dynamic 'SAME' padding, has runtime/GPU memory overhead 79 | padding = 0 80 | dynamic = True 81 | elif padding == 'valid': 82 | # 'VALID' padding, same as padding=0 83 | padding = 0 84 | else: 85 | # Default to PyTorch style 'same'-ish symmetric padding 86 | padding = get_padding(kernel_size, **kwargs) 87 | return padding, dynamic 88 | -------------------------------------------------------------------------------- /timm/models/_pretrained.py: -------------------------------------------------------------------------------- 1 | import copy 2 | from collections import deque, defaultdict 3 | from dataclasses import dataclass, field, replace, asdict 4 | from typing import Any, Deque, Dict, Tuple, Optional, Union 5 | 6 | 7 | __all__ = ['PretrainedCfg', 'filter_pretrained_cfg', 'DefaultCfg'] 8 | 9 | 10 | @dataclass 11 | class PretrainedCfg: 12 | """ 13 | """ 14 | # weight source locations 15 | url: Optional[Union[str, Tuple[str, str]]] = None # remote URL 16 | file: Optional[str] = None # local / shared filesystem path 17 | state_dict: Optional[Dict[str, Any]] = None # in-memory state dict 18 | hf_hub_id: Optional[str] = None # Hugging Face Hub model id ('organization/model') 19 | hf_hub_filename: Optional[str] = None # Hugging Face Hub filename (overrides default) 20 | 21 | source: Optional[str] = None # source of cfg / weight location used (url, file, hf-hub) 22 | architecture: Optional[str] = None # architecture variant can be set when not implicit 23 | tag: Optional[str] = None # pretrained tag of source 24 | custom_load: bool = False # use custom model specific model.load_pretrained() (ie for npz files) 25 | 26 | # input / data 
config 27 | input_size: Tuple[int, int, int] = (3, 224, 224) 28 | test_input_size: Optional[Tuple[int, int, int]] = None 29 | min_input_size: Optional[Tuple[int, int, int]] = None 30 | fixed_input_size: bool = False 31 | interpolation: str = 'bicubic' 32 | crop_pct: float = 0.875 33 | test_crop_pct: Optional[float] = None 34 | crop_mode: str = 'center' 35 | mean: Tuple[float, ...] = (0.485, 0.456, 0.406) 36 | std: Tuple[float, ...] = (0.229, 0.224, 0.225) 37 | 38 | # head / classifier config and meta-data 39 | num_classes: int = 1000 40 | label_offset: Optional[int] = None 41 | label_names: Optional[Tuple[str, ...]] = None 42 | label_descriptions: Optional[Dict[str, str]] = None 43 | 44 | # model attributes that vary with above or required for pretrained adaptation 45 | pool_size: Optional[Tuple[int, ...]] = None 46 | test_pool_size: Optional[Tuple[int, ...]] = None 47 | first_conv: Optional[str] = None 48 | classifier: Optional[str] = None 49 | 50 | license: Optional[str] = None 51 | description: Optional[str] = None 52 | origin_url: Optional[str] = None 53 | paper_name: Optional[str] = None 54 | paper_ids: Optional[Union[str, Tuple[str, ...]]] = None 55 | notes: Optional[Tuple[str, ...]] = None 56 | 57 | @property 58 | def has_weights(self): 59 | return self.url or self.file or self.hf_hub_id 60 | 61 | def to_dict(self, remove_source=False, remove_null=True): 62 | return filter_pretrained_cfg( 63 | asdict(self), 64 | remove_source=remove_source, 65 | remove_null=remove_null 66 | ) 67 | 68 | 69 | def filter_pretrained_cfg(cfg, remove_source=False, remove_null=True): 70 | filtered_cfg = {} 71 | keep_null = {'pool_size', 'first_conv', 'classifier'} # always keep these keys, even if none 72 | for k, v in cfg.items(): 73 | if remove_source and k in {'url', 'file', 'hf_hub_id', 'hf_hub_filename', 'source'}: 74 | continue 75 | if remove_null and v is None and k not in keep_null: 76 | continue 77 | filtered_cfg[k] = v 78 | return filtered_cfg 79 | 80 | 81 | @dataclass 82 | class DefaultCfg: 83 | tags: Deque[str] = field(default_factory=deque) # priority queue of tags (first is default) 84 | cfgs: Dict[str, PretrainedCfg] = field(default_factory=dict) # pretrained cfgs by tag 85 | is_pretrained: bool = False # at least one of the configs has a pretrained source set 86 | 87 | @property 88 | def default(self): 89 | return self.cfgs[self.tags[0]] 90 | 91 | @property 92 | def default_with_tag(self): 93 | tag = self.tags[0] 94 | return tag, self.cfgs[tag] 95 | -------------------------------------------------------------------------------- /timm/task/task.py: -------------------------------------------------------------------------------- 1 | """Base training task abstraction. 2 | 3 | This module provides the base TrainingTask class that encapsulates a complete 4 | forward pass including loss computation. Tasks return a dictionary with loss 5 | components and outputs for logging. 6 | """ 7 | from typing import Dict, Optional 8 | 9 | import torch 10 | import torch.nn as nn 11 | 12 | 13 | class TrainingTask(nn.Module): 14 | """Base class for training tasks. 15 | 16 | A training task encapsulates a complete forward pass including loss computation. 17 | Tasks return a dictionary containing the training loss and other components for logging.
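For example, a minimal subclass sketch (the ``model`` and ``criterion`` attributes here are
illustrative, not part of this base class):

    >>> class SimpleClassificationTask(TrainingTask):
    ...     def __init__(self, model, criterion, **kwargs):
    ...         super().__init__(**kwargs)
    ...         self.model = model
    ...         self.criterion = criterion
    ...
    ...     def forward(self, input, target):
    ...         output = self.model(input)
    ...         return {'loss': self.criterion(output, target), 'output': output}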
18 | 19 | The returned dictionary must contain: 20 | - 'loss': The training loss for backward pass (required) 21 | - 'output': Model output/logits for metric computation (recommended) 22 | - Other task-specific loss components for logging (optional) 23 | 24 | Args: 25 | device: Device for task tensors/buffers (defaults to cpu) 26 | dtype: Dtype for task tensors/buffers (defaults to torch default) 27 | verbose: Enable info logging 28 | 29 | Example: 30 | >>> task = SomeTask(model, criterion, device=torch.device('cuda')) 31 | >>> 32 | >>> # Prepare for distributed training (if needed) 33 | >>> if distributed: 34 | >>> task.prepare_distributed(device_ids=[local_rank]) 35 | >>> 36 | >>> # Training loop 37 | >>> result = task(input, target) 38 | >>> result['loss'].backward() 39 | """ 40 | 41 | def __init__( 42 | self, 43 | device: Optional[torch.device] = None, 44 | dtype: Optional[torch.dtype] = None, 45 | verbose: bool = True, 46 | ): 47 | super().__init__() 48 | self.device = device if device is not None else torch.device('cpu') 49 | self.dtype = dtype if dtype is not None else torch.get_default_dtype() 50 | self.verbose = verbose 51 | 52 | def to(self, *args, **kwargs): 53 | """Move task to device/dtype, keeping self.device and self.dtype in sync.""" 54 | dummy = torch.empty(0).to(*args, **kwargs) 55 | self.device = dummy.device 56 | self.dtype = dummy.dtype 57 | return super().to(*args, **kwargs) 58 | 59 | def prepare_distributed( 60 | self, 61 | device_ids: Optional[list] = None, 62 | **ddp_kwargs 63 | ) -> 'TrainingTask': 64 | """Prepare task for distributed training. 65 | 66 | This method wraps trainable components in DistributedDataParallel (DDP) 67 | while leaving non-trainable components (like frozen teacher models) unwrapped. 68 | 69 | Should be called after task initialization but before training loop. 70 | 71 | Args: 72 | device_ids: List of device IDs for DDP (e.g., [local_rank]) 73 | **ddp_kwargs: Additional arguments passed to DistributedDataParallel 74 | 75 | Returns: 76 | self (for method chaining) 77 | 78 | Example: 79 | >>> task = LogitDistillationTask(student, teacher, criterion) 80 | >>> task.prepare_distributed(device_ids=[args.local_rank]) 81 | >>> task = torch.compile(task) # Compile after DDP 82 | """ 83 | # Default implementation - subclasses override if they need DDP 84 | return self 85 | 86 | def forward( 87 | self, 88 | input: torch.Tensor, 89 | target: torch.Tensor, 90 | ) -> Dict[str, torch.Tensor]: 91 | """Perform forward pass and compute loss. 92 | 93 | Args: 94 | input: Input tensor [B, C, H, W] 95 | target: Target labels [B] 96 | 97 | Returns: 98 | Dictionary with at least 'loss' key containing the training loss 99 | """ 100 | raise NotImplementedError 101 | -------------------------------------------------------------------------------- /timm/layers/split_batchnorm.py: -------------------------------------------------------------------------------- 1 | """ Split BatchNorm 2 | 3 | A PyTorch BatchNorm layer that splits input batch into N equal parts and passes each through 4 | a separate BN layer. The first split is passed through the parent BN layers with weight/bias 5 | keys the same as the original BN. All other splits pass through BN sub-layers under the '.aux_bn' 6 | namespace. 
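For example, with ``num_splits=2`` a training batch of 64 is normalized as two independent halves:
the first 32 samples update the parent BN statistics, the remaining 32 update the single aux BN.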
7 | 8 | This allows easily removing the auxiliary BN layers after training to efficiently 9 | achieve the 'Auxiliary BatchNorm' as described in the AdvProp Paper, section 4.2, 10 | 'Disentangled Learning via An Auxiliary BN' 11 | 12 | Hacked together by / Copyright 2020 Ross Wightman 13 | """ 14 | import torch 15 | import torch.nn as nn 16 | 17 | 18 | class SplitBatchNorm2d(torch.nn.BatchNorm2d): 19 | 20 | def __init__( 21 | self, 22 | num_features, 23 | eps=1e-5, 24 | momentum=0.1, 25 | affine=True, 26 | track_running_stats=True, 27 | num_splits=2, 28 | device=None, 29 | dtype=None, 30 | ): 31 | dd = {'device': device, 'dtype': dtype} 32 | super().__init__(num_features, eps, momentum, affine, track_running_stats, **dd) 33 | assert num_splits > 1, 'Should have at least one aux BN layer (num_splits at least 2)' 34 | self.num_splits = num_splits 35 | self.aux_bn = nn.ModuleList([ 36 | nn.BatchNorm2d(num_features, eps, momentum, affine, track_running_stats, **dd) 37 | for _ in range(num_splits - 1) 38 | ]) 39 | 40 | def forward(self, input: torch.Tensor): 41 | if self.training: # aux BN only relevant while training 42 | split_size = input.shape[0] // self.num_splits 43 | assert input.shape[0] == split_size * self.num_splits, "batch size must be evenly divisible by num_splits" 44 | split_input = input.split(split_size) 45 | x = [super().forward(split_input[0])] 46 | for i, a in enumerate(self.aux_bn): 47 | x.append(a(split_input[i + 1])) 48 | return torch.cat(x, dim=0) 49 | else: 50 | return super().forward(input) 51 | 52 | 53 | def convert_splitbn_model(module, num_splits=2): 54 | """ 55 | Recursively traverse module and its children to replace all instances of 56 | ``torch.nn.modules.batchnorm._BatchNorm`` with `SplitBatchNorm2d`. 57 | Args: 58 | module (torch.nn.Module): input module 59 | num_splits: number of separate batchnorm layers to split input across 60 | Example:: 61 | >>> # model is an instance of torch.nn.Module 62 | >>> model = timm.models.convert_splitbn_model(model, num_splits=2) 63 | """ 64 | mod = module 65 | if isinstance(module, torch.nn.modules.instancenorm._InstanceNorm): 66 | return module 67 | if isinstance(module, torch.nn.modules.batchnorm._BatchNorm): 68 | mod = SplitBatchNorm2d( 69 | module.num_features, module.eps, module.momentum, module.affine, 70 | module.track_running_stats, num_splits=num_splits) 71 | mod.running_mean = module.running_mean 72 | mod.running_var = module.running_var 73 | mod.num_batches_tracked = module.num_batches_tracked 74 | if module.affine: 75 | mod.weight.data = module.weight.data.clone().detach() 76 | mod.bias.data = module.bias.data.clone().detach() 77 | for aux in mod.aux_bn: 78 | aux.running_mean = module.running_mean.clone() 79 | aux.running_var = module.running_var.clone() 80 | aux.num_batches_tracked = module.num_batches_tracked.clone() 81 | if module.affine: 82 | aux.weight.data = module.weight.data.clone().detach() 83 | aux.bias.data = module.bias.data.clone().detach() 84 | for name, child in module.named_children(): 85 | mod.add_module(name, convert_splitbn_model(child, num_splits=num_splits)) 86 | del module 87 | return mod 88 | -------------------------------------------------------------------------------- /timm/data/readers/reader_image_folder.py: -------------------------------------------------------------------------------- 1 | """ A dataset reader that extracts images from folders 2 | 3 | Folders are scanned recursively to find image files. Labels are based 4 | on the folder hierarchy, just leaf folders by default.
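For example (hypothetical layout)::

    root/dog/001.jpg
    root/cat/123.png

yields classes 'dog' and 'cat' taken from the leaf folder names.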
5 | 6 | Hacked together by / Copyright 2020 Ross Wightman 7 | """ 8 | import os 9 | from typing import Dict, List, Optional, Set, Tuple, Union 10 | 11 | from timm.utils.misc import natural_key 12 | 13 | from .class_map import load_class_map 14 | from .img_extensions import get_img_extensions 15 | from .reader import Reader 16 | 17 | 18 | def find_images_and_targets( 19 | folder: str, 20 | types: Optional[Union[List, Tuple, Set]] = None, 21 | class_to_idx: Optional[Dict] = None, 22 | leaf_name_only: bool = True, 23 | sort: bool = True 24 | ): 25 | """ Walk folder recursively to discover images and map them to classes by folder names. 26 | 27 | Args: 28 | folder: root of folder to recursively search 29 | types: types (file extensions) to search for in path 30 | class_to_idx: specify mapping for class (folder name) to class index if set 31 | leaf_name_only: use only leaf-name of folder walk for class names 32 | sort: re-sort found images by name (for consistent ordering) 33 | 34 | Returns: 35 | A list of image and target tuples, class_to_idx mapping 36 | """ 37 | types = get_img_extensions(as_set=True) if not types else set(types) 38 | labels = [] 39 | filenames = [] 40 | for root, subdirs, files in os.walk(folder, topdown=False, followlinks=True): 41 | rel_path = os.path.relpath(root, folder) if (root != folder) else '' 42 | label = os.path.basename(rel_path) if leaf_name_only else rel_path.replace(os.path.sep, '_') 43 | for f in files: 44 | base, ext = os.path.splitext(f) 45 | if ext.lower() in types: 46 | filenames.append(os.path.join(root, f)) 47 | labels.append(label) 48 | if class_to_idx is None: 49 | # building class index 50 | unique_labels = set(labels) 51 | sorted_labels = list(sorted(unique_labels, key=natural_key)) 52 | class_to_idx = {c: idx for idx, c in enumerate(sorted_labels)} 53 | images_and_targets = [(f, class_to_idx[l]) for f, l in zip(filenames, labels) if l in class_to_idx] 54 | if sort: 55 | images_and_targets = sorted(images_and_targets, key=lambda k: natural_key(k[0])) 56 | return images_and_targets, class_to_idx 57 | 58 | 59 | class ReaderImageFolder(Reader): 60 | 61 | def __init__( 62 | self, 63 | root, 64 | class_map='', 65 | input_key=None, 66 | ): 67 | super().__init__() 68 | 69 | self.root = root 70 | class_to_idx = None 71 | if class_map: 72 | class_to_idx = load_class_map(class_map, root) 73 | find_types = None 74 | if input_key: 75 | find_types = input_key.split(';') 76 | self.samples, self.class_to_idx = find_images_and_targets( 77 | root, 78 | class_to_idx=class_to_idx, 79 | types=find_types, 80 | ) 81 | if len(self.samples) == 0: 82 | raise RuntimeError( 83 | f'Found 0 images in subfolders of {root}. 
' 84 | f'Supported image extensions are {", ".join(get_img_extensions())}') 85 | 86 | def __getitem__(self, index): 87 | path, target = self.samples[index] 88 | return open(path, 'rb'), target 89 | 90 | def __len__(self): 91 | return len(self.samples) 92 | 93 | def _filename(self, index, basename=False, absolute=False): 94 | filename = self.samples[index][0] 95 | if basename: 96 | filename = os.path.basename(filename) 97 | elif not absolute: 98 | filename = os.path.relpath(filename, self.root) 99 | return filename 100 | -------------------------------------------------------------------------------- /timm/layers/inplace_abn.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn as nn 3 | 4 | try: 5 | from inplace_abn.functions import inplace_abn, inplace_abn_sync 6 | has_iabn = True 7 | except ImportError: 8 | has_iabn = False 9 | 10 | def inplace_abn(x, weight, bias, running_mean, running_var, 11 | training=True, momentum=0.1, eps=1e-05, activation="leaky_relu", activation_param=0.01): 12 | raise ImportError( 13 | "Please install InplaceABN:'pip install git+https://github.com/mapillary/inplace_abn.git@v1.0.12'") 14 | 15 | def inplace_abn_sync(**kwargs): 16 | inplace_abn(**kwargs) 17 | 18 | from ._fx import register_notrace_module 19 | 20 | 21 | @register_notrace_module 22 | class InplaceAbn(nn.Module): 23 | """Activated Batch Normalization 24 | 25 | This gathers a BatchNorm and an activation function in a single module 26 | 27 | Parameters 28 | ---------- 29 | num_features : int 30 | Number of feature channels in the input and output. 31 | eps : float 32 | Small constant to prevent numerical issues. 33 | momentum : float 34 | Momentum factor applied to compute running statistics. 35 | affine : bool 36 | If `True` apply learned scale and shift transformation after normalization. 37 | act_layer : str or nn.Module type 38 | Name or type of the activation functions, one of: `leaky_relu`, `elu` 39 | act_param : float 40 | Negative slope for the `leaky_relu` activation. 
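apply_act : bool
    If ``False``, skip the activation entirely (equivalent to ``act_layer='identity'``).
drop_layer : nn.Module type, optional
    Accepted for interface compatibility with other norm-act layers; unused by this module.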
41 | """ 42 | 43 | def __init__( 44 | self, 45 | num_features, 46 | eps=1e-5, 47 | momentum=0.1, 48 | affine=True, 49 | apply_act=True, 50 | act_layer="leaky_relu", 51 | act_param=0.01, 52 | drop_layer=None, 53 | ): 54 | super().__init__() 55 | self.num_features = num_features 56 | self.affine = affine 57 | self.eps = eps 58 | self.momentum = momentum 59 | if apply_act: 60 | if isinstance(act_layer, str): 61 | assert act_layer in ('leaky_relu', 'elu', 'identity', '') 62 | self.act_name = act_layer if act_layer else 'identity' 63 | else: 64 | # convert act layer passed as type to string 65 | if act_layer == nn.ELU: 66 | self.act_name = 'elu' 67 | elif act_layer == nn.LeakyReLU: 68 | self.act_name = 'leaky_relu' 69 | elif act_layer is None or act_layer == nn.Identity: 70 | self.act_name = 'identity' 71 | else: 72 | assert False, f'Invalid act layer {act_layer.__name__} for IABN' 73 | else: 74 | self.act_name = 'identity' 75 | self.act_param = act_param 76 | if self.affine: 77 | self.weight = nn.Parameter(torch.ones(num_features)) 78 | self.bias = nn.Parameter(torch.zeros(num_features)) 79 | else: 80 | self.register_parameter('weight', None) 81 | self.register_parameter('bias', None) 82 | self.register_buffer('running_mean', torch.zeros(num_features)) 83 | self.register_buffer('running_var', torch.ones(num_features)) 84 | self.reset_parameters() 85 | 86 | def reset_parameters(self): 87 | nn.init.constant_(self.running_mean, 0) 88 | nn.init.constant_(self.running_var, 1) 89 | if self.affine: 90 | nn.init.constant_(self.weight, 1) 91 | nn.init.constant_(self.bias, 0) 92 | 93 | def forward(self, x): 94 | output = inplace_abn( 95 | x, self.weight, self.bias, self.running_mean, self.running_var, 96 | self.training, self.momentum, self.eps, self.act_name, self.act_param) 97 | if isinstance(output, tuple): 98 | output = output[0] 99 | return output 100 | -------------------------------------------------------------------------------- /timm/layers/pool2d_same.py: -------------------------------------------------------------------------------- 1 | """ AvgPool2d w/ Same Padding 2 | 3 | Hacked together by / Copyright 2020 Ross Wightman 4 | """ 5 | import torch 6 | import torch.nn as nn 7 | import torch.nn.functional as F 8 | from typing import List, Tuple, Optional, Union 9 | 10 | from ._fx import register_notrace_module 11 | from .helpers import to_2tuple 12 | from .padding import pad_same, get_padding_value 13 | 14 | 15 | def avg_pool2d_same( 16 | x: torch.Tensor, 17 | kernel_size: List[int], 18 | stride: List[int], 19 | padding: List[int] = (0, 0), 20 | ceil_mode: bool = False, 21 | count_include_pad: bool = True, 22 | ): 23 | # FIXME how to deal with count_include_pad vs not for external padding? 
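# NOTE: the input is padded here, so the pooling op below always runs with padding=(0, 0) and the
# `padding` arg of this fn is effectively unused.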
24 | x = pad_same(x, kernel_size, stride) 25 | return F.avg_pool2d(x, kernel_size, stride, (0, 0), ceil_mode, count_include_pad) 26 | 27 | 28 | @register_notrace_module 29 | class AvgPool2dSame(nn.AvgPool2d): 30 | """ Tensorflow like 'SAME' wrapper for 2D average pooling 31 | """ 32 | def __init__( 33 | self, 34 | kernel_size: Union[int, Tuple[int, int]], 35 | stride: Optional[Union[int, Tuple[int, int]]] = None, 36 | padding: Union[int, Tuple[int, int], str] = 0, 37 | ceil_mode: bool = False, 38 | count_include_pad: bool = True, 39 | ): 40 | kernel_size = to_2tuple(kernel_size) 41 | stride = to_2tuple(stride) 42 | super().__init__(kernel_size, stride, (0, 0), ceil_mode, count_include_pad) 43 | 44 | def forward(self, x): 45 | x = pad_same(x, self.kernel_size, self.stride) 46 | return F.avg_pool2d( 47 | x, self.kernel_size, self.stride, self.padding, self.ceil_mode, self.count_include_pad) 48 | 49 | 50 | def max_pool2d_same( 51 | x: torch.Tensor, 52 | kernel_size: List[int], 53 | stride: List[int], 54 | padding: List[int] = (0, 0), 55 | dilation: List[int] = (1, 1), 56 | ceil_mode: bool = False, 57 | ): 58 | x = pad_same(x, kernel_size, stride, value=-float('inf')) 59 | return F.max_pool2d(x, kernel_size, stride, (0, 0), dilation, ceil_mode) 60 | 61 | 62 | @register_notrace_module 63 | class MaxPool2dSame(nn.MaxPool2d): 64 | """ Tensorflow like 'SAME' wrapper for 2D max pooling 65 | """ 66 | def __init__( 67 | self, 68 | kernel_size: Union[int, Tuple[int, int]], 69 | stride: Optional[Union[int, Tuple[int, int]]] = None, 70 | padding: Union[int, Tuple[int, int], str] = 0, 71 | dilation: Union[int, Tuple[int, int]] = 1, 72 | ceil_mode: bool = False, 73 | ): 74 | kernel_size = to_2tuple(kernel_size) 75 | stride = to_2tuple(stride) 76 | dilation = to_2tuple(dilation) 77 | super().__init__(kernel_size, stride, (0, 0), dilation, ceil_mode) 78 | 79 | def forward(self, x): 80 | x = pad_same(x, self.kernel_size, self.stride, value=-float('inf')) 81 | return F.max_pool2d(x, self.kernel_size, self.stride, (0, 0), self.dilation, self.ceil_mode) 82 | 83 | 84 | def create_pool2d(pool_type, kernel_size, stride=None, **kwargs): 85 | stride = stride or kernel_size 86 | padding = kwargs.pop('padding', '') 87 | padding, is_dynamic = get_padding_value(padding, kernel_size, stride=stride, **kwargs) 88 | if is_dynamic: 89 | if pool_type == 'avg': 90 | return AvgPool2dSame(kernel_size, stride=stride, **kwargs) 91 | elif pool_type == 'max': 92 | return MaxPool2dSame(kernel_size, stride=stride, **kwargs) 93 | else: 94 | assert False, f'Unsupported pool type {pool_type}' 95 | else: 96 | if pool_type == 'avg': 97 | return nn.AvgPool2d(kernel_size, stride=stride, padding=padding, **kwargs) 98 | elif pool_type == 'max': 99 | return nn.MaxPool2d(kernel_size, stride=stride, padding=padding, **kwargs) 100 | else: 101 | assert False, f'Unsupported pool type {pool_type}' 102 | -------------------------------------------------------------------------------- /timm/utils/onnx.py: -------------------------------------------------------------------------------- 1 | from typing import Optional, Tuple, List 2 | 3 | import torch 4 | 5 | 6 | def onnx_forward(onnx_file, example_input): 7 | import onnxruntime 8 | 9 | sess_options = onnxruntime.SessionOptions() 10 | session = onnxruntime.InferenceSession(onnx_file, sess_options) 11 | input_name = session.get_inputs()[0].name 12 | output = session.run([], {input_name: example_input.numpy()}) 13 | output = output[0] 14 | return output 15 | 16 | 17 | def onnx_export( 18 | model: 
torch.nn.Module, 19 | output_file: str, 20 | example_input: Optional[torch.Tensor] = None, 21 | training: bool = False, 22 | verbose: bool = False, 23 | check: bool = True, 24 | check_forward: bool = False, 25 | batch_size: int = 64, 26 | input_size: Optional[Tuple[int, int, int]] = None, 27 | opset: Optional[int] = None, 28 | dynamic_size: bool = False, 29 | aten_fallback: bool = False, 30 | keep_initializers: Optional[bool] = None, 31 | use_dynamo: bool = False, 32 | input_names: Optional[List[str]] = None, 33 | output_names: Optional[List[str]] = None, 34 | ): 35 | import onnx 36 | 37 | if training: 38 | training_mode = torch.onnx.TrainingMode.TRAINING 39 | model.train() 40 | else: 41 | training_mode = torch.onnx.TrainingMode.EVAL 42 | model.eval() 43 | 44 | if example_input is None: 45 | if not input_size: 46 | assert hasattr(model, 'default_cfg'), 'Cannot find model default config, input size must be provided' 47 | input_size = model.default_cfg.get('input_size') 48 | example_input = torch.randn((batch_size,) + input_size, requires_grad=training) 49 | 50 | # Run model once before export trace, sets padding for models with Conv2dSameExport. This means 51 | # that the padding for models with Conv2dSameExport (most models with tf_ prefix) is fixed for 52 | # the input img_size specified in this script. 53 | 54 | # Opset >= 11 should allow for dynamic padding, however I cannot get it to work due to 55 | # issues in the tracing of the dynamic padding or errors attempting to export the model after jit 56 | # scripting it (an approach that should work). Perhaps in future PyTorch or ONNX versions... 57 | with torch.inference_mode(): 58 | original_out = model(example_input) 59 | 60 | input_names = input_names or ["input0"] 61 | output_names = output_names or ["output0"] 62 | 63 | dynamic_axes = {'input0': {0: 'batch'}, 'output0': {0: 'batch'}} 64 | if dynamic_size: 65 | dynamic_axes['input0'][2] = 'height' 66 | dynamic_axes['input0'][3] = 'width' 67 | 68 | if aten_fallback: 69 | export_type = torch.onnx.OperatorExportTypes.ONNX_ATEN_FALLBACK 70 | else: 71 | export_type = torch.onnx.OperatorExportTypes.ONNX 72 | 73 | if use_dynamo: 74 | export_options = torch.onnx.ExportOptions(dynamic_shapes=dynamic_size) 75 | export_output = torch.onnx.dynamo_export( 76 | model, 77 | example_input, 78 | export_options=export_options, 79 | ) 80 | export_output.save(output_file) 81 | else: 82 | torch.onnx.export( 83 | model, 84 | example_input, 85 | output_file, 86 | training=training_mode, 87 | export_params=True, 88 | verbose=verbose, 89 | input_names=input_names, 90 | output_names=output_names, 91 | keep_initializers_as_inputs=keep_initializers, 92 | dynamic_axes=dynamic_axes, 93 | opset_version=opset, 94 | operator_export_type=export_type 95 | ) 96 | 97 | if check: 98 | onnx_model = onnx.load(output_file) 99 | onnx.checker.check_model(onnx_model, full_check=True) # assuming throw on error 100 | if check_forward and not training: 101 | import numpy as np 102 | onnx_out = onnx_forward(output_file, example_input) 103 | np.testing.assert_almost_equal(original_out.numpy(), onnx_out, decimal=3) 104 | 105 | -------------------------------------------------------------------------------- /timm/layers/separable_conv.py: -------------------------------------------------------------------------------- 1 | """ Depthwise Separable Conv Modules 2 | 3 | Basic DWS convs. Other variations of DWS exist with batch norm or activations between the 4 | DW and PW convs such as the Depthwise modules in MobileNetV2 / EfficientNet and Xception.
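The payoff is parameter efficiency: e.g. for a 3x3 conv with 64 in/out channels, a dense conv
has 64 * 64 * 3 * 3 = 36,864 weights while the depthwise (64 * 3 * 3 = 576) plus pointwise
(64 * 64 = 4,096) pair has only 4,672.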
5 | 6 | Hacked together by / Copyright 2020 Ross Wightman 7 | """ 8 | from typing import Optional, Type, Union 9 | 10 | from torch import nn as nn 11 | 12 | from .create_conv2d import create_conv2d 13 | from .create_norm_act import get_norm_act_layer 14 | 15 | 16 | class SeparableConvNormAct(nn.Module): 17 | """ Separable Conv w/ trailing Norm and Activation 18 | """ 19 | def __init__( 20 | self, 21 | in_channels: int, 22 | out_channels: int, 23 | kernel_size: int = 3, 24 | stride: int = 1, 25 | dilation: int = 1, 26 | padding: str = '', 27 | bias: bool = False, 28 | channel_multiplier: float = 1.0, 29 | pw_kernel_size: int = 1, 30 | norm_layer: Type[nn.Module] = nn.BatchNorm2d, 31 | act_layer: Type[nn.Module] = nn.ReLU, 32 | apply_act: bool = True, 33 | drop_layer: Optional[Type[nn.Module]] = None, 34 | device=None, 35 | dtype=None, 36 | ): 37 | dd = {'device': device, 'dtype': dtype} 38 | super().__init__() 39 | 40 | self.conv_dw = create_conv2d( 41 | in_channels, 42 | int(in_channels * channel_multiplier), 43 | kernel_size, 44 | stride=stride, 45 | dilation=dilation, 46 | padding=padding, 47 | depthwise=True, 48 | **dd, 49 | ) 50 | 51 | self.conv_pw = create_conv2d( 52 | int(in_channels * channel_multiplier), 53 | out_channels, 54 | pw_kernel_size, 55 | padding=padding, 56 | bias=bias, 57 | **dd, 58 | ) 59 | 60 | norm_act_layer = get_norm_act_layer(norm_layer, act_layer) 61 | norm_kwargs = dict(drop_layer=drop_layer) if drop_layer is not None else {} 62 | self.bn = norm_act_layer(out_channels, apply_act=apply_act, **norm_kwargs, **dd) 63 | 64 | @property 65 | def in_channels(self): 66 | return self.conv_dw.in_channels 67 | 68 | @property 69 | def out_channels(self): 70 | return self.conv_pw.out_channels 71 | 72 | def forward(self, x): 73 | x = self.conv_dw(x) 74 | x = self.conv_pw(x) 75 | x = self.bn(x) 76 | return x 77 | 78 | 79 | SeparableConvBnAct = SeparableConvNormAct 80 | 81 | 82 | class SeparableConv2d(nn.Module): 83 | """ Separable Conv 84 | """ 85 | def __init__( 86 | self, 87 | in_channels, 88 | out_channels, 89 | kernel_size=3, 90 | stride=1, 91 | dilation=1, 92 | padding='', 93 | bias=False, 94 | channel_multiplier=1.0, 95 | pw_kernel_size=1, 96 | device=None, 97 | dtype=None, 98 | ): 99 | dd = {'device': device, 'dtype': dtype} 100 | super().__init__() 101 | 102 | self.conv_dw = create_conv2d( 103 | in_channels, 104 | int(in_channels * channel_multiplier), 105 | kernel_size, 106 | stride=stride, 107 | dilation=dilation, 108 | padding=padding, 109 | depthwise=True, 110 | **dd, 111 | ) 112 | 113 | self.conv_pw = create_conv2d( 114 | int(in_channels * channel_multiplier), 115 | out_channels, 116 | pw_kernel_size, 117 | padding=padding, 118 | bias=bias, 119 | **dd, 120 | ) 121 | 122 | @property 123 | def in_channels(self): 124 | return self.conv_dw.in_channels 125 | 126 | @property 127 | def out_channels(self): 128 | return self.conv_pw.out_channels 129 | 130 | def forward(self, x): 131 | x = self.conv_dw(x) 132 | x = self.conv_pw(x) 133 | return x 134 | -------------------------------------------------------------------------------- /timm/layers/split_attn.py: -------------------------------------------------------------------------------- 1 | """ Split Attention Conv2d (for ResNeSt Models) 2 | 3 | Paper: `ResNeSt: Split-Attention Networks` - https://arxiv.org/abs/2004.08955 4 | 5 | Adapted from original PyTorch impl at https://github.com/zhanghang1989/ResNeSt 6 | 7 | Modified for torchscript compat, performance, and consistency with timm by Ross Wightman 8 | """ 9 | from 
typing import Optional, Type, Union 10 | 11 | import torch 12 | import torch.nn.functional as F 13 | from torch import nn 14 | 15 | from .helpers import make_divisible 16 | 17 | 18 | class RadixSoftmax(nn.Module): 19 | def __init__(self, radix: int, cardinality: int): 20 | super().__init__() 21 | self.radix = radix 22 | self.cardinality = cardinality 23 | 24 | def forward(self, x): 25 | batch = x.size(0) 26 | if self.radix > 1: 27 | x = x.view(batch, self.cardinality, self.radix, -1).transpose(1, 2) 28 | x = F.softmax(x, dim=1) 29 | x = x.reshape(batch, -1) 30 | else: 31 | x = torch.sigmoid(x) 32 | return x 33 | 34 | 35 | class SplitAttn(nn.Module): 36 | """Split-Attention (aka Splat) 37 | """ 38 | def __init__( 39 | self, 40 | in_channels: int, 41 | out_channels: Optional[int] = None, 42 | kernel_size: int = 3, 43 | stride: int = 1, 44 | padding: Optional[int] = None, 45 | dilation: int = 1, 46 | groups: int = 1, 47 | bias: bool = False, 48 | radix: int = 2, 49 | rd_ratio: float = 0.25, 50 | rd_channels: Optional[int] = None, 51 | rd_divisor: int = 8, 52 | act_layer: Type[nn.Module] = nn.ReLU, 53 | norm_layer: Optional[Type[nn.Module]] = None, 54 | drop_layer: Optional[Type[nn.Module]] = None, 55 | **kwargs, 56 | ): 57 | dd = {'device': kwargs.pop('device', None), 'dtype': kwargs.pop('dtype', None)} 58 | super().__init__() 59 | out_channels = out_channels or in_channels 60 | self.radix = radix 61 | mid_chs = out_channels * radix 62 | if rd_channels is None: 63 | attn_chs = make_divisible(in_channels * radix * rd_ratio, min_value=32, divisor=rd_divisor) 64 | else: 65 | attn_chs = rd_channels * radix 66 | 67 | padding = kernel_size // 2 if padding is None else padding 68 | self.conv = nn.Conv2d( 69 | in_channels, 70 | mid_chs, 71 | kernel_size, 72 | stride, 73 | padding, 74 | dilation, 75 | groups=groups * radix, 76 | bias=bias, 77 | **kwargs, 78 | **dd, 79 | ) 80 | self.bn0 = norm_layer(mid_chs, **dd) if norm_layer else nn.Identity() 81 | self.drop = drop_layer() if drop_layer is not None else nn.Identity() 82 | self.act0 = act_layer(inplace=True) 83 | self.fc1 = nn.Conv2d(out_channels, attn_chs, 1, groups=groups, **dd) 84 | self.bn1 = norm_layer(attn_chs, **dd) if norm_layer else nn.Identity() 85 | self.act1 = act_layer(inplace=True) 86 | self.fc2 = nn.Conv2d(attn_chs, mid_chs, 1, groups=groups, **dd) 87 | self.rsoftmax = RadixSoftmax(radix, groups) 88 | 89 | def forward(self, x): 90 | x = self.conv(x) 91 | x = self.bn0(x) 92 | x = self.drop(x) 93 | x = self.act0(x) 94 | 95 | B, RC, H, W = x.shape 96 | if self.radix > 1: 97 | x = x.reshape((B, self.radix, RC // self.radix, H, W)) 98 | x_gap = x.sum(dim=1) 99 | else: 100 | x_gap = x 101 | x_gap = x_gap.mean((2, 3), keepdim=True) 102 | x_gap = self.fc1(x_gap) 103 | x_gap = self.bn1(x_gap) 104 | x_gap = self.act1(x_gap) 105 | x_attn = self.fc2(x_gap) 106 | 107 | x_attn = self.rsoftmax(x_attn).view(B, -1, 1, 1) 108 | if self.radix > 1: 109 | out = (x * x_attn.reshape((B, self.radix, RC // self.radix, 1, 1))).sum(dim=1) 110 | else: 111 | out = x * x_attn 112 | return out.contiguous() 113 | -------------------------------------------------------------------------------- /timm/scheduler/plateau_lr.py: -------------------------------------------------------------------------------- 1 | """ Plateau Scheduler 2 | 3 | Adapts PyTorch plateau scheduler and allows application of noise, warmup. 
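A minimal usage sketch (the train/validate helpers are placeholders, not part of this module):

    scheduler = PlateauLRScheduler(optimizer, patience_t=5, warmup_t=3)
    for epoch in range(num_epochs):
        train_one_epoch(...)
        eval_metric = validate(...)  # higher-is-better by default (mode='max')
        scheduler.step(epoch + 1, eval_metric)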
4 | 5 | Hacked together by / Copyright 2020 Ross Wightman 6 | """ 7 | import torch 8 | from typing import List, Optional 9 | 10 | from .scheduler import Scheduler 11 | 12 | 13 | class PlateauLRScheduler(Scheduler): 14 | """Decay the LR by a factor each time the monitored metric plateaus (higher-is-better by default, mode='max').""" 15 | 16 | def __init__( 17 | self, 18 | optimizer, 19 | decay_rate=0.1, 20 | patience_t=10, 21 | verbose=True, 22 | threshold=1e-4, 23 | cooldown_t=0, 24 | warmup_t=0, 25 | warmup_lr_init=0, 26 | lr_min=0, 27 | mode='max', 28 | noise_range_t=None, 29 | noise_type='normal', 30 | noise_pct=0.67, 31 | noise_std=1.0, 32 | noise_seed=None, 33 | initialize=True, 34 | ): 35 | super().__init__( 36 | optimizer, 37 | 'lr', 38 | noise_range_t=noise_range_t, 39 | noise_type=noise_type, 40 | noise_pct=noise_pct, 41 | noise_std=noise_std, 42 | noise_seed=noise_seed, 43 | initialize=initialize, 44 | ) 45 | 46 | self.lr_scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau( 47 | self.optimizer, 48 | patience=patience_t, 49 | factor=decay_rate, 50 | verbose=verbose, 51 | threshold=threshold, 52 | cooldown=cooldown_t, 53 | mode=mode, 54 | min_lr=lr_min 55 | ) 56 | 57 | self.warmup_t = warmup_t 58 | self.warmup_lr_init = warmup_lr_init 59 | if self.warmup_t: 60 | self.warmup_steps = [(v - warmup_lr_init) / self.warmup_t for v in self.base_values] 61 | super().update_groups(self.warmup_lr_init) 62 | else: 63 | self.warmup_steps = [1 for _ in self.base_values] 64 | self.restore_lr = None 65 | 66 | def state_dict(self): 67 | return { 68 | 'best': self.lr_scheduler.best, 69 | 'last_epoch': self.lr_scheduler.last_epoch, 70 | } 71 | 72 | def load_state_dict(self, state_dict): 73 | self.lr_scheduler.best = state_dict['best'] 74 | if 'last_epoch' in state_dict: 75 | self.lr_scheduler.last_epoch = state_dict['last_epoch'] 76 | 77 | # override the base class step fn completely 78 | def step(self, epoch, metric=None): 79 | if epoch <= self.warmup_t: 80 | lrs = [self.warmup_lr_init + epoch * s for s in self.warmup_steps] 81 | super().update_groups(lrs) 82 | else: 83 | if self.restore_lr is not None: 84 | # restore actual LR from before our last noise perturbation before stepping base 85 | for i, param_group in enumerate(self.optimizer.param_groups): 86 | param_group['lr'] = self.restore_lr[i] 87 | self.restore_lr = None 88 | 89 | # step the base scheduler if metric given 90 | if metric is not None: 91 | self.lr_scheduler.step(metric, epoch) 92 | 93 | if self._is_apply_noise(epoch): 94 | self._apply_noise(epoch) 95 | 96 | def step_update(self, num_updates: int, metric: Optional[float] = None): 97 | return None 98 | 99 | def _apply_noise(self, epoch): 100 | noise = self._calculate_noise(epoch) 101 | 102 | # apply the noise on top of previous LR, cache the old value so we can restore for normal 103 | # stepping of base scheduler 104 | restore_lr = [] 105 | for i, param_group in enumerate(self.optimizer.param_groups): 106 | old_lr = float(param_group['lr']) 107 | restore_lr.append(old_lr) 108 | new_lr = old_lr + old_lr * noise 109 | param_group['lr'] = new_lr 110 | self.restore_lr = restore_lr 111 | 112 | def _get_lr(self, t: int) -> List[float]: 113 | assert False, 'should not be called as step is overridden' 114 | -------------------------------------------------------------------------------- /timm/scheduler/tanh_lr.py: -------------------------------------------------------------------------------- 1 | """ TanH Scheduler 2 | 3 | TanH schedule with warmup, cycle/restarts, noise.
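Within a cycle of length t_i, with progress tr = t_curr / t_i, the schedule implemented in
_get_lr below is

    lr = lr_min + 0.5 * (lr_max - lr_min) * (1 - tanh(lb * (1 - tr) + ub * tr))

so lb and ub select the segment of the tanh curve traversed over the cycle.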
4 | 5 | Hacked together by / Copyright 2021 Ross Wightman 6 | """ 7 | import logging 8 | import math 9 | import numpy as np 10 | import torch 11 | from typing import List 12 | 13 | from .scheduler import Scheduler 14 | 15 | 16 | _logger = logging.getLogger(__name__) 17 | 18 | 19 | class TanhLRScheduler(Scheduler): 20 | """ 21 | Hyperbolic-Tangent decay with restarts. 22 | This is described in the paper https://arxiv.org/abs/1806.01593 23 | """ 24 | 25 | def __init__( 26 | self, 27 | optimizer: torch.optim.Optimizer, 28 | t_initial: int, 29 | lb: float = -7., 30 | ub: float = 3., 31 | lr_min: float = 0., 32 | cycle_mul: float = 1., 33 | cycle_decay: float = 1., 34 | cycle_limit: int = 1, 35 | warmup_t=0, 36 | warmup_lr_init=0, 37 | warmup_prefix=False, 38 | t_in_epochs=True, 39 | noise_range_t=None, 40 | noise_pct=0.67, 41 | noise_std=1.0, 42 | noise_seed=42, 43 | initialize=True, 44 | ) -> None: 45 | super().__init__( 46 | optimizer, 47 | param_group_field="lr", 48 | t_in_epochs=t_in_epochs, 49 | noise_range_t=noise_range_t, 50 | noise_pct=noise_pct, 51 | noise_std=noise_std, 52 | noise_seed=noise_seed, 53 | initialize=initialize, 54 | ) 55 | 56 | assert t_initial > 0 57 | assert lr_min >= 0 58 | assert lb < ub 59 | assert cycle_limit >= 0 60 | assert warmup_t >= 0 61 | assert warmup_lr_init >= 0 62 | self.lb = lb 63 | self.ub = ub 64 | self.t_initial = t_initial 65 | self.lr_min = lr_min 66 | self.cycle_mul = cycle_mul 67 | self.cycle_decay = cycle_decay 68 | self.cycle_limit = cycle_limit 69 | self.warmup_t = warmup_t 70 | self.warmup_lr_init = warmup_lr_init 71 | self.warmup_prefix = warmup_prefix 72 | if self.warmup_t: 73 | t_v = self.base_values if self.warmup_prefix else self._get_lr(self.warmup_t) 74 | self.warmup_steps = [(v - warmup_lr_init) / self.warmup_t for v in t_v] 75 | super().update_groups(self.warmup_lr_init) 76 | else: 77 | self.warmup_steps = [1 for _ in self.base_values] 78 | 79 | def _get_lr(self, t: int) -> List[float]: 80 | if t < self.warmup_t: 81 | lrs = [self.warmup_lr_init + t * s for s in self.warmup_steps] 82 | else: 83 | if self.warmup_prefix: 84 | t = t - self.warmup_t 85 | 86 | if self.cycle_mul != 1: 87 | i = math.floor(math.log(1 - t / self.t_initial * (1 - self.cycle_mul), self.cycle_mul)) 88 | t_i = self.cycle_mul ** i * self.t_initial 89 | t_curr = t - (1 - self.cycle_mul ** i) / (1 - self.cycle_mul) * self.t_initial 90 | else: 91 | i = t // self.t_initial 92 | t_i = self.t_initial 93 | t_curr = t - (self.t_initial * i) 94 | 95 | if i < self.cycle_limit: 96 | gamma = self.cycle_decay ** i 97 | lr_max_values = [v * gamma for v in self.base_values] 98 | 99 | tr = t_curr / t_i 100 | lrs = [ 101 | self.lr_min + 0.5 * (lr_max - self.lr_min) * (1 - math.tanh(self.lb * (1. 
- tr) + self.ub * tr)) 102 | for lr_max in lr_max_values 103 | ] 104 | else: 105 | lrs = [self.lr_min for _ in self.base_values] 106 | return lrs 107 | 108 | def get_cycle_length(self, cycles=0): 109 | cycles = max(1, cycles or self.cycle_limit) 110 | if self.cycle_mul == 1.0: 111 | t = self.t_initial * cycles 112 | else: 113 | t = int(math.floor(-self.t_initial * (self.cycle_mul ** cycles - 1) / (1 - self.cycle_mul))) 114 | return t + self.warmup_t if self.warmup_prefix else t 115 | -------------------------------------------------------------------------------- /timm/layers/create_attn.py: -------------------------------------------------------------------------------- 1 | """ Attention Factory 2 | 3 | Hacked together by / Copyright 2021 Ross Wightman 4 | """ 5 | import torch 6 | from functools import partial 7 | 8 | from .bottleneck_attn import BottleneckAttn 9 | from .cbam import CbamModule, LightCbamModule 10 | from .coord_attn import CoordAttn, EfficientLocalAttn, StripAttn, SimpleCoordAttn 11 | from .eca import EcaModule, CecaModule 12 | from .gather_excite import GatherExcite 13 | from .global_context import GlobalContext 14 | from .halo_attn import HaloAttn 15 | from .lambda_layer import LambdaLayer 16 | from .non_local_attn import NonLocalAttn, BatNonLocalAttn 17 | from .selective_kernel import SelectiveKernel 18 | from .split_attn import SplitAttn 19 | from .squeeze_excite import SEModule, EffectiveSEModule 20 | 21 | 22 | def get_attn(attn_type): 23 | if isinstance(attn_type, torch.nn.Module): 24 | return attn_type 25 | module_cls = None 26 | if attn_type: 27 | if isinstance(attn_type, str): 28 | attn_type = attn_type.lower() 29 | # Lightweight attention modules (channel and/or coarse spatial). 30 | # Typically added to existing network architecture blocks in addition to existing convolutions. 31 | if attn_type == 'se': 32 | module_cls = SEModule 33 | elif attn_type == 'ese': 34 | module_cls = EffectiveSEModule 35 | elif attn_type == 'eca': 36 | module_cls = EcaModule 37 | elif attn_type == 'ecam': 38 | module_cls = partial(EcaModule, use_mlp=True) 39 | elif attn_type == 'ceca': 40 | module_cls = CecaModule 41 | elif attn_type == 'ge': 42 | module_cls = GatherExcite 43 | elif attn_type == 'gc': 44 | module_cls = GlobalContext 45 | elif attn_type == 'gca': 46 | module_cls = partial(GlobalContext, fuse_add=True, fuse_scale=False) 47 | elif attn_type == 'cbam': 48 | module_cls = CbamModule 49 | elif attn_type == 'lcbam': 50 | module_cls = LightCbamModule 51 | elif attn_type == 'coord': 52 | module_cls = CoordAttn 53 | elif attn_type == 'scoord': 54 | module_cls = SimpleCoordAttn 55 | elif attn_type == 'ela': 56 | module_cls = EfficientLocalAttn 57 | elif attn_type == 'strip': 58 | module_cls = StripAttn 59 | 60 | # Attention / attention-like modules w/ significant params 61 | # Typically replace some of the existing workhorse convs in a network architecture. 62 | # All of these accept a stride argument and can spatially downsample the input. 63 | elif attn_type == 'sk': 64 | module_cls = SelectiveKernel 65 | elif attn_type == 'splat': 66 | module_cls = SplitAttn 67 | 68 | # Self-attention / attention-like modules w/ significant compute and/or params 69 | # Typically replace some of the existing workhorse convs in a network architecture. 70 | # All of these accept a stride argument and can spatially downsample the input. 
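# (the next three return the class directly; since get_attn otherwise just returns module_cls
# at the end, this is equivalent to assigning module_cls)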
71 | elif attn_type == 'lambda': 72 | return LambdaLayer 73 | elif attn_type == 'bottleneck': 74 | return BottleneckAttn 75 | elif attn_type == 'halo': 76 | return HaloAttn 77 | elif attn_type == 'nl': 78 | module_cls = NonLocalAttn 79 | elif attn_type == 'bat': 80 | module_cls = BatNonLocalAttn 81 | 82 | # Woops! 83 | else: 84 | assert False, "Invalid attn module (%s)" % attn_type 85 | elif isinstance(attn_type, bool): 86 | if attn_type: 87 | module_cls = SEModule 88 | else: 89 | module_cls = attn_type 90 | return module_cls 91 | 92 | 93 | def create_attn(attn_type, channels, **kwargs): 94 | module_cls = get_attn(attn_type) 95 | if module_cls is not None: 96 | # NOTE: it's expected the first (positional) argument of all attention layers is the number of input channels 97 | return module_cls(channels, **kwargs) 98 | return None 99 | --------------------------------------------------------------------------------