├── timm
│ ├── py.typed
│ ├── version.py
│ ├── data
│ │ ├── readers
│ │ │ ├── __init__.py
│ │ │ ├── shared_count.py
│ │ │ ├── reader.py
│ │ │ ├── class_map.py
│ │ │ ├── img_extensions.py
│ │ │ ├── reader_factory.py
│ │ │ ├── reader_image_tar.py
│ │ │ ├── reader_hfds.py
│ │ │ └── reader_image_folder.py
│ │ ├── constants.py
│ │ ├── _info
│ │ │ ├── mini_imagenet_indices.txt
│ │ │ ├── mini_imagenet_synsets.txt
│ │ │ ├── imagenet_r_indices.txt
│ │ │ ├── imagenet_a_indices.txt
│ │ │ ├── imagenet_a_synsets.txt
│ │ │ └── imagenet_r_synsets.txt
│ │ ├── __init__.py
│ │ ├── real_labels.py
│ │ └── dataset_info.py
│ ├── models
│ │ ├── hub.py
│ │ ├── factory.py
│ │ ├── features.py
│ │ ├── registry.py
│ │ ├── fx_features.py
│ │ ├── helpers.py
│ │ ├── layers
│ │ │ └── __init__.py
│ │ ├── _features_fx.py
│ │ └── _pretrained.py
│ ├── utils
│ │ ├── random.py
│ │ ├── clip_grad.py
│ │ ├── __init__.py
│ │ ├── metrics.py
│ │ ├── log.py
│ │ ├── misc.py
│ │ ├── summary.py
│ │ ├── agc.py
│ │ ├── decay_batch.py
│ │ ├── jit.py
│ │ ├── cuda.py
│ │ ├── attention_extract.py
│ │ └── onnx.py
│ ├── loss
│ │ ├── __init__.py
│ │ ├── cross_entropy.py
│ │ ├── jsd.py
│ │ ├── binary_cross_entropy.py
│ │ └── asymmetric_loss.py
│ ├── scheduler
│ │ ├── __init__.py
│ │ ├── step_lr.py
│ │ ├── multistep_lr.py
│ │ ├── plateau_lr.py
│ │ └── tanh_lr.py
│ ├── layers
│ │ ├── trace_utils.py
│ │ ├── pool1d.py
│ │ ├── linear.py
│ │ ├── typing.py
│ │ ├── space_to_depth.py
│ │ ├── helpers.py
│ │ ├── format.py
│ │ ├── grn.py
│ │ ├── layer_scale.py
│ │ ├── create_conv2d.py
│ │ ├── grid.py
│ │ ├── median_pool.py
│ │ ├── test_time_pool.py
│ │ ├── create_norm.py
│ │ ├── mixed_conv2d.py
│ │ ├── _fx.py
│ │ ├── interpolate.py
│ │ ├── pos_embed.py
│ │ ├── global_context.py
│ │ ├── filter_response_norm.py
│ │ ├── patch_dropout.py
│ │ ├── conv_bn_act.py
│ │ ├── padding.py
│ │ ├── split_batchnorm.py
│ │ ├── inplace_abn.py
│ │ ├── pool2d_same.py
│ │ ├── separable_conv.py
│ │ ├── split_attn.py
│ │ └── create_attn.py
│ ├── optim
│ │ ├── optim_factory.py
│ │ ├── _types.py
│ │ ├── __init__.py
│ │ ├── sgdp.py
│ │ └── lookahead.py
│ ├── task
│ │ ├── __init__.py
│ │ ├── classification.py
│ │ └── task.py
│ └── __init__.py
├── tests
│ └── __init__.py
├── .gitattributes
├── .github
│ ├── FUNDING.yml
│ ├── ISSUE_TEMPLATE
│ │ ├── config.yml
│ │ ├── feature_request.md
│ │ └── bug_report.md
│ └── workflows
│   ├── trufflehog.yml
│   ├── upload_pr_documentation.yml
│   ├── build_documentation.yml
│   ├── build_pr_documentation.yml
│   └── tests.yml
├── requirements-dev.txt
├── MANIFEST.in
├── hubconf.py
├── requirements.txt
├── distributed_train.sh
├── hfdocs
│ ├── source
│ │ ├── reference
│ │ │ ├── models.mdx
│ │ │ ├── data.mdx
│ │ │ ├── schedulers.mdx
│ │ │ └── optimizers.mdx
│ │ ├── index.mdx
│ │ ├── hf_hub.mdx
│ │ ├── installation.mdx
│ │ └── hparams.mdx
│ └── README.md
├── setup.cfg
├── CITATION.cff
├── CLAUDE.md
├── .gitignore
├── pyproject.toml
├── UPGRADING.md
└── results
  └── generate_csv_results.py
/timm/py.typed:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/tests/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/timm/version.py:
--------------------------------------------------------------------------------
1 | __version__ = '1.0.23.dev0'
2 |
--------------------------------------------------------------------------------
/.gitattributes:
--------------------------------------------------------------------------------
1 | *.ipynb linguist-documentation
2 |
--------------------------------------------------------------------------------
/.github/FUNDING.yml:
--------------------------------------------------------------------------------
1 | # These are supported funding model platforms
2 | github: rwightman
3 |
--------------------------------------------------------------------------------
/requirements-dev.txt:
--------------------------------------------------------------------------------
1 | pytest
2 | pytest-timeout
3 | pytest-xdist
4 | pytest-forked
5 | expecttest
6 |
--------------------------------------------------------------------------------
/timm/data/readers/__init__.py:
--------------------------------------------------------------------------------
1 | from .reader_factory import create_reader
2 | from .img_extensions import *
3 |
--------------------------------------------------------------------------------
/MANIFEST.in:
--------------------------------------------------------------------------------
1 | include timm/models/_pruned/*.txt
2 | include timm/data/_info/*.txt
3 | include timm/data/_info/*.json
4 |
--------------------------------------------------------------------------------
/hubconf.py:
--------------------------------------------------------------------------------
1 | dependencies = ['torch']
2 | import timm
3 | globals().update(timm.models._registry._model_entrypoints)
4 |
--------------------------------------------------------------------------------
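Usage note: because hubconf.py copies every registered model entrypoint into the module globals, each timm model name doubles as a torch.hub entrypoint. A minimal sketch (the model name here is just an example):

```python
import torch

# any registered timm model name works as the entrypoint; 'resnet50' is an example
model = torch.hub.load('huggingface/pytorch-image-models', 'resnet50', pretrained=True)
model.eval()
```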
/requirements.txt:
--------------------------------------------------------------------------------
1 | torch>=1.7
2 | torchvision
3 | pyyaml
4 | huggingface_hub>=0.17.0
5 | safetensors>=0.2
6 | numpy
7 |
--------------------------------------------------------------------------------
/distributed_train.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 | NUM_PROC=$1
3 | shift
4 | torchrun --nproc_per_node=$NUM_PROC train.py "$@"
5 |
6 |
--------------------------------------------------------------------------------
/hfdocs/source/reference/models.mdx:
--------------------------------------------------------------------------------
1 | # Models
2 |
3 | [[autodoc]] timm.create_model
4 |
5 | [[autodoc]] timm.list_models
6 |
--------------------------------------------------------------------------------
/timm/models/hub.py:
--------------------------------------------------------------------------------
1 | from ._hub import *
2 |
3 | import warnings
4 | warnings.warn(f"Importing from {__name__} is deprecated, please import via timm.models", FutureWarning)
5 |
--------------------------------------------------------------------------------
/timm/models/factory.py:
--------------------------------------------------------------------------------
1 | from ._factory import *
2 |
3 | import warnings
4 | warnings.warn(f"Importing from {__name__} is deprecated, please import via timm.models", FutureWarning)
5 |
--------------------------------------------------------------------------------
/timm/models/features.py:
--------------------------------------------------------------------------------
1 | from ._features import *
2 |
3 | import warnings
4 | warnings.warn(f"Importing from {__name__} is deprecated, please import via timm.models", FutureWarning)
5 |
--------------------------------------------------------------------------------
/timm/models/registry.py:
--------------------------------------------------------------------------------
1 | from ._registry import *
2 |
3 | import warnings
4 | warnings.warn(f"Importing from {__name__} is deprecated, please import via timm.models", FutureWarning)
5 |
--------------------------------------------------------------------------------
/timm/models/fx_features.py:
--------------------------------------------------------------------------------
1 | from ._features_fx import *
2 |
3 | import warnings
4 | warnings.warn(f"Importing from {__name__} is deprecated, please import via timm.models", FutureWarning)
5 |
--------------------------------------------------------------------------------
/setup.cfg:
--------------------------------------------------------------------------------
1 | [dist_conda]
2 |
3 | conda_name_differences = 'torch:pytorch'
4 | channels = pytorch
5 | noarch = True
6 |
7 | [metadata]
8 |
9 | url = "https://github.com/huggingface/pytorch-image-models"
--------------------------------------------------------------------------------
/hfdocs/source/reference/data.mdx:
--------------------------------------------------------------------------------
1 | # Data
2 |
3 | [[autodoc]] timm.data.create_dataset
4 |
5 | [[autodoc]] timm.data.create_loader
6 |
7 | [[autodoc]] timm.data.create_transform
8 |
9 | [[autodoc]] timm.data.resolve_data_config
--------------------------------------------------------------------------------
/timm/utils/random.py:
--------------------------------------------------------------------------------
1 | import random
2 | import numpy as np
3 | import torch
4 |
5 |
6 | def random_seed(seed=42, rank=0):
7 | torch.manual_seed(seed + rank)
8 | np.random.seed(seed + rank)
9 | random.seed(seed + rank)
10 |
--------------------------------------------------------------------------------
/timm/models/helpers.py:
--------------------------------------------------------------------------------
1 | from ._builder import *
2 | from ._helpers import *
3 | from ._manipulate import *
4 | from ._prune import *
5 |
6 | import warnings
7 | warnings.warn(f"Importing from {__name__} is deprecated, please import via timm.models", FutureWarning)
8 |
--------------------------------------------------------------------------------
/timm/loss/__init__.py:
--------------------------------------------------------------------------------
1 | from .asymmetric_loss import AsymmetricLossMultiLabel, AsymmetricLossSingleLabel
2 | from .binary_cross_entropy import BinaryCrossEntropy
3 | from .cross_entropy import LabelSmoothingCrossEntropy, SoftTargetCrossEntropy
4 | from .jsd import JsdCrossEntropy
5 |
--------------------------------------------------------------------------------
/.github/ISSUE_TEMPLATE/config.yml:
--------------------------------------------------------------------------------
1 | blank_issues_enabled: false
2 | contact_links:
3 | - name: Community Discussions
4 | url: https://github.com/rwightman/pytorch-image-models/discussions
5 | about: Hparam requests in issues will be ignored! Issues are for features and bugs. Questions can be asked in Discussions.
6 |
--------------------------------------------------------------------------------
/hfdocs/README.md:
--------------------------------------------------------------------------------
1 | # Hugging Face Timm Docs
2 |
3 | ## Getting Started
4 |
5 | ```
6 | pip install git+https://github.com/huggingface/doc-builder.git@main#egg=hf-doc-builder
7 | pip install watchdog black
8 | ```
9 |
10 | ## Preview the Docs Locally
11 |
12 | ```
13 | doc-builder preview timm hfdocs/source
14 | ```
15 |
--------------------------------------------------------------------------------
/.github/workflows/trufflehog.yml:
--------------------------------------------------------------------------------
1 | on:
2 | push:
3 |
4 | name: Secret Leaks
5 |
6 | jobs:
7 | trufflehog:
8 | runs-on: ubuntu-latest
9 | steps:
10 | - name: Checkout code
11 | uses: actions/checkout@v4
12 | with:
13 | fetch-depth: 0
14 | - name: Secret Scanning
15 | uses: trufflesecurity/trufflehog@main
16 |
--------------------------------------------------------------------------------
/CITATION.cff:
--------------------------------------------------------------------------------
1 | message: "If you use this software, please cite it as below."
2 | title: "PyTorch Image Models"
3 | cff-version: "1.2.2"
4 | doi: "10.5281/zenodo.4414861"
5 | authors:
6 | - family-names: Wightman
7 | given-names: Ross
8 | version: 1.0.11
9 | year: "2019"
10 | url: "https://github.com/huggingface/pytorch-image-models"
11 | license: "Apache 2.0"
--------------------------------------------------------------------------------
/timm/scheduler/__init__.py:
--------------------------------------------------------------------------------
1 | from .cosine_lr import CosineLRScheduler
2 | from .multistep_lr import MultiStepLRScheduler
3 | from .plateau_lr import PlateauLRScheduler
4 | from .poly_lr import PolyLRScheduler
5 | from .step_lr import StepLRScheduler
6 | from .tanh_lr import TanhLRScheduler
7 |
8 | from .scheduler_factory import create_scheduler, create_scheduler_v2, scheduler_kwargs
9 |
--------------------------------------------------------------------------------
/timm/data/readers/shared_count.py:
--------------------------------------------------------------------------------
1 | from multiprocessing import Value
2 |
3 |
4 | class SharedCount:
5 | def __init__(self, epoch: int = 0):
6 | self.shared_epoch = Value('i', epoch)
7 |
8 | @property
9 | def value(self):
10 | return self.shared_epoch.value
11 |
12 | @value.setter
13 | def value(self, epoch):
14 | self.shared_epoch.value = epoch
15 |
--------------------------------------------------------------------------------
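Usage note: SharedCount exists so a value (typically the epoch) set in the parent process is visible to worker processes holding a reference to the same object. A minimal sketch of that pattern, assuming the SharedCount class above is in scope:

```python
from multiprocessing import Process

def show_epoch(count):
    # the child process reads the same underlying multiprocessing.Value
    print('epoch seen by worker:', count.value)

if __name__ == '__main__':
    count = SharedCount(epoch=0)
    count.value = 3  # parent updates the shared epoch before starting the worker
    p = Process(target=show_epoch, args=(count,))
    p.start()
    p.join()  # prints: epoch seen by worker: 3
```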
/timm/layers/trace_utils.py:
--------------------------------------------------------------------------------
1 | try:
2 | from torch import _assert
3 | except ImportError:
4 | def _assert(condition: bool, message: str):
5 | assert condition, message
6 |
7 |
8 | def _float_to_int(x: float) -> int:
9 | """
10 | Symbolic tracing helper to substitute for inbuilt `int`.
11 | Hint: Inbuilt `int` can't accept an argument of type `Proxy`
12 | """
13 | return int(x)
14 |
--------------------------------------------------------------------------------
/timm/optim/optim_factory.py:
--------------------------------------------------------------------------------
1 | # lots of uses of these functions directly, ala 'import timm.optim.optim_factory as optim_factory', fun :/
2 |
3 | from ._optim_factory import create_optimizer, create_optimizer_v2, optimizer_kwargs
4 | from ._param_groups import param_groups_layer_decay, param_groups_weight_decay, group_parameters, _layer_map, _group
5 |
6 | import warnings
7 | warnings.warn(f"Importing from {__name__} is deprecated, please import via timm.optim", FutureWarning)
8 |
--------------------------------------------------------------------------------
/.github/workflows/upload_pr_documentation.yml:
--------------------------------------------------------------------------------
1 | name: Upload PR Documentation
2 |
3 | on:
4 | workflow_run:
5 | workflows: ["Build PR Documentation"]
6 | types:
7 | - completed
8 |
9 | jobs:
10 | build:
11 | uses: huggingface/doc-builder/.github/workflows/upload_pr_documentation.yml@main
12 | with:
13 | package_name: timm
14 | secrets:
15 | hf_token: ${{ secrets.HF_DOC_BUILD_PUSH }}
16 | comment_bot_token: ${{ secrets.COMMENT_BOT_TOKEN }}
--------------------------------------------------------------------------------
/timm/data/constants.py:
--------------------------------------------------------------------------------
1 | DEFAULT_CROP_PCT = 0.875
2 | DEFAULT_CROP_MODE = 'center'
3 | IMAGENET_DEFAULT_MEAN = (0.485, 0.456, 0.406)
4 | IMAGENET_DEFAULT_STD = (0.229, 0.224, 0.225)
5 | IMAGENET_INCEPTION_MEAN = (0.5, 0.5, 0.5)
6 | IMAGENET_INCEPTION_STD = (0.5, 0.5, 0.5)
7 | IMAGENET_DPN_MEAN = (124 / 255, 117 / 255, 104 / 255)
8 | IMAGENET_DPN_STD = tuple([1 / (.0167 * 255)] * 3)
9 | OPENAI_CLIP_MEAN = (0.48145466, 0.4578275, 0.40821073)
10 | OPENAI_CLIP_STD = (0.26862954, 0.26130258, 0.27577711)
11 |
--------------------------------------------------------------------------------
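Usage note: these tuples are consumed as normalization parameters by the transform pipeline. A quick sketch of what that amounts to, using the constants directly:

```python
import torch

# normalize a batch the way the default ImageNet transforms would
mean = torch.tensor(IMAGENET_DEFAULT_MEAN).view(1, 3, 1, 1)
std = torch.tensor(IMAGENET_DEFAULT_STD).view(1, 3, 1, 1)
img = torch.rand(1, 3, 224, 224)  # values in [0, 1]
normalized = (img - mean) / std
```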
/timm/data/readers/reader.py:
--------------------------------------------------------------------------------
1 | from abc import abstractmethod
2 |
3 |
4 | class Reader:
5 | def __init__(self):
6 | pass
7 |
8 | @abstractmethod
9 | def _filename(self, index, basename=False, absolute=False):
10 | pass
11 |
12 | def filename(self, index, basename=False, absolute=False):
13 | return self._filename(index, basename=basename, absolute=absolute)
14 |
15 | def filenames(self, basename=False, absolute=False):
16 | return [self._filename(index, basename=basename, absolute=absolute) for index in range(len(self))]
17 |
18 |
--------------------------------------------------------------------------------
/.github/workflows/build_documentation.yml:
--------------------------------------------------------------------------------
1 | name: Build documentation
2 |
3 | on:
4 | push:
5 | branches:
6 | - main
7 | - doc-builder*
8 | - v*-release
9 |
10 | jobs:
11 | build:
12 | uses: huggingface/doc-builder/.github/workflows/build_main_documentation.yml@main
13 | with:
14 | commit_sha: ${{ github.sha }}
15 | package: pytorch-image-models
16 | package_name: timm
17 | path_to_docs: pytorch-image-models/hfdocs/source
18 | version_tag_suffix: ""
19 | secrets:
20 | hf_token: ${{ secrets.HF_DOC_BUILD_PUSH }}
21 |
--------------------------------------------------------------------------------
/.github/workflows/build_pr_documentation.yml:
--------------------------------------------------------------------------------
1 | name: Build PR Documentation
2 |
3 | on:
4 | pull_request:
5 |
6 | concurrency:
7 | group: ${{ github.workflow }}-${{ github.head_ref || github.run_id }}
8 | cancel-in-progress: true
9 |
10 | jobs:
11 | build:
12 | uses: huggingface/doc-builder/.github/workflows/build_pr_documentation.yml@main
13 | with:
14 | commit_sha: ${{ github.event.pull_request.head.sha }}
15 | pr_number: ${{ github.event.number }}
16 | package: pytorch-image-models
17 | package_name: timm
18 | path_to_docs: pytorch-image-models/hfdocs/source
19 | version_tag_suffix: ""
20 |
--------------------------------------------------------------------------------
/timm/task/__init__.py:
--------------------------------------------------------------------------------
1 | """Training task abstractions for timm.
2 |
3 | This module provides task-based abstractions for training loops where each task
4 | encapsulates both the forward pass and loss computation, returning a dictionary
5 | with loss components and outputs for logging.
6 | """
7 | from .task import TrainingTask
8 | from .classification import ClassificationTask
9 | from .distillation import DistillationTeacher, LogitDistillationTask, FeatureDistillationTask
10 |
11 | __all__ = [
12 | 'TrainingTask',
13 | 'ClassificationTask',
14 | 'DistillationTeacher',
15 | 'LogitDistillationTask',
16 | 'FeatureDistillationTask',
17 | ]
18 |
--------------------------------------------------------------------------------
/timm/__init__.py:
--------------------------------------------------------------------------------
1 | from .version import __version__ as __version__
2 | from .layers import (
3 | is_scriptable as is_scriptable,
4 | is_exportable as is_exportable,
5 | set_scriptable as set_scriptable,
6 | set_exportable as set_exportable,
7 | )
8 | from .models import (
9 | create_model as create_model,
10 | list_models as list_models,
11 | list_pretrained as list_pretrained,
12 | is_model as is_model,
13 | list_modules as list_modules,
14 | model_entrypoint as model_entrypoint,
15 | is_model_pretrained as is_model_pretrained,
16 | get_pretrained_cfg as get_pretrained_cfg,
17 | get_pretrained_cfg_value as get_pretrained_cfg_value,
18 | )
19 |
--------------------------------------------------------------------------------
/hfdocs/source/reference/schedulers.mdx:
--------------------------------------------------------------------------------
1 | # Learning Rate Schedulers
2 |
3 | This page contains the API reference documentation for learning rate schedulers included in `timm`.
4 |
5 | ## Schedulers
6 |
7 | ### Factory functions
8 |
9 | [[autodoc]] timm.scheduler.scheduler_factory.create_scheduler
10 | [[autodoc]] timm.scheduler.scheduler_factory.create_scheduler_v2
11 |
12 | ### Scheduler Classes
13 |
14 | [[autodoc]] timm.scheduler.cosine_lr.CosineLRScheduler
15 | [[autodoc]] timm.scheduler.multistep_lr.MultiStepLRScheduler
16 | [[autodoc]] timm.scheduler.plateau_lr.PlateauLRScheduler
17 | [[autodoc]] timm.scheduler.poly_lr.PolyLRScheduler
18 | [[autodoc]] timm.scheduler.step_lr.StepLRScheduler
19 | [[autodoc]] timm.scheduler.tanh_lr.TanhLRScheduler
20 |
--------------------------------------------------------------------------------
/timm/layers/pool1d.py:
--------------------------------------------------------------------------------
1 | import torch
2 |
3 |
4 | def global_pool_nlc(
5 | x: torch.Tensor,
6 | pool_type: str = 'token',
7 | num_prefix_tokens: int = 1,
8 | reduce_include_prefix: bool = False,
9 | ):
10 | if not pool_type:
11 | return x
12 |
13 | if pool_type == 'token':
14 | x = x[:, 0] # class token
15 | else:
16 | x = x if reduce_include_prefix else x[:, num_prefix_tokens:]
17 | if pool_type == 'avg':
18 | x = x.mean(dim=1)
19 | elif pool_type == 'avgmax':
20 | x = 0.5 * (x.amax(dim=1) + x.mean(dim=1))
21 | elif pool_type == 'max':
22 | x = x.amax(dim=1)
23 | else:
24 | assert not pool_type, f'Unknown pool type {pool_type}'
25 |
26 | return x
--------------------------------------------------------------------------------
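Usage note: global_pool_nlc reduces a (batch, length, channels) sequence, optionally excluding prefix tokens such as a ViT class token. A small sketch, assuming the function above is in scope:

```python
import torch

x = torch.randn(2, 197, 768)                  # 1 class token + 196 patch tokens
cls = global_pool_nlc(x, pool_type='token')   # (2, 768): the class token itself
avg = global_pool_nlc(x, pool_type='avg')     # (2, 768): mean over patch tokens only
mix = global_pool_nlc(x, pool_type='avgmax')  # (2, 768): 0.5 * (max + mean)
```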
/timm/layers/linear.py:
--------------------------------------------------------------------------------
1 | """ Linear layer (alternate definition)
2 | """
3 | import torch
4 | import torch.nn.functional as F
5 | from torch import nn as nn
6 |
7 |
8 | class Linear(nn.Linear):
9 | r"""Applies a linear transformation to the incoming data: :math:`y = xA^T + b`
10 |
11 | Wraps torch.nn.Linear to support AMP + torchscript usage by manually casting
12 | weight & bias to input.dtype to work around an issue w/ torch.addmm in this use case.
13 | """
14 | def forward(self, input: torch.Tensor) -> torch.Tensor:
15 | if torch.jit.is_scripting():
16 | bias = self.bias.to(dtype=input.dtype) if self.bias is not None else None
17 | return F.linear(input, self.weight.to(dtype=input.dtype), bias=bias)
18 | else:
19 | return F.linear(input, self.weight, self.bias)
20 |
--------------------------------------------------------------------------------
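Usage note: the manual cast matters because autocast does not reach inside scripted modules the way it does in eager mode, so half-precision inputs meeting float32 weights would otherwise fail in torch.addmm. A sketch of the case this handles, assuming the Linear class above:

```python
import torch

lin = Linear(8, 4)                       # parameters stay float32
scripted = torch.jit.script(lin)
x = torch.randn(2, 8, dtype=torch.half)
out = scripted(x)                        # weight/bias cast to input dtype in forward
print(out.dtype)                         # torch.float16
```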
/.github/ISSUE_TEMPLATE/feature_request.md:
--------------------------------------------------------------------------------
1 | ---
2 | name: Feature request
3 | about: Suggest an idea for this project. Hparam requests, training help are not feature requests.
4 |   The discussion forum is available for asking questions or seeking help from the community.
5 | title: "[FEATURE] Feature title..."
6 | labels: enhancement
7 | assignees: ''
8 |
9 | ---
10 |
11 | **Is your feature request related to a problem? Please describe.**
12 | A clear and concise description of what the problem is.
13 |
14 | **Describe the solution you'd like**
15 | A clear and concise description of what you want to happen.
16 |
17 | **Describe alternatives you've considered**
18 | A clear and concise description of any alternative solutions or features you've considered.
19 |
20 | **Additional context**
21 | Add any other context or screenshots about the feature request here.
22 |
--------------------------------------------------------------------------------
/timm/optim/_types.py:
--------------------------------------------------------------------------------
1 | from typing import Any, Dict, Iterable, Union, Protocol, Type
2 | try:
3 | from typing import TypeAlias
4 | except ImportError:
5 | from typing_extensions import TypeAlias
6 | try:
7 | from typing import TypeVar
8 | except ImportError:
9 | from typing_extensions import TypeVar
10 |
11 | import torch
12 | import torch.optim
13 |
14 | try:
15 | from torch.optim.optimizer import ParamsT
16 | except (ImportError, TypeError):
17 | ParamsT: TypeAlias = Union[Iterable[torch.Tensor], Iterable[Dict[str, Any]]]
18 |
19 |
20 | OptimType = Type[torch.optim.Optimizer]
21 |
22 |
23 | class OptimizerCallable(Protocol):
24 | """Protocol for optimizer constructor signatures."""
25 |
26 | def __call__(self, params: ParamsT, **kwargs) -> torch.optim.Optimizer: ...
27 |
28 |
29 | __all__ = ['ParamsT', 'OptimType', 'OptimizerCallable']
--------------------------------------------------------------------------------
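Usage note: OptimizerCallable is a structural protocol, so any callable taking params plus kwargs and returning an Optimizer satisfies it. A sketch, assuming the names above are in scope:

```python
import torch


def make_sgd(params: ParamsT, **kwargs) -> torch.optim.Optimizer:
    # any constructor-like callable matches the OptimizerCallable protocol
    return torch.optim.SGD(params, lr=kwargs.get('lr', 0.1))


factory: OptimizerCallable = make_sgd  # type-checks structurally
opt = factory(torch.nn.Linear(2, 2).parameters(), lr=0.01)
```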
/timm/utils/clip_grad.py:
--------------------------------------------------------------------------------
1 | import torch
2 |
3 | from timm.utils.agc import adaptive_clip_grad
4 |
5 |
6 | def dispatch_clip_grad(parameters, value: float, mode: str = 'norm', norm_type: float = 2.0):
7 | """ Dispatch to gradient clipping method
8 |
9 | Args:
10 | parameters (Iterable): model parameters to clip
11 | value (float): clipping value/factor/norm, mode dependant
12 | mode (str): clipping mode, one of 'norm', 'value', 'agc'
13 | norm_type (float): p-norm, default 2.0
14 | """
15 | if mode == 'norm':
16 | torch.nn.utils.clip_grad_norm_(parameters, value, norm_type=norm_type)
17 | elif mode == 'value':
18 | torch.nn.utils.clip_grad_value_(parameters, value)
19 | elif mode == 'agc':
20 | adaptive_clip_grad(parameters, value, norm_type=norm_type)
21 | else:
22 | assert False, f"Unknown clip mode ({mode})."
23 |
24 |
--------------------------------------------------------------------------------
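Usage note: in a training step the dispatch sits between backward and the optimizer step. A minimal sketch, assuming dispatch_clip_grad above is in scope:

```python
import torch
import torch.nn.functional as F

model = torch.nn.Linear(4, 2)
opt = torch.optim.SGD(model.parameters(), lr=0.1)

loss = F.mse_loss(model(torch.randn(8, 4)), torch.randn(8, 2))
loss.backward()
dispatch_clip_grad(model.parameters(), value=1.0, mode='norm')  # clip, then step
opt.step()
opt.zero_grad()
```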
/timm/utils/__init__.py:
--------------------------------------------------------------------------------
1 | from .agc import adaptive_clip_grad
2 | from .attention_extract import AttentionExtract
3 | from .checkpoint_saver import CheckpointSaver
4 | from .clip_grad import dispatch_clip_grad
5 | from .cuda import ApexScaler, NativeScaler
6 | from .decay_batch import decay_batch_step, check_batch_size_retry
7 | from .distributed import distribute_bn, reduce_tensor, init_distributed_device,\
8 | world_info_from_env, is_distributed_env, is_primary
9 | from .jit import set_jit_legacy, set_jit_fuser
10 | from .log import setup_default_logging, FormatterNoInfo
11 | from .metrics import AverageMeter, accuracy
12 | from .misc import natural_key, add_bool_arg, ParseKwargs
13 | from .model import unwrap_model, get_state_dict, freeze, unfreeze, reparameterize_model
14 | from .model_ema import ModelEma, ModelEmaV2, ModelEmaV3
15 | from .random import random_seed
16 | from .summary import update_summary, get_outdir
17 |
--------------------------------------------------------------------------------
/CLAUDE.md:
--------------------------------------------------------------------------------
1 | # CLAUDE.md - PyTorch Image Models (timm)
2 |
3 | ## Build/Test Commands
4 | - Install: `python -m pip install -e .`
5 | - Run tests: `pytest tests/`
6 | - Run specific test: `pytest tests/test_models.py::test_specific_function -v`
7 | - Run tests in parallel: `pytest -n 4 tests/`
8 | - Filter tests: `pytest -k "substring-to-match" tests/`
9 |
10 | ## Code Style Guidelines
11 | - Line length: 120 chars
12 | - Indentation: 4-space hanging indents, arguments should have an extra level of indent, use 'sadface' (closing parenthesis and colon on a separate line)
13 | - Typing: Use PEP484 type annotations in function signatures
14 | - Docstrings: Google style (do not duplicate type annotations and defaults)
15 | - Imports: Standard library first, then third-party, then local
16 | - Function naming: snake_case
17 | - Class naming: PascalCase
18 | - Error handling: Use try/except with specific exceptions
19 | - Conditional expressions: Use parentheses for complex expressions
--------------------------------------------------------------------------------
/timm/data/_info/mini_imagenet_indices.txt:
--------------------------------------------------------------------------------
1 | 12
2 | 15
3 | 51
4 | 64
5 | 70
6 | 96
7 | 99
8 | 107
9 | 111
10 | 121
11 | 149
12 | 166
13 | 173
14 | 176
15 | 207
16 | 214
17 | 228
18 | 242
19 | 244
20 | 245
21 | 249
22 | 251
23 | 256
24 | 266
25 | 270
26 | 275
27 | 279
28 | 291
29 | 299
30 | 301
31 | 306
32 | 310
33 | 359
34 | 364
35 | 392
36 | 403
37 | 412
38 | 427
39 | 440
40 | 454
41 | 471
42 | 476
43 | 478
44 | 484
45 | 494
46 | 502
47 | 503
48 | 507
49 | 519
50 | 524
51 | 533
52 | 538
53 | 546
54 | 553
55 | 556
56 | 567
57 | 569
58 | 584
59 | 597
60 | 602
61 | 604
62 | 605
63 | 629
64 | 655
65 | 657
66 | 659
67 | 683
68 | 687
69 | 702
70 | 709
71 | 713
72 | 735
73 | 741
74 | 758
75 | 779
76 | 781
77 | 800
78 | 801
79 | 807
80 | 815
81 | 819
82 | 847
83 | 854
84 | 858
85 | 860
86 | 880
87 | 881
88 | 883
89 | 909
90 | 912
91 | 914
92 | 919
93 | 925
94 | 927
95 | 934
96 | 950
97 | 972
98 | 973
99 | 997
100 | 998
101 |
--------------------------------------------------------------------------------
/timm/data/readers/class_map.py:
--------------------------------------------------------------------------------
1 | import os
2 | import pickle
3 |
4 |
5 | def load_class_map(map_or_filename, root=''):
6 | if isinstance(map_or_filename, dict):
7 | assert map_or_filename, 'class_map dict must be non-empty'
8 | return map_or_filename
9 | class_map_path = map_or_filename
10 | if not os.path.exists(class_map_path):
11 | class_map_path = os.path.join(root, class_map_path)
12 | assert os.path.exists(class_map_path), 'Cannot locate specified class map file (%s)' % map_or_filename
13 | class_map_ext = os.path.splitext(map_or_filename)[-1].lower()
14 | if class_map_ext == '.txt':
15 | with open(class_map_path) as f:
16 | class_to_idx = {v.strip(): k for k, v in enumerate(f)}
17 | elif class_map_ext == '.pkl':
18 | with open(class_map_path, 'rb') as f:
19 | class_to_idx = pickle.load(f)
20 | else:
21 | assert False, f'Unsupported class map file extension ({class_map_ext}).'
22 | return class_to_idx
23 |
24 |
--------------------------------------------------------------------------------
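Usage note: the function accepts either a ready dict or a path; .txt files map one class name per line to its line index. A sketch, where 'class_map.txt' is a hypothetical file:

```python
# dicts pass straight through
class_to_idx = load_class_map({'cat': 0, 'dog': 1})

# a hypothetical text file with lines 'cat' and 'dog' would yield the same mapping:
# class_to_idx = load_class_map('class_map.txt', root='/data/my_dataset')
```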
/timm/utils/metrics.py:
--------------------------------------------------------------------------------
1 | """ Eval metrics and related
2 |
3 | Hacked together by / Copyright 2020 Ross Wightman
4 | """
5 |
6 |
7 | class AverageMeter:
8 | """Computes and stores the average and current value"""
9 | def __init__(self):
10 | self.reset()
11 |
12 | def reset(self):
13 | self.val = 0
14 | self.avg = 0
15 | self.sum = 0
16 | self.count = 0
17 |
18 | def update(self, val, n=1):
19 | self.val = val
20 | self.sum += val * n
21 | self.count += n
22 | self.avg = self.sum / self.count
23 |
24 |
25 | def accuracy(output, target, topk=(1,)):
26 | """Computes the accuracy over the k top predictions for the specified values of k"""
27 | maxk = min(max(topk), output.size()[1])
28 | batch_size = target.size(0)
29 | _, pred = output.topk(maxk, 1, True, True)
30 | pred = pred.t()
31 | correct = pred.eq(target.reshape(1, -1).expand_as(pred))
32 | return [correct[:min(k, maxk)].reshape(-1).float().sum(0) * 100. / batch_size for k in topk]
33 |
--------------------------------------------------------------------------------
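Usage note: a sketch of the two utilities together, tracking running top-1 over batches (assumes the definitions above):

```python
import torch

meter = AverageMeter()
for _ in range(3):
    output = torch.randn(16, 10)           # logits: 16 samples, 10 classes
    target = torch.randint(0, 10, (16,))
    top1, top5 = accuracy(output, target, topk=(1, 5))
    meter.update(top1.item(), n=output.size(0))
print(f'running top-1: {meter.avg:.2f}%')
```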
/timm/layers/typing.py:
--------------------------------------------------------------------------------
1 | from contextlib import nullcontext
2 | from functools import wraps
3 | from typing import Callable, Optional, Tuple, Type, TypeVar, Union, overload, ContextManager
4 |
5 | import torch
6 |
7 | __all__ = ["LayerType", "PadType", "nullwrap", "disable_compiler"]
8 |
9 |
10 | LayerType = Union[str, Callable, Type[torch.nn.Module]]
11 | PadType = Union[str, int, Tuple[int, int]]
12 |
13 | F = TypeVar("F", bound=Callable[..., object])
14 |
15 |
16 | @overload
17 | def nullwrap(fn: F) -> F: ... # decorator form
18 |
19 | @overload
20 | def nullwrap(fn: None = ...) -> ContextManager: ...  # context-manager form
21 |
22 | def nullwrap(fn: Optional[F] = None):
23 | # as a context manager
24 | if fn is None:
25 | return nullcontext() # `with nullwrap():`
26 |
27 | # as a decorator
28 | @wraps(fn)
29 | def wrapper(*args, **kwargs):
30 | return fn(*args, **kwargs)
31 | return wrapper # `@nullwrap`
32 |
33 |
34 | disable_compiler = getattr(getattr(torch, "compiler", None), "disable", None) or nullwrap
35 |
--------------------------------------------------------------------------------
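Usage note: nullwrap is the no-op stand-in that lets disable_compiler degrade gracefully on older torch builds lacking torch.compiler. Both call forms, assuming the definitions above:

```python
@disable_compiler          # torch.compiler.disable where available, else a no-op wrap
def fn(x):
    return x + 1

with nullwrap():           # context-manager form is a plain nullcontext
    print(fn(1))           # 2
```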
/.github/ISSUE_TEMPLATE/bug_report.md:
--------------------------------------------------------------------------------
1 | ---
2 | name: Bug report
3 | about: Create a bug report to help us improve. Issues are for reporting bugs or requesting
4 |   features, the discussion forum is available for asking questions or seeking help
5 |   from the community.
6 | title: "[BUG] Issue title..."
7 | labels: bug
8 | assignees: rwightman
9 |
10 | ---
11 |
12 | **Describe the bug**
13 | A clear and concise description of what the bug is.
14 |
15 | **To Reproduce**
16 | Steps to reproduce the behavior:
17 | 1.
18 | 2.
19 |
20 | **Expected behavior**
21 | A clear and concise description of what you expected to happen.
22 |
23 | **Screenshots**
24 | If applicable, add screenshots to help explain your problem.
25 |
26 | **Desktop (please complete the following information):**
27 | - OS: [e.g. Windows 10, Ubuntu 18.04]
28 | - This repository version [e.g. pip 0.3.1 or commit ref]
29 | - PyTorch version w/ CUDA/cuDNN [e.g. from `conda list`, 1.7.0 py3.8_cuda11.0.221_cudnn8.0.3_0]
30 |
31 | **Additional context**
32 | Add any other context about the problem here.
33 |
--------------------------------------------------------------------------------
/timm/utils/log.py:
--------------------------------------------------------------------------------
1 | """ Logging helpers
2 |
3 | Hacked together by / Copyright 2020 Ross Wightman
4 | """
5 | import logging
6 | import logging.handlers
7 |
8 |
9 | class FormatterNoInfo(logging.Formatter):
10 | def __init__(self, fmt='%(levelname)s: %(message)s'):
11 | logging.Formatter.__init__(self, fmt)
12 |
13 | def format(self, record):
14 | if record.levelno == logging.INFO:
15 | return str(record.getMessage())
16 | return logging.Formatter.format(self, record)
17 |
18 |
19 | def setup_default_logging(default_level=logging.INFO, log_path=''):
20 | console_handler = logging.StreamHandler()
21 | console_handler.setFormatter(FormatterNoInfo())
22 | logging.root.addHandler(console_handler)
23 | logging.root.setLevel(default_level)
24 | if log_path:
25 | file_handler = logging.handlers.RotatingFileHandler(log_path, maxBytes=(1024 ** 2 * 2), backupCount=3)
26 | file_formatter = logging.Formatter("%(asctime)s - %(name)20s: [%(levelname)8s] - %(message)s")
27 | file_handler.setFormatter(file_formatter)
28 | logging.root.addHandler(file_handler)
29 |
--------------------------------------------------------------------------------
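Usage note: with FormatterNoInfo on the console handler, INFO records print bare while other levels keep their prefix. A quick sketch, assuming the function above:

```python
import logging

setup_default_logging()
logging.info('epoch 1 done')    # -> epoch 1 done
logging.warning('lr is zero')   # -> WARNING: lr is zero
```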
/hfdocs/source/reference/optimizers.mdx:
--------------------------------------------------------------------------------
1 | # Optimization
2 |
3 | This page contains the API reference documentation for the optimizers included in `timm`.
4 |
5 | ## Optimizers
6 |
7 | ### Factory functions
8 |
9 | [[autodoc]] timm.optim.create_optimizer_v2
10 | [[autodoc]] timm.optim.list_optimizers
11 | [[autodoc]] timm.optim.get_optimizer_class
12 |
13 | ### Optimizer Classes
14 |
15 | [[autodoc]] timm.optim.adabelief.AdaBelief
16 | [[autodoc]] timm.optim.adafactor.Adafactor
17 | [[autodoc]] timm.optim.adafactor_bv.AdafactorBigVision
18 | [[autodoc]] timm.optim.adahessian.Adahessian
19 | [[autodoc]] timm.optim.adamp.AdamP
20 | [[autodoc]] timm.optim.adan.Adan
21 | [[autodoc]] timm.optim.adopt.Adopt
22 | [[autodoc]] timm.optim.lamb.Lamb
23 | [[autodoc]] timm.optim.laprop.LaProp
24 | [[autodoc]] timm.optim.lars.Lars
25 | [[autodoc]] timm.optim.lion.Lion
26 | [[autodoc]] timm.optim.lookahead.Lookahead
27 | [[autodoc]] timm.optim.madgrad.MADGRAD
28 | [[autodoc]] timm.optim.mars.Mars
29 | [[autodoc]] timm.optim.nadamw.NAdamW
30 | [[autodoc]] timm.optim.nvnovograd.NvNovoGrad
31 | [[autodoc]] timm.optim.rmsprop_tf.RMSpropTF
32 | [[autodoc]] timm.optim.sgdp.SGDP
33 | [[autodoc]] timm.optim.sgdw.SGDW
--------------------------------------------------------------------------------
/timm/layers/space_to_depth.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 |
4 |
5 | class SpaceToDepth(nn.Module):
6 | bs: torch.jit.Final[int]
7 |
8 | def __init__(self, block_size: int = 4):
9 | super().__init__()
10 | assert block_size == 4
11 | self.bs = block_size
12 |
13 | def forward(self, x):
14 | N, C, H, W = x.size()
15 | x = x.view(N, C, H // self.bs, self.bs, W // self.bs, self.bs) # (N, C, H//bs, bs, W//bs, bs)
16 | x = x.permute(0, 3, 5, 1, 2, 4).contiguous() # (N, bs, bs, C, H//bs, W//bs)
17 | x = x.view(N, C * self.bs * self.bs, H // self.bs, W // self.bs) # (N, C*bs^2, H//bs, W//bs)
18 | return x
19 |
20 |
21 | class DepthToSpace(nn.Module):
22 |
23 | def __init__(self, block_size):
24 | super().__init__()
25 | self.bs = block_size
26 |
27 | def forward(self, x):
28 | N, C, H, W = x.size()
29 | x = x.view(N, self.bs, self.bs, C // (self.bs ** 2), H, W) # (N, bs, bs, C//bs^2, H, W)
30 | x = x.permute(0, 3, 4, 1, 5, 2).contiguous() # (N, C//bs^2, H, bs, W, bs)
31 | x = x.view(N, C // (self.bs ** 2), H * self.bs, W * self.bs) # (N, C//bs^2, H * bs, W * bs)
32 | return x
33 |
--------------------------------------------------------------------------------
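Usage note: the two modules are inverses for a matching block size; channels grow by bs^2 as each spatial dim shrinks by bs. A round-trip check, assuming the classes above:

```python
import torch

x = torch.randn(1, 3, 8, 8)
y = SpaceToDepth(block_size=4)(x)
print(y.shape)                                           # torch.Size([1, 48, 2, 2])
print(torch.allclose(DepthToSpace(block_size=4)(y), x))  # True
```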
/timm/utils/misc.py:
--------------------------------------------------------------------------------
1 | """ Misc utils
2 |
3 | Hacked together by / Copyright 2020 Ross Wightman
4 | """
5 | import argparse
6 | import ast
7 | import re
8 |
9 |
10 | def natural_key(string_):
11 | """See http://www.codinghorror.com/blog/archives/001018.html"""
12 | return [int(s) if s.isdigit() else s for s in re.split(r'(\d+)', string_.lower())]
13 |
14 |
15 | def add_bool_arg(parser, name, default=False, help=''):
16 | dest_name = name.replace('-', '_')
17 | group = parser.add_mutually_exclusive_group(required=False)
18 | group.add_argument('--' + name, dest=dest_name, action='store_true', help=help)
19 | group.add_argument('--no-' + name, dest=dest_name, action='store_false', help=help)
20 | parser.set_defaults(**{dest_name: default})
21 |
22 |
23 | class ParseKwargs(argparse.Action):
24 | def __call__(self, parser, namespace, values, option_string=None):
25 | kw = {}
26 | for value in values:
27 | key, value = value.split('=')
28 | try:
29 | kw[key] = ast.literal_eval(value)
30 | except ValueError:
31 | kw[key] = str(value) # fallback to string (avoid need to escape on command line)
32 | setattr(namespace, self.dest, kw)
33 |
--------------------------------------------------------------------------------
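Usage note: how the two argparse helpers behave, assuming the definitions above:

```python
import argparse

parser = argparse.ArgumentParser()
add_bool_arg(parser, 'amp', default=False, help='use mixed precision')
parser.add_argument('--model-kwargs', nargs='*', default={}, action=ParseKwargs)

args = parser.parse_args(['--amp', '--model-kwargs', 'depth=12', 'act_layer=gelu'])
print(args.amp)           # True
print(args.model_kwargs)  # {'depth': 12, 'act_layer': 'gelu'}
```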
/timm/layers/helpers.py:
--------------------------------------------------------------------------------
1 | """ Layer/Module Helpers
2 |
3 | Hacked together by / Copyright 2020 Ross Wightman
4 | """
5 | from itertools import repeat
6 | import collections.abc
7 |
8 |
9 | # From PyTorch internals
10 | def _ntuple(n):
11 | def parse(x):
12 | if isinstance(x, collections.abc.Iterable) and not isinstance(x, str):
13 | return tuple(x)
14 | return tuple(repeat(x, n))
15 | return parse
16 |
17 |
18 | to_1tuple = _ntuple(1)
19 | to_2tuple = _ntuple(2)
20 | to_3tuple = _ntuple(3)
21 | to_4tuple = _ntuple(4)
22 | to_ntuple = _ntuple
23 |
24 |
25 | def make_divisible(v, divisor=8, min_value=None, round_limit=.9):
26 | min_value = min_value or divisor
27 | new_v = max(min_value, int(v + divisor / 2) // divisor * divisor)
28 | # Make sure that round down does not go down by more than 10%.
29 | if new_v < round_limit * v:
30 | new_v += divisor
31 | return new_v
32 |
33 |
34 | def extend_tuple(x, n):
35 | # pads a tuple to specified n by padding with last value
36 | if not isinstance(x, (tuple, list)):
37 | x = (x,)
38 | else:
39 | x = tuple(x)
40 | pad_n = n - len(x)
41 | if pad_n <= 0:
42 | return x[:n]
43 | return x + (x[-1],) * pad_n
44 |
--------------------------------------------------------------------------------
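Usage note: concrete values for the helpers above:

```python
print(to_2tuple(3))             # (3, 3)
print(to_2tuple((3, 5)))        # (3, 5) -- iterables pass through
print(make_divisible(30))       # 32: rounded to a multiple of 8, never <90% of input
print(extend_tuple((1, 2), 4))  # (1, 2, 2, 2): padded with the last value
```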
/timm/loss/cross_entropy.py:
--------------------------------------------------------------------------------
1 | """ Cross Entropy w/ smoothing or soft targets
2 |
3 | Hacked together by / Copyright 2021 Ross Wightman
4 | """
5 |
6 | import torch
7 | import torch.nn as nn
8 | import torch.nn.functional as F
9 |
10 |
11 | class LabelSmoothingCrossEntropy(nn.Module):
12 | """ NLL loss with label smoothing.
13 | """
14 | def __init__(self, smoothing=0.1):
15 | super(LabelSmoothingCrossEntropy, self).__init__()
16 | assert smoothing < 1.0
17 | self.smoothing = smoothing
18 | self.confidence = 1. - smoothing
19 |
20 | def forward(self, x: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
21 | logprobs = F.log_softmax(x, dim=-1)
22 | nll_loss = -logprobs.gather(dim=-1, index=target.unsqueeze(1))
23 | nll_loss = nll_loss.squeeze(1)
24 | smooth_loss = -logprobs.mean(dim=-1)
25 | loss = self.confidence * nll_loss + self.smoothing * smooth_loss
26 | return loss.mean()
27 |
28 |
29 | class SoftTargetCrossEntropy(nn.Module):
30 |
31 | def __init__(self):
32 | super(SoftTargetCrossEntropy, self).__init__()
33 |
34 | def forward(self, x: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
35 | loss = torch.sum(-target * F.log_softmax(x, dim=-1), dim=-1)
36 | return loss.mean()
37 |
--------------------------------------------------------------------------------
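Usage note: the two losses agree when the dense target is the uniformly smoothed one-hot, q = (1-s)*onehot + s/K. A quick consistency check, assuming the classes above:

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
s, K = 0.1, 5
x = torch.randn(4, K)
target = torch.randint(0, K, (4,))
soft_target = F.one_hot(target, K).float() * (1 - s) + s / K

a = LabelSmoothingCrossEntropy(smoothing=s)(x, target)
b = SoftTargetCrossEntropy()(x, soft_target)
print(torch.allclose(a, b))  # True
```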
/timm/data/__init__.py:
--------------------------------------------------------------------------------
1 | from .auto_augment import RandAugment, AutoAugment, rand_augment_ops, auto_augment_policy,\
2 | rand_augment_transform, auto_augment_transform
3 | from .config import resolve_data_config, resolve_model_data_config
4 | from .constants import *
5 | from .dataset import ImageDataset, IterableImageDataset, AugMixDataset
6 | from .dataset_factory import create_dataset
7 | from .dataset_info import DatasetInfo, CustomDatasetInfo
8 | from .imagenet_info import ImageNetInfo, infer_imagenet_subset
9 | from .loader import create_loader
10 | from .mixup import Mixup, FastCollateMixup
11 | from .naflex_dataset import NaFlexMapDatasetWrapper, calculate_naflex_batch_size
12 | from .naflex_loader import create_naflex_loader
13 | from .naflex_mixup import NaFlexMixup, pairwise_mixup_target, mix_batch_variable_size
14 | from .naflex_transforms import (
15 | ResizeToSequence,
16 | CenterCropToSequence,
17 | RandomCropToSequence,
18 | RandomResizedCropToSequence,
19 | ResizeKeepRatioToSequence,
20 | Patchify,
21 | patchify_image,
22 | )
23 | from .readers import create_reader
24 | from .readers import get_img_extensions, is_img_extension, set_img_extensions, add_img_extensions, del_img_extensions
25 | from .real_labels import RealLabelsImagenet
26 | from .transforms import *
27 | from .transforms_factory import create_transform
28 |
--------------------------------------------------------------------------------
/timm/layers/format.py:
--------------------------------------------------------------------------------
1 | from enum import Enum
2 | from typing import Union
3 |
4 | import torch
5 |
6 |
7 | class Format(str, Enum):
8 | NCHW = 'NCHW'
9 | NHWC = 'NHWC'
10 | NCL = 'NCL'
11 | NLC = 'NLC'
12 |
13 |
14 | FormatT = Union[str, Format]
15 |
16 |
17 | def get_spatial_dim(fmt: FormatT):
18 | fmt = Format(fmt)
19 | if fmt is Format.NLC:
20 | dim = (1,)
21 | elif fmt is Format.NCL:
22 | dim = (2,)
23 | elif fmt is Format.NHWC:
24 | dim = (1, 2)
25 | else:
26 | dim = (2, 3)
27 | return dim
28 |
29 |
30 | def get_channel_dim(fmt: FormatT):
31 | fmt = Format(fmt)
32 | if fmt is Format.NHWC:
33 | dim = 3
34 | elif fmt is Format.NLC:
35 | dim = 2
36 | else:
37 | dim = 1
38 | return dim
39 |
40 |
41 | def nchw_to(x: torch.Tensor, fmt: Format):
42 | if fmt == Format.NHWC:
43 | x = x.permute(0, 2, 3, 1)
44 | elif fmt == Format.NLC:
45 | x = x.flatten(2).transpose(1, 2)
46 | elif fmt == Format.NCL:
47 | x = x.flatten(2)
48 | return x
49 |
50 |
51 | def nhwc_to(x: torch.Tensor, fmt: Format):
52 | if fmt == Format.NCHW:
53 | x = x.permute(0, 3, 1, 2)
54 | elif fmt == Format.NLC:
55 | x = x.flatten(1, 2)
56 | elif fmt == Format.NCL:
57 | x = x.flatten(1, 2).transpose(1, 2)
58 | return x
59 |
--------------------------------------------------------------------------------
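Usage note: a shape-level sketch of the conversions, assuming the helpers above:

```python
import torch

x = torch.randn(2, 64, 7, 7)           # NCHW feature map
print(nchw_to(x, Format.NHWC).shape)   # torch.Size([2, 7, 7, 64])
print(nchw_to(x, Format.NLC).shape)    # torch.Size([2, 49, 64])
print(get_channel_dim('NHWC'))         # 3
```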
/timm/optim/__init__.py:
--------------------------------------------------------------------------------
1 | from .adabelief import AdaBelief
2 | from .adafactor import Adafactor
3 | from .adafactor_bv import AdafactorBigVision
4 | from .adahessian import Adahessian
5 | from .adamp import AdamP
6 | from .adamw import AdamWLegacy
7 | from .adan import Adan
8 | from .adopt import Adopt
9 | from .lamb import Lamb
10 | from .laprop import LaProp
11 | from .lars import Lars
12 | from .lion import Lion
13 | from .lookahead import Lookahead
14 | from .madgrad import MADGRAD
15 | from .mars import Mars
16 | from .muon import Muon
17 | from .nadam import NAdamLegacy
18 | from .nadamw import NAdamW
19 | from .nvnovograd import NvNovoGrad
20 | from .radam import RAdamLegacy
21 | from .rmsprop_tf import RMSpropTF
22 | from .sgdp import SGDP
23 | from .sgdw import SGDW
24 |
25 | # bring common torch.optim Optimizers into timm.optim namespace for consistency
26 | from torch.optim import Adadelta, Adagrad, Adamax, Adam, AdamW, RMSprop, SGD
27 | try:
28 | # in case any very old torch versions being used
29 | from torch.optim import NAdam, RAdam
30 | except ImportError:
31 | pass
32 |
33 | from ._optim_factory import list_optimizers, get_optimizer_class, get_optimizer_info, OptimInfo, OptimizerRegistry, \
34 | create_optimizer_v2, create_optimizer, optimizer_kwargs
35 | from ._param_groups import param_groups_layer_decay, param_groups_weight_decay, auto_group_layers
36 |
--------------------------------------------------------------------------------
/timm/utils/summary.py:
--------------------------------------------------------------------------------
1 | """ Summary utilities
2 |
3 | Hacked together by / Copyright 2020 Ross Wightman
4 | """
5 | import csv
6 | import os
7 | from collections import OrderedDict
8 | try:
9 | import wandb
10 | except ImportError:
11 | pass
12 |
13 |
14 | def get_outdir(path, *paths, inc=False):
15 | outdir = os.path.join(path, *paths)
16 | if not os.path.exists(outdir):
17 | os.makedirs(outdir)
18 | elif inc:
19 | count = 1
20 | outdir_inc = outdir + '-' + str(count)
21 | while os.path.exists(outdir_inc):
22 | count = count + 1
23 | outdir_inc = outdir + '-' + str(count)
24 | assert count < 100
25 | outdir = outdir_inc
26 | os.makedirs(outdir)
27 | return outdir
28 |
29 |
30 | def update_summary(
31 | epoch,
32 | train_metrics,
33 | eval_metrics,
34 | filename,
35 | lr=None,
36 | write_header=False,
37 | log_wandb=False,
38 | ):
39 | rowd = OrderedDict(epoch=epoch)
40 | rowd.update([('train_' + k, v) for k, v in train_metrics.items()])
41 | if eval_metrics:
42 | rowd.update([('eval_' + k, v) for k, v in eval_metrics.items()])
43 | if lr is not None:
44 | rowd['lr'] = lr
45 | if log_wandb:
46 | wandb.log(rowd)
47 | with open(filename, mode='a') as cf:
48 | dw = csv.DictWriter(cf, fieldnames=rowd.keys())
49 | if write_header: # first iteration (epoch == 1 can't be used)
50 | dw.writeheader()
51 | dw.writerow(rowd)
52 |
--------------------------------------------------------------------------------
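Usage note: get_outdir's inc behaviour appends -1, -2, ... when the target directory already exists. A quick sketch, assuming the function above:

```python
import os
import tempfile

root = tempfile.mkdtemp()
a = get_outdir(root, 'exp', inc=True)  # creates <root>/exp
b = get_outdir(root, 'exp', inc=True)  # exp exists -> creates <root>/exp-1
print(os.path.basename(a), os.path.basename(b))  # exp exp-1
```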
/timm/data/_info/mini_imagenet_synsets.txt:
--------------------------------------------------------------------------------
1 | n01532829
2 | n01558993
3 | n01704323
4 | n01749939
5 | n01770081
6 | n01843383
7 | n01855672
8 | n01910747
9 | n01930112
10 | n01981276
11 | n02074367
12 | n02089867
13 | n02091244
14 | n02091831
15 | n02099601
16 | n02101006
17 | n02105505
18 | n02108089
19 | n02108551
20 | n02108915
21 | n02110063
22 | n02110341
23 | n02111277
24 | n02113712
25 | n02114548
26 | n02116738
27 | n02120079
28 | n02129165
29 | n02138441
30 | n02165456
31 | n02174001
32 | n02219486
33 | n02443484
34 | n02457408
35 | n02606052
36 | n02687172
37 | n02747177
38 | n02795169
39 | n02823428
40 | n02871525
41 | n02950826
42 | n02966193
43 | n02971356
44 | n02981792
45 | n03017168
46 | n03047690
47 | n03062245
48 | n03075370
49 | n03127925
50 | n03146219
51 | n03207743
52 | n03220513
53 | n03272010
54 | n03337140
55 | n03347037
56 | n03400231
57 | n03417042
58 | n03476684
59 | n03527444
60 | n03535780
61 | n03544143
62 | n03584254
63 | n03676483
64 | n03770439
65 | n03773504
66 | n03775546
67 | n03838899
68 | n03854065
69 | n03888605
70 | n03908618
71 | n03924679
72 | n03980874
73 | n03998194
74 | n04067472
75 | n04146614
76 | n04149813
77 | n04243546
78 | n04251144
79 | n04258138
80 | n04275548
81 | n04296562
82 | n04389033
83 | n04418357
84 | n04435653
85 | n04443257
86 | n04509417
87 | n04515003
88 | n04522168
89 | n04596742
90 | n04604644
91 | n04612504
92 | n06794110
93 | n07584110
94 | n07613480
95 | n07697537
96 | n07747607
97 | n09246464
98 | n09256479
99 | n13054560
100 | n13133613
101 |
--------------------------------------------------------------------------------
/hfdocs/source/index.mdx:
--------------------------------------------------------------------------------
1 | # timm
2 |
3 |
4 |
5 | `timm` is a library containing SOTA computer vision models, layers, utilities, optimizers, schedulers, data-loaders, augmentations, and training/evaluation scripts.
6 |
7 | It comes packaged with >700 pretrained models, and is designed to be flexible and easy to use.
8 |
9 | Read the [quick start guide](quickstart) to get up and running with the `timm` library. You will learn how to load, discover, and use pretrained models included in the library.
10 |
11 |
23 |
--------------------------------------------------------------------------------
/timm/layers/grn.py:
--------------------------------------------------------------------------------
1 | """ Global Response Normalization Module
2 |
3 | Based on the GRN layer presented in
4 | `ConvNeXt-V2 - Co-designing and Scaling ConvNets with Masked Autoencoders` - https://arxiv.org/abs/2301.00808
5 |
6 | This implementation
7 | * works for both NCHW and NHWC tensor layouts
8 | * uses affine param names matching existing torch norm layers
9 | * slightly improves eager mode performance via fused addcmul
10 |
11 | Hacked together by / Copyright 2023 Ross Wightman
12 | """
13 |
14 | import torch
15 | from torch import nn as nn
16 |
17 |
18 | class GlobalResponseNorm(nn.Module):
19 | """ Global Response Normalization layer
20 | """
21 | def __init__(
22 | self,
23 | dim: int,
24 | eps: float = 1e-6,
25 | channels_last: bool = True,
26 | device=None,
27 | dtype=None,
28 | ):
29 | dd = {'device': device, 'dtype': dtype}
30 | super().__init__()
31 | self.eps = eps
32 | if channels_last:
33 | self.spatial_dim = (1, 2)
34 | self.channel_dim = -1
35 | self.wb_shape = (1, 1, 1, -1)
36 | else:
37 | self.spatial_dim = (2, 3)
38 | self.channel_dim = 1
39 | self.wb_shape = (1, -1, 1, 1)
40 |
41 | self.weight = nn.Parameter(torch.zeros(dim, **dd))
42 | self.bias = nn.Parameter(torch.zeros(dim, **dd))
43 |
44 | def forward(self, x):
45 | x_g = x.norm(p=2, dim=self.spatial_dim, keepdim=True)
46 | x_n = x_g / (x_g.mean(dim=self.channel_dim, keepdim=True) + self.eps)
47 | return x + torch.addcmul(self.bias.view(self.wb_shape), self.weight.view(self.wb_shape), x * x_n)
48 |
--------------------------------------------------------------------------------
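Usage note: shape expectations for both layouts, assuming GlobalResponseNorm above:

```python
import torch

grn = GlobalResponseNorm(dim=64)                  # channels_last=True expects NHWC
print(grn(torch.randn(2, 7, 7, 64)).shape)        # torch.Size([2, 7, 7, 64])

grn_nchw = GlobalResponseNorm(dim=64, channels_last=False)
print(grn_nchw(torch.randn(2, 64, 7, 7)).shape)   # torch.Size([2, 64, 7, 7])
```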
/timm/data/readers/img_extensions.py:
--------------------------------------------------------------------------------
1 | from copy import deepcopy
2 |
3 | __all__ = ['get_img_extensions', 'is_img_extension', 'set_img_extensions', 'add_img_extensions', 'del_img_extensions']
4 |
5 |
6 | IMG_EXTENSIONS = ('.png', '.jpg', '.jpeg') # singleton, kept public for bwd compat use
7 | _IMG_EXTENSIONS_SET = set(IMG_EXTENSIONS) # set version, private, kept in sync
8 |
9 |
10 | def _set_extensions(extensions):
11 | global IMG_EXTENSIONS
12 | global _IMG_EXTENSIONS_SET
13 | dedupe = set() # NOTE de-duping tuple while keeping original order
14 | IMG_EXTENSIONS = tuple(x for x in extensions if x not in dedupe and not dedupe.add(x))
15 | _IMG_EXTENSIONS_SET = set(extensions)
16 |
17 |
18 | def _valid_extension(x: str):
19 | return x and isinstance(x, str) and len(x) >= 2 and x.startswith('.')
20 |
21 |
22 | def is_img_extension(ext):
23 | return ext in _IMG_EXTENSIONS_SET
24 |
25 |
26 | def get_img_extensions(as_set=False):
27 | return deepcopy(_IMG_EXTENSIONS_SET if as_set else IMG_EXTENSIONS)
28 |
29 |
30 | def set_img_extensions(extensions):
31 | assert len(extensions)
32 | for x in extensions:
33 | assert _valid_extension(x)
34 | _set_extensions(extensions)
35 |
36 |
37 | def add_img_extensions(ext):
38 | if not isinstance(ext, (list, tuple, set)):
39 | ext = (ext,)
40 | for x in ext:
41 | assert _valid_extension(x)
42 | extensions = IMG_EXTENSIONS + tuple(ext)
43 | _set_extensions(extensions)
44 |
45 |
46 | def del_img_extensions(ext):
47 | if not isinstance(ext, (list, tuple, set)):
48 | ext = (ext,)
49 | extensions = tuple(x for x in IMG_EXTENSIONS if x not in ext)
50 | _set_extensions(extensions)
51 |
--------------------------------------------------------------------------------
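Usage note: the registry API, assuming the functions above; tuple order is preserved and entries de-duped:

```python
print(get_img_extensions())       # ('.png', '.jpg', '.jpeg')
add_img_extensions('.webp')
print(is_img_extension('.webp'))  # True
del_img_extensions('.webp')
print(get_img_extensions(as_set=True) == {'.png', '.jpg', '.jpeg'})  # True
```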
/timm/layers/layer_scale.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch import nn
3 |
4 |
5 | class LayerScale(nn.Module):
6 | """ LayerScale on tensors with channels in last-dim.
7 | """
8 | def __init__(
9 | self,
10 | dim: int,
11 | init_values: float = 1e-5,
12 | inplace: bool = False,
13 | device=None,
14 | dtype=None,
15 | ) -> None:
16 | super().__init__()
17 | self.init_values = init_values
18 | self.inplace = inplace
19 | self.gamma = nn.Parameter(torch.empty(dim, device=device, dtype=dtype))
20 |
21 | self.reset_parameters()
22 |
23 | def reset_parameters(self):
24 | torch.nn.init.constant_(self.gamma, self.init_values)
25 |
26 | def forward(self, x: torch.Tensor) -> torch.Tensor:
27 | return x.mul_(self.gamma) if self.inplace else x * self.gamma
28 |
29 |
30 | class LayerScale2d(nn.Module):
31 | """ LayerScale for tensors with torch 2D NCHW layout.
32 | """
33 | def __init__(
34 | self,
35 | dim: int,
36 | init_values: float = 1e-5,
37 | inplace: bool = False,
38 | device=None,
39 | dtype=None,
40 | ):
41 | super().__init__()
42 | self.init_values = init_values
43 | self.inplace = inplace
44 | self.gamma = nn.Parameter(torch.empty(dim, device=device, dtype=dtype))
45 |
46 | self.reset_parameters()
47 |
48 | def reset_parameters(self):
49 | torch.nn.init.constant_(self.gamma, self.init_values)
50 |
51 | def forward(self, x):
52 | gamma = self.gamma.view(1, -1, 1, 1)
53 | return x.mul_(gamma) if self.inplace else x * gamma
54 |
55 |
--------------------------------------------------------------------------------
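Usage note: gamma starts at init_values, so a fresh LayerScale is a small constant scale of its input. A sketch, assuming the class above:

```python
import torch

ls = LayerScale(dim=768, init_values=1e-5)
x = torch.randn(2, 197, 768)
print(torch.allclose(ls(x), x * 1e-5))  # True
```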
/timm/loss/jsd.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 |
5 | from .cross_entropy import LabelSmoothingCrossEntropy
6 |
7 |
8 | class JsdCrossEntropy(nn.Module):
9 | """ Jensen-Shannon Divergence + Cross-Entropy Loss
10 |
11 | Based on impl here: https://github.com/google-research/augmix/blob/master/imagenet.py
12 | From paper: 'AugMix: A Simple Data Processing Method to Improve Robustness and Uncertainty' -
13 | https://arxiv.org/abs/1912.02781
14 |
15 | Hacked together by / Copyright 2020 Ross Wightman
16 | """
17 | def __init__(self, num_splits=3, alpha=12, smoothing=0.1):
18 | super().__init__()
19 | self.num_splits = num_splits
20 | self.alpha = alpha
21 | if smoothing is not None and smoothing > 0:
22 | self.cross_entropy_loss = LabelSmoothingCrossEntropy(smoothing)
23 | else:
24 | self.cross_entropy_loss = torch.nn.CrossEntropyLoss()
25 |
26 | def __call__(self, output, target):
27 | split_size = output.shape[0] // self.num_splits
28 | assert split_size * self.num_splits == output.shape[0]
29 | logits_split = torch.split(output, split_size)
30 |
31 | # Cross-entropy is only computed on clean images
32 | loss = self.cross_entropy_loss(logits_split[0], target[:split_size])
33 | probs = [F.softmax(logits, dim=1) for logits in logits_split]
34 |
35 | # Clamp mixture distribution to avoid exploding KL divergence
36 | logp_mixture = torch.clamp(torch.stack(probs).mean(axis=0), 1e-7, 1).log()
37 | loss += self.alpha * sum([F.kl_div(
38 | logp_mixture, p_split, reduction='batchmean') for p_split in probs]) / len(probs)
39 | return loss
40 |
--------------------------------------------------------------------------------
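Usage note: the loss expects the batch stacked as [clean | augmented_1 | augmented_2 | ...] with targets covering only the clean split (the AugMix convention). A shape sketch, assuming JsdCrossEntropy above:

```python
import torch

B, K = 8, 10
jsd = JsdCrossEntropy(num_splits=3, alpha=12, smoothing=0.1)
output = torch.randn(B * 3, K)      # logits for clean + 2 augmented copies
target = torch.randint(0, K, (B,))  # labels for the clean split only
loss = jsd(output, target)
```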
/timm/layers/create_conv2d.py:
--------------------------------------------------------------------------------
1 | """ Create Conv2d Factory Method
2 |
3 | Hacked together by / Copyright 2020 Ross Wightman
4 | """
5 |
6 | from .mixed_conv2d import MixedConv2d
7 | from .cond_conv2d import CondConv2d
8 | from .conv2d_same import create_conv2d_pad
9 |
10 |
11 | def create_conv2d(in_channels, out_channels, kernel_size, **kwargs):
12 | """ Select a 2d convolution implementation based on arguments
13 | Creates and returns one of torch.nn.Conv2d, Conv2dSame, MixedConv2d, or CondConv2d.
14 |
15 | Used extensively by EfficientNet, MobileNetv3 and related networks.
16 | """
17 | if isinstance(kernel_size, list):
18 | assert 'num_experts' not in kwargs # MixNet + CondConv combo not supported currently
19 | if 'groups' in kwargs:
20 | groups = kwargs.pop('groups')
21 | if groups == in_channels:
22 | kwargs['depthwise'] = True
23 | else:
24 | assert groups == 1
25 | # We're going to use only lists for defining the MixedConv2d kernel groups,
26 | # ints, tuples, other iterables will continue to pass to normal conv and specify h, w.
27 | m = MixedConv2d(in_channels, out_channels, kernel_size, **kwargs)
28 | else:
29 | depthwise = kwargs.pop('depthwise', False)
30 | # for DW out_channels must be multiple of in_channels as must have out_channels % groups == 0
31 | groups = in_channels if depthwise else kwargs.pop('groups', 1)
32 | if 'num_experts' in kwargs and kwargs['num_experts'] > 0:
33 | m = CondConv2d(in_channels, out_channels, kernel_size, groups=groups, **kwargs)
34 | else:
35 | m = create_conv2d_pad(in_channels, out_channels, kernel_size, groups=groups, **kwargs)
36 | return m
37 |
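38 |
39 | # NOTE: minimal usage sketch (illustrative only): plain depthwise vs mixed-kernel selection.
40 | # Run via `python -m timm.layers.create_conv2d`.
41 | if __name__ == '__main__':
42 |     import torch
43 |     x = torch.randn(2, 32, 16, 16)
44 |     dw = create_conv2d(32, 32, kernel_size=3, depthwise=True)  # nn.Conv2d w/ groups=32
45 |     mixed = create_conv2d(32, 64, kernel_size=[3, 5], depthwise=True)  # MixedConv2d
46 |     print(dw(x).shape, mixed(x).shape)
47 |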
--------------------------------------------------------------------------------
/timm/utils/agc.py:
--------------------------------------------------------------------------------
1 | """ Adaptive Gradient Clipping
2 |
3 | An impl of AGC, as per (https://arxiv.org/abs/2102.06171):
4 |
5 | @article{brock2021high,
6 | author={Andrew Brock and Soham De and Samuel L. Smith and Karen Simonyan},
7 | title={High-Performance Large-Scale Image Recognition Without Normalization},
8 | journal={arXiv preprint arXiv:2102.06171},
9 | year={2021}
10 | }
11 |
12 | Code references:
13 | * Official JAX impl (paper authors): https://github.com/deepmind/deepmind-research/tree/master/nfnets
14 | * Phil Wang's PyTorch gist: https://gist.github.com/lucidrains/0d6560077edac419ab5d3aa29e674d5c
15 |
16 | Hacked together by / Copyright 2021 Ross Wightman
17 | """
18 | import torch
19 |
20 |
21 | def unitwise_norm(x, norm_type=2.0):
22 | if x.ndim <= 1:
23 | return x.norm(norm_type)
24 | else:
25 | # works for nn.ConvNd and nn.Linear where output dim is first in the kernel/weight tensor
26 | # might need special cases for other weights (possibly MHA) where this may not be true
27 | return x.norm(norm_type, dim=tuple(range(1, x.ndim)), keepdim=True)
28 |
29 |
30 | def adaptive_clip_grad(parameters, clip_factor=0.01, eps=1e-3, norm_type=2.0):
31 | if isinstance(parameters, torch.Tensor):
32 | parameters = [parameters]
33 | for p in parameters:
34 | if p.grad is None:
35 | continue
36 | p_data = p.detach()
37 | g_data = p.grad.detach()
38 | max_norm = unitwise_norm(p_data, norm_type=norm_type).clamp_(min=eps).mul_(clip_factor)
39 | grad_norm = unitwise_norm(g_data, norm_type=norm_type)
40 | clipped_grad = g_data * (max_norm / grad_norm.clamp(min=1e-6))
41 | new_grads = torch.where(grad_norm < max_norm, g_data, clipped_grad)
42 | p.grad.detach().copy_(new_grads)
43 |
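44 |
45 | # NOTE: minimal usage sketch (illustrative only): AGC is applied between backward() and
46 | # optimizer.step(). Run via `python -m timm.utils.agc`.
47 | if __name__ == '__main__':
48 |     import torch.nn as nn
49 |     model = nn.Linear(8, 4)
50 |     loss = model(torch.randn(2, 8)).pow(2).mean()
51 |     loss.backward()
52 |     adaptive_clip_grad(model.parameters(), clip_factor=0.01)
53 |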
--------------------------------------------------------------------------------
/timm/layers/grid.py:
--------------------------------------------------------------------------------
1 | from typing import Tuple
2 |
3 | import torch
4 |
5 |
6 | def ndgrid(*tensors) -> Tuple[torch.Tensor, ...]:
7 | """generate N-D grid in dimension order.
8 |
9 | The ndgrid function is like meshgrid except that the order of the first two input arguments is switched.
10 |
11 | That is, the statement
12 | [X1,X2,X3] = ndgrid(x1,x2,x3)
13 |
14 | produces the same result as
15 |
16 | [X2,X1,X3] = meshgrid(x2,x1,x3)
17 |
18 | This naming is based on MATLAB; the purpose is to avoid confusion caused by torch changing the
19 | torch.meshgrid behaviour from matching ndgrid ('ij') indexing to the numpy meshgrid default of ('xy').
20 |
21 | """
22 | try:
23 | return torch.meshgrid(*tensors, indexing='ij')
24 | except TypeError:
25 | # old PyTorch < 1.10 will follow this path as it does not have indexing arg,
26 | # the old behaviour of meshgrid was 'ij'
27 | return torch.meshgrid(*tensors)
28 |
29 |
30 | def meshgrid(*tensors) -> Tuple[torch.Tensor, ...]:
31 | """generate N-D grid in spatial dim order.
32 |
33 | The meshgrid function is similar to ndgrid except that the order of the
34 | first two input and output arguments is switched.
35 |
36 | That is, the statement
37 |
38 | [X,Y,Z] = meshgrid(x,y,z)
39 | produces the same result as
40 |
41 | [Y,X,Z] = ndgrid(y,x,z)
42 | Because of this, meshgrid is better suited to problems in two- or three-dimensional Cartesian space,
43 | while ndgrid is better suited to multidimensional problems that aren't spatially based.
44 | """
45 |
46 | # NOTE: this will throw in PyTorch < 1.10 as meshgrid did not support indexing arg or have
47 | # capability of generating grid in xy order before then.
48 | return torch.meshgrid(*tensors, indexing='xy')
49 |
50 |
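51 | # NOTE: minimal usage sketch (illustrative only) of the indexing difference.
52 | # Run via `python -m timm.layers.grid`.
53 | if __name__ == '__main__':
54 |     h, w = torch.arange(2), torch.arange(3)
55 |     yy, xx = ndgrid(h, w)    # 'ij' indexing, both (2, 3)
56 |     xg, yg = meshgrid(w, h)  # 'xy' indexing, both (2, 3)
57 |     print(yy.shape, xg.shape)
58 |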
--------------------------------------------------------------------------------
/timm/layers/median_pool.py:
--------------------------------------------------------------------------------
1 | """ Median Pool
2 | Hacked together by / Copyright 2020 Ross Wightman
3 | """
4 | import torch.nn as nn
5 | import torch.nn.functional as F
6 | from .helpers import to_2tuple, to_4tuple
7 |
8 |
9 | class MedianPool2d(nn.Module):
10 | """ Median pool (usable as median filter when stride=1) module.
11 |
12 | Args:
13 | kernel_size: size of pooling kernel, int or 2-tuple
14 | stride: pool stride, int or 2-tuple
15 | padding: pool padding, int or 4-tuple (l, r, t, b) as in pytorch F.pad
16 | same: override padding and enforce same padding, boolean
17 | """
18 | def __init__(self, kernel_size=3, stride=1, padding=0, same=False):
19 | super().__init__()
20 | self.k = to_2tuple(kernel_size)
21 | self.stride = to_2tuple(stride)
22 | self.padding = to_4tuple(padding) # convert to l, r, t, b
23 | self.same = same
24 |
25 | def _padding(self, x):
26 | if self.same:
27 | ih, iw = x.size()[2:]
28 | if ih % self.stride[0] == 0:
29 | ph = max(self.k[0] - self.stride[0], 0)
30 | else:
31 | ph = max(self.k[0] - (ih % self.stride[0]), 0)
32 | if iw % self.stride[1] == 0:
33 | pw = max(self.k[1] - self.stride[1], 0)
34 | else:
35 | pw = max(self.k[1] - (iw % self.stride[1]), 0)
36 | pl = pw // 2
37 | pr = pw - pl
38 | pt = ph // 2
39 | pb = ph - pt
40 | padding = (pl, pr, pt, pb)
41 | else:
42 | padding = self.padding
43 | return padding
44 |
45 | def forward(self, x):
46 | x = F.pad(x, self._padding(x), mode='reflect')
47 | x = x.unfold(2, self.k[0], self.stride[0]).unfold(3, self.k[1], self.stride[1])
48 | x = x.contiguous().view(x.size()[:4] + (-1,)).median(dim=-1)[0]
49 | return x
50 |
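51 |
52 | # NOTE: minimal usage sketch (illustrative only): a 3x3 median filter at stride 1 with
53 | # 'same' padding preserves spatial dims. Run via `python -m timm.layers.median_pool`.
54 | if __name__ == '__main__':
55 |     import torch
56 |     pool = MedianPool2d(kernel_size=3, stride=1, same=True)
57 |     print(pool(torch.randn(1, 3, 8, 8)).shape)  # torch.Size([1, 3, 8, 8])
58 |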
--------------------------------------------------------------------------------
/timm/utils/decay_batch.py:
--------------------------------------------------------------------------------
1 | """ Batch size decay and retry helpers.
2 |
3 | Copyright 2022 Ross Wightman
4 | """
5 | import math
6 |
7 |
8 | def decay_batch_step(batch_size, num_intra_steps=2, no_odd=False):
9 | """ power of two batch-size decay with intra steps
10 |
11 | Decay by stepping between powers of 2:
12 | * determine power-of-2 floor of current batch size (base batch size)
13 | * divide above value by num_intra_steps to determine step size
14 | * floor batch_size to nearest multiple of step_size (from base batch size)
15 | Examples:
16 | num_steps == 4 --> 64, 56, 48, 40, 32, 28, 24, 20, 16, 14, 12, 10, 8, 7, 6, 5, 4, 3, 2, 1
17 | num_steps (no_odd=True) == 4 --> 64, 56, 48, 40, 32, 28, 24, 20, 16, 14, 12, 10, 8, 6, 4, 2
18 | num_steps == 2 --> 64, 48, 32, 24, 16, 12, 8, 6, 4, 3, 2, 1
19 | num_steps == 1 --> 64, 32, 16, 8, 4, 2, 1
20 | """
21 | if batch_size <= 1:
22 | # return 0 for stopping value so easy to use in loop
23 | return 0
24 | base_batch_size = int(2 ** (math.log(batch_size - 1) // math.log(2)))
25 | step_size = max(base_batch_size // num_intra_steps, 1)
26 | batch_size = base_batch_size + ((batch_size - base_batch_size - 1) // step_size) * step_size
27 | if no_odd and batch_size % 2:
28 | batch_size -= 1
29 | return batch_size
30 |
31 |
32 | def check_batch_size_retry(error_str):
33 | """ check failure error string for conditions where batch decay retry should not be attempted
34 | """
35 | error_str = error_str.lower()
36 | if 'required rank' in error_str:
37 | # Errors involving phrase 'required rank' typically happen when a conv is used that's
38 | # not compatible with channels_last memory format.
39 | return False
40 | if 'illegal' in error_str:
41 | # 'Illegal memory access' errors in CUDA typically leave process in unusable state
42 | return False
43 | return True
44 |
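45 |
46 | # NOTE: minimal usage sketch (illustrative only): step the batch size down until a batch
47 | # fits in memory; decay_batch_step() returns 0 when no smaller size remains.
48 | if __name__ == '__main__':
49 |     bs = 64
50 |     while bs:
51 |         print(bs)  # 64, 48, 32, 24, 16, 12, 8, 6, 4, 3, 2, 1 w/ num_intra_steps=2
52 |         bs = decay_batch_step(bs, num_intra_steps=2)
53 |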
--------------------------------------------------------------------------------
/timm/data/_info/imagenet_r_indices.txt:
--------------------------------------------------------------------------------
1 | 1
2 | 2
3 | 4
4 | 6
5 | 8
6 | 9
7 | 11
8 | 13
9 | 22
10 | 23
11 | 26
12 | 29
13 | 31
14 | 39
15 | 47
16 | 63
17 | 71
18 | 76
19 | 79
20 | 84
21 | 90
22 | 94
23 | 96
24 | 97
25 | 99
26 | 100
27 | 105
28 | 107
29 | 113
30 | 122
31 | 125
32 | 130
33 | 132
34 | 144
35 | 145
36 | 147
37 | 148
38 | 150
39 | 151
40 | 155
41 | 160
42 | 161
43 | 162
44 | 163
45 | 171
46 | 172
47 | 178
48 | 187
49 | 195
50 | 199
51 | 203
52 | 207
53 | 208
54 | 219
55 | 231
56 | 232
57 | 234
58 | 235
59 | 242
60 | 245
61 | 247
62 | 250
63 | 251
64 | 254
65 | 259
66 | 260
67 | 263
68 | 265
69 | 267
70 | 269
71 | 276
72 | 277
73 | 281
74 | 288
75 | 289
76 | 291
77 | 292
78 | 293
79 | 296
80 | 299
81 | 301
82 | 308
83 | 309
84 | 310
85 | 311
86 | 314
87 | 315
88 | 319
89 | 323
90 | 327
91 | 330
92 | 334
93 | 335
94 | 337
95 | 338
96 | 340
97 | 341
98 | 344
99 | 347
100 | 353
101 | 355
102 | 361
103 | 362
104 | 365
105 | 366
106 | 367
107 | 368
108 | 372
109 | 388
110 | 390
111 | 393
112 | 397
113 | 401
114 | 407
115 | 413
116 | 414
117 | 425
118 | 428
119 | 430
120 | 435
121 | 437
122 | 441
123 | 447
124 | 448
125 | 457
126 | 462
127 | 463
128 | 469
129 | 470
130 | 471
131 | 472
132 | 476
133 | 483
134 | 487
135 | 515
136 | 546
137 | 555
138 | 558
139 | 570
140 | 579
141 | 583
142 | 587
143 | 593
144 | 594
145 | 596
146 | 609
147 | 613
148 | 617
149 | 621
150 | 629
151 | 637
152 | 657
153 | 658
154 | 701
155 | 717
156 | 724
157 | 763
158 | 768
159 | 774
160 | 776
161 | 779
162 | 780
163 | 787
164 | 805
165 | 812
166 | 815
167 | 820
168 | 824
169 | 833
170 | 847
171 | 852
172 | 866
173 | 875
174 | 883
175 | 889
176 | 895
177 | 907
178 | 928
179 | 931
180 | 932
181 | 933
182 | 934
183 | 936
184 | 937
185 | 943
186 | 945
187 | 947
188 | 948
189 | 949
190 | 951
191 | 953
192 | 954
193 | 957
194 | 963
195 | 965
196 | 967
197 | 980
198 | 981
199 | 983
200 | 988
201 |
--------------------------------------------------------------------------------
/timm/data/_info/imagenet_a_indices.txt:
--------------------------------------------------------------------------------
1 | 6
2 | 11
3 | 13
4 | 15
5 | 17
6 | 22
7 | 23
8 | 27
9 | 30
10 | 37
11 | 39
12 | 42
13 | 47
14 | 50
15 | 57
16 | 70
17 | 71
18 | 76
19 | 79
20 | 89
21 | 90
22 | 94
23 | 96
24 | 97
25 | 99
26 | 105
27 | 107
28 | 108
29 | 110
30 | 113
31 | 124
32 | 125
33 | 130
34 | 132
35 | 143
36 | 144
37 | 150
38 | 151
39 | 207
40 | 234
41 | 235
42 | 254
43 | 277
44 | 283
45 | 287
46 | 291
47 | 295
48 | 298
49 | 301
50 | 306
51 | 307
52 | 308
53 | 309
54 | 310
55 | 311
56 | 313
57 | 314
58 | 315
59 | 317
60 | 319
61 | 323
62 | 324
63 | 326
64 | 327
65 | 330
66 | 334
67 | 335
68 | 336
69 | 347
70 | 361
71 | 363
72 | 372
73 | 378
74 | 386
75 | 397
76 | 400
77 | 401
78 | 402
79 | 404
80 | 407
81 | 411
82 | 416
83 | 417
84 | 420
85 | 425
86 | 428
87 | 430
88 | 437
89 | 438
90 | 445
91 | 456
92 | 457
93 | 461
94 | 462
95 | 470
96 | 472
97 | 483
98 | 486
99 | 488
100 | 492
101 | 496
102 | 514
103 | 516
104 | 528
105 | 530
106 | 539
107 | 542
108 | 543
109 | 549
110 | 552
111 | 557
112 | 561
113 | 562
114 | 569
115 | 572
116 | 573
117 | 575
118 | 579
119 | 589
120 | 606
121 | 607
122 | 609
123 | 614
124 | 626
125 | 627
126 | 640
127 | 641
128 | 642
129 | 643
130 | 658
131 | 668
132 | 677
133 | 682
134 | 684
135 | 687
136 | 701
137 | 704
138 | 719
139 | 736
140 | 746
141 | 749
142 | 752
143 | 758
144 | 763
145 | 765
146 | 768
147 | 773
148 | 774
149 | 776
150 | 779
151 | 780
152 | 786
153 | 792
154 | 797
155 | 802
156 | 803
157 | 804
158 | 813
159 | 815
160 | 820
161 | 823
162 | 831
163 | 833
164 | 835
165 | 839
166 | 845
167 | 847
168 | 850
169 | 859
170 | 862
171 | 870
172 | 879
173 | 880
174 | 888
175 | 890
176 | 897
177 | 900
178 | 907
179 | 913
180 | 924
181 | 932
182 | 933
183 | 934
184 | 937
185 | 943
186 | 945
187 | 947
188 | 951
189 | 954
190 | 956
191 | 957
192 | 959
193 | 971
194 | 972
195 | 980
196 | 981
197 | 984
198 | 986
199 | 987
200 | 988
201 |
--------------------------------------------------------------------------------
/timm/data/real_labels.py:
--------------------------------------------------------------------------------
1 | """ Real labels evaluator for ImageNet
2 | Paper: `Are we done with ImageNet?` - https://arxiv.org/abs/2006.07159
3 | Based on Numpy example at https://github.com/google-research/reassessed-imagenet
4 |
5 | Hacked together by / Copyright 2020 Ross Wightman
6 | """
7 | import os
8 | import json
9 | import numpy as np
10 | import pkgutil
11 |
12 |
13 | class RealLabelsImagenet:
14 |
15 | def __init__(self, filenames, real_json=None, topk=(1, 5)):
16 | if real_json is not None:
17 | with open(real_json) as real_labels:
18 | real_labels = json.load(real_labels)
19 | else:
20 | real_labels = json.loads(
21 | pkgutil.get_data(__name__, os.path.join('_info', 'imagenet_real_labels.json')).decode('utf-8'))
22 | real_labels = {f'ILSVRC2012_val_{i + 1:08d}.JPEG': labels for i, labels in enumerate(real_labels)}
23 | self.real_labels = real_labels
24 | self.filenames = filenames
25 | assert len(self.filenames) == len(self.real_labels)
26 | self.topk = topk
27 | self.is_correct = {k: [] for k in topk}
28 | self.sample_idx = 0
29 |
30 | def add_result(self, output):
31 | maxk = max(self.topk)
32 | _, pred_batch = output.topk(maxk, 1, True, True)
33 | pred_batch = pred_batch.cpu().numpy()
34 | for pred in pred_batch:
35 | filename = self.filenames[self.sample_idx]
36 | filename = os.path.basename(filename)
37 | if self.real_labels[filename]:
38 | for k in self.topk:
39 | self.is_correct[k].append(
40 | any([p in self.real_labels[filename] for p in pred[:k]]))
41 | self.sample_idx += 1
42 |
43 | def get_accuracy(self, k=None):
44 | if k is None:
45 | return {k: float(np.mean(self.is_correct[k])) * 100 for k in self.topk}
46 | else:
47 | return float(np.mean(self.is_correct[k])) * 100
48 |
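49 |
50 | # NOTE: minimal usage sketch (illustrative only, assumes `model` and a `loader` that yields
51 | # ImageNet validation batches in the same order as `filenames`):
52 | #
53 | #   real_labels = RealLabelsImagenet(filenames, topk=(1, 5))
54 | #   for input, _ in loader:
55 | #       real_labels.add_result(model(input))
56 | #   print(real_labels.get_accuracy(k=1), real_labels.get_accuracy(k=5))
57 |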
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | # Byte-compiled / optimized / DLL files
2 | __pycache__/
3 | *.py[cod]
4 | *$py.class
5 |
6 | # C extensions
7 | *.so
8 |
9 | # Distribution / packaging
10 | .Python
11 | build/
12 | develop-eggs/
13 | dist/
14 | downloads/
15 | eggs/
16 | .eggs/
17 | lib/
18 | lib64/
19 | parts/
20 | sdist/
21 | var/
22 | wheels/
23 | *.egg-info/
24 | .installed.cfg
25 | *.egg
26 | MANIFEST
27 |
28 | # PyInstaller
29 | # Usually these files are written by a python script from a template
30 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
31 | *.manifest
32 | *.spec
33 |
34 | # Installer logs
35 | pip-log.txt
36 | pip-delete-this-directory.txt
37 |
38 | # Unit test / coverage reports
39 | htmlcov/
40 | .tox/
41 | .coverage
42 | .coverage.*
43 | .cache
44 | nosetests.xml
45 | coverage.xml
46 | *.cover
47 | .hypothesis/
48 | .pytest_cache/
49 |
50 | # Translations
51 | *.mo
52 | *.pot
53 |
54 | # Django stuff:
55 | *.log
56 | local_settings.py
57 | db.sqlite3
58 |
59 | # Flask stuff:
60 | instance/
61 | .webassets-cache
62 |
63 | # Scrapy stuff:
64 | .scrapy
65 |
66 | # Sphinx documentation
67 | docs/_build/
68 |
69 | # PyBuilder
70 | target/
71 |
72 | # Jupyter Notebook
73 | .ipynb_checkpoints
74 |
75 | # pyenv
76 | .python-version
77 |
78 | # celery beat schedule file
79 | celerybeat-schedule
80 |
81 | # SageMath parsed files
82 | *.sage.py
83 |
84 | # Environments
85 | .env
86 | .venv
87 | env/
88 | venv/
89 | ENV/
90 | env.bak/
91 | venv.bak/
92 |
93 | # Spyder project settings
94 | .spyderproject
95 | .spyproject
96 |
97 | # Rope project settings
98 | .ropeproject
99 |
100 | # PyCharm
101 | .idea
102 |
103 | output/
104 |
105 | # PyTorch weights
106 | *.tar
107 | *.pth
108 | *.pt
109 | *.torch
110 | *.gz
111 | Untitled.ipynb
112 | Testing notebook.ipynb
113 |
114 | # Root dir exclusions
115 | /*.csv
116 | /*.yaml
117 | /*.json
118 | /*.jpg
119 | /*.png
120 | /*.zip
121 | /*.tar.*
--------------------------------------------------------------------------------
/timm/scheduler/step_lr.py:
--------------------------------------------------------------------------------
1 | """ Step Scheduler
2 |
3 | Basic step LR schedule with warmup, noise.
4 |
5 | Hacked together by / Copyright 2020 Ross Wightman
6 | """
7 | import math
8 | import torch
9 | from typing import List
10 |
11 |
12 | from .scheduler import Scheduler
13 |
14 |
15 | class StepLRScheduler(Scheduler):
16 | """
17 | """
18 |
19 | def __init__(
20 | self,
21 | optimizer: torch.optim.Optimizer,
22 | decay_t: float,
23 | decay_rate: float = 1.,
24 | warmup_t=0,
25 | warmup_lr_init=0,
26 | warmup_prefix=True,
27 | t_in_epochs=True,
28 | noise_range_t=None,
29 | noise_pct=0.67,
30 | noise_std=1.0,
31 | noise_seed=42,
32 | initialize=True,
33 | ) -> None:
34 | super().__init__(
35 | optimizer,
36 | param_group_field="lr",
37 | t_in_epochs=t_in_epochs,
38 | noise_range_t=noise_range_t,
39 | noise_pct=noise_pct,
40 | noise_std=noise_std,
41 | noise_seed=noise_seed,
42 | initialize=initialize,
43 | )
44 |
45 | self.decay_t = decay_t
46 | self.decay_rate = decay_rate
47 | self.warmup_t = warmup_t
48 | self.warmup_lr_init = warmup_lr_init
49 | self.warmup_prefix = warmup_prefix
50 | if self.warmup_t:
51 | self.warmup_steps = [(v - warmup_lr_init) / self.warmup_t for v in self.base_values]
52 | super().update_groups(self.warmup_lr_init)
53 | else:
54 | self.warmup_steps = [1 for _ in self.base_values]
55 |
56 | def _get_lr(self, t: int) -> List[float]:
57 | if t < self.warmup_t:
58 | lrs = [self.warmup_lr_init + t * s for s in self.warmup_steps]
59 | else:
60 | if self.warmup_prefix:
61 | t = t - self.warmup_t
62 | lrs = [v * (self.decay_rate ** (t // self.decay_t)) for v in self.base_values]
63 | return lrs
64 |
--------------------------------------------------------------------------------
/hfdocs/source/hf_hub.mdx:
--------------------------------------------------------------------------------
1 | # Sharing and Loading Models From the Hugging Face Hub
2 |
3 | The `timm` library has a built-in integration with the Hugging Face Hub, making it easy to share and load models from the 🤗 Hub.
4 |
5 | In this short guide, we'll see how to:
6 | 1. Share a `timm` model on the Hub
7 | 2. Load that model back from the Hub
8 |
9 | ## Authenticating
10 |
11 | First, you'll need to make sure you have the `huggingface_hub` package installed.
12 |
13 | ```bash
14 | pip install huggingface_hub
15 | ```
16 |
17 | Then, you'll need to authenticate yourself. You can do this by running the following command:
18 |
19 | ```bash
20 | huggingface-cli login
21 | ```
22 |
23 | Or, if you're using a notebook, you can use the `notebook_login` helper:
24 |
25 | ```py
26 | >>> from huggingface_hub import notebook_login
27 | >>> notebook_login()
28 | ```
29 |
30 | ## Sharing a Model
31 |
32 | ```py
33 | >>> import timm
34 | >>> model = timm.create_model('resnet18', pretrained=True, num_classes=4)
35 | ```
36 |
37 | Here is where you would normally train or fine-tune the model. We'll skip that for the sake of this tutorial.
38 |
39 | Let's pretend we've now fine-tuned the model. The next step would be to push it to the Hub! We can do this with the `timm.models.push_to_hf_hub` function.
40 |
41 | ```py
42 | >>> model_cfg = dict(label_names=['a', 'b', 'c', 'd'])
43 | >>> timm.models.push_to_hf_hub(model, 'resnet18-random', model_config=model_cfg)
44 | ```
45 |
46 | Running the above would push the model to `<your-username>/resnet18-random` on the Hub. You can now share this model with your friends, or use it in your own code!
47 |
48 | ## Loading a Model
49 |
50 | Loading a model from the Hub is as simple as calling `timm.create_model` with a model name of the form `hf_hub:<model-id>` and `pretrained=True`. In this case, we'll use [`nateraw/resnet18-random`](https://huggingface.co/nateraw/resnet18-random), which is the model we just pushed to the Hub.
51 |
52 | ```py
53 | >>> model_reloaded = timm.create_model('hf_hub:nateraw/resnet18-random', pretrained=True)
54 | ```
55 |
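56 | As a quick (illustrative) sanity check, the reloaded model can be run on a random input; the output shape reflects the `num_classes=4` head we created above:
57 |
58 | ```py
59 | >>> import torch
60 | >>> out = model_reloaded.eval()(torch.randn(1, 3, 224, 224))
61 | >>> out.shape
62 | torch.Size([1, 4])
63 | ```
64 |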
--------------------------------------------------------------------------------
/timm/data/readers/reader_factory.py:
--------------------------------------------------------------------------------
1 | import os
2 | from typing import Optional
3 |
4 | from .reader_image_folder import ReaderImageFolder
5 | from .reader_image_in_tar import ReaderImageInTar
6 |
7 |
8 | def create_reader(
9 | name: str,
10 | root: Optional[str] = None,
11 | split: str = 'train',
12 | **kwargs,
13 | ):
14 | kwargs = {k: v for k, v in kwargs.items() if v is not None}
15 | name = name.lower()
16 | name = name.split('/', 1)
17 | prefix = ''
18 | if len(name) > 1:
19 | prefix = name[0]
20 | name = name[-1]
21 |
22 | # FIXME the additional features are only supported by ReaderHfds for now.
23 | additional_features = kwargs.pop("additional_features", None)
24 |
25 | # FIXME improve the selection; right now it's just known prefixes or the fallback path, will need
26 | # options to explicitly select other readers shortly
27 | if prefix == 'hfds':
28 | from .reader_hfds import ReaderHfds # defer Hf datasets import
29 | reader = ReaderHfds(name=name, root=root, split=split, additional_features=additional_features, **kwargs)
30 | elif prefix == 'hfids':
31 | from .reader_hfids import ReaderHfids # defer HF datasets import
32 | reader = ReaderHfids(name=name, root=root, split=split, **kwargs)
33 | elif prefix == 'tfds':
34 | from .reader_tfds import ReaderTfds # defer tensorflow import
35 | reader = ReaderTfds(name=name, root=root, split=split, **kwargs)
36 | elif prefix == 'wds':
37 | from .reader_wds import ReaderWds
38 | kwargs.pop('download', False)
39 | reader = ReaderWds(root=root, name=name, split=split, **kwargs)
40 | else:
41 | assert os.path.exists(root)
42 | # default fallback path (backwards compat), use image tar if root is a .tar file, otherwise image folder
43 | # FIXME support split here or in reader?
44 | if os.path.isfile(root) and os.path.splitext(root)[1] == '.tar':
45 | reader = ReaderImageInTar(root, **kwargs)
46 | else:
47 | reader = ReaderImageFolder(root, **kwargs)
48 | return reader
49 |
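50 |
51 | # NOTE: minimal usage sketch (illustrative only; paths and dataset names are placeholders):
52 | #
53 | #   reader = create_reader('', root='/data/imagenet/train')  # folder fallback -> ReaderImageFolder
54 | #   reader = create_reader('hfds/imagenet-1k', root=None, split='train')  # ReaderHfds
55 | #   reader = create_reader('tfds/imagenet2012', root='/data/tfds', split='train')  # ReaderTfds
56 |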
--------------------------------------------------------------------------------
/timm/layers/test_time_pool.py:
--------------------------------------------------------------------------------
1 | """ Test Time Pooling (Average-Max Pool)
2 |
3 | Hacked together by / Copyright 2020 Ross Wightman
4 | """
5 |
6 | import logging
7 | from torch import nn
8 | import torch.nn.functional as F
9 |
10 | from .adaptive_avgmax_pool import adaptive_avgmax_pool2d
11 |
12 |
13 | _logger = logging.getLogger(__name__)
14 |
15 |
16 | class TestTimePoolHead(nn.Module):
17 | def __init__(self, base, original_pool=7):
18 | super().__init__()
19 | self.base = base
20 | self.original_pool = original_pool
21 | base_fc = self.base.get_classifier()
22 | if isinstance(base_fc, nn.Conv2d):
23 | self.fc = base_fc
24 | else:
25 | self.fc = nn.Conv2d(
26 | self.base.num_features, self.base.num_classes, kernel_size=1, bias=True)
27 | self.fc.weight.data.copy_(base_fc.weight.data.view(self.fc.weight.size()))
28 | self.fc.bias.data.copy_(base_fc.bias.data.view(self.fc.bias.size()))
29 | self.base.reset_classifier(0) # delete original fc layer
30 |
31 | def forward(self, x):
32 | x = self.base.forward_features(x)
33 | x = F.avg_pool2d(x, kernel_size=self.original_pool, stride=1)
34 | x = self.fc(x)
35 | x = adaptive_avgmax_pool2d(x, 1)
36 | return x.view(x.size(0), -1)
37 |
38 |
39 | def apply_test_time_pool(model, config, use_test_size=False):
40 | test_time_pool = False
41 | if not hasattr(model, 'default_cfg') or not model.default_cfg:
42 | return model, False
43 | if use_test_size and 'test_input_size' in model.default_cfg:
44 | df_input_size = model.default_cfg['test_input_size']
45 | else:
46 | df_input_size = model.default_cfg['input_size']
47 | if config['input_size'][-1] > df_input_size[-1] and config['input_size'][-2] > df_input_size[-2]:
48 | _logger.info('Target input size %s > pretrained default %s, using test time pooling' %
49 | (str(config['input_size'][-2:]), str(df_input_size[-2:])))
50 | model = TestTimePoolHead(model, original_pool=model.default_cfg['pool_size'])
51 | test_time_pool = True
52 | return model, test_time_pool
53 |
--------------------------------------------------------------------------------
/timm/scheduler/multistep_lr.py:
--------------------------------------------------------------------------------
1 | """ MultiStep LR Scheduler
2 |
3 | Basic multi step LR schedule with warmup, noise.
4 | """
5 | import torch
6 | import bisect
7 | from timm.scheduler.scheduler import Scheduler
8 | from typing import List
9 |
10 | class MultiStepLRScheduler(Scheduler):
11 | """
12 | """
13 |
14 | def __init__(
15 | self,
16 | optimizer: torch.optim.Optimizer,
17 | decay_t: List[int],
18 | decay_rate: float = 1.,
19 | warmup_t=0,
20 | warmup_lr_init=0,
21 | warmup_prefix=True,
22 | t_in_epochs=True,
23 | noise_range_t=None,
24 | noise_pct=0.67,
25 | noise_std=1.0,
26 | noise_seed=42,
27 | initialize=True,
28 | ) -> None:
29 | super().__init__(
30 | optimizer,
31 | param_group_field="lr",
32 | t_in_epochs=t_in_epochs,
33 | noise_range_t=noise_range_t,
34 | noise_pct=noise_pct,
35 | noise_std=noise_std,
36 | noise_seed=noise_seed,
37 | initialize=initialize,
38 | )
39 |
40 | self.decay_t = decay_t
41 | self.decay_rate = decay_rate
42 | self.warmup_t = warmup_t
43 | self.warmup_lr_init = warmup_lr_init
44 | self.warmup_prefix = warmup_prefix
45 | if self.warmup_t:
46 | self.warmup_steps = [(v - warmup_lr_init) / self.warmup_t for v in self.base_values]
47 | super().update_groups(self.warmup_lr_init)
48 | else:
49 | self.warmup_steps = [1 for _ in self.base_values]
50 |
51 | def get_curr_decay_steps(self, t):
52 | # find where in the array t goes,
53 | # assumes self.decay_t is sorted
54 | return bisect.bisect_right(self.decay_t, t + 1)
55 |
56 | def _get_lr(self, t: int) -> List[float]:
57 | if t < self.warmup_t:
58 | lrs = [self.warmup_lr_init + t * s for s in self.warmup_steps]
59 | else:
60 | if self.warmup_prefix:
61 | t = t - self.warmup_t
62 | lrs = [v * (self.decay_rate ** self.get_curr_decay_steps(t)) for v in self.base_values]
63 | return lrs
64 |
--------------------------------------------------------------------------------
/hfdocs/source/installation.mdx:
--------------------------------------------------------------------------------
1 | # Installation
2 |
3 | Before you start, you'll need to set up your environment and install the appropriate packages. `timm` is tested on **Python 3+**.
4 |
5 | ## Virtual Environment
6 |
7 | You should install `timm` in a [virtual environment](https://docs.python.org/3/library/venv.html) to keep things tidy and avoid dependency conflicts.
8 |
9 | 1. Create and navigate to your project directory:
10 |
11 | ```bash
12 | mkdir ~/my-project
13 | cd ~/my-project
14 | ```
15 |
16 | 2. Start a virtual environment inside your directory:
17 |
18 | ```bash
19 | python -m venv .env
20 | ```
21 |
22 | 3. Activate and deactivate the virtual environment with the following commands:
23 |
24 | ```bash
25 | # Activate the virtual environment
26 | source .env/bin/activate
27 |
28 | # Deactivate the virtual environment
29 | deactivate
30 | ```
31 |
32 | Once you've created your virtual environment, you can install `timm` in it.
33 |
34 | ## Using pip
35 |
36 | The most straightforward way to install `timm` is with pip:
37 |
38 | ```bash
39 | pip install timm
40 | ```
41 |
42 | Alternatively, you can install `timm` from GitHub directly to get the latest, bleeding-edge version:
43 |
44 | ```bash
45 | pip install git+https://github.com/rwightman/pytorch-image-models.git
46 | ```
47 |
48 | Run the following command to check if `timm` has been properly installed:
49 |
50 | ```bash
51 | python -c "from timm import list_models; print(list_models(pretrained=True)[:5])"
52 | ```
53 |
54 | This command lists the first five pretrained models available in `timm` (which are sorted alphabetically). You should see the following output:
55 |
56 | ```python
57 | ['adv_inception_v3', 'bat_resnext26ts', 'beit_base_patch16_224', 'beit_base_patch16_224_in22k', 'beit_base_patch16_384']
58 | ```
59 |
60 | ## From Source
61 |
62 | Building `timm` from source lets you make changes to the code base. To install from the source, clone the repository and install with the following commands:
63 |
64 | ```bash
65 | git clone https://github.com/rwightman/pytorch-image-models.git
66 | cd pytorch-image-models
67 | pip install -e .
68 | ```
69 |
70 | Again, you can check if `timm` was properly installed with the following command:
71 |
72 | ```bash
73 | python -c "from timm import list_models; print(list_models(pretrained=True)[:5])"
74 | ```
75 |
--------------------------------------------------------------------------------
/pyproject.toml:
--------------------------------------------------------------------------------
1 | [build-system]
2 | requires = ["pdm-backend"]
3 | build-backend = "pdm.backend"
4 |
5 | [project]
6 | name = "timm"
7 | authors = [
8 | {name = "Ross Wightman", email = "ross@huggingface.co"},
9 | ]
10 | description = "PyTorch Image Models"
11 | readme = "README.md"
12 | requires-python = ">=3.8"
13 | keywords = ["pytorch", "image-classification"]
14 | license = {text = "Apache-2.0"}
15 | classifiers = [
16 | 'Development Status :: 5 - Production/Stable',
17 | 'Intended Audience :: Education',
18 | 'Intended Audience :: Science/Research',
19 | 'License :: OSI Approved :: Apache Software License',
20 | 'Programming Language :: Python :: 3.8',
21 | 'Programming Language :: Python :: 3.9',
22 | 'Programming Language :: Python :: 3.10',
23 | 'Programming Language :: Python :: 3.11',
24 | 'Programming Language :: Python :: 3.12',
25 | 'Topic :: Scientific/Engineering',
26 | 'Topic :: Scientific/Engineering :: Artificial Intelligence',
27 | 'Topic :: Software Development',
28 | 'Topic :: Software Development :: Libraries',
29 | 'Topic :: Software Development :: Libraries :: Python Modules',
30 | ]
31 | dependencies = [
32 | 'torch',
33 | 'torchvision',
34 | 'pyyaml',
35 | 'huggingface_hub',
36 | 'safetensors',
37 | ]
38 | dynamic = ["version"]
39 |
40 | [project.urls]
41 | homepage = "https://github.com/huggingface/pytorch-image-models"
42 | documentation = "https://huggingface.co/docs/timm/en/index"
43 | repository = "https://github.com/huggingface/pytorch-image-models"
44 |
45 | [tool.pdm.dev-dependencies]
46 | test = [
47 | 'pytest',
48 | 'pytest-timeout',
49 | 'pytest-xdist',
50 | 'pytest-forked',
51 | 'expecttest',
52 | ]
53 |
54 | [tool.pdm.version]
55 | source = "file"
56 | path = "timm/version.py"
57 |
58 | [tool.pytest.ini_options]
59 | testpaths = ['tests']
60 | markers = [
61 | "base: marker for model tests using the basic setup",
62 | "cfg: marker for model tests checking the config",
63 | "torchscript: marker for model tests using torchscript",
64 | "features: marker for model tests checking feature extraction",
65 | "fxforward: marker for model tests using torch fx (only forward)",
66 | "fxbackward: marker for model tests using torch fx (only backward)",
67 | ]
--------------------------------------------------------------------------------
/timm/layers/create_norm.py:
--------------------------------------------------------------------------------
1 | """ Norm Layer Factory
2 |
3 | Create norm modules by string (to mirror create_act and create_norm-act fns)
4 |
5 | Copyright 2022 Ross Wightman
6 | """
7 | import functools
8 | import types
9 | from typing import Type
10 |
11 | import torch.nn as nn
12 |
13 | from .norm import (
14 | GroupNorm,
15 | GroupNorm1,
16 | LayerNorm,
17 | LayerNorm2d,
18 | LayerNormFp32,
19 | LayerNorm2dFp32,
20 | RmsNorm,
21 | RmsNorm2d,
22 | RmsNormFp32,
23 | RmsNorm2dFp32,
24 | SimpleNorm,
25 | SimpleNorm2d,
26 | SimpleNormFp32,
27 | SimpleNorm2dFp32,
28 | )
29 | from torchvision.ops.misc import FrozenBatchNorm2d
30 |
31 | _NORM_MAP = dict(
32 | batchnorm=nn.BatchNorm2d,
33 | batchnorm2d=nn.BatchNorm2d,
34 | batchnorm1d=nn.BatchNorm1d,
35 | groupnorm=GroupNorm,
36 | groupnorm1=GroupNorm1,
37 | layernorm=LayerNorm,
38 | layernorm2d=LayerNorm2d,
39 | layernormfp32=LayerNormFp32,
40 | layernorm2dfp32=LayerNorm2dFp32,
41 | rmsnorm=RmsNorm,
42 | rmsnorm2d=RmsNorm2d,
43 | rmsnormfp32=RmsNormFp32,
44 | rmsnorm2dfp32=RmsNorm2dFp32,
45 | simplenorm=SimpleNorm,
46 | simplenorm2d=SimpleNorm2d,
47 | simplenormfp32=SimpleNormFp32,
48 | simplenorm2dfp32=SimpleNorm2dFp32,
49 | frozenbatchnorm2d=FrozenBatchNorm2d,
50 | )
51 | _NORM_TYPES = {m for n, m in _NORM_MAP.items()}
52 |
53 |
54 | def create_norm_layer(layer_name, num_features, **kwargs):
55 | layer = get_norm_layer(layer_name)
56 | layer_instance = layer(num_features, **kwargs)
57 | return layer_instance
58 |
59 |
60 | def get_norm_layer(norm_layer):
61 | if norm_layer is None:
62 | return None
63 | assert isinstance(norm_layer, (type, str, types.FunctionType, functools.partial))
64 | norm_kwargs = {}
65 |
66 | # unbind partial fn, so args can be rebound later
67 | if isinstance(norm_layer, functools.partial):
68 | norm_kwargs.update(norm_layer.keywords)
69 | norm_layer = norm_layer.func
70 |
71 | if isinstance(norm_layer, str):
72 | if not norm_layer:
73 | return None
74 | layer_name = norm_layer.replace('_', '').lower()
75 | norm_layer = _NORM_MAP[layer_name]
76 | else:
77 | norm_layer = norm_layer
78 |
79 | if norm_kwargs:
80 | norm_layer = functools.partial(norm_layer, **norm_kwargs) # bind/rebind args
81 | return norm_layer
82 |
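83 |
84 | # NOTE: minimal usage sketch (illustrative only): create by string name, or rebind a partial.
85 | # Run via `python -m timm.layers.create_norm`.
86 | if __name__ == '__main__':
87 |     ln = create_norm_layer('layernorm2d', 64)
88 |     gn = get_norm_layer(functools.partial(GroupNorm, num_groups=8))(64)
89 |     print(type(ln).__name__, type(gn).__name__)
90 |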
--------------------------------------------------------------------------------
/timm/utils/jit.py:
--------------------------------------------------------------------------------
1 | """ JIT scripting/tracing utils
2 |
3 | Hacked together by / Copyright 2020 Ross Wightman
4 | """
5 | import os
6 |
7 | import torch
8 |
9 |
10 | def set_jit_legacy():
11 | """ Set JIT executor to legacy w/ support for op fusion
12 | This is hopefully a temporary need in 1.5/1.5.1/1.6 to restore performance due to changes
13 | in the JIT executor. These APIs are not supported and could change.
14 | """
15 | #
16 | assert hasattr(torch._C, '_jit_set_profiling_executor'), "Old JIT behavior doesn't exist!"
17 | torch._C._jit_set_profiling_executor(False)
18 | torch._C._jit_set_profiling_mode(False)
19 | torch._C._jit_override_can_fuse_on_gpu(True)
20 | #torch._C._jit_set_texpr_fuser_enabled(True)
21 |
22 |
23 | def set_jit_fuser(fuser):
24 | if fuser == "te":
25 | # default fuser should be == 'te'
26 | torch._C._jit_set_profiling_executor(True)
27 | torch._C._jit_set_profiling_mode(True)
28 | torch._C._jit_override_can_fuse_on_cpu(False)
29 | torch._C._jit_override_can_fuse_on_gpu(True)
30 | torch._C._jit_set_texpr_fuser_enabled(True)
31 | try:
32 | torch._C._jit_set_nvfuser_enabled(False)
33 | except Exception:
34 | pass
35 | elif fuser == "old" or fuser == "legacy":
36 | torch._C._jit_set_profiling_executor(False)
37 | torch._C._jit_set_profiling_mode(False)
38 | torch._C._jit_override_can_fuse_on_gpu(True)
39 | torch._C._jit_set_texpr_fuser_enabled(False)
40 | try:
41 | torch._C._jit_set_nvfuser_enabled(False)
42 | except Exception:
43 | pass
44 | elif fuser == "nvfuser" or fuser == "nvf":
45 | os.environ['PYTORCH_NVFUSER_DISABLE_FALLBACK'] = '1'
46 | #os.environ['PYTORCH_NVFUSER_DISABLE_FMA'] = '1'
47 | #os.environ['PYTORCH_NVFUSER_JIT_OPT_LEVEL'] = '0'
48 | torch._C._jit_set_texpr_fuser_enabled(False)
49 | torch._C._jit_set_profiling_executor(True)
50 | torch._C._jit_set_profiling_mode(True)
51 | torch._C._jit_can_fuse_on_cpu()
52 | torch._C._jit_can_fuse_on_gpu()
53 | torch._C._jit_override_can_fuse_on_cpu(False)
54 | torch._C._jit_override_can_fuse_on_gpu(False)
55 | torch._C._jit_set_nvfuser_guard_mode(True)
56 | torch._C._jit_set_nvfuser_enabled(True)
57 | else:
58 | assert False, f"Invalid jit fuser ({fuser})"
59 |
--------------------------------------------------------------------------------
/timm/layers/mixed_conv2d.py:
--------------------------------------------------------------------------------
1 | """ PyTorch Mixed Convolution
2 |
3 | Paper: MixConv: Mixed Depthwise Convolutional Kernels (https://arxiv.org/abs/1907.09595)
4 |
5 | Hacked together by / Copyright 2020 Ross Wightman
6 | """
7 | from typing import List, Union
8 |
9 | import torch
10 | from torch import nn as nn
11 |
12 | from .conv2d_same import create_conv2d_pad
13 |
14 |
15 | def _split_channels(num_chan, num_groups):
16 | split = [num_chan // num_groups for _ in range(num_groups)]
17 | split[0] += num_chan - sum(split)
18 | return split
19 |
20 |
21 | class MixedConv2d(nn.ModuleDict):
22 | """ Mixed Grouped Convolution
23 |
24 | Based on MDConv and GroupedConv in MixNet impl:
25 | https://github.com/tensorflow/tpu/blob/master/models/official/mnasnet/mixnet/custom_layers.py
26 | """
27 | def __init__(
28 | self,
29 | in_channels: int,
30 | out_channels: int,
31 | kernel_size: Union[int, List[int]] = 3,
32 | stride: int = 1,
33 | padding: str = '',
34 | dilation: int = 1,
35 | depthwise: bool = False,
36 | **kwargs
37 | ):
38 | super().__init__()
39 |
40 | kernel_size = kernel_size if isinstance(kernel_size, list) else [kernel_size]
41 | num_groups = len(kernel_size)
42 | in_splits = _split_channels(in_channels, num_groups)
43 | out_splits = _split_channels(out_channels, num_groups)
44 | self.in_channels = sum(in_splits)
45 | self.out_channels = sum(out_splits)
46 | for idx, (k, in_ch, out_ch) in enumerate(zip(kernel_size, in_splits, out_splits)):
47 | conv_groups = in_ch if depthwise else 1
48 | # use add_module to keep key space clean
49 | self.add_module(
50 | str(idx),
51 | create_conv2d_pad(
52 | in_ch,
53 | out_ch,
54 | k,
55 | stride=stride,
56 | padding=padding,
57 | dilation=dilation,
58 | groups=conv_groups,
59 | **kwargs,
60 | )
61 | )
62 | self.splits = in_splits
63 |
64 | def forward(self, x):
65 | x_split = torch.split(x, self.splits, 1)
66 | x_out = [c(x_split[i]) for i, c in enumerate(self.values())]
67 | x = torch.cat(x_out, 1)
68 | return x
69 |
--------------------------------------------------------------------------------
/timm/utils/cuda.py:
--------------------------------------------------------------------------------
1 | """ CUDA / AMP utils
2 |
3 | Hacked together by / Copyright 2020 Ross Wightman
4 | """
5 | import torch
6 |
7 | try:
8 | from apex import amp
9 | has_apex = True
10 | except ImportError:
11 | amp = None
12 | has_apex = False
13 |
14 | from .clip_grad import dispatch_clip_grad
15 |
16 |
17 | class ApexScaler:
18 | state_dict_key = "amp"
19 |
20 | def __call__(
21 | self,
22 | loss,
23 | optimizer,
24 | clip_grad=None,
25 | clip_mode='norm',
26 | parameters=None,
27 | create_graph=False,
28 | need_update=True,
29 | ):
30 | with amp.scale_loss(loss, optimizer) as scaled_loss:
31 | scaled_loss.backward(create_graph=create_graph)
32 | if need_update:
33 | if clip_grad is not None:
34 | dispatch_clip_grad(amp.master_params(optimizer), clip_grad, mode=clip_mode)
35 | optimizer.step()
36 |
37 | def state_dict(self):
38 | if 'state_dict' in amp.__dict__:
39 | return amp.state_dict()
40 |
41 | def load_state_dict(self, state_dict):
42 | if 'load_state_dict' in amp.__dict__:
43 | amp.load_state_dict(state_dict)
44 |
45 |
46 | class NativeScaler:
47 | state_dict_key = "amp_scaler"
48 |
49 | def __init__(self, device='cuda'):
50 | try:
51 | self._scaler = torch.amp.GradScaler(device=device)
52 | except (AttributeError, TypeError):
53 | self._scaler = torch.cuda.amp.GradScaler()  # fallback for older PyTorch w/o torch.amp.GradScaler
54 |
55 | def __call__(
56 | self,
57 | loss,
58 | optimizer,
59 | clip_grad=None,
60 | clip_mode='norm',
61 | parameters=None,
62 | create_graph=False,
63 | need_update=True,
64 | ):
65 | self._scaler.scale(loss).backward(create_graph=create_graph)
66 | if need_update:
67 | if clip_grad is not None:
68 | assert parameters is not None
69 | self._scaler.unscale_(optimizer) # unscale the gradients of optimizer's assigned params in-place
70 | dispatch_clip_grad(parameters, clip_grad, mode=clip_mode)
71 | self._scaler.step(optimizer)
72 | self._scaler.update()
73 |
74 | def state_dict(self):
75 | return self._scaler.state_dict()
76 |
77 | def load_state_dict(self, state_dict):
78 | self._scaler.load_state_dict(state_dict)
79 |
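80 |
81 | # NOTE: minimal usage sketch (illustrative only) of NativeScaler in an AMP training step:
82 | #
83 | #   scaler = NativeScaler()
84 | #   with torch.autocast(device_type='cuda'):
85 | #       loss = loss_fn(model(input), target)
86 | #   scaler(loss, optimizer, clip_grad=1.0, parameters=model.parameters())
87 |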
--------------------------------------------------------------------------------
/timm/layers/_fx.py:
--------------------------------------------------------------------------------
1 | from typing import Callable, Dict, List, Optional, Union, Tuple, Type
2 |
3 | import torch
4 | from torch import nn
5 |
6 | try:
7 | # NOTE we wrap torchvision fns to use timm leaf / no trace definitions
8 | from torchvision.models.feature_extraction import create_feature_extractor as _create_feature_extractor
9 | from torchvision.models.feature_extraction import get_graph_node_names as _get_graph_node_names
10 | has_fx_feature_extraction = True
11 | except ImportError:
12 | has_fx_feature_extraction = False
13 |
14 |
15 | __all__ = [
16 | 'register_notrace_module',
17 | 'is_notrace_module',
18 | 'get_notrace_modules',
19 | 'register_notrace_function',
20 | 'is_notrace_function',
21 | 'get_notrace_functions',
22 | 'create_feature_extractor',
23 | 'get_graph_node_names',
24 | ]
25 |
26 | # modules to treat as leafs when tracing
27 | _leaf_modules = set()
28 |
29 |
30 | def register_notrace_module(module: Type[nn.Module]):
31 | """
32 | Any module not under timm.models.layers should get this decorator if we don't want to trace through it.
33 | """
34 | _leaf_modules.add(module)
35 | return module
36 |
37 |
38 | def is_notrace_module(module: Type[nn.Module]):
39 | return module in _leaf_modules
40 |
41 |
42 | def get_notrace_modules():
43 | return list(_leaf_modules)
44 |
45 |
46 | # Functions we want to autowrap (treat them as leaves)
47 | _autowrap_functions = set()
48 |
49 |
50 | def register_notrace_function(name_or_fn):
51 | _autowrap_functions.add(name_or_fn)
52 | return name_or_fn
53 |
54 |
55 | def is_notrace_function(func: Callable):
56 | return func in _autowrap_functions
57 |
58 |
59 | def get_notrace_functions():
60 | return list(_autowrap_functions)
61 |
62 |
63 | def get_graph_node_names(model: nn.Module) -> Tuple[List[str], List[str]]:
64 | return _get_graph_node_names(
65 | model,
66 | tracer_kwargs={
67 | 'leaf_modules': list(_leaf_modules),
68 | 'autowrap_functions': list(_autowrap_functions)
69 | }
70 | )
71 |
72 |
73 | def create_feature_extractor(model: nn.Module, return_nodes: Union[Dict[str, str], List[str]]):
74 | assert has_fx_feature_extraction, 'Please update to PyTorch 1.10+, torchvision 0.11+ for FX feature extraction'
75 | return _create_feature_extractor(
76 | model, return_nodes,
77 | tracer_kwargs={
78 | 'leaf_modules': list(_leaf_modules),
79 | 'autowrap_functions': list(_autowrap_functions)
80 | }
81 | )
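82 |
83 |
84 | # NOTE: minimal usage sketch (illustrative only): register a leaf module so FX tracing does
85 | # not descend into it, then build a feature extractor over named graph nodes:
86 | #
87 | #   @register_notrace_module
88 | #   class MyBlock(nn.Module):
89 | #       ...
90 | #
91 | #   train_nodes, eval_nodes = get_graph_node_names(model)
92 | #   fx_model = create_feature_extractor(model, return_nodes=['blocks.0', 'blocks.1'])
93 |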
--------------------------------------------------------------------------------
/UPGRADING.md:
--------------------------------------------------------------------------------
1 | # Upgrading from previous versions
2 |
3 | I generally try to maintain code interface and especially model weight compatibility across many `timm` versions. Sometimes there are exceptions.
4 |
5 | ## Checkpoint remapping
6 |
7 | Pretrained weight remapping is handled by `checkpoint_filter_fn` in a model implementation module. This remaps old pretrained checkpoints to new, and also 3rd party (original) checkpoints to `timm` format if the model was modified when brought into `timm`.
8 |
9 | The `checkpoint_filter_fn` is automatically called when loading pretrained weights via `pretrained=True`, but it can also be called manually, passing the current model instance and an old state dict, to remap a checkpoint.
10 |
11 | ## Upgrading from 0.6 and earlier
12 |
13 | Many changes were made since the 0.6.x stable releases. They were previewed in 0.8.x dev releases but not everyone transitioned.
14 | * `timm.models.layers` moved to `timm.layers`:
15 | * `from timm.models.layers import name` will still work via deprecation mapping (but please transition to `timm.layers`).
16 | * `import timm.models.layers.module` or `from timm.models.layers.module import name` needs to be changed now.
17 | * Builder, helper, non-model modules in `timm.models` have a `_` prefix added, i.e. `timm.models.helpers` -> `timm.models._helpers`. There are temporary deprecation mapping files, but those will be removed.
18 | * All models now support `architecture.pretrained_tag` naming (ex `resnet50.rsb_a1`).
19 | * The pretrained_tag is the specific weight variant (different head) for the architecture.
20 | * Using only `architecture` defaults to the first weights in the default_cfgs for that model architecture.
21 | * In adding pretrained tags, many model names that existed to differentiate were renamed to use the tag (ex: `vit_base_patch16_224_in21k` -> `vit_base_patch16_224.augreg_in21k`). There are deprecation mappings for these.
22 | * A number of models had their checkpoints remapped to match architecture changes needed to better support `features_only=True`, there are `checkpoint_filter_fn` methods in any model module that was remapped. These can be passed to `timm.models.load_checkpoint(..., filter_fn=timm.models.swin_transformer_v2.checkpoint_filter_fn)` to remap your existing checkpoint.
23 | * The Hugging Face Hub (https://huggingface.co/timm) is now the primary source for `timm` weights. Model cards include links to papers, original source, and license.
24 | * Previous 0.6.x can be cloned from [0.6.x](https://github.com/rwightman/pytorch-image-models/tree/0.6.x) branch or installed via pip with version.
25 |
--------------------------------------------------------------------------------
/timm/data/dataset_info.py:
--------------------------------------------------------------------------------
1 | from abc import ABC, abstractmethod
2 | from typing import Dict, List, Optional, Union
3 |
4 |
5 | class DatasetInfo(ABC):
6 |
7 | def __init__(self):
8 | pass
9 |
10 | @abstractmethod
11 | def num_classes(self):
12 | pass
13 |
14 | @abstractmethod
15 | def label_names(self):
16 | pass
17 |
18 | @abstractmethod
19 | def label_descriptions(self, detailed: bool = False, as_dict: bool = False) -> Union[List[str], Dict[str, str]]:
20 | pass
21 |
22 | @abstractmethod
23 | def index_to_label_name(self, index) -> str:
24 | pass
25 |
26 | @abstractmethod
27 | def index_to_description(self, index: int, detailed: bool = False) -> str:
28 | pass
29 |
30 | @abstractmethod
31 | def label_name_to_description(self, label: str, detailed: bool = False) -> str:
32 | pass
33 |
34 |
35 | class CustomDatasetInfo(DatasetInfo):
36 | """ DatasetInfo that wraps passed values for custom datasets."""
37 |
38 | def __init__(
39 | self,
40 | label_names: Union[List[str], Dict[int, str]],
41 | label_descriptions: Optional[Dict[str, str]] = None
42 | ):
43 | super().__init__()
44 | assert len(label_names) > 0
45 | self._label_names = label_names # label index => label name mapping
46 | self._label_descriptions = label_descriptions # label name => label description mapping
47 | if self._label_descriptions is not None:
48 | # validate descriptions (label names required)
49 | assert isinstance(self._label_descriptions, dict)
50 | for n in self._label_names:
51 | assert n in self._label_descriptions
52 |
53 | def num_classes(self):
54 | return len(self._label_names)
55 |
56 | def label_names(self):
57 | return self._label_names
58 |
59 | def label_descriptions(self, detailed: bool = False, as_dict: bool = False) -> Union[List[str], Dict[str, str]]:
60 | return self._label_descriptions
61 |
62 | def label_name_to_description(self, label: str, detailed: bool = False) -> str:
63 | if self._label_descriptions:
64 | return self._label_descriptions[label]
65 | return label  # return label name itself if a description is not present
66 |
67 | def index_to_label_name(self, index) -> str:
68 | assert 0 <= index < len(self._label_names)
69 | return self._label_names[index]
70 |
71 | def index_to_description(self, index: int, detailed: bool = False) -> str:
72 | label = self.index_to_label_name(index)
73 | return self.label_name_to_description(label, detailed=detailed)
74 |
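75 |
76 | # NOTE: minimal usage sketch (illustrative only). Run via `python -m timm.data.dataset_info`.
77 | if __name__ == '__main__':
78 |     info = CustomDatasetInfo(
79 |         label_names=['cat', 'dog'],
80 |         label_descriptions={'cat': 'domestic cat', 'dog': 'domestic dog'},
81 |     )
82 |     print(info.num_classes(), info.index_to_description(1))  # 2 domestic dog
83 |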
--------------------------------------------------------------------------------
/.github/workflows/tests.yml:
--------------------------------------------------------------------------------
1 | name: Python tests
2 |
3 | on:
4 | push:
5 | branches: [ main ]
6 | pull_request:
7 | branches: [ main ]
8 |
9 | env:
10 | OMP_NUM_THREADS: 2
11 | MKL_NUM_THREADS: 2
12 |
13 | jobs:
14 | test:
15 | name: Run tests on ${{ matrix.os }} with Python ${{ matrix.python }}
16 | strategy:
17 | matrix:
18 | os: [ubuntu-latest]
19 | python: ['3.10', '3.13']
20 | torch: [{base: '1.13.0', vision: '0.14.0'}, {base: '2.9.1', vision: '0.24.1'}]
21 | testmarker: ['-k "not test_models"', '-m base', '-m cfg', '-m torchscript', '-m features', '-m fxforward', '-m fxbackward']
22 | exclude:
23 | - python: '3.13'
24 | torch: {base: '1.13.0', vision: '0.14.0'}
25 | runs-on: ${{ matrix.os }}
26 |
27 | steps:
28 | - uses: actions/checkout@v2
29 | - name: Set up Python ${{ matrix.python }}
30 | uses: actions/setup-python@v1
31 | with:
32 | python-version: ${{ matrix.python }}
33 | - name: Install testing dependencies
34 | run: |
35 | python -m pip install --upgrade pip
36 | pip install -r requirements-dev.txt
37 | - name: Install torch on mac
38 | if: startsWith(matrix.os, 'macOS')
39 | run: pip install --no-cache-dir torch==${{ matrix.torch.base }} torchvision==${{ matrix.torch.vision }}
40 | - name: Install torch on Windows
41 | if: startsWith(matrix.os, 'windows')
42 | run: pip install --no-cache-dir torch==${{ matrix.torch.base }} torchvision==${{ matrix.torch.vision }}
43 | - name: Install torch on ubuntu
44 | if: startsWith(matrix.os, 'ubuntu')
45 | run: |
46 | sudo sed -i 's/azure\.//' /etc/apt/sources.list
47 | sudo apt update
48 | sudo apt install -y google-perftools
49 | pip install --no-cache-dir torch==${{ matrix.torch.base }}+cpu torchvision==${{ matrix.torch.vision }}+cpu --index-url https://download.pytorch.org/whl/cpu
50 | - name: Install requirements
51 | run: |
52 | pip install -r requirements.txt
53 | - name: Force old numpy for old torch
54 | if: ${{ matrix.torch.base == '1.13.0' }}
55 | run: pip install --upgrade 'numpy<2.0'
56 | - name: Run tests on Windows
57 | if: startsWith(matrix.os, 'windows')
58 | env:
59 | PYTHONDONTWRITEBYTECODE: 1
60 | run: |
61 | pytest -vv tests
62 | - name: Run '${{ matrix.testmarker }}' tests on Linux / Mac
63 | if: ${{ !startsWith(matrix.os, 'windows') }}
64 | env:
65 | LD_PRELOAD: /usr/lib/x86_64-linux-gnu/libtcmalloc.so.4
66 | PYTHONDONTWRITEBYTECODE: 1
67 | run: |
68 | pytest -vv --forked --durations=0 ${{ matrix.testmarker }} tests
69 |
--------------------------------------------------------------------------------
/timm/layers/interpolate.py:
--------------------------------------------------------------------------------
1 | """ Interpolation helpers for timm layers
2 |
3 | RegularGridInterpolator from https://github.com/sbarratt/torch_interpolations
4 | Copyright Shane Barratt, Apache 2.0 license
5 | """
6 | import torch
7 | from itertools import product
8 |
9 |
10 | class RegularGridInterpolator:
11 | """ Interpolate data defined on a rectilinear grid with even or uneven spacing.
12 | Produces similar results to scipy RegularGridInterpolator or interp2d
13 | in 'linear' mode.
14 |
15 | Taken from https://github.com/sbarratt/torch_interpolations
16 | """
17 |
18 | def __init__(self, points, values):
19 | self.points = points
20 | self.values = values
21 |
22 | assert isinstance(self.points, (tuple, list))
23 | assert isinstance(self.values, torch.Tensor)
24 |
25 | self.ms = list(self.values.shape)
26 | self.n = len(self.points)
27 |
28 | assert len(self.ms) == self.n
29 |
30 | for i, p in enumerate(self.points):
31 | assert isinstance(p, torch.Tensor)
32 | assert p.shape[0] == self.values.shape[i]
33 |
34 | def __call__(self, points_to_interp):
35 | assert self.points is not None
36 | assert self.values is not None
37 |
38 | assert len(points_to_interp) == len(self.points)
39 | K = points_to_interp[0].shape[0]
40 | for x in points_to_interp:
41 | assert x.shape[0] == K
42 |
43 | idxs = []
44 | dists = []
45 | overalls = []
46 | for p, x in zip(self.points, points_to_interp):
47 | idx_right = torch.bucketize(x, p)
48 | idx_right[idx_right >= p.shape[0]] = p.shape[0] - 1
49 | idx_left = (idx_right - 1).clamp(0, p.shape[0] - 1)
50 | dist_left = x - p[idx_left]
51 | dist_right = p[idx_right] - x
52 | dist_left[dist_left < 0] = 0.
53 | dist_right[dist_right < 0] = 0.
54 | both_zero = (dist_left == 0) & (dist_right == 0)
55 | dist_left[both_zero] = dist_right[both_zero] = 1.
56 |
57 | idxs.append((idx_left, idx_right))
58 | dists.append((dist_left, dist_right))
59 | overalls.append(dist_left + dist_right)
60 |
61 | numerator = 0.
62 | for indexer in product([0, 1], repeat=self.n):
63 | as_s = [idx[onoff] for onoff, idx in zip(indexer, idxs)]
64 | bs_s = [dist[1 - onoff] for onoff, dist in zip(indexer, dists)]
65 | numerator += self.values[as_s] * \
66 | torch.prod(torch.stack(bs_s), dim=0)
67 | denominator = torch.prod(torch.stack(overalls), dim=0)
68 | return numerator / denominator
69 |
--------------------------------------------------------------------------------
/timm/loss/binary_cross_entropy.py:
--------------------------------------------------------------------------------
1 | """ Binary Cross Entropy w/ a few extras
2 |
3 | Hacked together by / Copyright 2021 Ross Wightman
4 | """
5 | from typing import Optional, Union
6 |
7 | import torch
8 | import torch.nn as nn
9 | import torch.nn.functional as F
10 |
11 |
12 | class BinaryCrossEntropy(nn.Module):
13 | """ BCE with optional one-hot from dense targets, label smoothing, thresholding
14 | NOTE for experiments comparing CE to BCE w/ label smoothing, may remove
15 | """
16 | def __init__(
17 | self,
18 | smoothing=0.1,
19 | target_threshold: Optional[float] = None,
20 | weight: Optional[torch.Tensor] = None,
21 | reduction: str = 'mean',
22 | sum_classes: bool = False,
23 | pos_weight: Optional[Union[torch.Tensor, float]] = None,
24 | ):
25 | super(BinaryCrossEntropy, self).__init__()
26 | assert 0. <= smoothing < 1.0
27 | if pos_weight is not None:
28 | if not isinstance(pos_weight, torch.Tensor):
29 | pos_weight = torch.tensor(pos_weight)
30 | self.smoothing = smoothing
31 | self.target_threshold = target_threshold
32 | self.reduction = 'none' if sum_classes else reduction
33 | self.sum_classes = sum_classes
34 | self.register_buffer('weight', weight)
35 | self.register_buffer('pos_weight', pos_weight)
36 |
37 | def forward(self, x: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
38 | batch_size = x.shape[0]
39 | assert batch_size == target.shape[0]
40 |
41 | if target.shape != x.shape:
42 | # NOTE currently assume smoothing or other label softening is applied upstream if targets are already sparse
43 | num_classes = x.shape[-1]
44 | # FIXME should off/on be different for smoothing w/ BCE? Other impl out there differ
45 | off_value = self.smoothing / num_classes
46 | on_value = 1. - self.smoothing + off_value
47 | target = target.long().view(-1, 1)
48 | target = torch.full(
49 | (batch_size, num_classes),
50 | off_value,
51 | device=x.device, dtype=x.dtype).scatter_(1, target, on_value)
52 |
53 | if self.target_threshold is not None:
54 | # Make target 0, or 1 if threshold set
55 | target = target.gt(self.target_threshold).to(dtype=target.dtype)
56 |
57 | loss = F.binary_cross_entropy_with_logits(
58 | x, target,
59 | self.weight,
60 | pos_weight=self.pos_weight,
61 | reduction=self.reduction,
62 | )
63 | if self.sum_classes:
64 | loss = loss.sum(-1).mean()
65 | return loss
66 |
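67 |
68 | # NOTE: minimal usage sketch (illustrative only): dense integer targets are one-hot encoded
69 | # and smoothed internally. Run via `python -m timm.loss.binary_cross_entropy`.
70 | if __name__ == '__main__':
71 |     loss_fn = BinaryCrossEntropy(smoothing=0.1)
72 |     logits = torch.randn(4, 10)
73 |     target = torch.randint(0, 10, (4,))
74 |     print(loss_fn(logits, target))
75 |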
--------------------------------------------------------------------------------
/timm/optim/sgdp.py:
--------------------------------------------------------------------------------
1 | """
2 | SGDP Optimizer Implementation copied from https://github.com/clovaai/AdamP/blob/master/adamp/sgdp.py
3 |
4 | Paper: `Slowing Down the Weight Norm Increase in Momentum-based Optimizers` - https://arxiv.org/abs/2006.08217
5 | Code: https://github.com/clovaai/AdamP
6 |
7 | Copyright (c) 2020-present NAVER Corp.
8 | MIT license
9 | """
10 |
11 | import torch
12 | import torch.nn.functional as F
13 | from torch.optim.optimizer import Optimizer, required
14 | import math
15 |
16 | from .adamp import projection
17 |
18 |
19 | class SGDP(Optimizer):
20 | def __init__(
21 | self,
22 | params,
23 | lr=required,
24 | momentum=0,
25 | dampening=0,
26 | weight_decay=0,
27 | nesterov=False,
28 | eps=1e-8,
29 | delta=0.1,
30 | wd_ratio=0.1
31 | ):
32 | defaults = dict(
33 | lr=lr,
34 | momentum=momentum,
35 | dampening=dampening,
36 | weight_decay=weight_decay,
37 | nesterov=nesterov,
38 | eps=eps,
39 | delta=delta,
40 | wd_ratio=wd_ratio,
41 | )
42 | super(SGDP, self).__init__(params, defaults)
43 |
44 | @torch.no_grad()
45 | def step(self, closure=None):
46 | loss = None
47 | if closure is not None:
48 | with torch.enable_grad():
49 | loss = closure()
50 |
51 | for group in self.param_groups:
52 | weight_decay = group['weight_decay']
53 | momentum = group['momentum']
54 | dampening = group['dampening']
55 | nesterov = group['nesterov']
56 |
57 | for p in group['params']:
58 | if p.grad is None:
59 | continue
60 | grad = p.grad
61 | state = self.state[p]
62 |
63 | # State initialization
64 | if len(state) == 0:
65 | state['momentum'] = torch.zeros_like(p)
66 |
67 | # SGD
68 | buf = state['momentum']
69 | buf.mul_(momentum).add_(grad, alpha=1. - dampening)
70 | if nesterov:
71 | d_p = grad + momentum * buf
72 | else:
73 | d_p = buf
74 |
75 | # Projection
76 | wd_ratio = 1.
77 | if len(p.shape) > 1:
78 | d_p, wd_ratio = projection(p, grad, d_p, group['delta'], group['wd_ratio'], group['eps'])
79 |
80 | # Weight decay
81 | if weight_decay != 0:
82 | p.mul_(1. - group['lr'] * group['weight_decay'] * wd_ratio / (1-momentum))
83 |
84 | # Step
85 | p.add_(d_p, alpha=-group['lr'])
86 |
87 | return loss
88 |
--------------------------------------------------------------------------------
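A short usage sketch (the tiny model and dummy objective are illustrative): SGDP steps like SGD w/ momentum, but updates for parameters with more than one dimension are run through `projection` (imported from `adamp` above) to slow weight-norm growth.

    import torch
    from timm.optim.sgdp import SGDP

    model = torch.nn.Linear(8, 2)
    opt = SGDP(model.parameters(), lr=0.1, momentum=0.9, nesterov=True, weight_decay=1e-4)
    loss = model(torch.randn(4, 8)).sum()   # dummy objective for illustration
    loss.backward()
    opt.step()
    opt.zero_grad()
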
/timm/optim/lookahead.py:
--------------------------------------------------------------------------------
1 | """ Lookahead Optimizer Wrapper.
2 | Implementation modified from: https://github.com/alphadl/lookahead.pytorch
3 | Paper: `Lookahead Optimizer: k steps forward, 1 step back` - https://arxiv.org/abs/1907.08610
4 |
5 | Hacked together by / Copyright 2020 Ross Wightman
6 | """
7 | from collections import OrderedDict
8 | from typing import Callable, Dict
9 |
10 | import torch
11 | from torch.optim.optimizer import Optimizer
12 | from collections import defaultdict
13 |
14 |
15 | class Lookahead(Optimizer):
16 | def __init__(self, base_optimizer, alpha=0.5, k=6):
17 | # NOTE super().__init__() not called on purpose
18 | self._optimizer_step_pre_hooks: Dict[int, Callable] = OrderedDict()
19 | self._optimizer_step_post_hooks: Dict[int, Callable] = OrderedDict()
20 | if not 0.0 <= alpha <= 1.0:
21 | raise ValueError(f'Invalid slow update rate: {alpha}')
22 | if not 1 <= k:
23 | raise ValueError(f'Invalid lookahead steps: {k}')
24 | defaults = dict(lookahead_alpha=alpha, lookahead_k=k, lookahead_step=0)
25 | self._base_optimizer = base_optimizer
26 | self.param_groups = base_optimizer.param_groups
27 | self.defaults = base_optimizer.defaults
28 | self.defaults.update(defaults)
29 | self.state = defaultdict(dict)
30 | # manually add our defaults to the param groups
31 | for name, default in defaults.items():
32 | for group in self._base_optimizer.param_groups:
33 | group.setdefault(name, default)
34 |
35 | @torch.no_grad()
36 | def update_slow(self, group):
37 | for fast_p in group["params"]:
38 | if fast_p.grad is None:
39 | continue
40 | param_state = self._base_optimizer.state[fast_p]
41 | if 'lookahead_slow_buff' not in param_state:
42 | param_state['lookahead_slow_buff'] = torch.empty_like(fast_p)
43 | param_state['lookahead_slow_buff'].copy_(fast_p)
44 | slow = param_state['lookahead_slow_buff']
45 | slow.add_(fast_p - slow, alpha=group['lookahead_alpha'])
46 | fast_p.copy_(slow)
47 |
48 | def sync_lookahead(self):
49 | for group in self._base_optimizer.param_groups:
50 | self.update_slow(group)
51 |
52 | @torch.no_grad()
53 | def step(self, closure=None):
54 | loss = self._base_optimizer.step(closure)
55 | for group in self._base_optimizer.param_groups:
56 | group['lookahead_step'] += 1
57 | if group['lookahead_step'] % group['lookahead_k'] == 0:
58 | self.update_slow(group)
59 | return loss
60 |
61 | def state_dict(self):
62 | return self._base_optimizer.state_dict()
63 |
64 | def load_state_dict(self, state_dict):
65 | self._base_optimizer.load_state_dict(state_dict)
66 | self.param_groups = self._base_optimizer.param_groups
67 |
--------------------------------------------------------------------------------
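Usage sketch for the Lookahead wrapper (base optimizer and loop are illustrative): every `k` steps the slow weights are pulled toward the fast weights by `alpha`, and `sync_lookahead()` copies slow weights back into the parameters, which is useful before eval or checkpointing.

    import torch
    from timm.optim.lookahead import Lookahead

    model = torch.nn.Linear(8, 2)
    opt = Lookahead(torch.optim.SGD(model.parameters(), lr=0.1, momentum=0.9), alpha=0.5, k=6)
    for _ in range(6):
        model(torch.randn(4, 8)).sum().backward()
        opt.step()          # the slow update fires on the 6th step
        model.zero_grad()
    opt.sync_lookahead()    # align fast params with slow weights before eval
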
/timm/data/readers/reader_image_tar.py:
--------------------------------------------------------------------------------
1 | """ A dataset reader that reads single tarfile based datasets
2 |
3 | This reader can read datasets consisting of a single tarfile containing images.
4 | I am planning to deprecate it in favour of ReaderImageInTar.
5 |
6 | Hacked together by / Copyright 2020 Ross Wightman
7 | """
8 | import os
9 | import tarfile
10 |
11 | from timm.utils.misc import natural_key
12 |
13 | from .class_map import load_class_map
14 | from .img_extensions import get_img_extensions
15 | from .reader import Reader
16 |
17 |
18 | def extract_tarinfo(tarfile, class_to_idx=None, sort=True):
19 | extensions = get_img_extensions(as_set=True)
20 | files = []
21 | labels = []
22 | for ti in tarfile.getmembers():
23 | if not ti.isfile():
24 | continue
25 | dirname, basename = os.path.split(ti.path)
26 | label = os.path.basename(dirname)
27 | ext = os.path.splitext(basename)[1]
28 | if ext.lower() in extensions:
29 | files.append(ti)
30 | labels.append(label)
31 | if class_to_idx is None:
32 | unique_labels = set(labels)
33 | sorted_labels = list(sorted(unique_labels, key=natural_key))
34 | class_to_idx = {c: idx for idx, c in enumerate(sorted_labels)}
35 | tarinfo_and_targets = [(f, class_to_idx[l]) for f, l in zip(files, labels) if l in class_to_idx]
36 | if sort:
37 | tarinfo_and_targets = sorted(tarinfo_and_targets, key=lambda k: natural_key(k[0].path))
38 | return tarinfo_and_targets, class_to_idx
39 |
40 |
41 | class ReaderImageTar(Reader):
42 | """ Single tarfile dataset where classes are mapped to folders within tar
43 | NOTE: This class is being deprecated in favour of the more capable ReaderImageInTar that can
44 | operate on folders of tars or tars in tars.
45 | """
46 | def __init__(self, root, class_map=''):
47 | super().__init__()
48 |
49 | class_to_idx = None
50 | if class_map:
51 | class_to_idx = load_class_map(class_map, root)
52 | assert os.path.isfile(root)
53 | self.root = root
54 |
55 | with tarfile.open(root) as tf: # cannot keep this open across processes, reopen later
56 | self.samples, self.class_to_idx = extract_tarinfo(tf, class_to_idx)
57 | self.imgs = self.samples
58 | self.tarfile = None # lazy init in __getitem__
59 |
60 | def __getitem__(self, index):
61 | if self.tarfile is None:
62 | self.tarfile = tarfile.open(self.root)
63 | tarinfo, target = self.samples[index]
64 | fileobj = self.tarfile.extractfile(tarinfo)
65 | return fileobj, target
66 |
67 | def __len__(self):
68 | return len(self.samples)
69 |
70 | def _filename(self, index, basename=False, absolute=False):
71 | filename = self.samples[index][0].name
72 | if basename:
73 | filename = os.path.basename(filename)
74 | return filename
75 |
--------------------------------------------------------------------------------
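A hypothetical sketch (the tar path is an assumption): the tar is expected to contain class folders of images, and samples come back as `(file_obj, class_idx)` with the tarfile lazily reopened in each worker process.

    from timm.data.readers.reader_image_tar import ReaderImageTar

    reader = ReaderImageTar('imagenet_val.tar')   # hypothetical path
    file_obj, target = reader[0]
    print(len(reader), reader.filename(0, basename=True), target)
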
/timm/layers/pos_embed.py:
--------------------------------------------------------------------------------
1 | """ Position Embedding Utilities
2 |
3 | Hacked together by / Copyright 2022 Ross Wightman
4 | """
5 | import logging
6 | import math
7 | from typing import List, Tuple, Optional, Union
8 |
9 | import torch
10 | import torch.nn.functional as F
11 |
12 | from ._fx import register_notrace_function
13 |
14 | _logger = logging.getLogger(__name__)
15 |
16 |
17 | @torch.fx.wrap
18 | @register_notrace_function
19 | def resample_abs_pos_embed(
20 | posemb: torch.Tensor,
21 | new_size: List[int],
22 | old_size: Optional[List[int]] = None,
23 | num_prefix_tokens: int = 1,
24 | interpolation: str = 'bicubic',
25 | antialias: bool = True,
26 | verbose: bool = False,
27 | ):
28 | # sort out sizes, assume square if old size not provided
29 | num_pos_tokens = posemb.shape[1]
30 | num_new_tokens = new_size[0] * new_size[1] + num_prefix_tokens
31 | if num_new_tokens == num_pos_tokens and new_size[0] == new_size[1]:
32 | return posemb
33 |
34 | if old_size is None:
35 | hw = int(math.sqrt(num_pos_tokens - num_prefix_tokens))
36 | old_size = hw, hw
37 |
38 | if num_prefix_tokens:
39 | posemb_prefix, posemb = posemb[:, :num_prefix_tokens], posemb[:, num_prefix_tokens:]
40 | else:
41 | posemb_prefix, posemb = None, posemb
42 |
43 | # do the interpolation
44 | embed_dim = posemb.shape[-1]
45 | orig_dtype = posemb.dtype
46 | posemb = posemb.float() # interpolate needs float32
47 | posemb = posemb.reshape(1, old_size[0], old_size[1], -1).permute(0, 3, 1, 2)
48 | posemb = F.interpolate(posemb, size=new_size, mode=interpolation, antialias=antialias)
49 | posemb = posemb.permute(0, 2, 3, 1).reshape(1, -1, embed_dim)
50 | posemb = posemb.to(orig_dtype)
51 |
52 | # add back extra (class, etc) prefix tokens
53 | if posemb_prefix is not None:
54 | posemb = torch.cat([posemb_prefix, posemb], dim=1)
55 |
56 | if not torch.jit.is_scripting() and verbose:
57 | _logger.info(f'Resized position embedding: {old_size} to {new_size}.')
58 |
59 | return posemb
60 |
61 |
62 | @torch.fx.wrap
63 | @register_notrace_function
64 | def resample_abs_pos_embed_nhwc(
65 | posemb: torch.Tensor,
66 | new_size: List[int],
67 | interpolation: str = 'bicubic',
68 | antialias: bool = True,
69 | verbose: bool = False,
70 | ):
71 | if new_size[0] == posemb.shape[-3] and new_size[1] == posemb.shape[-2]:
72 | return posemb
73 |
74 | orig_dtype = posemb.dtype
75 | posemb = posemb.float()
76 | posemb = posemb.reshape(1, posemb.shape[-3], posemb.shape[-2], posemb.shape[-1]).permute(0, 3, 1, 2)
77 | posemb = F.interpolate(posemb, size=new_size, mode=interpolation, antialias=antialias)
78 | posemb = posemb.permute(0, 2, 3, 1).to(orig_dtype)
79 |
80 | if not torch.jit.is_scripting() and verbose:
81 | _logger.info(f'Resized position embedding: {posemb.shape[-3:-1]} to {new_size}.')
82 |
83 | return posemb
84 |
--------------------------------------------------------------------------------
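A worked sketch of `resample_abs_pos_embed`: resizing a ViT-style embedding with one class token from a 14x14 to a 16x16 grid. The prefix token is split off, the grid is bicubic-interpolated in float32, then the prefix is re-attached.

    import torch
    from timm.layers.pos_embed import resample_abs_pos_embed

    posemb = torch.randn(1, 1 + 14 * 14, 768)                    # (1, 197, 768)
    resized = resample_abs_pos_embed(posemb, new_size=[16, 16])  # default num_prefix_tokens=1
    print(resized.shape)                                         # torch.Size([1, 257, 768])
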
/timm/task/classification.py:
--------------------------------------------------------------------------------
1 | """Classification training task."""
2 | import logging
3 | from typing import Callable, Dict, Optional, Union
4 |
5 | import torch
6 | import torch.nn as nn
7 |
8 | from .task import TrainingTask
9 |
10 | _logger = logging.getLogger(__name__)
11 |
12 |
13 | class ClassificationTask(TrainingTask):
14 | """Standard supervised classification task.
15 |
16 | Simple task that performs a forward pass through the model and computes
17 | the classification loss.
18 |
19 | Args:
20 | model: The model to train
21 | criterion: Loss function (e.g., CrossEntropyLoss)
22 | device: Device for task tensors/buffers
23 | dtype: Dtype for task tensors/buffers
24 | verbose: Enable info logging
25 |
26 | Example:
27 | >>> task = ClassificationTask(model, nn.CrossEntropyLoss(), device=torch.device('cuda'))
28 | >>> result = task(input, target)
29 | >>> result['loss'].backward()
30 | """
31 |
32 | def __init__(
33 | self,
34 | model: nn.Module,
35 | criterion: Union[nn.Module, Callable],
36 | device: Optional[torch.device] = None,
37 | dtype: Optional[torch.dtype] = None,
38 | verbose: bool = True,
39 | ):
40 | super().__init__(device=device, dtype=dtype, verbose=verbose)
41 | self.model = model
42 | self.criterion = criterion
43 |
44 | if self.verbose:
45 | loss_name = getattr(criterion, '__name__', None) or type(criterion).__name__
46 | _logger.info(f"ClassificationTask: criterion={loss_name}")
47 |
48 | def prepare_distributed(
49 | self,
50 | device_ids: Optional[list] = None,
51 | **ddp_kwargs
52 | ) -> 'ClassificationTask':
53 | """Prepare task for distributed training.
54 |
55 | Wraps the model in DistributedDataParallel (DDP).
56 |
57 | Args:
58 | device_ids: List of device IDs for DDP (e.g., [local_rank])
59 | **ddp_kwargs: Additional arguments passed to DistributedDataParallel
60 |
61 | Returns:
62 | self (for method chaining)
63 | """
64 | from torch.nn.parallel import DistributedDataParallel as DDP
65 | self.model = DDP(self.model, device_ids=device_ids, **ddp_kwargs)
66 | return self
67 |
68 | def forward(
69 | self,
70 | input: torch.Tensor,
71 | target: torch.Tensor,
72 | ) -> Dict[str, torch.Tensor]:
73 | """Forward pass through model and compute classification loss.
74 |
75 | Args:
76 | input: Input tensor [B, C, H, W]
77 | target: Target labels [B]
78 |
79 | Returns:
80 | Dictionary containing:
81 | - 'loss': Classification loss
82 | - 'output': Model logits
83 | """
84 | output = self.model(input)
85 | loss = self.criterion(output, target)
86 |
87 | return {
88 | 'loss': loss,
89 | 'output': output,
90 | }
91 |
--------------------------------------------------------------------------------
/results/generate_csv_results.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import pandas as pd
3 |
4 |
5 | results = {
6 | 'results-imagenet.csv': [
7 | 'results-imagenet-real.csv',
8 | 'results-imagenetv2-matched-frequency.csv',
9 | 'results-sketch.csv'
10 | ],
11 | 'results-imagenet-a-clean.csv': [
12 | 'results-imagenet-a.csv',
13 | ],
14 | 'results-imagenet-r-clean.csv': [
15 | 'results-imagenet-r.csv',
16 | ],
17 | }
18 |
19 |
20 | def diff(base_df, test_csv):
21 | base_df['mi'] = base_df.model + '-' + base_df.img_size.astype('str')
22 | base_models = base_df['mi'].values
23 | test_df = pd.read_csv(test_csv)
24 | test_df['mi'] = test_df.model + '-' + test_df.img_size.astype('str')
25 | test_models = test_df['mi'].values
26 |
27 | rank_diff = np.zeros_like(test_models, dtype='object')
28 | top1_diff = np.zeros_like(test_models, dtype='object')
29 | top5_diff = np.zeros_like(test_models, dtype='object')
30 |
31 | for rank, model in enumerate(test_models):
32 | if model in base_models:
33 |             base_rank = int(np.where(base_models == model)[0][0])  # first (only) match
34 | top1_d = test_df['top1'][rank] - base_df['top1'][base_rank]
35 | top5_d = test_df['top5'][rank] - base_df['top5'][base_rank]
36 |
37 | # rank_diff
38 | if rank == base_rank:
39 |                 rank_diff[rank] = '0'
40 | elif rank > base_rank:
41 | rank_diff[rank] = f'-{rank - base_rank}'
42 | else:
43 | rank_diff[rank] = f'+{base_rank - rank}'
44 |
45 | # top1_diff
46 | if top1_d >= .0:
47 | top1_diff[rank] = f'+{top1_d:.3f}'
48 | else:
49 | top1_diff[rank] = f'-{abs(top1_d):.3f}'
50 |
51 | # top5_diff
52 | if top5_d >= .0:
53 | top5_diff[rank] = f'+{top5_d:.3f}'
54 | else:
55 | top5_diff[rank] = f'-{abs(top5_d):.3f}'
56 |
57 | else:
58 | rank_diff[rank] = ''
59 | top1_diff[rank] = ''
60 | top5_diff[rank] = ''
61 |
62 | test_df['top1_diff'] = top1_diff
63 | test_df['top5_diff'] = top5_diff
64 | test_df['rank_diff'] = rank_diff
65 |
66 | test_df.drop('mi', axis=1, inplace=True)
67 | base_df.drop('mi', axis=1, inplace=True)
68 | test_df['param_count'] = test_df['param_count'].map('{:,.2f}'.format)
69 | test_df.sort_values(['top1', 'top5', 'model'], ascending=[False, False, True], inplace=True)
70 | test_df.to_csv(test_csv, index=False, float_format='%.3f')
71 |
72 |
73 | for base_results, test_results in results.items():
74 | base_df = pd.read_csv(base_results)
75 | base_df.sort_values(['top1', 'top5', 'model'], ascending=[False, False, True], inplace=True)
76 | for test_csv in test_results:
77 | diff(base_df, test_csv)
78 | base_df['param_count'] = base_df['param_count'].map('{:,.2f}'.format)
79 | base_df.to_csv(base_results, index=False, float_format='%.3f')
80 |
--------------------------------------------------------------------------------
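A self-contained sketch of the `diff()` sign conventions above (temp file and numbers are made up, and `diff()` from this script is assumed in scope): a positive `rank_diff` means the model ranks higher on the test set than on the base set.

    import os
    import tempfile

    import pandas as pd

    cols = dict(model=['a', 'b'], img_size=[224, 224], top1=[80.0, 79.0],
                top5=[95.0, 94.0], param_count=[5.0, 6.0])
    base_df = pd.DataFrame(cols)                           # already sorted by top1
    test_df = pd.DataFrame(dict(cols, top1=[78.0, 81.0]))  # 'b' overtakes 'a' on test
    test_df = test_df.sort_values(['top1', 'top5', 'model'], ascending=[False, False, True])
    path = os.path.join(tempfile.mkdtemp(), 'results-test.csv')
    test_df.to_csv(path, index=False)
    diff(base_df, path)
    print(pd.read_csv(path)[['model', 'rank_diff', 'top1_diff']])  # b: +1/+2.000, a: -1/-2.000
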
/timm/data/_info/imagenet_a_synsets.txt:
--------------------------------------------------------------------------------
1 | n01498041
2 | n01531178
3 | n01534433
4 | n01558993
5 | n01580077
6 | n01614925
7 | n01616318
8 | n01631663
9 | n01641577
10 | n01669191
11 | n01677366
12 | n01687978
13 | n01694178
14 | n01698640
15 | n01735189
16 | n01770081
17 | n01770393
18 | n01774750
19 | n01784675
20 | n01819313
21 | n01820546
22 | n01833805
23 | n01843383
24 | n01847000
25 | n01855672
26 | n01882714
27 | n01910747
28 | n01914609
29 | n01924916
30 | n01944390
31 | n01985128
32 | n01986214
33 | n02007558
34 | n02009912
35 | n02037110
36 | n02051845
37 | n02077923
38 | n02085620
39 | n02099601
40 | n02106550
41 | n02106662
42 | n02110958
43 | n02119022
44 | n02123394
45 | n02127052
46 | n02129165
47 | n02133161
48 | n02137549
49 | n02165456
50 | n02174001
51 | n02177972
52 | n02190166
53 | n02206856
54 | n02219486
55 | n02226429
56 | n02231487
57 | n02233338
58 | n02236044
59 | n02259212
60 | n02268443
61 | n02279972
62 | n02280649
63 | n02281787
64 | n02317335
65 | n02325366
66 | n02346627
67 | n02356798
68 | n02361337
69 | n02410509
70 | n02445715
71 | n02454379
72 | n02486410
73 | n02492035
74 | n02504458
75 | n02655020
76 | n02669723
77 | n02672831
78 | n02676566
79 | n02690373
80 | n02701002
81 | n02730930
82 | n02777292
83 | n02782093
84 | n02787622
85 | n02793495
86 | n02797295
87 | n02802426
88 | n02814860
89 | n02815834
90 | n02837789
91 | n02879718
92 | n02883205
93 | n02895154
94 | n02906734
95 | n02948072
96 | n02951358
97 | n02980441
98 | n02992211
99 | n02999410
100 | n03014705
101 | n03026506
102 | n03124043
103 | n03125729
104 | n03187595
105 | n03196217
106 | n03223299
107 | n03250847
108 | n03255030
109 | n03291819
110 | n03325584
111 | n03355925
112 | n03384352
113 | n03388043
114 | n03417042
115 | n03443371
116 | n03444034
117 | n03445924
118 | n03452741
119 | n03483316
120 | n03584829
121 | n03590841
122 | n03594945
123 | n03617480
124 | n03666591
125 | n03670208
126 | n03717622
127 | n03720891
128 | n03721384
129 | n03724870
130 | n03775071
131 | n03788195
132 | n03804744
133 | n03837869
134 | n03840681
135 | n03854065
136 | n03888257
137 | n03891332
138 | n03935335
139 | n03982430
140 | n04019541
141 | n04033901
142 | n04039381
143 | n04067472
144 | n04086273
145 | n04099969
146 | n04118538
147 | n04131690
148 | n04133789
149 | n04141076
150 | n04146614
151 | n04147183
152 | n04179913
153 | n04208210
154 | n04235860
155 | n04252077
156 | n04252225
157 | n04254120
158 | n04270147
159 | n04275548
160 | n04310018
161 | n04317175
162 | n04344873
163 | n04347754
164 | n04355338
165 | n04366367
166 | n04376876
167 | n04389033
168 | n04399382
169 | n04442312
170 | n04456115
171 | n04482393
172 | n04507155
173 | n04509417
174 | n04532670
175 | n04540053
176 | n04554684
177 | n04562935
178 | n04591713
179 | n04606251
180 | n07583066
181 | n07695742
182 | n07697313
183 | n07697537
184 | n07714990
185 | n07718472
186 | n07720875
187 | n07734744
188 | n07749582
189 | n07753592
190 | n07760859
191 | n07768694
192 | n07831146
193 | n09229709
194 | n09246464
195 | n09472597
196 | n09835506
197 | n11879895
198 | n12057211
199 | n12144580
200 | n12267677
201 |
--------------------------------------------------------------------------------
/timm/data/_info/imagenet_r_synsets.txt:
--------------------------------------------------------------------------------
1 | n01443537
2 | n01484850
3 | n01494475
4 | n01498041
5 | n01514859
6 | n01518878
7 | n01531178
8 | n01534433
9 | n01614925
10 | n01616318
11 | n01630670
12 | n01632777
13 | n01644373
14 | n01677366
15 | n01694178
16 | n01748264
17 | n01770393
18 | n01774750
19 | n01784675
20 | n01806143
21 | n01820546
22 | n01833805
23 | n01843383
24 | n01847000
25 | n01855672
26 | n01860187
27 | n01882714
28 | n01910747
29 | n01944390
30 | n01983481
31 | n01986214
32 | n02007558
33 | n02009912
34 | n02051845
35 | n02056570
36 | n02066245
37 | n02071294
38 | n02077923
39 | n02085620
40 | n02086240
41 | n02088094
42 | n02088238
43 | n02088364
44 | n02088466
45 | n02091032
46 | n02091134
47 | n02092339
48 | n02094433
49 | n02096585
50 | n02097298
51 | n02098286
52 | n02099601
53 | n02099712
54 | n02102318
55 | n02106030
56 | n02106166
57 | n02106550
58 | n02106662
59 | n02108089
60 | n02108915
61 | n02109525
62 | n02110185
63 | n02110341
64 | n02110958
65 | n02112018
66 | n02112137
67 | n02113023
68 | n02113624
69 | n02113799
70 | n02114367
71 | n02117135
72 | n02119022
73 | n02123045
74 | n02128385
75 | n02128757
76 | n02129165
77 | n02129604
78 | n02130308
79 | n02134084
80 | n02138441
81 | n02165456
82 | n02190166
83 | n02206856
84 | n02219486
85 | n02226429
86 | n02233338
87 | n02236044
88 | n02268443
89 | n02279972
90 | n02317335
91 | n02325366
92 | n02346627
93 | n02356798
94 | n02363005
95 | n02364673
96 | n02391049
97 | n02395406
98 | n02398521
99 | n02410509
100 | n02423022
101 | n02437616
102 | n02445715
103 | n02447366
104 | n02480495
105 | n02480855
106 | n02481823
107 | n02483362
108 | n02486410
109 | n02510455
110 | n02526121
111 | n02607072
112 | n02655020
113 | n02672831
114 | n02701002
115 | n02749479
116 | n02769748
117 | n02793495
118 | n02797295
119 | n02802426
120 | n02808440
121 | n02814860
122 | n02823750
123 | n02841315
124 | n02843684
125 | n02883205
126 | n02906734
127 | n02909870
128 | n02939185
129 | n02948072
130 | n02950826
131 | n02951358
132 | n02966193
133 | n02980441
134 | n02992529
135 | n03124170
136 | n03272010
137 | n03345487
138 | n03372029
139 | n03424325
140 | n03452741
141 | n03467068
142 | n03481172
143 | n03494278
144 | n03495258
145 | n03498962
146 | n03594945
147 | n03602883
148 | n03630383
149 | n03649909
150 | n03676483
151 | n03710193
152 | n03773504
153 | n03775071
154 | n03888257
155 | n03930630
156 | n03947888
157 | n04086273
158 | n04118538
159 | n04133789
160 | n04141076
161 | n04146614
162 | n04147183
163 | n04192698
164 | n04254680
165 | n04266014
166 | n04275548
167 | n04310018
168 | n04325704
169 | n04347754
170 | n04389033
171 | n04409515
172 | n04465501
173 | n04487394
174 | n04522168
175 | n04536866
176 | n04552348
177 | n04591713
178 | n07614500
179 | n07693725
180 | n07695742
181 | n07697313
182 | n07697537
183 | n07714571
184 | n07714990
185 | n07718472
186 | n07720875
187 | n07734744
188 | n07742313
189 | n07745940
190 | n07749582
191 | n07753275
192 | n07753592
193 | n07768694
194 | n07873807
195 | n07880968
196 | n07920052
197 | n09472597
198 | n09835506
199 | n10565667
200 | n12267677
201 |
--------------------------------------------------------------------------------
/timm/layers/global_context.py:
--------------------------------------------------------------------------------
1 | """ Global Context Attention Block
2 |
3 | Paper: `GCNet: Non-local Networks Meet Squeeze-Excitation Networks and Beyond`
4 | - https://arxiv.org/abs/1904.11492
5 |
6 | Official code consulted as reference: https://github.com/xvjiarui/GCNet
7 |
8 | Hacked together by / Copyright 2021 Ross Wightman
9 | """
10 | from typing import Optional, Tuple, Type, Union
11 |
12 | from torch import nn as nn
13 | import torch.nn.functional as F
14 |
15 | from .create_act import create_act_layer, get_act_layer
16 | from .helpers import make_divisible
17 | from .mlp import ConvMlp
18 | from .norm import LayerNorm2d
19 |
20 |
21 | class GlobalContext(nn.Module):
22 |
23 | def __init__(
24 | self,
25 | channels: int,
26 | use_attn: bool = True,
27 | fuse_add: bool = False,
28 | fuse_scale: bool = True,
29 | init_last_zero: bool = False,
30 | rd_ratio: float = 1./8,
31 | rd_channels: Optional[int] = None,
32 | rd_divisor: int = 1,
33 | act_layer: Type[nn.Module] = nn.ReLU,
34 | gate_layer: Union[str, Type[nn.Module]] = 'sigmoid',
35 | device=None,
36 | dtype=None
37 | ):
38 | dd = {'device': device, 'dtype': dtype}
39 | super().__init__()
40 | act_layer = get_act_layer(act_layer)
41 |
42 | self.conv_attn = nn.Conv2d(channels, 1, kernel_size=1, bias=True, **dd) if use_attn else None
43 |
44 | if rd_channels is None:
45 | rd_channels = make_divisible(channels * rd_ratio, rd_divisor, round_limit=0.)
46 | if fuse_add:
47 | self.mlp_add = ConvMlp(channels, rd_channels, act_layer=act_layer, norm_layer=LayerNorm2d, **dd)
48 | else:
49 | self.mlp_add = None
50 | if fuse_scale:
51 | self.mlp_scale = ConvMlp(channels, rd_channels, act_layer=act_layer, norm_layer=LayerNorm2d, **dd)
52 | else:
53 | self.mlp_scale = None
54 |
55 | self.gate = create_act_layer(gate_layer)
56 | self.init_last_zero = init_last_zero
57 |
58 | self.reset_parameters()
59 |
60 | def reset_parameters(self):
61 | if self.conv_attn is not None:
62 | nn.init.kaiming_normal_(self.conv_attn.weight, mode='fan_in', nonlinearity='relu')
63 | if self.mlp_add is not None:
64 | nn.init.zeros_(self.mlp_add.fc2.weight)
65 |
66 | def forward(self, x):
67 | B, C, H, W = x.shape
68 |
69 | if self.conv_attn is not None:
70 | attn = self.conv_attn(x).reshape(B, 1, H * W) # (B, 1, H * W)
71 | attn = F.softmax(attn, dim=-1).unsqueeze(3) # (B, 1, H * W, 1)
72 | context = x.reshape(B, C, H * W).unsqueeze(1) @ attn
73 | context = context.view(B, C, 1, 1)
74 | else:
75 | context = x.mean(dim=(2, 3), keepdim=True)
76 |
77 | if self.mlp_scale is not None:
78 | mlp_x = self.mlp_scale(context)
79 | x = x * self.gate(mlp_x)
80 | if self.mlp_add is not None:
81 | mlp_x = self.mlp_add(context)
82 | x = x + mlp_x
83 |
84 | return x
85 |
--------------------------------------------------------------------------------
/timm/layers/filter_response_norm.py:
--------------------------------------------------------------------------------
1 | """ Filter Response Norm in PyTorch
2 |
3 | Based on `Filter Response Normalization Layer` - https://arxiv.org/abs/1911.09737
4 |
5 | Hacked together by / Copyright 2021 Ross Wightman
6 | """
7 | from typing import Optional, Type
8 |
9 | import torch
10 | import torch.nn as nn
11 |
12 | from .create_act import create_act_layer
13 | from .trace_utils import _assert
14 |
15 |
16 | def inv_instance_rms(x, eps: float = 1e-5):
17 | rms = x.square().float().mean(dim=(2, 3), keepdim=True).add(eps).rsqrt().to(x.dtype)
18 | return rms.expand(x.shape)
19 |
20 |
21 | class FilterResponseNormTlu2d(nn.Module):
22 | def __init__(
23 | self,
24 | num_features: int,
25 | apply_act: bool = True,
26 | eps: float = 1e-5,
27 | rms: bool = True,
28 | device=None,
29 | dtype=None,
30 | **_,
31 | ):
32 | dd = {'device': device, 'dtype': dtype}
33 | super().__init__()
34 | self.apply_act = apply_act # apply activation (non-linearity)
35 | self.rms = rms
36 | self.eps = eps
37 | self.weight = nn.Parameter(torch.empty(num_features, **dd))
38 | self.bias = nn.Parameter(torch.empty(num_features, **dd))
39 | self.tau = nn.Parameter(torch.empty(num_features, **dd)) if apply_act else None
40 |
41 | self.reset_parameters()
42 |
43 | def reset_parameters(self):
44 | nn.init.ones_(self.weight)
45 | nn.init.zeros_(self.bias)
46 | if self.tau is not None:
47 | nn.init.zeros_(self.tau)
48 |
49 | def forward(self, x):
50 | _assert(x.dim() == 4, 'expected 4D input')
51 | x_dtype = x.dtype
52 | v_shape = (1, -1, 1, 1)
53 | x = x * inv_instance_rms(x, self.eps)
54 | x = x * self.weight.view(v_shape).to(dtype=x_dtype) + self.bias.view(v_shape).to(dtype=x_dtype)
55 | return torch.maximum(x, self.tau.reshape(v_shape).to(dtype=x_dtype)) if self.tau is not None else x
56 |
57 |
58 | class FilterResponseNormAct2d(nn.Module):
59 | def __init__(
60 | self,
61 | num_features: int,
62 | apply_act: bool = True,
63 | act_layer: Type[nn.Module] = nn.ReLU,
64 | inplace: Optional[bool] = None,
65 | rms: bool = True,
66 | eps: float = 1e-5,
67 | device=None,
68 | dtype=None,
69 | **_,
70 | ):
71 | dd = {'device': device, 'dtype': dtype}
72 | super().__init__()
73 | if act_layer is not None and apply_act:
74 | self.act = create_act_layer(act_layer, inplace=inplace)
75 | else:
76 | self.act = nn.Identity()
77 | self.rms = rms
78 | self.eps = eps
79 | self.weight = nn.Parameter(torch.empty(num_features, **dd))
80 | self.bias = nn.Parameter(torch.empty(num_features, **dd))
81 |
82 | self.reset_parameters()
83 |
84 | def reset_parameters(self):
85 | nn.init.ones_(self.weight)
86 | nn.init.zeros_(self.bias)
87 |
88 | def forward(self, x):
89 | _assert(x.dim() == 4, 'expected 4D input')
90 | x_dtype = x.dtype
91 | v_shape = (1, -1, 1, 1)
92 | x = x * inv_instance_rms(x, self.eps)
93 | x = x * self.weight.view(v_shape).to(dtype=x_dtype) + self.bias.view(v_shape).to(dtype=x_dtype)
94 | return self.act(x)
95 |
--------------------------------------------------------------------------------
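A minimal sketch: FRN normalizes by per-instance RMS over the spatial dims (no batch statistics), so it behaves identically at batch size 1; with the default `apply_act=True` a ReLU clamps outputs at zero.

    import torch
    from timm.layers import FilterResponseNormAct2d

    frn = FilterResponseNormAct2d(32)
    y = frn(torch.randn(1, 32, 16, 16))
    print(y.shape, bool((y >= 0).all()))   # torch.Size([1, 32, 16, 16]) True
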
/timm/models/layers/__init__.py:
--------------------------------------------------------------------------------
1 | # NOTE timm.models.layers is DEPRECATED, please use timm.layers, this is here to reduce breakages in transition
2 | from timm.layers.activations import *
3 | from timm.layers.adaptive_avgmax_pool import \
4 | adaptive_avgmax_pool2d, select_adaptive_pool2d, AdaptiveAvgMaxPool2d, SelectAdaptivePool2d
5 | from timm.layers.attention_pool2d import AttentionPool2d, RotAttentionPool2d, RotaryEmbedding
6 | from timm.layers.blur_pool import BlurPool2d
7 | from timm.layers.classifier import ClassifierHead, create_classifier
8 | from timm.layers.cond_conv2d import CondConv2d, get_condconv_initializer
9 | from timm.layers.config import is_exportable, is_scriptable, is_no_jit, set_exportable, set_scriptable, set_no_jit,\
10 | set_layer_config
11 | from timm.layers.conv2d_same import Conv2dSame, conv2d_same
12 | from timm.layers.conv_bn_act import ConvNormAct, ConvNormActAa, ConvBnAct
13 | from timm.layers.create_act import create_act_layer, get_act_layer, get_act_fn
14 | from timm.layers.create_attn import get_attn, create_attn
15 | from timm.layers.create_conv2d import create_conv2d
16 | from timm.layers.create_norm import get_norm_layer, create_norm_layer
17 | from timm.layers.create_norm_act import get_norm_act_layer, create_norm_act_layer
18 | from timm.layers.drop import DropBlock2d, DropPath, drop_block_2d, drop_path
19 | from timm.layers.eca import EcaModule, CecaModule, EfficientChannelAttn, CircularEfficientChannelAttn
20 | from timm.layers.evo_norm import EvoNorm2dB0, EvoNorm2dB1, EvoNorm2dB2,\
21 | EvoNorm2dS0, EvoNorm2dS0a, EvoNorm2dS1, EvoNorm2dS1a, EvoNorm2dS2, EvoNorm2dS2a
22 | from timm.layers.fast_norm import is_fast_norm, set_fast_norm, fast_group_norm, fast_layer_norm
23 | from timm.layers.filter_response_norm import FilterResponseNormTlu2d, FilterResponseNormAct2d
24 | from timm.layers.gather_excite import GatherExcite
25 | from timm.layers.global_context import GlobalContext
26 | from timm.layers.helpers import to_ntuple, to_2tuple, to_3tuple, to_4tuple, make_divisible, extend_tuple
27 | from timm.layers.inplace_abn import InplaceAbn
28 | from timm.layers.linear import Linear
29 | from timm.layers.mixed_conv2d import MixedConv2d
30 | from timm.layers.mlp import Mlp, GluMlp, GatedMlp, ConvMlp
31 | from timm.layers.non_local_attn import NonLocalAttn, BatNonLocalAttn
32 | from timm.layers.norm import GroupNorm, GroupNorm1, LayerNorm, LayerNorm2d
33 | from timm.layers.norm_act import BatchNormAct2d, GroupNormAct, convert_sync_batchnorm
34 | from timm.layers.padding import get_padding, get_same_padding, pad_same
35 | from timm.layers.patch_embed import PatchEmbed
36 | from timm.layers.pool2d_same import AvgPool2dSame, create_pool2d
37 | from timm.layers.squeeze_excite import SEModule, SqueezeExcite, EffectiveSEModule, EffectiveSqueezeExcite
38 | from timm.layers.selective_kernel import SelectiveKernel
39 | from timm.layers.separable_conv import SeparableConv2d, SeparableConvNormAct
40 | from timm.layers.split_attn import SplitAttn
41 | from timm.layers.split_batchnorm import SplitBatchNorm2d, convert_splitbn_model
42 | from timm.layers.std_conv import StdConv2d, StdConv2dSame, ScaledStdConv2d, ScaledStdConv2dSame
43 | from timm.layers.test_time_pool import TestTimePoolHead, apply_test_time_pool
44 | from timm.layers.trace_utils import _assert, _float_to_int
45 | from timm.layers.weight_init import trunc_normal_, trunc_normal_tf_, variance_scaling_, lecun_normal_
46 |
47 | import warnings
48 | warnings.warn(f"Importing from {__name__} is deprecated, please import via timm.layers", FutureWarning)
49 |
--------------------------------------------------------------------------------
/timm/utils/attention_extract.py:
--------------------------------------------------------------------------------
1 | import fnmatch
2 | import re
3 | from collections import OrderedDict
4 | from typing import Optional, List
5 |
6 | import torch
7 |
8 |
9 | class AttentionExtract(torch.nn.Module):
10 | # defaults should cover a significant number of timm models with attention maps.
11 | default_node_names = ['*attn.softmax']
12 | default_module_names = ['*attn_drop']
13 |
14 | def __init__(
15 | self,
16 |         model: torch.nn.Module,
17 | names: Optional[List[str]] = None,
18 | mode: str = 'eval',
19 | method: str = 'fx',
20 | hook_type: str = 'forward',
21 | use_regex: bool = False,
22 | ):
23 | """ Extract attention maps (or other activations) from a model by name.
24 |
25 | Args:
26 | model: Instantiated model to extract from.
27 | names: List of concrete or wildcard names to extract. Names are nodes for fx and modules for hooks.
28 | mode: 'train' or 'eval' model mode.
29 | method: 'fx' or 'hook' extraction method.
30 | hook_type: 'forward' or 'forward_pre' hooks used.
31 | use_regex: Use regex instead of fnmatch
32 | """
33 | super().__init__()
34 | assert mode in ('train', 'eval')
35 | if mode == 'train':
36 | model = model.train()
37 | else:
38 | model = model.eval()
39 |
40 | assert method in ('fx', 'hook')
41 | if method == 'fx':
42 | # names are activation node names
43 | from timm.models._features_fx import get_graph_node_names, GraphExtractNet
44 |
45 | node_names = get_graph_node_names(model)[0 if mode == 'train' else 1]
46 | names = names or self.default_node_names
47 | if use_regex:
48 | regexes = [re.compile(r) for r in names]
49 | matched = [g for g in node_names if any([r.match(g) for r in regexes])]
50 | else:
51 | matched = [g for g in node_names if any([fnmatch.fnmatch(g, n) for n in names])]
52 | if not matched:
53 | raise RuntimeError(f'No node names found matching {names}.')
54 |
55 | self.model = GraphExtractNet(model, matched, return_dict=True)
56 | self.hooks = None
57 | else:
58 | # names are module names
59 | assert hook_type in ('forward', 'forward_pre')
60 | from timm.models._features import FeatureHooks
61 |
62 | module_names = [n for n, m in model.named_modules()]
63 | names = names or self.default_module_names
64 | if use_regex:
65 | regexes = [re.compile(r) for r in names]
66 | matched = [m for m in module_names if any([r.match(m) for r in regexes])]
67 | else:
68 | matched = [m for m in module_names if any([fnmatch.fnmatch(m, n) for n in names])]
69 | if not matched:
70 | raise RuntimeError(f'No module names found matching {names}.')
71 |
72 | self.model = model
73 | self.hooks = FeatureHooks(matched, model.named_modules(), default_hook_type=hook_type)
74 |
75 | self.names = matched
76 | self.mode = mode
77 | self.method = method
78 |
79 | def forward(self, x):
80 | if self.hooks is not None:
81 | self.model(x)
82 | output = self.hooks.get_output(device=x.device)
83 | else:
84 | output = self.model(x)
85 | return output
86 |
--------------------------------------------------------------------------------
/timm/layers/patch_dropout.py:
--------------------------------------------------------------------------------
1 | from typing import Optional, Tuple, Union
2 |
3 | import torch
4 | import torch.nn as nn
5 |
6 |
7 | def patch_dropout_forward(
8 | x: torch.Tensor,
9 | prob: float,
10 | num_prefix_tokens: int,
11 | ordered: bool,
12 | training: bool,
13 | ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
14 | """
15 | Common forward logic for patch dropout.
16 |
17 | Args:
18 | x: Input tensor of shape (B, L, D)
19 | prob: Dropout probability
20 | num_prefix_tokens: Number of prefix tokens to preserve
21 | ordered: Whether to maintain patch order
22 | training: Whether in training mode
23 |
24 | Returns:
25 | Tuple of (output tensor, keep_indices or None)
26 | """
27 | if not training or prob == 0.:
28 | return x, None
29 |
30 | if num_prefix_tokens:
31 | prefix_tokens, x = x[:, :num_prefix_tokens], x[:, num_prefix_tokens:]
32 | else:
33 | prefix_tokens = None
34 |
35 | B = x.shape[0]
36 | L = x.shape[1]
37 | num_keep = max(1, int(L * (1. - prob)))
38 | keep_indices = torch.argsort(torch.randn(B, L, device=x.device), dim=-1)[:, :num_keep]
39 |
40 | if ordered:
41 | # NOTE does not need to maintain patch order in typical transformer use,
42 | # but possibly useful for debug / visualization
43 | keep_indices = keep_indices.sort(dim=-1)[0]
44 |
45 | x = x.gather(1, keep_indices.unsqueeze(-1).expand((-1, -1) + x.shape[2:]))
46 |
47 | if prefix_tokens is not None:
48 | x = torch.cat((prefix_tokens, x), dim=1)
49 |
50 | return x, keep_indices
51 |
52 |
53 | class PatchDropout(nn.Module):
54 | """
55 | Patch Dropout without returning indices.
56 | https://arxiv.org/abs/2212.00794 and https://arxiv.org/pdf/2208.07220
57 | """
58 |
59 | def __init__(
60 | self,
61 | prob: float = 0.5,
62 | num_prefix_tokens: int = 1,
63 | ordered: bool = False,
64 | ):
65 | super().__init__()
66 | assert 0 <= prob < 1.
67 | self.prob = prob
68 | self.num_prefix_tokens = num_prefix_tokens # exclude CLS token (or other prefix tokens)
69 | self.ordered = ordered
70 |
71 | def forward(self, x: torch.Tensor) -> torch.Tensor:
72 | output, _ = patch_dropout_forward(
73 | x,
74 | self.prob,
75 | self.num_prefix_tokens,
76 | self.ordered,
77 | self.training
78 | )
79 | return output
80 |
81 |
82 | class PatchDropoutWithIndices(nn.Module):
83 | """
84 | Patch Dropout that returns both output and keep indices.
85 | https://arxiv.org/abs/2212.00794 and https://arxiv.org/pdf/2208.07220
86 | """
87 |
88 | def __init__(
89 | self,
90 | prob: float = 0.5,
91 | num_prefix_tokens: int = 1,
92 | ordered: bool = False,
93 | ):
94 | super().__init__()
95 | assert 0 <= prob < 1.
96 | self.prob = prob
97 | self.num_prefix_tokens = num_prefix_tokens # exclude CLS token (or other prefix tokens)
98 | self.ordered = ordered
99 |
100 | def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
101 | return patch_dropout_forward(
102 | x,
103 | self.prob,
104 | self.num_prefix_tokens,
105 | self.ordered,
106 | self.training
107 | )
108 |
--------------------------------------------------------------------------------
/timm/layers/conv_bn_act.py:
--------------------------------------------------------------------------------
1 | """ Conv2d + BN + Act
2 |
3 | Hacked together by / Copyright 2020 Ross Wightman
4 | """
5 | from typing import Any, Dict, Optional, Type
6 |
7 | from torch import nn as nn
8 |
9 | from .typing import LayerType, PadType
10 | from .blur_pool import create_aa
11 | from .create_conv2d import create_conv2d
12 | from .create_norm_act import get_norm_act_layer
13 |
14 |
15 | class ConvNormAct(nn.Module):
16 | def __init__(
17 | self,
18 | in_channels: int,
19 | out_channels: int,
20 | kernel_size: int = 1,
21 | stride: int = 1,
22 | padding: PadType = '',
23 | dilation: int = 1,
24 | groups: int = 1,
25 | bias: bool = False,
26 | apply_norm: bool = True,
27 | apply_act: bool = True,
28 | norm_layer: LayerType = nn.BatchNorm2d,
29 | act_layer: Optional[LayerType] = nn.ReLU,
30 | aa_layer: Optional[LayerType] = None,
31 | drop_layer: Optional[Type[nn.Module]] = None,
32 | conv_kwargs: Optional[Dict[str, Any]] = None,
33 | norm_kwargs: Optional[Dict[str, Any]] = None,
34 | act_kwargs: Optional[Dict[str, Any]] = None,
35 | device=None,
36 | dtype=None,
37 | ):
38 | dd = {'device': device, 'dtype': dtype}
39 | super().__init__()
40 | conv_kwargs = {**dd, **(conv_kwargs or {})}
41 | norm_kwargs = {**dd, **(norm_kwargs or {})}
42 | act_kwargs = act_kwargs or {}
43 | use_aa = aa_layer is not None and stride > 1
44 |
45 | self.conv = create_conv2d(
46 | in_channels,
47 | out_channels,
48 | kernel_size,
49 | stride=1 if use_aa else stride,
50 | padding=padding,
51 | dilation=dilation,
52 | groups=groups,
53 | bias=bias,
54 | **conv_kwargs,
55 | )
56 |
57 | if apply_norm:
58 | # NOTE for backwards compatibility with models that use separate norm and act layer definitions
59 | norm_act_layer = get_norm_act_layer(norm_layer, act_layer)
60 | # NOTE for backwards (weight) compatibility, norm layer name remains `.bn`
61 | if drop_layer:
62 | norm_kwargs['drop_layer'] = drop_layer
63 | self.bn = norm_act_layer(
64 | out_channels,
65 | apply_act=apply_act,
66 | act_kwargs=act_kwargs,
67 | **norm_kwargs,
68 | )
69 | else:
70 | self.bn = nn.Sequential()
71 | if drop_layer:
72 |                 # no norm layer to receive drop_layer, attach the drop module directly
73 | self.bn.add_module('drop', drop_layer())
74 |
75 | self.aa = create_aa(
76 | aa_layer,
77 | out_channels,
78 | stride=stride,
79 | enable=use_aa,
80 | noop=None,
81 | **dd,
82 | )
83 |
84 | @property
85 | def in_channels(self):
86 | return self.conv.in_channels
87 |
88 | @property
89 | def out_channels(self):
90 | return self.conv.out_channels
91 |
92 | def forward(self, x):
93 | x = self.conv(x)
94 | x = self.bn(x)
95 | aa = getattr(self, 'aa', None)
96 | if aa is not None:
97 | x = self.aa(x)
98 | return x
99 |
100 |
101 | ConvBnAct = ConvNormAct
102 | ConvNormActAa = ConvNormAct # backwards compat, when they were separate
103 |
--------------------------------------------------------------------------------
/hfdocs/source/hparams.mdx:
--------------------------------------------------------------------------------
1 | # HParams
2 | Over the years, many `timm` models have been trained with varying hyper-parameters as the library and models evolved. I don't have a record of every run, but I have recorded many that can serve as a very good starting point.
3 |
4 | ## Tags
5 | Most `timm` trained models have an identifier in their pretrained tag that relates them (roughly) to a family / version of hparams I've used over the years.
6 |
7 | | Tag(s) | Description | Optimizer | LR Schedule | Other Notes |
8 | |--------|-------------|-----------|-------------|-------------|
9 | | `a1h` | Based on [ResNet Strikes Back](https://arxiv.org/abs/2110.00476) `A1` recipe | LAMB | Cosine with warmup | Stronger dropout, stochastic depth, and RandAugment than paper `A1` recipe |
10 | | `ah` | Based on [ResNet Strikes Back](https://arxiv.org/abs/2110.00476) `A1` recipe | LAMB | Cosine with warmup | No CutMix. Stronger dropout, stochastic depth, and RandAugment than paper `A1` recipe |
11 | | `a1`, `a2`, `a3` | ResNet Strikes Back `A{1,2,3}` recipe | LAMB with BCE loss | Cosine with warmup | — |
12 | | `b1`, `b2`, `b1k`, `b2k` | Based on [ResNet Strikes Back](https://arxiv.org/abs/2110.00476) `B` recipe (equivalent to `timm` `RA2` recipes) | RMSProp (TF 1.0 behaviour) | Step (exponential decay w/ staircase) with warmup | — |
13 | | `c`, `c1`, `c2`, `c3` | Based on [ResNet Strikes Back](https://arxiv.org/abs/2110.00476) `C` recipes | SGD (Nesterov) with AGC | Cosine with warmup | — |
14 | | `ch` | Based on [ResNet Strikes Back](https://arxiv.org/abs/2110.00476) `C` recipes | SGD (Nesterov) with AGC | Cosine with warmup | Stronger dropout, stochastic depth, and RandAugment than paper `C1`/`C2` recipes |
15 | | `d`, `d1`, `d2` | Based on [ResNet Strikes Back](https://arxiv.org/abs/2110.00476) `D` recipe | AdamW with BCE loss | Cosine with warmup | — |
16 | | `sw` | Based on Swin Transformer train/pretrain recipe (basis of DeiT and ConvNeXt recipes) | AdamW with gradient clipping, EMA | Cosine with warmup | — |
17 | | `ra`, `ra2`, `ra3`, `racm`, `raa` | RandAugment recipes. Inspired by EfficientNet RandAugment recipes. Covered by `B` recipe in [ResNet Strikes Back](https://arxiv.org/abs/2110.00476). | RMSProp (TF 1.0 behaviour), EMA | Step (exponential decay w/ staircase) with warmup | — |
18 | | `ra4` | RandAugment v4. Inspired by MobileNetV4 hparams. | — | — | — |
19 | | `am` | AugMix recipe | SGD (Nesterov) with JSD loss | Cosine with warmup | — |
20 | | `ram` | AugMix (with RandAugment) recipe | SGD (Nesterov) with JSD loss | Cosine with warmup | — |
21 | | `bt` | Bag-of-Tricks recipe | SGD (Nesterov) | Cosine with warmup | — |
22 |
23 | ## Config File Gists
24 | I've collected several of the hparam families in a series of gists. These can be downloaded and used via the `--config hparam.yaml` argument of the `timm` train script. Some adjustment is always required, in particular scaling the LR to match your effective global batch size.
25 |
26 | | Tag | Key Model Architectures | Gist Link |
27 | |-----|------------------------|-----------|
28 | | `ra2` | ResNet, EfficientNet, RegNet, NFNet | [Link](https://gist.github.com/rwightman/07839a82d0f50e42840168bc43df70b3) |
29 | | `ra3` | RegNet | [Link](https://gist.github.com/rwightman/37252f8d7d850a94e43f1fcb7b3b8322) |
30 | | `ra4` | MobileNetV4 | [Link](https://gist.github.com/rwightman/f6705cb65c03daeebca8aa129b1b94ad) |
31 | | `sw` | ViT, ConvNeXt, CoAtNet, MaxViT | [Link](https://gist.github.com/rwightman/943c0fe59293b44024bbd2d5d23e6303) |
32 | | `sbb` | ViT | [Link](https://gist.github.com/rwightman/fb37c339efd2334177ff99a8083ebbc4) |
33 | | — | Tiny Test Models | [Link](https://gist.github.com/rwightman/9ba8efc39a546426e99055720d2f705f) |
34 |
--------------------------------------------------------------------------------
/timm/loss/asymmetric_loss.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 |
4 |
5 | class AsymmetricLossMultiLabel(nn.Module):
6 | def __init__(self, gamma_neg=4, gamma_pos=1, clip=0.05, eps=1e-8, disable_torch_grad_focal_loss=False):
7 | super(AsymmetricLossMultiLabel, self).__init__()
8 |
9 | self.gamma_neg = gamma_neg
10 | self.gamma_pos = gamma_pos
11 | self.clip = clip
12 | self.disable_torch_grad_focal_loss = disable_torch_grad_focal_loss
13 | self.eps = eps
14 |
15 | def forward(self, x, y):
16 |         """
17 | Parameters
18 | ----------
19 | x: input logits
20 | y: targets (multi-label binarized vector)
21 | """
22 |
23 | # Calculating Probabilities
24 | x_sigmoid = torch.sigmoid(x)
25 | xs_pos = x_sigmoid
26 | xs_neg = 1 - x_sigmoid
27 |
28 | # Asymmetric Clipping
29 | if self.clip is not None and self.clip > 0:
30 | xs_neg = (xs_neg + self.clip).clamp(max=1)
31 |
32 | # Basic CE calculation
33 | los_pos = y * torch.log(xs_pos.clamp(min=self.eps))
34 | los_neg = (1 - y) * torch.log(xs_neg.clamp(min=self.eps))
35 | loss = los_pos + los_neg
36 |
37 | # Asymmetric Focusing
38 | if self.gamma_neg > 0 or self.gamma_pos > 0:
39 | if self.disable_torch_grad_focal_loss:
40 | torch.set_grad_enabled(False)
41 | pt0 = xs_pos * y
42 | pt1 = xs_neg * (1 - y) # pt = p if t > 0 else 1-p
43 | pt = pt0 + pt1
44 | one_sided_gamma = self.gamma_pos * y + self.gamma_neg * (1 - y)
45 | one_sided_w = torch.pow(1 - pt, one_sided_gamma)
46 | if self.disable_torch_grad_focal_loss:
47 | torch.set_grad_enabled(True)
48 | loss *= one_sided_w
49 |
50 | return -loss.sum()
51 |
52 |
53 | class AsymmetricLossSingleLabel(nn.Module):
54 | def __init__(self, gamma_pos=1, gamma_neg=4, eps: float = 0.1, reduction='mean'):
55 | super(AsymmetricLossSingleLabel, self).__init__()
56 |
57 | self.eps = eps
58 | self.logsoftmax = nn.LogSoftmax(dim=-1)
59 |         self.targets_classes = []  # cached to avoid repeated GPU memory allocation
60 | self.gamma_pos = gamma_pos
61 | self.gamma_neg = gamma_neg
62 | self.reduction = reduction
63 |
64 | def forward(self, inputs, target, reduction=None):
65 |         """
66 |         Parameters
67 |         ----------
68 |         inputs: input logits
69 |         target: targets (class indices, one-hot encoded internally)
70 | """
71 |
72 | num_classes = inputs.size()[-1]
73 | log_preds = self.logsoftmax(inputs)
74 | self.targets_classes = torch.zeros_like(inputs).scatter_(1, target.long().unsqueeze(1), 1)
75 |
76 | # ASL weights
77 | targets = self.targets_classes
78 | anti_targets = 1 - targets
79 | xs_pos = torch.exp(log_preds)
80 | xs_neg = 1 - xs_pos
81 | xs_pos = xs_pos * targets
82 | xs_neg = xs_neg * anti_targets
83 | asymmetric_w = torch.pow(1 - xs_pos - xs_neg,
84 | self.gamma_pos * targets + self.gamma_neg * anti_targets)
85 | log_preds = log_preds * asymmetric_w
86 |
87 | if self.eps > 0: # label smoothing
88 | self.targets_classes = self.targets_classes.mul(1 - self.eps).add(self.eps / num_classes)
89 |
90 | # loss calculation
91 | loss = - self.targets_classes.mul(log_preds)
92 |
93 | loss = loss.sum(dim=-1)
94 | if self.reduction == 'mean':
95 | loss = loss.mean()
96 |
97 | return loss
98 |
--------------------------------------------------------------------------------
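A minimal sketch of the multi-label variant above (random data for illustration); note the returned loss is a negative sum over batch and classes, not a mean.

    import torch
    from timm.loss import AsymmetricLossMultiLabel

    loss_fn = AsymmetricLossMultiLabel(gamma_neg=4, gamma_pos=1, clip=0.05)
    logits = torch.randn(4, 10, requires_grad=True)
    targets = torch.randint(0, 2, (4, 10)).float()   # multi-label binarized targets
    loss = loss_fn(logits, targets)
    loss.backward()
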
/timm/data/readers/reader_hfds.py:
--------------------------------------------------------------------------------
1 | """ Dataset reader that wraps Hugging Face datasets
2 |
3 | Hacked together by / Copyright 2022 Ross Wightman
4 | """
5 | import io
6 | import math
7 | from typing import Optional
8 |
9 | import torch
10 | import torch.distributed as dist
11 | from PIL import Image
12 |
13 | try:
14 | import datasets
15 | except ImportError as e:
16 | print("Please install Hugging Face datasets package `pip install datasets`.")
17 | raise e
18 | from .class_map import load_class_map
19 | from .reader import Reader
20 |
21 |
22 | def get_class_labels(info, label_key='label'):
23 |     if label_key not in info.features:
24 | return {}
25 | class_label = info.features[label_key]
26 | class_to_idx = {n: class_label.str2int(n) for n in class_label.names}
27 | return class_to_idx
28 |
29 |
30 | class ReaderHfds(Reader):
31 |
32 | def __init__(
33 | self,
34 | name: str,
35 | root: Optional[str] = None,
36 | split: str = 'train',
37 |         class_map: Optional[dict] = None,
38 | input_key: str = 'image',
39 | target_key: str = 'label',
40 | additional_features: Optional[list[str]] = None,
41 | download: bool = False,
42 | trust_remote_code: bool = False
43 | ):
44 |         """ Wrap a Hugging Face `datasets` dataset as a timm image Reader.
45 |         """
46 | super().__init__()
47 | self.root = root
48 | self.split = split
49 | self.dataset = datasets.load_dataset(
50 | name, # 'name' maps to path arg in hf datasets
51 | split=split,
52 | cache_dir=self.root, # timm doesn't expect hidden cache dir for datasets, specify a path if root set
53 | trust_remote_code=trust_remote_code
54 | )
55 | # leave decode for caller, plus we want easy access to original path names...
56 | self.dataset = self.dataset.cast_column(input_key, datasets.Image(decode=False))
57 |
58 | self.image_key = input_key
59 | self.label_key = target_key
60 | self.remap_class = False
61 | if class_map:
62 | self.class_to_idx = load_class_map(class_map)
63 | self.remap_class = True
64 | else:
65 | self.class_to_idx = get_class_labels(self.dataset.info, self.label_key)
66 | self.split_info = self.dataset.info.splits[split]
67 | self.num_samples = self.split_info.num_examples
68 |
69 | if additional_features is not None:
70 | if isinstance(additional_features, list):
71 | self.additional_features = additional_features
72 | else:
73 | self.additional_features = [additional_features]
74 | else:
75 | self.additional_features = None
76 |
77 | def __getitem__(self, index):
78 | item = self.dataset[index]
79 | image = item[self.image_key]
80 |
81 | if 'bytes' in image and image['bytes']:
82 | image = io.BytesIO(image['bytes'])
83 | else:
84 | assert 'path' in image and image['path']
85 | image = open(image['path'], 'rb')
86 |
87 | label = item[self.label_key]
88 | if self.remap_class:
89 | label = self.class_to_idx[label]
90 |
91 | if self.additional_features is not None:
92 | features = [item[feat] for feat in self.additional_features]
93 | return image, label, *features
94 | else:
95 | return image, label
96 |
97 | def __len__(self):
98 | return len(self.dataset)
99 |
100 | def _filename(self, index, basename=False, absolute=False):
101 | item = self.dataset[index]
102 | return item[self.image_key]['path']
103 |
--------------------------------------------------------------------------------
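A hypothetical sketch (the 'mnist' dataset name is an example and downloads from the Hub): images come back un-decoded as file objects / `BytesIO`, leaving decode to the caller.

    from PIL import Image
    from timm.data.readers.reader_hfds import ReaderHfds

    reader = ReaderHfds(name='mnist', split='train')
    img_file, label = reader[0]
    img = Image.open(img_file)
    print(img.size, label, len(reader))
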
/timm/models/_features_fx.py:
--------------------------------------------------------------------------------
1 | """ PyTorch FX Based Feature Extraction Helpers
2 | Using https://pytorch.org/vision/stable/feature_extraction.html
3 | """
4 | from typing import Callable, Dict, List, Optional, Union, Tuple, Type
5 |
6 | import torch
7 | from torch import nn
8 |
9 | from timm.layers import (
10 | create_feature_extractor,
11 | get_graph_node_names,
12 | register_notrace_module,
13 | register_notrace_function,
14 | is_notrace_module,
15 | is_notrace_function,
16 | get_notrace_functions,
17 | get_notrace_modules,
18 | Format,
19 | )
20 | from ._features import _get_feature_info, _get_return_layers
21 |
22 |
23 |
24 | __all__ = [
25 | 'register_notrace_module',
26 | 'is_notrace_module',
27 | 'get_notrace_modules',
28 | 'register_notrace_function',
29 | 'is_notrace_function',
30 | 'get_notrace_functions',
31 | 'create_feature_extractor',
32 | 'get_graph_node_names',
33 | 'FeatureGraphNet',
34 | 'GraphExtractNet',
35 | ]
36 |
37 |
38 | class FeatureGraphNet(nn.Module):
39 | """ A FX Graph based feature extractor that works with the model feature_info metadata
40 | """
41 | return_dict: torch.jit.Final[bool]
42 |
43 | def __init__(
44 | self,
45 | model: nn.Module,
46 | out_indices: Tuple[int, ...],
47 | out_map: Optional[Dict] = None,
48 | output_fmt: str = 'NCHW',
49 | return_dict: bool = False,
50 | ):
51 | super().__init__()
52 | self.feature_info = _get_feature_info(model, out_indices)
53 | if out_map is not None:
54 | assert len(out_map) == len(out_indices)
55 | self.output_fmt = Format(output_fmt)
56 | return_nodes = _get_return_layers(self.feature_info, out_map)
57 | self.graph_module = create_feature_extractor(model, return_nodes)
58 | self.return_dict = return_dict
59 |
60 | def forward(self, x):
61 | out = self.graph_module(x)
62 | if self.return_dict:
63 | return out
64 | return list(out.values())
65 |
66 |
67 | class GraphExtractNet(nn.Module):
68 | """ A standalone feature extraction wrapper that maps dict -> list or single tensor
69 | NOTE:
70 |       * create_feature_extractor can be used directly if dictionary output is desired;
71 |         this wrapper mainly adds list / single-tensor output handling
72 |       * unlike FeatureGraphNet, this is intended to be used standalone and not with model feature_info
73 |         metadata for builtin feature extraction mode
74 |
75 | Args:
76 | model: model to extract features from
77 | return_nodes: node names to return features from (dict or list)
78 | squeeze_out: if only one output, and output in list format, flatten to single tensor
79 | return_dict: return as dictionary from extractor with node names as keys, ignores squeeze_out arg
80 | """
81 | return_dict: torch.jit.Final[bool]
82 |
83 | def __init__(
84 | self,
85 | model: nn.Module,
86 | return_nodes: Union[Dict[str, str], List[str]],
87 | squeeze_out: bool = True,
88 | return_dict: bool = False,
89 | ):
90 | super().__init__()
91 | self.squeeze_out = squeeze_out
92 | self.graph_module = create_feature_extractor(model, return_nodes)
93 | self.return_dict = return_dict
94 |
95 | def forward(self, x) -> Union[List[torch.Tensor], torch.Tensor]:
96 | out = self.graph_module(x)
97 | if self.return_dict:
98 | return out
99 | out = list(out.values())
100 | return out[0] if self.squeeze_out and len(out) == 1 else out
101 |
--------------------------------------------------------------------------------
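A sketch of standalone graph extraction: list the eval-mode node names, then pull one intermediate activation with `GraphExtractNet` (the chosen node index is arbitrary).

    import torch
    import timm
    from timm.models._features_fx import get_graph_node_names, GraphExtractNet

    model = timm.create_model('resnet18', pretrained=False)
    train_nodes, eval_nodes = get_graph_node_names(model)
    net = GraphExtractNet(model, [eval_nodes[-3]], squeeze_out=True)
    feat = net(torch.randn(1, 3, 224, 224))   # single tensor due to squeeze_out
    print(eval_nodes[-3], feat.shape)
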
/timm/layers/padding.py:
--------------------------------------------------------------------------------
1 | """ Padding Helpers
2 |
3 | Hacked together by / Copyright 2020 Ross Wightman
4 | """
5 | import math
6 | from typing import List, Tuple, Union
7 |
8 | import torch
9 | import torch.nn.functional as F
10 |
11 | from .helpers import to_2tuple
12 |
13 |
14 | # Calculate symmetric padding for a convolution
15 | def get_padding(kernel_size: int, stride: int = 1, dilation: int = 1, **_) -> Union[int, List[int]]:
16 | if any([isinstance(v, (tuple, list)) for v in [kernel_size, stride, dilation]]):
17 | kernel_size, stride, dilation = to_2tuple(kernel_size), to_2tuple(stride), to_2tuple(dilation)
18 | return [get_padding(*a) for a in zip(kernel_size, stride, dilation)]
19 | padding = ((stride - 1) + dilation * (kernel_size - 1)) // 2
20 | return padding
21 |
22 |
23 | # Calculate asymmetric TensorFlow-like 'SAME' padding for a convolution
24 | def get_same_padding(x: int, kernel_size: int, stride: int, dilation: int):
25 | if isinstance(x, torch.Tensor):
26 | return torch.clamp(((x / stride).ceil() - 1) * stride + (kernel_size - 1) * dilation + 1 - x, min=0)
27 | else:
28 | return max((math.ceil(x / stride) - 1) * stride + (kernel_size - 1) * dilation + 1 - x, 0)
29 |
30 |
31 | # Can SAME padding for given args be done statically?
32 | def is_static_pad(kernel_size: int, stride: int = 1, dilation: int = 1, **_):
33 | if any([isinstance(v, (tuple, list)) for v in [kernel_size, stride, dilation]]):
34 | kernel_size, stride, dilation = to_2tuple(kernel_size), to_2tuple(stride), to_2tuple(dilation)
35 | return all([is_static_pad(*a) for a in zip(kernel_size, stride, dilation)])
36 | return stride == 1 and (dilation * (kernel_size - 1)) % 2 == 0
37 |
38 |
39 | def pad_same_arg(
40 | input_size: List[int],
41 | kernel_size: List[int],
42 | stride: List[int],
43 | dilation: List[int] = (1, 1),
44 | ) -> List[int]:
45 | ih, iw = input_size
46 | kh, kw = kernel_size
47 | pad_h = get_same_padding(ih, kh, stride[0], dilation[0])
48 | pad_w = get_same_padding(iw, kw, stride[1], dilation[1])
49 | return [pad_w // 2, pad_w - pad_w // 2, pad_h // 2, pad_h - pad_h // 2]
50 |
51 |
52 | # Dynamically pad input x with 'SAME' padding for conv with specified args
53 | def pad_same(
54 | x,
55 | kernel_size: List[int],
56 | stride: List[int],
57 | dilation: List[int] = (1, 1),
58 | value: float = 0,
59 | ):
60 | ih, iw = x.size()[-2:]
61 | pad_h = get_same_padding(ih, kernel_size[0], stride[0], dilation[0])
62 | pad_w = get_same_padding(iw, kernel_size[1], stride[1], dilation[1])
63 | x = F.pad(x, (pad_w // 2, pad_w - pad_w // 2, pad_h // 2, pad_h - pad_h // 2), value=value)
64 | return x
65 |
66 |
67 | def get_padding_value(padding, kernel_size, **kwargs) -> Tuple[Union[int, List[int]], bool]:
68 | dynamic = False
69 | if isinstance(padding, str):
70 | # for any string padding, the padding will be calculated for you, one of three ways
71 | padding = padding.lower()
72 | if padding == 'same':
73 | # TF compatible 'SAME' padding, has a performance and GPU memory allocation impact
74 | if is_static_pad(kernel_size, **kwargs):
75 | # static case, no extra overhead
76 | padding = get_padding(kernel_size, **kwargs)
77 | else:
78 | # dynamic 'SAME' padding, has runtime/GPU memory overhead
79 | padding = 0
80 | dynamic = True
81 | elif padding == 'valid':
82 | # 'VALID' padding, same as padding=0
83 | padding = 0
84 | else:
85 | # Default to PyTorch style 'same'-ish symmetric padding
86 | padding = get_padding(kernel_size, **kwargs)
87 | return padding, dynamic
88 |
--------------------------------------------------------------------------------
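A few worked values for the helpers above (the numbers follow directly from the formulas):

import torch
from timm.layers.padding import get_padding, get_same_padding, pad_same

# symmetric padding that preserves spatial size for a stride-1 conv
assert get_padding(kernel_size=3) == 1
assert get_padding(kernel_size=3, dilation=2) == 2

# total 'SAME' padding along one dim: x=7, k=3, s=2 -> 2 extra pixels needed
assert get_same_padding(7, kernel_size=3, stride=2, dilation=1) == 2

x = torch.randn(1, 1, 7, 7)
y = pad_same(x, kernel_size=[3, 3], stride=[2, 2])
print(y.shape)  # torch.Size([1, 1, 9, 9]); a k=3/s=2 conv now yields ceil(7/2)=4
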
/timm/models/_pretrained.py:
--------------------------------------------------------------------------------
1 | import copy
2 | from collections import deque, defaultdict
3 | from dataclasses import dataclass, field, replace, asdict
4 | from typing import Any, Deque, Dict, Tuple, Optional, Union
5 |
6 |
7 | __all__ = ['PretrainedCfg', 'filter_pretrained_cfg', 'DefaultCfg']
8 |
9 |
10 | @dataclass
11 | class PretrainedCfg:
12 | """ Pretrained weight locations and associated input / label metadata for a model variant.
13 | """
14 | # weight source locations
15 | url: Optional[Union[str, Tuple[str, str]]] = None # remote URL
16 | file: Optional[str] = None # local / shared filesystem path
17 | state_dict: Optional[Dict[str, Any]] = None # in-memory state dict
18 | hf_hub_id: Optional[str] = None # Hugging Face Hub model id ('organization/model')
19 | hf_hub_filename: Optional[str] = None # Hugging Face Hub filename (overrides default)
20 |
21 | source: Optional[str] = None # source of cfg / weight location used (url, file, hf-hub)
22 | architecture: Optional[str] = None # architecture variant can be set when not implicit
23 | tag: Optional[str] = None # pretrained tag of source
24 | custom_load: bool = False # use custom model specific model.load_pretrained() (ie for npz files)
25 |
26 | # input / data config
27 | input_size: Tuple[int, int, int] = (3, 224, 224)
28 | test_input_size: Optional[Tuple[int, int, int]] = None
29 | min_input_size: Optional[Tuple[int, int, int]] = None
30 | fixed_input_size: bool = False
31 | interpolation: str = 'bicubic'
32 | crop_pct: float = 0.875
33 | test_crop_pct: Optional[float] = None
34 | crop_mode: str = 'center'
35 | mean: Tuple[float, ...] = (0.485, 0.456, 0.406)
36 | std: Tuple[float, ...] = (0.229, 0.224, 0.225)
37 |
38 | # head / classifier config and meta-data
39 | num_classes: int = 1000
40 | label_offset: Optional[int] = None
41 | label_names: Optional[Tuple[str, ...]] = None
42 | label_descriptions: Optional[Dict[str, str]] = None
43 |
44 | # model attributes that vary with above or required for pretrained adaptation
45 | pool_size: Optional[Tuple[int, ...]] = None
46 | test_pool_size: Optional[Tuple[int, ...]] = None
47 | first_conv: Optional[str] = None
48 | classifier: Optional[str] = None
49 |
50 | license: Optional[str] = None
51 | description: Optional[str] = None
52 | origin_url: Optional[str] = None
53 | paper_name: Optional[str] = None
54 | paper_ids: Optional[Union[str, Tuple[str, ...]]] = None
55 | notes: Optional[Tuple[str, ...]] = None
56 |
57 | @property
58 | def has_weights(self):
59 | return self.url or self.file or self.hf_hub_id
60 |
61 | def to_dict(self, remove_source=False, remove_null=True):
62 | return filter_pretrained_cfg(
63 | asdict(self),
64 | remove_source=remove_source,
65 | remove_null=remove_null
66 | )
67 |
68 |
69 | def filter_pretrained_cfg(cfg, remove_source=False, remove_null=True):
70 | filtered_cfg = {}
71 | keep_null = {'pool_size', 'first_conv', 'classifier'}  # always keep these keys, even if None
72 | for k, v in cfg.items():
73 | if remove_source and k in {'url', 'file', 'hf_hub_id', 'hf_hub_filename', 'source'}:
74 | continue
75 | if remove_null and v is None and k not in keep_null:
76 | continue
77 | filtered_cfg[k] = v
78 | return filtered_cfg
79 |
80 |
81 | @dataclass
82 | class DefaultCfg:
83 | tags: Deque[str] = field(default_factory=deque) # priority queue of tags (first is default)
84 | cfgs: Dict[str, PretrainedCfg] = field(default_factory=dict) # pretrained cfgs by tag
85 | is_pretrained: bool = False # at least one of the configs has a pretrained source set
86 |
87 | @property
88 | def default(self):
89 | return self.cfgs[self.tags[0]]
90 |
91 | @property
92 | def default_with_tag(self):
93 | tag = self.tags[0]
94 | return tag, self.cfgs[tag]
95 |
--------------------------------------------------------------------------------
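A short sketch of how PretrainedCfg and its to_dict filtering interact (the hub id below is a placeholder):

from timm.models._pretrained import PretrainedCfg

cfg = PretrainedCfg(hf_hub_id='org/model', input_size=(3, 256, 256), crop_pct=0.95)
print(bool(cfg.has_weights))  # True: a weight source (hf_hub_id) is set

d = cfg.to_dict(remove_source=True)
# weight-source keys are dropped; null fields removed except the keep_null set
assert 'hf_hub_id' not in d
assert d['pool_size'] is None  # kept despite being None
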
/timm/task/task.py:
--------------------------------------------------------------------------------
1 | """Base training task abstraction.
2 |
3 | This module provides the base TrainingTask class that encapsulates a complete
4 | forward pass including loss computation. Tasks return a dictionary with loss
5 | components and outputs for logging.
6 | """
7 | from typing import Dict, List, Optional
8 |
9 | import torch
10 | import torch.nn as nn
11 |
12 |
13 | class TrainingTask(nn.Module):
14 | """Base class for training tasks.
15 |
16 | A training task encapsulates a complete forward pass including loss computation.
17 | Tasks return a dictionary containing the training loss and other components for logging.
18 |
19 | The returned dictionary must contain:
20 | - 'loss': The training loss for backward pass (required)
21 | - 'output': Model output/logits for metric computation (recommended)
22 | - Other task-specific loss components for logging (optional)
23 |
24 | Args:
25 | device: Device for task tensors/buffers (defaults to cpu)
26 | dtype: Dtype for task tensors/buffers (defaults to torch default)
27 | verbose: Enable info logging
28 |
29 | Example:
30 | >>> task = SomeTask(model, criterion, device=torch.device('cuda'))
31 | >>>
32 | >>> # Prepare for distributed training (if needed)
33 | >>> if distributed:
34 | >>> task.prepare_distributed(device_ids=[local_rank])
35 | >>>
36 | >>> # Training loop
37 | >>> result = task(input, target)
38 | >>> result['loss'].backward()
39 | """
40 |
41 | def __init__(
42 | self,
43 | device: Optional[torch.device] = None,
44 | dtype: Optional[torch.dtype] = None,
45 | verbose: bool = True,
46 | ):
47 | super().__init__()
48 | self.device = device if device is not None else torch.device('cpu')
49 | self.dtype = dtype if dtype is not None else torch.get_default_dtype()
50 | self.verbose = verbose
51 |
52 | def to(self, *args, **kwargs):
53 | """Move task to device/dtype, keeping self.device and self.dtype in sync."""
54 | dummy = torch.empty(0).to(*args, **kwargs)
55 | self.device = dummy.device
56 | self.dtype = dummy.dtype
57 | return super().to(*args, **kwargs)
58 |
59 | def prepare_distributed(
60 | self,
61 | device_ids: Optional[List[int]] = None,
62 | **ddp_kwargs
63 | ) -> 'TrainingTask':
64 | """Prepare task for distributed training.
65 |
66 | This method wraps trainable components in DistributedDataParallel (DDP)
67 | while leaving non-trainable components (like frozen teacher models) unwrapped.
68 |
69 | Should be called after task initialization but before training loop.
70 |
71 | Args:
72 | device_ids: List of device IDs for DDP (e.g., [local_rank])
73 | **ddp_kwargs: Additional arguments passed to DistributedDataParallel
74 |
75 | Returns:
76 | self (for method chaining)
77 |
78 | Example:
79 | >>> task = LogitDistillationTask(student, teacher, criterion)
80 | >>> task.prepare_distributed(device_ids=[args.local_rank])
81 | >>> task = torch.compile(task) # Compile after DDP
82 | """
83 | # Default implementation - subclasses override if they need DDP
84 | return self
85 |
86 | def forward(
87 | self,
88 | input: torch.Tensor,
89 | target: torch.Tensor,
90 | ) -> Dict[str, torch.Tensor]:
91 | """Perform forward pass and compute loss.
92 |
93 | Args:
94 | input: Input tensor [B, C, H, W]
95 | target: Target labels [B]
96 |
97 | Returns:
98 | Dictionary with at least 'loss' key containing the training loss
99 | """
100 | raise NotImplementedError
101 |
--------------------------------------------------------------------------------
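A minimal sketch of the contract: a hypothetical subclass (not a class shipped by timm) wrapping a model and criterion:

import torch
import torch.nn as nn
from timm.task.task import TrainingTask

class SimpleSupervisedTask(TrainingTask):
    # hypothetical example, shown only to illustrate the task contract
    def __init__(self, model: nn.Module, criterion: nn.Module, **kwargs):
        super().__init__(**kwargs)
        self.model = model
        self.criterion = criterion

    def forward(self, input: torch.Tensor, target: torch.Tensor):
        output = self.model(input)
        loss = self.criterion(output, target)
        return {'loss': loss, 'output': output}

task = SimpleSupervisedTask(nn.Linear(8, 3), nn.CrossEntropyLoss())
result = task(torch.randn(4, 8), torch.randint(0, 3, (4,)))
result['loss'].backward()  # 'loss' drives the backward pass; 'output' feeds metrics
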
/timm/layers/split_batchnorm.py:
--------------------------------------------------------------------------------
1 | """ Split BatchNorm
2 |
3 | A PyTorch BatchNorm layer that splits input batch into N equal parts and passes each through
4 | a separate BN layer. The first split is passed through the parent BN layers with weight/bias
5 | keys the same as the original BN. All other splits pass through BN sub-layers under the '.aux_bn'
6 | namespace.
7 |
8 | This allows easily removing the auxiliary BN layers after training to efficiently
9 | achieve the 'Auxiliary BatchNorm' as described in the AdvProp Paper, section 4.2,
10 | 'Disentangled Learning via An Auxiliary BN'
11 |
12 | Hacked together by / Copyright 2020 Ross Wightman
13 | """
14 | import torch
15 | import torch.nn as nn
16 |
17 |
18 | class SplitBatchNorm2d(torch.nn.BatchNorm2d):
19 |
20 | def __init__(
21 | self,
22 | num_features,
23 | eps=1e-5,
24 | momentum=0.1,
25 | affine=True,
26 | track_running_stats=True,
27 | num_splits=2,
28 | device=None,
29 | dtype=None,
30 | ):
31 | dd = {'device': device, 'dtype': dtype}
32 | super().__init__(num_features, eps, momentum, affine, track_running_stats, **dd)
33 | assert num_splits > 1, 'Should have at least one aux BN layer (num_splits at least 2)'
34 | self.num_splits = num_splits
35 | self.aux_bn = nn.ModuleList([
36 | nn.BatchNorm2d(num_features, eps, momentum, affine, track_running_stats, **dd)
37 | for _ in range(num_splits - 1)
38 | ])
39 |
40 | def forward(self, input: torch.Tensor):
41 | if self.training: # aux BN only relevant while training
42 | split_size = input.shape[0] // self.num_splits
43 | assert input.shape[0] == split_size * self.num_splits, "batch size must be evenly divisible by num_splits"
44 | split_input = input.split(split_size)
45 | x = [super().forward(split_input[0])]
46 | for i, a in enumerate(self.aux_bn):
47 | x.append(a(split_input[i + 1]))
48 | return torch.cat(x, dim=0)
49 | else:
50 | return super().forward(input)
51 |
52 |
53 | def convert_splitbn_model(module, num_splits=2):
54 | """
55 | Recursively traverse module and its children to replace all instances of
56 | ``torch.nn.modules.batchnorm._BatchNorm`` with ``SplitBatchNorm2d``.
57 | Args:
58 | module (torch.nn.Module): input module
59 | num_splits: number of separate batchnorm layers to split input across
60 | Example::
61 | >>> # model is an instance of torch.nn.Module
62 | >>> model = timm.layers.convert_splitbn_model(model, num_splits=2)
63 | """
64 | mod = module
65 | if isinstance(module, torch.nn.modules.instancenorm._InstanceNorm):
66 | return module
67 | if isinstance(module, torch.nn.modules.batchnorm._BatchNorm):
68 | mod = SplitBatchNorm2d(
69 | module.num_features, module.eps, module.momentum, module.affine,
70 | module.track_running_stats, num_splits=num_splits)
71 | mod.running_mean = module.running_mean
72 | mod.running_var = module.running_var
73 | mod.num_batches_tracked = module.num_batches_tracked
74 | if module.affine:
75 | mod.weight.data = module.weight.data.clone().detach()
76 | mod.bias.data = module.bias.data.clone().detach()
77 | for aux in mod.aux_bn:
78 | aux.running_mean = module.running_mean.clone()
79 | aux.running_var = module.running_var.clone()
80 | aux.num_batches_tracked = module.num_batches_tracked.clone()
81 | if module.affine:
82 | aux.weight.data = module.weight.data.clone().detach()
83 | aux.bias.data = module.bias.data.clone().detach()
84 | for name, child in module.named_children():
85 | mod.add_module(name, convert_splitbn_model(child, num_splits=num_splits))
86 | del module
87 | return mod
88 |
--------------------------------------------------------------------------------
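A conversion sketch; while training, the batch must split evenly across num_splits:

import torch
import torch.nn as nn
from timm.layers.split_batchnorm import SplitBatchNorm2d, convert_splitbn_model

net = nn.Sequential(nn.Conv2d(3, 8, 3), nn.BatchNorm2d(8), nn.ReLU())
net = convert_splitbn_model(net, num_splits=2)
assert isinstance(net[1], SplitBatchNorm2d)

# first half of the batch uses the main BN, second half the aux BN
# (e.g. clean vs adversarial samples for AdvProp-style training)
out = net(torch.randn(4, 3, 8, 8))
print(out.shape)  # torch.Size([4, 8, 6, 6])
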
/timm/data/readers/reader_image_folder.py:
--------------------------------------------------------------------------------
1 | """ A dataset reader that extracts images from folders
2 |
3 | Folders are scanned recursively to find image files. Labels are based
4 | on the folder hierarchy (just the leaf folder name by default).
5 |
6 | Hacked together by / Copyright 2020 Ross Wightman
7 | """
8 | import os
9 | from typing import Dict, List, Optional, Set, Tuple, Union
10 |
11 | from timm.utils.misc import natural_key
12 |
13 | from .class_map import load_class_map
14 | from .img_extensions import get_img_extensions
15 | from .reader import Reader
16 |
17 |
18 | def find_images_and_targets(
19 | folder: str,
20 | types: Optional[Union[List, Tuple, Set]] = None,
21 | class_to_idx: Optional[Dict] = None,
22 | leaf_name_only: bool = True,
23 | sort: bool = True
24 | ):
25 | """ Walk folder recursively to discover images and map them to classes by folder names.
26 |
27 | Args:
28 | folder: root of folder to recursively search
29 | types: types (file extensions) to search for in path
30 | class_to_idx: specify mapping for class (folder name) to class index if set
31 | leaf_name_only: use only leaf-name of folder walk for class names
32 | sort: re-sort found images by name (for consistent ordering)
33 |
34 | Returns:
35 | A list of image and target tuples, class_to_idx mapping
36 | """
37 | types = get_img_extensions(as_set=True) if not types else set(types)
38 | labels = []
39 | filenames = []
40 | for root, subdirs, files in os.walk(folder, topdown=False, followlinks=True):
41 | rel_path = os.path.relpath(root, folder) if (root != folder) else ''
42 | label = os.path.basename(rel_path) if leaf_name_only else rel_path.replace(os.path.sep, '_')
43 | for f in files:
44 | base, ext = os.path.splitext(f)
45 | if ext.lower() in types:
46 | filenames.append(os.path.join(root, f))
47 | labels.append(label)
48 | if class_to_idx is None:
49 | # building class index
50 | unique_labels = set(labels)
51 | sorted_labels = list(sorted(unique_labels, key=natural_key))
52 | class_to_idx = {c: idx for idx, c in enumerate(sorted_labels)}
53 | images_and_targets = [(f, class_to_idx[l]) for f, l in zip(filenames, labels) if l in class_to_idx]
54 | if sort:
55 | images_and_targets = sorted(images_and_targets, key=lambda k: natural_key(k[0]))
56 | return images_and_targets, class_to_idx
57 |
58 |
59 | class ReaderImageFolder(Reader):
60 |
61 | def __init__(
62 | self,
63 | root,
64 | class_map='',
65 | input_key=None,
66 | ):
67 | super().__init__()
68 |
69 | self.root = root
70 | class_to_idx = None
71 | if class_map:
72 | class_to_idx = load_class_map(class_map, root)
73 | find_types = None
74 | if input_key:
75 | find_types = input_key.split(';')
76 | self.samples, self.class_to_idx = find_images_and_targets(
77 | root,
78 | class_to_idx=class_to_idx,
79 | types=find_types,
80 | )
81 | if len(self.samples) == 0:
82 | raise RuntimeError(
83 | f'Found 0 images in subfolders of {root}. '
84 | f'Supported image extensions are {", ".join(get_img_extensions())}')
85 |
86 | def __getitem__(self, index):
87 | path, target = self.samples[index]
88 | return open(path, 'rb'), target
89 |
90 | def __len__(self):
91 | return len(self.samples)
92 |
93 | def _filename(self, index, basename=False, absolute=False):
94 | filename = self.samples[index][0]
95 | if basename:
96 | filename = os.path.basename(filename)
97 | elif not absolute:
98 | filename = os.path.relpath(filename, self.root)
99 | return filename
100 |
--------------------------------------------------------------------------------
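Typical use, assuming an ImageNet-style root/class_name/img.jpg layout (the path below is a placeholder, and filename() is assumed to be the base Reader helper wrapping _filename):

from timm.data.readers.reader_image_folder import ReaderImageFolder

reader = ReaderImageFolder('/path/to/train')
print(len(reader), 'samples across', len(reader.class_to_idx), 'classes')

img_fh, target = reader[0]  # binary file handle + integer class index
print(reader.filename(0, basename=True), target)
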
/timm/layers/inplace_abn.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch import nn as nn
3 |
4 | try:
5 | from inplace_abn.functions import inplace_abn, inplace_abn_sync
6 | has_iabn = True
7 | except ImportError:
8 | has_iabn = False
9 |
10 | def inplace_abn(x, weight, bias, running_mean, running_var,
11 | training=True, momentum=0.1, eps=1e-05, activation="leaky_relu", activation_param=0.01):
12 | raise ImportError(
13 | "Please install InplaceABN: 'pip install git+https://github.com/mapillary/inplace_abn.git@v1.0.12'")
14 |
15 | def inplace_abn_sync(**kwargs):
16 | inplace_abn(**kwargs)
17 |
18 | from ._fx import register_notrace_module
19 |
20 |
21 | @register_notrace_module
22 | class InplaceAbn(nn.Module):
23 | """Activated Batch Normalization
24 |
25 | This gathers a BatchNorm and an activation function in a single module
26 |
27 | Parameters
28 | ----------
29 | num_features : int
30 | Number of feature channels in the input and output.
31 | eps : float
32 | Small constant to prevent numerical issues.
33 | momentum : float
34 | Momentum factor applied to compute running statistics.
35 | affine : bool
36 | If `True` apply learned scale and shift transformation after normalization.
37 | act_layer : str or nn.Module type
38 | Name or type of the activation functions, one of: `leaky_relu`, `elu`
39 | act_param : float
40 | Negative slope for the `leaky_relu` activation.
41 | """
42 |
43 | def __init__(
44 | self,
45 | num_features,
46 | eps=1e-5,
47 | momentum=0.1,
48 | affine=True,
49 | apply_act=True,
50 | act_layer="leaky_relu",
51 | act_param=0.01,
52 | drop_layer=None,
53 | ):
54 | super().__init__()
55 | self.num_features = num_features
56 | self.affine = affine
57 | self.eps = eps
58 | self.momentum = momentum
59 | if apply_act:
60 | if isinstance(act_layer, str):
61 | assert act_layer in ('leaky_relu', 'elu', 'identity', '')
62 | self.act_name = act_layer if act_layer else 'identity'
63 | else:
64 | # convert act layer passed as type to string
65 | if act_layer == nn.ELU:
66 | self.act_name = 'elu'
67 | elif act_layer == nn.LeakyReLU:
68 | self.act_name = 'leaky_relu'
69 | elif act_layer is None or act_layer == nn.Identity:
70 | self.act_name = 'identity'
71 | else:
72 | assert False, f'Invalid act layer {act_layer.__name__} for IABN'
73 | else:
74 | self.act_name = 'identity'
75 | self.act_param = act_param
76 | if self.affine:
77 | self.weight = nn.Parameter(torch.ones(num_features))
78 | self.bias = nn.Parameter(torch.zeros(num_features))
79 | else:
80 | self.register_parameter('weight', None)
81 | self.register_parameter('bias', None)
82 | self.register_buffer('running_mean', torch.zeros(num_features))
83 | self.register_buffer('running_var', torch.ones(num_features))
84 | self.reset_parameters()
85 |
86 | def reset_parameters(self):
87 | nn.init.constant_(self.running_mean, 0)
88 | nn.init.constant_(self.running_var, 1)
89 | if self.affine:
90 | nn.init.constant_(self.weight, 1)
91 | nn.init.constant_(self.bias, 0)
92 |
93 | def forward(self, x):
94 | output = inplace_abn(
95 | x, self.weight, self.bias, self.running_mean, self.running_var,
96 | self.training, self.momentum, self.eps, self.act_name, self.act_param)
97 | if isinstance(output, tuple):
98 | output = output[0]
99 | return output
100 |
--------------------------------------------------------------------------------
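Since the fallback stubs raise ImportError when called, guard forward passes on has_iabn; a small sketch:

import torch
from timm.layers.inplace_abn import InplaceAbn, has_iabn

bn_act = InplaceAbn(16, act_layer='leaky_relu', act_param=0.01)  # construction needs no extra deps
if has_iabn:
    y = bn_act(torch.randn(2, 16, 8, 8))  # fused in-place BN + leaky_relu
else:
    print('inplace_abn not installed; forward would raise ImportError')
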
/timm/layers/pool2d_same.py:
--------------------------------------------------------------------------------
1 | """ AvgPool2d w/ Same Padding
2 |
3 | Hacked together by / Copyright 2020 Ross Wightman
4 | """
5 | import torch
6 | import torch.nn as nn
7 | import torch.nn.functional as F
8 | from typing import List, Tuple, Optional, Union
9 |
10 | from ._fx import register_notrace_module
11 | from .helpers import to_2tuple
12 | from .padding import pad_same, get_padding_value
13 |
14 |
15 | def avg_pool2d_same(
16 | x: torch.Tensor,
17 | kernel_size: List[int],
18 | stride: List[int],
19 | padding: List[int] = (0, 0),
20 | ceil_mode: bool = False,
21 | count_include_pad: bool = True,
22 | ):
23 | # FIXME how to deal with count_include_pad vs not for external padding?
24 | x = pad_same(x, kernel_size, stride)
25 | return F.avg_pool2d(x, kernel_size, stride, (0, 0), ceil_mode, count_include_pad)
26 |
27 |
28 | @register_notrace_module
29 | class AvgPool2dSame(nn.AvgPool2d):
30 | """ Tensorflow like 'SAME' wrapper for 2D average pooling
31 | """
32 | def __init__(
33 | self,
34 | kernel_size: Union[int, Tuple[int, int]],
35 | stride: Optional[Union[int, Tuple[int, int]]] = None,
36 | padding: Union[int, Tuple[int, int], str] = 0,
37 | ceil_mode: bool = False,
38 | count_include_pad: bool = True,
39 | ):
40 | kernel_size = to_2tuple(kernel_size)
41 | stride = to_2tuple(stride)
42 | super().__init__(kernel_size, stride, (0, 0), ceil_mode, count_include_pad)
43 |
44 | def forward(self, x):
45 | x = pad_same(x, self.kernel_size, self.stride)
46 | return F.avg_pool2d(
47 | x, self.kernel_size, self.stride, self.padding, self.ceil_mode, self.count_include_pad)
48 |
49 |
50 | def max_pool2d_same(
51 | x: torch.Tensor,
52 | kernel_size: List[int],
53 | stride: List[int],
54 | padding: List[int] = (0, 0),
55 | dilation: List[int] = (1, 1),
56 | ceil_mode: bool = False,
57 | ):
58 | x = pad_same(x, kernel_size, stride, value=-float('inf'))
59 | return F.max_pool2d(x, kernel_size, stride, (0, 0), dilation, ceil_mode)
60 |
61 |
62 | @register_notrace_module
63 | class MaxPool2dSame(nn.MaxPool2d):
64 | """ Tensorflow like 'SAME' wrapper for 2D max pooling
65 | """
66 | def __init__(
67 | self,
68 | kernel_size: Union[int, Tuple[int, int]],
69 | stride: Optional[Union[int, Tuple[int, int]]] = None,
70 | padding: Union[int, Tuple[int, int], str] = 0,
71 | dilation: Union[int, Tuple[int, int]] = 1,
72 | ceil_mode: bool = False,
73 | ):
74 | kernel_size = to_2tuple(kernel_size)
75 | stride = to_2tuple(stride)
76 | dilation = to_2tuple(dilation)
77 | super().__init__(kernel_size, stride, (0, 0), dilation, ceil_mode)
78 |
79 | def forward(self, x):
80 | x = pad_same(x, self.kernel_size, self.stride, value=-float('inf'))
81 | return F.max_pool2d(x, self.kernel_size, self.stride, (0, 0), self.dilation, self.ceil_mode)
82 |
83 |
84 | def create_pool2d(pool_type, kernel_size, stride=None, **kwargs):
85 | stride = stride or kernel_size
86 | padding = kwargs.pop('padding', '')
87 | padding, is_dynamic = get_padding_value(padding, kernel_size, stride=stride, **kwargs)
88 | if is_dynamic:
89 | if pool_type == 'avg':
90 | return AvgPool2dSame(kernel_size, stride=stride, **kwargs)
91 | elif pool_type == 'max':
92 | return MaxPool2dSame(kernel_size, stride=stride, **kwargs)
93 | else:
94 | assert False, f'Unsupported pool type {pool_type}'
95 | else:
96 | if pool_type == 'avg':
97 | return nn.AvgPool2d(kernel_size, stride=stride, padding=padding, **kwargs)
98 | elif pool_type == 'max':
99 | return nn.MaxPool2d(kernel_size, stride=stride, padding=padding, **kwargs)
100 | else:
101 | assert False, f'Unsupported pool type {pool_type}'
102 |
--------------------------------------------------------------------------------
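A sketch of how create_pool2d selects the static vs dynamic path:

import torch
from timm.layers.pool2d_same import MaxPool2dSame, create_pool2d

# k=3/s=2 'SAME' cannot be expressed as static symmetric padding -> dynamic wrapper
pool = create_pool2d('max', kernel_size=3, stride=2, padding='same')
assert isinstance(pool, MaxPool2dSame)
print(pool(torch.randn(1, 8, 7, 7)).shape)  # torch.Size([1, 8, 4, 4]) = ceil(7/2)

# k=3/s=1 'SAME' is static -> plain nn.MaxPool2d with padding=1
pool = create_pool2d('max', kernel_size=3, stride=1, padding='same')
print(type(pool).__name__, pool.padding)  # MaxPool2d 1
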
/timm/utils/onnx.py:
--------------------------------------------------------------------------------
1 | from typing import Optional, Tuple, List
2 |
3 | import torch
4 |
5 |
6 | def onnx_forward(onnx_file, example_input):
7 | import onnxruntime
8 |
9 | sess_options = onnxruntime.SessionOptions()
10 | session = onnxruntime.InferenceSession(onnx_file, sess_options)
11 | input_name = session.get_inputs()[0].name
12 | output = session.run([], {input_name: example_input.numpy()})
13 | output = output[0]
14 | return output
15 |
16 |
17 | def onnx_export(
18 | model: torch.nn.Module,
19 | output_file: str,
20 | example_input: Optional[torch.Tensor] = None,
21 | training: bool = False,
22 | verbose: bool = False,
23 | check: bool = True,
24 | check_forward: bool = False,
25 | batch_size: int = 64,
26 | input_size: Optional[Tuple[int, int, int]] = None,
27 | opset: Optional[int] = None,
28 | dynamic_size: bool = False,
29 | aten_fallback: bool = False,
30 | keep_initializers: Optional[bool] = None,
31 | use_dynamo: bool = False,
32 | input_names: Optional[List[str]] = None,
33 | output_names: Optional[List[str]] = None,
34 | ):
35 | import onnx
36 |
37 | if training:
38 | training_mode = torch.onnx.TrainingMode.TRAINING
39 | model.train()
40 | else:
41 | training_mode = torch.onnx.TrainingMode.EVAL
42 | model.eval()
43 |
44 | if example_input is None:
45 | if not input_size:
46 | assert hasattr(model, 'default_cfg'), 'Cannot find model default config, input size must be provided'
47 | input_size = model.default_cfg.get('input_size')
48 | example_input = torch.randn((batch_size,) + input_size, requires_grad=training)
49 |
50 | # Run model once before export trace, sets padding for models with Conv2dSameExport. This means
51 | # that the padding for models with Conv2dSameExport (most models with tf_ prefix) is fixed for
52 | # the input img_size specified in this script.
53 |
54 | # Opset >= 11 should allow for dynamic padding, however I cannot get it to work due to
55 | # issues in the tracing of the dynamic padding or errors attempting to export the model after jit
56 | # scripting it (an approach that should work). Perhaps in future PyTorch or ONNX versions...
57 | with torch.inference_mode():
58 | original_out = model(example_input)
59 |
60 | input_names = input_names or ["input0"]
61 | output_names = output_names or ["output0"]
62 |
63 | dynamic_axes = {'input0': {0: 'batch'}, 'output0': {0: 'batch'}}
64 | if dynamic_size:
65 | dynamic_axes['input0'][2] = 'height'
66 | dynamic_axes['input0'][3] = 'width'
67 |
68 | if aten_fallback:
69 | export_type = torch.onnx.OperatorExportTypes.ONNX_ATEN_FALLBACK
70 | else:
71 | export_type = torch.onnx.OperatorExportTypes.ONNX
72 |
73 | if use_dynamo:
74 | export_options = torch.onnx.ExportOptions(dynamic_shapes=dynamic_size)
75 | export_output = torch.onnx.dynamo_export(
76 | model,
77 | example_input,
78 | export_options=export_options,
79 | )
80 | export_output.save(output_file)
81 | else:
82 | torch.onnx.export(
83 | model,
84 | example_input,
85 | output_file,
86 | training=training_mode,
87 | export_params=True,
88 | verbose=verbose,
89 | input_names=input_names,
90 | output_names=output_names,
91 | keep_initializers_as_inputs=keep_initializers,
92 | dynamic_axes=dynamic_axes,
93 | opset_version=opset,
94 | operator_export_type=export_type
95 | )
96 |
97 | if check:
98 | onnx_model = onnx.load(output_file)
99 | onnx.checker.check_model(onnx_model, full_check=True) # assuming throw on error
100 | if check_forward and not training:
101 | import numpy as np
102 | onnx_out = onnx_forward(output_file, example_input)
103 | np.testing.assert_almost_equal(original_out.numpy(), onnx_out, decimal=3)
104 |
105 |
--------------------------------------------------------------------------------
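A usage sketch, assuming the onnx package is installed (onnxruntime too if check_forward=True); the model name and opset are arbitrary choices:

import torch
import timm
from timm.utils.onnx import onnx_export

model = timm.create_model('resnet18', pretrained=False, exportable=True)
onnx_export(
    model,
    'resnet18.onnx',
    example_input=torch.randn(2, 3, 224, 224),
    opset=17,
    check=True,           # structural validation via onnx.checker
    check_forward=False,  # set True to also compare outputs through onnxruntime
)
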
/timm/layers/separable_conv.py:
--------------------------------------------------------------------------------
1 | """ Depthwise Separable Conv Modules
2 |
3 | Basic DWS convs. Other variations of DWS exist with batch norm or activations between the
4 | DW and PW convs such as the Depthwise modules in MobileNetV2 / EfficientNet and Xception.
5 |
6 | Hacked together by / Copyright 2020 Ross Wightman
7 | """
8 | from typing import Optional, Type, Union
9 |
10 | from torch import nn as nn
11 |
12 | from .create_conv2d import create_conv2d
13 | from .create_norm_act import get_norm_act_layer
14 |
15 |
16 | class SeparableConvNormAct(nn.Module):
17 | """ Separable Conv w/ trailing Norm and Activation
18 | """
19 | def __init__(
20 | self,
21 | in_channels: int,
22 | out_channels: int,
23 | kernel_size: int = 3,
24 | stride: int = 1,
25 | dilation: int = 1,
26 | padding: str = '',
27 | bias: bool = False,
28 | channel_multiplier: float = 1.0,
29 | pw_kernel_size: int = 1,
30 | norm_layer: Type[nn.Module] = nn.BatchNorm2d,
31 | act_layer: Type[nn.Module] = nn.ReLU,
32 | apply_act: bool = True,
33 | drop_layer: Optional[Type[nn.Module]] = None,
34 | device=None,
35 | dtype=None,
36 | ):
37 | dd = {'device': device, 'dtype': dtype}
38 | super().__init__()
39 |
40 | self.conv_dw = create_conv2d(
41 | in_channels,
42 | int(in_channels * channel_multiplier),
43 | kernel_size,
44 | stride=stride,
45 | dilation=dilation,
46 | padding=padding,
47 | depthwise=True,
48 | **dd,
49 | )
50 |
51 | self.conv_pw = create_conv2d(
52 | int(in_channels * channel_multiplier),
53 | out_channels,
54 | pw_kernel_size,
55 | padding=padding,
56 | bias=bias,
57 | **dd,
58 | )
59 |
60 | norm_act_layer = get_norm_act_layer(norm_layer, act_layer)
61 | norm_kwargs = dict(drop_layer=drop_layer) if drop_layer is not None else {}
62 | self.bn = norm_act_layer(out_channels, apply_act=apply_act, **norm_kwargs, **dd)
63 |
64 | @property
65 | def in_channels(self):
66 | return self.conv_dw.in_channels
67 |
68 | @property
69 | def out_channels(self):
70 | return self.conv_pw.out_channels
71 |
72 | def forward(self, x):
73 | x = self.conv_dw(x)
74 | x = self.conv_pw(x)
75 | x = self.bn(x)
76 | return x
77 |
78 |
79 | SeparableConvBnAct = SeparableConvNormAct
80 |
81 |
82 | class SeparableConv2d(nn.Module):
83 | """ Separable Conv
84 | """
85 | def __init__(
86 | self,
87 | in_channels,
88 | out_channels,
89 | kernel_size=3,
90 | stride=1,
91 | dilation=1,
92 | padding='',
93 | bias=False,
94 | channel_multiplier=1.0,
95 | pw_kernel_size=1,
96 | device=None,
97 | dtype=None,
98 | ):
99 | dd = {'device': device, 'dtype': dtype}
100 | super().__init__()
101 |
102 | self.conv_dw = create_conv2d(
103 | in_channels,
104 | int(in_channels * channel_multiplier),
105 | kernel_size,
106 | stride=stride,
107 | dilation=dilation,
108 | padding=padding,
109 | depthwise=True,
110 | **dd,
111 | )
112 |
113 | self.conv_pw = create_conv2d(
114 | int(in_channels * channel_multiplier),
115 | out_channels,
116 | pw_kernel_size,
117 | padding=padding,
118 | bias=bias,
119 | **dd,
120 | )
121 |
122 | @property
123 | def in_channels(self):
124 | return self.conv_dw.in_channels
125 |
126 | @property
127 | def out_channels(self):
128 | return self.conv_pw.out_channels
129 |
130 | def forward(self, x):
131 | x = self.conv_dw(x)
132 | x = self.conv_pw(x)
133 | return x
134 |
--------------------------------------------------------------------------------
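A sketch of the parameter savings from the depthwise + pointwise factorization:

import torch
import torch.nn as nn
from timm.layers.separable_conv import SeparableConv2d

sep = SeparableConv2d(32, 64, kernel_size=3)
full = nn.Conv2d(32, 64, kernel_size=3, padding=1, bias=False)

n_sep = sum(p.numel() for p in sep.parameters())    # 32*3*3 + 32*64 = 2336
n_full = sum(p.numel() for p in full.parameters())  # 32*64*3*3 = 18432
print(n_sep, n_full)

y = sep(torch.randn(1, 32, 16, 16))
print(y.shape)  # torch.Size([1, 64, 16, 16]); default padding keeps spatial size
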
/timm/layers/split_attn.py:
--------------------------------------------------------------------------------
1 | """ Split Attention Conv2d (for ResNeSt Models)
2 |
3 | Paper: `ResNeSt: Split-Attention Networks` - https://arxiv.org/abs/2004.08955
4 |
5 | Adapted from original PyTorch impl at https://github.com/zhanghang1989/ResNeSt
6 |
7 | Modified for torchscript compat, performance, and consistency with timm by Ross Wightman
8 | """
9 | from typing import Optional, Type, Union
10 |
11 | import torch
12 | import torch.nn.functional as F
13 | from torch import nn
14 |
15 | from .helpers import make_divisible
16 |
17 |
18 | class RadixSoftmax(nn.Module):
19 | def __init__(self, radix: int, cardinality: int):
20 | super().__init__()
21 | self.radix = radix
22 | self.cardinality = cardinality
23 |
24 | def forward(self, x):
25 | batch = x.size(0)
26 | if self.radix > 1:
27 | x = x.view(batch, self.cardinality, self.radix, -1).transpose(1, 2)
28 | x = F.softmax(x, dim=1)
29 | x = x.reshape(batch, -1)
30 | else:
31 | x = torch.sigmoid(x)
32 | return x
33 |
34 |
35 | class SplitAttn(nn.Module):
36 | """Split-Attention (aka Splat)
37 | """
38 | def __init__(
39 | self,
40 | in_channels: int,
41 | out_channels: Optional[int] = None,
42 | kernel_size: int = 3,
43 | stride: int = 1,
44 | padding: Optional[int] = None,
45 | dilation: int = 1,
46 | groups: int = 1,
47 | bias: bool = False,
48 | radix: int = 2,
49 | rd_ratio: float = 0.25,
50 | rd_channels: Optional[int] = None,
51 | rd_divisor: int = 8,
52 | act_layer: Type[nn.Module] = nn.ReLU,
53 | norm_layer: Optional[Type[nn.Module]] = None,
54 | drop_layer: Optional[Type[nn.Module]] = None,
55 | **kwargs,
56 | ):
57 | dd = {'device': kwargs.pop('device', None), 'dtype': kwargs.pop('dtype', None)}
58 | super().__init__()
59 | out_channels = out_channels or in_channels
60 | self.radix = radix
61 | mid_chs = out_channels * radix
62 | if rd_channels is None:
63 | attn_chs = make_divisible(in_channels * radix * rd_ratio, min_value=32, divisor=rd_divisor)
64 | else:
65 | attn_chs = rd_channels * radix
66 |
67 | padding = kernel_size // 2 if padding is None else padding
68 | self.conv = nn.Conv2d(
69 | in_channels,
70 | mid_chs,
71 | kernel_size,
72 | stride,
73 | padding,
74 | dilation,
75 | groups=groups * radix,
76 | bias=bias,
77 | **kwargs,
78 | **dd,
79 | )
80 | self.bn0 = norm_layer(mid_chs, **dd) if norm_layer else nn.Identity()
81 | self.drop = drop_layer() if drop_layer is not None else nn.Identity()
82 | self.act0 = act_layer(inplace=True)
83 | self.fc1 = nn.Conv2d(out_channels, attn_chs, 1, groups=groups, **dd)
84 | self.bn1 = norm_layer(attn_chs, **dd) if norm_layer else nn.Identity()
85 | self.act1 = act_layer(inplace=True)
86 | self.fc2 = nn.Conv2d(attn_chs, mid_chs, 1, groups=groups, **dd)
87 | self.rsoftmax = RadixSoftmax(radix, groups)
88 |
89 | def forward(self, x):
90 | x = self.conv(x)
91 | x = self.bn0(x)
92 | x = self.drop(x)
93 | x = self.act0(x)
94 |
95 | B, RC, H, W = x.shape
96 | if self.radix > 1:
97 | x = x.reshape((B, self.radix, RC // self.radix, H, W))
98 | x_gap = x.sum(dim=1)
99 | else:
100 | x_gap = x
101 | x_gap = x_gap.mean((2, 3), keepdim=True)
102 | x_gap = self.fc1(x_gap)
103 | x_gap = self.bn1(x_gap)
104 | x_gap = self.act1(x_gap)
105 | x_attn = self.fc2(x_gap)
106 |
107 | x_attn = self.rsoftmax(x_attn).view(B, -1, 1, 1)
108 | if self.radix > 1:
109 | out = (x * x_attn.reshape((B, self.radix, RC // self.radix, 1, 1))).sum(dim=1)
110 | else:
111 | out = x * x_attn
112 | return out.contiguous()
113 |
--------------------------------------------------------------------------------
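A shape-preserving usage sketch, mirroring the radix=2 configuration used in ResNeSt blocks:

import torch
import torch.nn as nn
from timm.layers.split_attn import SplitAttn

splat = SplitAttn(64, radix=2, norm_layer=nn.BatchNorm2d)
y = splat(torch.randn(2, 64, 14, 14))
print(y.shape)  # torch.Size([2, 64, 14, 14]); radix branches fused by attention
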
/timm/scheduler/plateau_lr.py:
--------------------------------------------------------------------------------
1 | """ Plateau Scheduler
2 |
3 | Adapts PyTorch plateau scheduler and allows application of noise, warmup.
4 |
5 | Hacked together by / Copyright 2020 Ross Wightman
6 | """
7 | import torch
8 | from typing import List, Optional
9 |
10 | from .scheduler import Scheduler
11 |
12 |
13 | class PlateauLRScheduler(Scheduler):
14 | """Decay the LR by a factor every time the validation metric plateaus."""
15 |
16 | def __init__(
17 | self,
18 | optimizer,
19 | decay_rate=0.1,
20 | patience_t=10,
21 | verbose=True,
22 | threshold=1e-4,
23 | cooldown_t=0,
24 | warmup_t=0,
25 | warmup_lr_init=0,
26 | lr_min=0,
27 | mode='max',
28 | noise_range_t=None,
29 | noise_type='normal',
30 | noise_pct=0.67,
31 | noise_std=1.0,
32 | noise_seed=None,
33 | initialize=True,
34 | ):
35 | super().__init__(
36 | optimizer,
37 | 'lr',
38 | noise_range_t=noise_range_t,
39 | noise_type=noise_type,
40 | noise_pct=noise_pct,
41 | noise_std=noise_std,
42 | noise_seed=noise_seed,
43 | initialize=initialize,
44 | )
45 |
46 | self.lr_scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
47 | self.optimizer,
48 | patience=patience_t,
49 | factor=decay_rate,
50 | verbose=verbose,
51 | threshold=threshold,
52 | cooldown=cooldown_t,
53 | mode=mode,
54 | min_lr=lr_min
55 | )
56 |
57 | self.warmup_t = warmup_t
58 | self.warmup_lr_init = warmup_lr_init
59 | if self.warmup_t:
60 | self.warmup_steps = [(v - warmup_lr_init) / self.warmup_t for v in self.base_values]
61 | super().update_groups(self.warmup_lr_init)
62 | else:
63 | self.warmup_steps = [1 for _ in self.base_values]
64 | self.restore_lr = None
65 |
66 | def state_dict(self):
67 | return {
68 | 'best': self.lr_scheduler.best,
69 | 'last_epoch': self.lr_scheduler.last_epoch,
70 | }
71 |
72 | def load_state_dict(self, state_dict):
73 | self.lr_scheduler.best = state_dict['best']
74 | if 'last_epoch' in state_dict:
75 | self.lr_scheduler.last_epoch = state_dict['last_epoch']
76 |
77 | # override the base class step fn completely
78 | def step(self, epoch, metric=None):
79 | if epoch <= self.warmup_t:
80 | lrs = [self.warmup_lr_init + epoch * s for s in self.warmup_steps]
81 | super().update_groups(lrs)
82 | else:
83 | if self.restore_lr is not None:
84 | # restore actual LR from before our last noise perturbation before stepping base
85 | for i, param_group in enumerate(self.optimizer.param_groups):
86 | param_group['lr'] = self.restore_lr[i]
87 | self.restore_lr = None
88 |
89 | # step the base scheduler if metric given
90 | if metric is not None:
91 | self.lr_scheduler.step(metric, epoch)
92 |
93 | if self._is_apply_noise(epoch):
94 | self._apply_noise(epoch)
95 |
96 | def step_update(self, num_updates: int, metric: Optional[float] = None):
97 | return None
98 |
99 | def _apply_noise(self, epoch):
100 | noise = self._calculate_noise(epoch)
101 |
102 | # apply the noise on top of previous LR, cache the old value so we can restore for normal
103 | # stepping of base scheduler
104 | restore_lr = []
105 | for i, param_group in enumerate(self.optimizer.param_groups):
106 | old_lr = float(param_group['lr'])
107 | restore_lr.append(old_lr)
108 | new_lr = old_lr + old_lr * noise
109 | param_group['lr'] = new_lr
110 | self.restore_lr = restore_lr
111 |
112 | def _get_lr(self, t: int) -> List[float]:
113 | assert False, 'should not be called as step is overridden'
114 |
--------------------------------------------------------------------------------
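A sketch of the intended driving loop: call step once per epoch with the monitored metric (mode='min' here for a loss; recent torch versions may warn about the deprecated verbose arg on the wrapped ReduceLROnPlateau):

import torch
from timm.scheduler.plateau_lr import PlateauLRScheduler

model = torch.nn.Linear(4, 2)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
scheduler = PlateauLRScheduler(
    optimizer, decay_rate=0.5, patience_t=2, mode='min',
    warmup_t=3, warmup_lr_init=1e-4,
)

for epoch in range(10):
    val_loss = 1.0  # a flat metric: LR decays once patience is exhausted
    scheduler.step(epoch, metric=val_loss)
    print(epoch, optimizer.param_groups[0]['lr'])
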
/timm/scheduler/tanh_lr.py:
--------------------------------------------------------------------------------
1 | """ TanH Scheduler
2 |
3 | TanH schedule with warmup, cycle/restarts, noise.
4 |
5 | Hacked together by / Copyright 2021 Ross Wightman
6 | """
7 | import logging
8 | import math
9 | import numpy as np
10 | import torch
11 | from typing import List
12 |
13 | from .scheduler import Scheduler
14 |
15 |
16 | _logger = logging.getLogger(__name__)
17 |
18 |
19 | class TanhLRScheduler(Scheduler):
20 | """
21 | Hyperbolic-Tangent decay with restarts.
22 | This is described in the paper https://arxiv.org/abs/1806.01593
23 | """
24 |
25 | def __init__(
26 | self,
27 | optimizer: torch.optim.Optimizer,
28 | t_initial: int,
29 | lb: float = -7.,
30 | ub: float = 3.,
31 | lr_min: float = 0.,
32 | cycle_mul: float = 1.,
33 | cycle_decay: float = 1.,
34 | cycle_limit: int = 1,
35 | warmup_t=0,
36 | warmup_lr_init=0,
37 | warmup_prefix=False,
38 | t_in_epochs=True,
39 | noise_range_t=None,
40 | noise_pct=0.67,
41 | noise_std=1.0,
42 | noise_seed=42,
43 | initialize=True,
44 | ) -> None:
45 | super().__init__(
46 | optimizer,
47 | param_group_field="lr",
48 | t_in_epochs=t_in_epochs,
49 | noise_range_t=noise_range_t,
50 | noise_pct=noise_pct,
51 | noise_std=noise_std,
52 | noise_seed=noise_seed,
53 | initialize=initialize,
54 | )
55 |
56 | assert t_initial > 0
57 | assert lr_min >= 0
58 | assert lb < ub
59 | assert cycle_limit >= 0
60 | assert warmup_t >= 0
61 | assert warmup_lr_init >= 0
62 | self.lb = lb
63 | self.ub = ub
64 | self.t_initial = t_initial
65 | self.lr_min = lr_min
66 | self.cycle_mul = cycle_mul
67 | self.cycle_decay = cycle_decay
68 | self.cycle_limit = cycle_limit
69 | self.warmup_t = warmup_t
70 | self.warmup_lr_init = warmup_lr_init
71 | self.warmup_prefix = warmup_prefix
72 | if self.warmup_t:
73 | t_v = self.base_values if self.warmup_prefix else self._get_lr(self.warmup_t)
74 | self.warmup_steps = [(v - warmup_lr_init) / self.warmup_t for v in t_v]
75 | super().update_groups(self.warmup_lr_init)
76 | else:
77 | self.warmup_steps = [1 for _ in self.base_values]
78 |
79 | def _get_lr(self, t: int) -> List[float]:
80 | if t < self.warmup_t:
81 | lrs = [self.warmup_lr_init + t * s for s in self.warmup_steps]
82 | else:
83 | if self.warmup_prefix:
84 | t = t - self.warmup_t
85 |
86 | if self.cycle_mul != 1:
87 | i = math.floor(math.log(1 - t / self.t_initial * (1 - self.cycle_mul), self.cycle_mul))
88 | t_i = self.cycle_mul ** i * self.t_initial
89 | t_curr = t - (1 - self.cycle_mul ** i) / (1 - self.cycle_mul) * self.t_initial
90 | else:
91 | i = t // self.t_initial
92 | t_i = self.t_initial
93 | t_curr = t - (self.t_initial * i)
94 |
95 | if i < self.cycle_limit:
96 | gamma = self.cycle_decay ** i
97 | lr_max_values = [v * gamma for v in self.base_values]
98 |
99 | tr = t_curr / t_i
100 | lrs = [
101 | self.lr_min + 0.5 * (lr_max - self.lr_min) * (1 - math.tanh(self.lb * (1. - tr) + self.ub * tr))
102 | for lr_max in lr_max_values
103 | ]
104 | else:
105 | lrs = [self.lr_min for _ in self.base_values]
106 | return lrs
107 |
108 | def get_cycle_length(self, cycles=0):
109 | cycles = max(1, cycles or self.cycle_limit)
110 | if self.cycle_mul == 1.0:
111 | t = self.t_initial * cycles
112 | else:
113 | t = int(math.floor(-self.t_initial * (self.cycle_mul ** cycles - 1) / (1 - self.cycle_mul)))
114 | return t + self.warmup_t if self.warmup_prefix else t
115 |
--------------------------------------------------------------------------------
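A usage sketch; like the other timm schedulers, it is driven explicitly via step(epoch):

import torch
from timm.scheduler.tanh_lr import TanhLRScheduler

model = torch.nn.Linear(4, 2)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
sched = TanhLRScheduler(optimizer, t_initial=20, warmup_t=3, warmup_lr_init=1e-4, lr_min=1e-5)

for epoch in range(20):
    sched.step(epoch)  # warmup for 3 epochs, then tanh decay toward lr_min
    print(epoch, round(optimizer.param_groups[0]['lr'], 6))
print('cycle length:', sched.get_cycle_length())  # 20 when cycle_mul == 1
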
/timm/layers/create_attn.py:
--------------------------------------------------------------------------------
1 | """ Attention Factory
2 |
3 | Hacked together by / Copyright 2021 Ross Wightman
4 | """
5 | import torch
6 | from functools import partial
7 |
8 | from .bottleneck_attn import BottleneckAttn
9 | from .cbam import CbamModule, LightCbamModule
10 | from .coord_attn import CoordAttn, EfficientLocalAttn, StripAttn, SimpleCoordAttn
11 | from .eca import EcaModule, CecaModule
12 | from .gather_excite import GatherExcite
13 | from .global_context import GlobalContext
14 | from .halo_attn import HaloAttn
15 | from .lambda_layer import LambdaLayer
16 | from .non_local_attn import NonLocalAttn, BatNonLocalAttn
17 | from .selective_kernel import SelectiveKernel
18 | from .split_attn import SplitAttn
19 | from .squeeze_excite import SEModule, EffectiveSEModule
20 |
21 |
22 | def get_attn(attn_type):
23 | if isinstance(attn_type, torch.nn.Module):
24 | return attn_type
25 | module_cls = None
26 | if attn_type:
27 | if isinstance(attn_type, str):
28 | attn_type = attn_type.lower()
29 | # Lightweight attention modules (channel and/or coarse spatial).
30 | # Typically added to existing network architecture blocks in addition to existing convolutions.
31 | if attn_type == 'se':
32 | module_cls = SEModule
33 | elif attn_type == 'ese':
34 | module_cls = EffectiveSEModule
35 | elif attn_type == 'eca':
36 | module_cls = EcaModule
37 | elif attn_type == 'ecam':
38 | module_cls = partial(EcaModule, use_mlp=True)
39 | elif attn_type == 'ceca':
40 | module_cls = CecaModule
41 | elif attn_type == 'ge':
42 | module_cls = GatherExcite
43 | elif attn_type == 'gc':
44 | module_cls = GlobalContext
45 | elif attn_type == 'gca':
46 | module_cls = partial(GlobalContext, fuse_add=True, fuse_scale=False)
47 | elif attn_type == 'cbam':
48 | module_cls = CbamModule
49 | elif attn_type == 'lcbam':
50 | module_cls = LightCbamModule
51 | elif attn_type == 'coord':
52 | module_cls = CoordAttn
53 | elif attn_type == 'scoord':
54 | module_cls = SimpleCoordAttn
55 | elif attn_type == 'ela':
56 | module_cls = EfficientLocalAttn
57 | elif attn_type == 'strip':
58 | module_cls = StripAttn
59 |
60 | # Attention / attention-like modules w/ significant params
61 | # Typically replace some of the existing workhorse convs in a network architecture.
62 | # All of these accept a stride argument and can spatially downsample the input.
63 | elif attn_type == 'sk':
64 | module_cls = SelectiveKernel
65 | elif attn_type == 'splat':
66 | module_cls = SplitAttn
67 |
68 | # Self-attention / attention-like modules w/ significant compute and/or params
69 | # Typically replace some of the existing workhorse convs in a network architecture.
70 | # All of these accept a stride argument and can spatially downsample the input.
71 | elif attn_type == 'lambda':
72 | return LambdaLayer
73 | elif attn_type == 'bottleneck':
74 | return BottleneckAttn
75 | elif attn_type == 'halo':
76 | return HaloAttn
77 | elif attn_type == 'nl':
78 | module_cls = NonLocalAttn
79 | elif attn_type == 'bat':
80 | module_cls = BatNonLocalAttn
81 |
82 | # Woops!
83 | else:
84 | assert False, "Invalid attn module (%s)" % attn_type
85 | elif isinstance(attn_type, bool):
86 | if attn_type:
87 | module_cls = SEModule
88 | else:
89 | module_cls = attn_type
90 | return module_cls
91 |
92 |
93 | def create_attn(attn_type, channels, **kwargs):
94 | module_cls = get_attn(attn_type)
95 | if module_cls is not None:
96 | # NOTE: it's expected the first (positional) argument of all attention layers is the number of input channels
97 | return module_cls(channels, **kwargs)
98 | return None
99 |
--------------------------------------------------------------------------------
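Factory usage sketch; a falsy attn_type yields None so callers can attach attention conditionally:

import torch
from timm.layers.create_attn import create_attn, get_attn

se = create_attn('se', 64)    # SEModule(64)
eca = create_attn('eca', 64)  # EcaModule(64)
assert create_attn(None, 64) is None  # falsy -> no attention module

x = torch.randn(2, 64, 8, 8)
print(se(x).shape, eca(x).shape)  # both shape-preserving

assert get_attn(True) is not None  # attn_type=True maps to the default SEModule
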