├── tests
│   ├── __init__.py
│   ├── test_layers.py
│   └── test_utils.py
├── docs
│   ├── models
│   │   ├── .pages
│   │   └── .templates
│   │       ├── generate_readmes.py
│   │       ├── models
│   │       │   ├── spnasnet.md
│   │       │   ├── gloun-senet.md
│   │       │   ├── nasnet.md
│   │       │   ├── gloun-xception.md
│   │       │   ├── legacy-senet.md
│   │       │   ├── inception-v4.md
│   │       │   ├── skresnext.md
│   │       │   ├── pnasnet.md
│   │       │   ├── inception-resnet-v2.md
│   │       │   ├── csp-resnet.md
│   │       │   ├── csp-resnext.md
│   │       │   ├── fbnet.md
│   │       │   ├── res2next.md
│   │       │   ├── csp-darknet.md
│   │       │   ├── gloun-inception-v3.md
│   │       │   ├── inception-v3.md
│   │       │   ├── ese-vovnet.md
│   │       │   ├── tf-inception-v3.md
│   │       │   ├── wide-resnet.md
│   │       │   └── ensemble-adversarial.md
│   │       └── code_snippets.md
│   ├── javascripts
│   │   └── tables.js
│   ├── scripts.md
│   └── index.md
├── timm
│   ├── version.py
│   ├── data
│   │   ├── readers
│   │   │   ├── __init__.py
│   │   │   ├── shared_count.py
│   │   │   ├── reader.py
│   │   │   ├── class_map.py
│   │   │   ├── reader_factory.py
│   │   │   ├── img_extensions.py
│   │   │   ├── reader_hfds.py
│   │   │   ├── reader_image_tar.py
│   │   │   └── reader_image_folder.py
│   │   ├── constants.py
│   │   ├── __init__.py
│   │   ├── _info
│   │   │   ├── imagenet_r_indices.txt
│   │   │   ├── imagenet_a_indices.txt
│   │   │   ├── imagenet_a_synsets.txt
│   │   │   └── imagenet_r_synsets.txt
│   │   ├── real_labels.py
│   │   └── dataset_info.py
│   ├── models
│   │   ├── hub.py
│   │   ├── factory.py
│   │   ├── features.py
│   │   ├── registry.py
│   │   ├── fx_features.py
│   │   ├── helpers.py
│   │   ├── layers
│   │   │   └── __init__.py
│   │   └── __init__.py
│   ├── layers
│   │   ├── typing.py
│   │   ├── trace_utils.py
│   │   ├── linear.py
│   │   ├── helpers.py
│   │   ├── format.py
│   │   ├── grn.py
│   │   ├── blur_pool.py
│   │   ├── create_conv2d.py
│   │   ├── create_norm.py
│   │   ├── median_pool.py
│   │   ├── patch_dropout.py
│   │   ├── space_to_depth.py
│   │   ├── mixed_conv2d.py
│   │   ├── test_time_pool.py
│   │   ├── interpolate.py
│   │   ├── global_context.py
│   │   ├── filter_response_norm.py
│   │   ├── activations_jit.py
│   │   ├── pos_embed.py
│   │   ├── separable_conv.py
│   │   ├── padding.py
│   │   ├── pool2d_same.py
│   │   ├── split_attn.py
│   │   └── inplace_abn.py
│   ├── utils
│   │   ├── random.py
│   │   ├── __init__.py
│   │   ├── clip_grad.py
│   │   ├── metrics.py
│   │   ├── log.py
│   │   ├── misc.py
│   │   ├── summary.py
│   │   ├── agc.py
│   │   ├── decay_batch.py
│   │   ├── cuda.py
│   │   └── jit.py
│   ├── loss
│   │   ├── __init__.py
│   │   ├── cross_entropy.py
│   │   ├── jsd.py
│   │   └── binary_cross_entropy.py
│   ├── __init__.py
│   ├── scheduler
│   │   ├── __init__.py
│   │   ├── step_lr.py
│   │   └── multistep_lr.py
│   └── optim
│       ├── __init__.py
│       ├── sgdp.py
│       └── lookahead.py
├── .gitattributes
├── .github
│   ├── FUNDING.yml
│   ├── ISSUE_TEMPLATE
│   │   ├── config.yml
│   │   ├── feature_request.md
│   │   └── bug_report.md
│   └── workflows
│       ├── delete_doc_comment_trigger.yml
│       ├── delete_doc_comment.yml
│       ├── upload_pr_documentation.yml
│       ├── build_documentation.yml
│       ├── build_pr_documentation.yml
│       └── tests.yml
├── requirements.txt
├── requirements-dev.txt
├── MANIFEST.in
├── hubconf.py
├── setup.cfg
├── distributed_train.sh
├── hfdocs
│   ├── source
│   │   ├── reference
│   │   │   ├── models.mdx
│   │   │   ├── data.mdx
│   │   │   ├── schedulers.mdx
│   │   │   └── optimizers.mdx
│   │   ├── index.mdx
│   │   ├── hf_hub.mdx
│   │   └── installation.mdx
│   └── README.md
├── requirements-docs.txt
├── pyproject.toml
├── model-index.yml
├── mkdocs.yml
├── .gitignore
├── setup.py
└── results
    └── generate_csv_results.py
/tests/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/docs/models/.pages:
--------------------------------------------------------------------------------
1 | title: Model Pages
--------------------------------------------------------------------------------
/timm/version.py:
--------------------------------------------------------------------------------
1 | __version__ = '0.9.11'
2 |
--------------------------------------------------------------------------------
/.gitattributes:
--------------------------------------------------------------------------------
1 | *.ipynb linguist-documentation
2 |
--------------------------------------------------------------------------------
/.github/FUNDING.yml:
--------------------------------------------------------------------------------
1 | # These are supported funding model platforms
2 | github: rwightman
3 |
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | torch>=1.7
2 | torchvision
3 | pyyaml
4 | huggingface_hub
5 | safetensors>=0.2
--------------------------------------------------------------------------------
/requirements-dev.txt:
--------------------------------------------------------------------------------
1 | pytest
2 | pytest-timeout
3 | pytest-xdist
4 | pytest-forked
5 | expecttest
6 |
--------------------------------------------------------------------------------
/timm/data/readers/__init__.py:
--------------------------------------------------------------------------------
1 | from .reader_factory import create_reader
2 | from .img_extensions import *
3 |
--------------------------------------------------------------------------------
/MANIFEST.in:
--------------------------------------------------------------------------------
1 | include timm/models/_pruned/*.txt
2 | include timm/data/_info/*.txt
3 | include timm/data/_info/*.json
4 |
--------------------------------------------------------------------------------
/hubconf.py:
--------------------------------------------------------------------------------
1 | dependencies = ['torch']
2 | import timm
3 | globals().update(timm.models._registry._model_entrypoints)
4 |
--------------------------------------------------------------------------------
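`hubconf.py` injects every registered timm model entrypoint into module globals, which makes all models loadable through `torch.hub`. A minimal usage sketch (the repo path and `resnet50` model name here are illustrative examples, not the only options):

```python
import torch

# Any registered timm entrypoint becomes a torch.hub callable;
# 'pretrained' is forwarded to the entrypoint function.
model = torch.hub.load('rwightman/pytorch-image-models', 'resnet50', pretrained=True)
model.eval()
```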
/setup.cfg:
--------------------------------------------------------------------------------
1 | [dist_conda]
2 |
3 | conda_name_differences = 'torch:pytorch'
4 | channels = pytorch
5 | noarch = True
6 |
--------------------------------------------------------------------------------
/distributed_train.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 | NUM_PROC=$1
3 | shift
4 | torchrun --nproc_per_node=$NUM_PROC train.py "$@"
5 |
6 |
--------------------------------------------------------------------------------
/hfdocs/source/reference/models.mdx:
--------------------------------------------------------------------------------
1 | # Models
2 |
3 | [[autodoc]] timm.create_model
4 |
5 | [[autodoc]] timm.list_models
6 |
--------------------------------------------------------------------------------
/requirements-docs.txt:
--------------------------------------------------------------------------------
1 | mkdocs
2 | mkdocs-material
3 | mkdocs-redirects
4 | mdx_truly_sane_lists
5 | mkdocs-awesome-pages-plugin
6 |
--------------------------------------------------------------------------------
/timm/models/hub.py:
--------------------------------------------------------------------------------
1 | from ._hub import *
2 |
3 | import warnings
4 | warnings.warn(f"Importing from {__name__} is deprecated, please import via timm.models", DeprecationWarning)
5 |
--------------------------------------------------------------------------------
/timm/models/factory.py:
--------------------------------------------------------------------------------
1 | from ._factory import *
2 |
3 | import warnings
4 | warnings.warn(f"Importing from {__name__} is deprecated, please import via timm.models", DeprecationWarning)
5 |
--------------------------------------------------------------------------------
/timm/models/features.py:
--------------------------------------------------------------------------------
1 | from ._features import *
2 |
3 | import warnings
4 | warnings.warn(f"Importing from {__name__} is deprecated, please import via timm.models", DeprecationWarning)
5 |
--------------------------------------------------------------------------------
/timm/models/registry.py:
--------------------------------------------------------------------------------
1 | from ._registry import *
2 |
3 | import warnings
4 | warnings.warn(f"Importing from {__name__} is deprecated, please import via timm.models", DeprecationWarning)
5 |
--------------------------------------------------------------------------------
/timm/models/fx_features.py:
--------------------------------------------------------------------------------
1 | from ._features_fx import *
2 |
3 | import warnings
4 | warnings.warn(f"Importing from {__name__} is deprecated, please import via timm.models", DeprecationWarning)
5 |
--------------------------------------------------------------------------------
/docs/javascripts/tables.js:
--------------------------------------------------------------------------------
1 | app.location$.subscribe(function() {
2 | var tables = document.querySelectorAll("article table")
3 | tables.forEach(function(table) {
4 | new Tablesort(table)
5 | })
6 | })
--------------------------------------------------------------------------------
/timm/layers/typing.py:
--------------------------------------------------------------------------------
1 | from typing import Callable, Tuple, Type, Union
2 |
3 | import torch
4 |
5 |
6 | LayerType = Union[str, Callable, Type[torch.nn.Module]]
7 | PadType = Union[str, int, Tuple[int, int]]
8 |
--------------------------------------------------------------------------------
/hfdocs/source/reference/data.mdx:
--------------------------------------------------------------------------------
1 | # Data
2 |
3 | [[autodoc]] timm.data.create_dataset
4 |
5 | [[autodoc]] timm.data.create_loader
6 |
7 | [[autodoc]] timm.data.create_transform
8 |
9 | [[autodoc]] timm.data.resolve_data_config
--------------------------------------------------------------------------------
/timm/utils/random.py:
--------------------------------------------------------------------------------
1 | import random
2 | import numpy as np
3 | import torch
4 |
5 |
6 | def random_seed(seed=42, rank=0):
7 | torch.manual_seed(seed + rank)
8 | np.random.seed(seed + rank)
9 | random.seed(seed + rank)
10 |
--------------------------------------------------------------------------------
/timm/models/helpers.py:
--------------------------------------------------------------------------------
1 | from ._builder import *
2 | from ._helpers import *
3 | from ._manipulate import *
4 | from ._prune import *
5 |
6 | import warnings
7 | warnings.warn(f"Importing from {__name__} is deprecated, please import via timm.models", DeprecationWarning)
8 |
--------------------------------------------------------------------------------
/timm/loss/__init__.py:
--------------------------------------------------------------------------------
1 | from .asymmetric_loss import AsymmetricLossMultiLabel, AsymmetricLossSingleLabel
2 | from .binary_cross_entropy import BinaryCrossEntropy
3 | from .cross_entropy import LabelSmoothingCrossEntropy, SoftTargetCrossEntropy
4 | from .jsd import JsdCrossEntropy
5 |
--------------------------------------------------------------------------------
/.github/ISSUE_TEMPLATE/config.yml:
--------------------------------------------------------------------------------
1 | blank_issues_enabled: false
2 | contact_links:
3 | - name: Community Discussions
4 | url: https://github.com/rwightman/pytorch-image-models/discussions
5 | about: Hparam requests in issues will be ignored! Issues are for features and bugs. Questions can be asked in Discussions.
6 |
--------------------------------------------------------------------------------
/timm/__init__.py:
--------------------------------------------------------------------------------
1 | from .version import __version__
2 | from .layers import is_scriptable, is_exportable, set_scriptable, set_exportable
3 | from .models import create_model, list_models, list_pretrained, is_model, list_modules, model_entrypoint, \
4 | is_model_pretrained, get_pretrained_cfg, get_pretrained_cfg_value
5 |
--------------------------------------------------------------------------------
/hfdocs/README.md:
--------------------------------------------------------------------------------
1 | # Hugging Face Timm Docs
2 |
3 | ## Getting Started
4 |
5 | ```
6 | pip install git+https://github.com/huggingface/doc-builder.git@main#egg=hf-doc-builder
7 | pip install watchdog black
8 | ```
9 |
10 | ## Preview the Docs Locally
11 |
12 | ```
13 | doc-builder preview timm hfdocs/source
14 | ```
15 |
--------------------------------------------------------------------------------
/.github/workflows/delete_doc_comment_trigger.yml:
--------------------------------------------------------------------------------
1 | name: Delete doc comment trigger
2 |
3 | on:
4 | pull_request:
5 | types: [ closed ]
6 |
7 |
8 | jobs:
9 | delete:
10 | uses: huggingface/doc-builder/.github/workflows/delete_doc_comment_trigger.yml@main
11 | with:
12 | pr_number: ${{ github.event.number }}
--------------------------------------------------------------------------------
/timm/scheduler/__init__.py:
--------------------------------------------------------------------------------
1 | from .cosine_lr import CosineLRScheduler
2 | from .multistep_lr import MultiStepLRScheduler
3 | from .plateau_lr import PlateauLRScheduler
4 | from .poly_lr import PolyLRScheduler
5 | from .step_lr import StepLRScheduler
6 | from .tanh_lr import TanhLRScheduler
7 |
8 | from .scheduler_factory import create_scheduler, create_scheduler_v2, scheduler_kwargs
9 |
--------------------------------------------------------------------------------
/.github/workflows/delete_doc_comment.yml:
--------------------------------------------------------------------------------
1 | name: Delete doc comment
2 |
3 | on:
4 | workflow_run:
5 | workflows: ["Delete doc comment trigger"]
6 | types:
7 | - completed
8 |
9 | jobs:
10 | delete:
11 | uses: huggingface/doc-builder/.github/workflows/delete_doc_comment.yml@main
12 | secrets:
13 | comment_bot_token: ${{ secrets.COMMENT_BOT_TOKEN }}
--------------------------------------------------------------------------------
/timm/data/readers/shared_count.py:
--------------------------------------------------------------------------------
1 | from multiprocessing import Value
2 |
3 |
4 | class SharedCount:
5 | def __init__(self, epoch: int = 0):
6 | self.shared_epoch = Value('i', epoch)
7 |
8 | @property
9 | def value(self):
10 | return self.shared_epoch.value
11 |
12 | @value.setter
13 | def value(self, epoch):
14 | self.shared_epoch.value = epoch
15 |
--------------------------------------------------------------------------------
/timm/layers/trace_utils.py:
--------------------------------------------------------------------------------
1 | try:
2 | from torch import _assert
3 | except ImportError:
4 | def _assert(condition: bool, message: str):
5 | assert condition, message
6 |
7 |
8 | def _float_to_int(x: float) -> int:
9 | """
10 | Symbolic tracing helper to substitute for inbuilt `int`.
11 | Hint: Inbuilt `int` can't accept an argument of type `Proxy`
12 | """
13 | return int(x)
14 |
--------------------------------------------------------------------------------
/.github/workflows/upload_pr_documentation.yml:
--------------------------------------------------------------------------------
1 | name: Upload PR Documentation
2 |
3 | on:
4 | workflow_run:
5 | workflows: ["Build PR Documentation"]
6 | types:
7 | - completed
8 |
9 | jobs:
10 | build:
11 | uses: huggingface/doc-builder/.github/workflows/upload_pr_documentation.yml@main
12 | with:
13 | package_name: timm
14 | secrets:
15 | hf_token: ${{ secrets.HF_DOC_BUILD_PUSH }}
16 | comment_bot_token: ${{ secrets.COMMENT_BOT_TOKEN }}
--------------------------------------------------------------------------------
/timm/data/constants.py:
--------------------------------------------------------------------------------
1 | DEFAULT_CROP_PCT = 0.875
2 | DEFAULT_CROP_MODE = 'center'
3 | IMAGENET_DEFAULT_MEAN = (0.485, 0.456, 0.406)
4 | IMAGENET_DEFAULT_STD = (0.229, 0.224, 0.225)
5 | IMAGENET_INCEPTION_MEAN = (0.5, 0.5, 0.5)
6 | IMAGENET_INCEPTION_STD = (0.5, 0.5, 0.5)
7 | IMAGENET_DPN_MEAN = (124 / 255, 117 / 255, 104 / 255)
8 | IMAGENET_DPN_STD = tuple([1 / (.0167 * 255)] * 3)
9 | OPENAI_CLIP_MEAN = (0.48145466, 0.4578275, 0.40821073)
10 | OPENAI_CLIP_STD = (0.26862954, 0.26130258, 0.27577711)
11 |
--------------------------------------------------------------------------------
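A sketch of how these constants are typically consumed, assuming a torchvision preprocessing pipeline (note that the 224/256 crop ratio matches `DEFAULT_CROP_PCT = 0.875`):

```python
from torchvision import transforms

from timm.data.constants import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD

# Minimal eval-time pipeline; 224 / 256 = 0.875 = DEFAULT_CROP_PCT.
preprocess = transforms.Compose([
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize(IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD),
])
```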
/timm/data/readers/reader.py:
--------------------------------------------------------------------------------
1 | from abc import abstractmethod
2 |
3 |
4 | class Reader:
5 | def __init__(self):
6 | pass
7 |
8 | @abstractmethod
9 | def _filename(self, index, basename=False, absolute=False):
10 | pass
11 |
12 | def filename(self, index, basename=False, absolute=False):
13 | return self._filename(index, basename=basename, absolute=absolute)
14 |
15 | def filenames(self, basename=False, absolute=False):
16 | return [self._filename(index, basename=basename, absolute=absolute) for index in range(len(self))]
17 |
18 |
--------------------------------------------------------------------------------
/.github/workflows/build_documentation.yml:
--------------------------------------------------------------------------------
1 | name: Build documentation
2 |
3 | on:
4 | push:
5 | branches:
6 | - main
7 | - doc-builder*
8 | - v*-release
9 |
10 | jobs:
11 | build:
12 | uses: huggingface/doc-builder/.github/workflows/build_main_documentation.yml@main
13 | with:
14 | commit_sha: ${{ github.sha }}
15 | package: pytorch-image-models
16 | package_name: timm
17 | path_to_docs: pytorch-image-models/hfdocs/source
18 | version_tag_suffix: ""
19 | secrets:
20 | hf_token: ${{ secrets.HF_DOC_BUILD_PUSH }}
21 |
--------------------------------------------------------------------------------
/pyproject.toml:
--------------------------------------------------------------------------------
1 | [tool.pytest.ini_options]
2 | markers = [
3 | "base: marker for model tests using the basic setup",
4 | "cfg: marker for model tests checking the config",
5 | "torchscript: marker for model tests using torchscript",
6 | "features: marker for model tests checking feature extraction",
7 | "fxforward: marker for model tests using torch fx (only forward)",
8 | "fxbackward: marker for model tests using torch fx (only backward)",
9 | ]
10 |
11 | [tool.black]
12 | line-length = 120
13 | target-version = ['py37', 'py38', 'py39', 'py310', 'py311']
14 | skip-string-normalization = true
15 |
--------------------------------------------------------------------------------
/timm/optim/__init__.py:
--------------------------------------------------------------------------------
1 | from .adabelief import AdaBelief
2 | from .adafactor import Adafactor
3 | from .adahessian import Adahessian
4 | from .adamp import AdamP
5 | from .adamw import AdamW
6 | from .adan import Adan
7 | from .lamb import Lamb
8 | from .lars import Lars
9 | from .lookahead import Lookahead
10 | from .madgrad import MADGRAD
11 | from .nadam import Nadam
12 | from .nvnovograd import NvNovoGrad
13 | from .radam import RAdam
14 | from .rmsprop_tf import RMSpropTF
15 | from .sgdp import SGDP
16 | from .lion import Lion
17 | from .optim_factory import create_optimizer, create_optimizer_v2, optimizer_kwargs
18 |
--------------------------------------------------------------------------------
/.github/workflows/build_pr_documentation.yml:
--------------------------------------------------------------------------------
1 | name: Build PR Documentation
2 |
3 | on:
4 | pull_request:
5 |
6 | concurrency:
7 | group: ${{ github.workflow }}-${{ github.head_ref || github.run_id }}
8 | cancel-in-progress: true
9 |
10 | jobs:
11 | build:
12 | uses: huggingface/doc-builder/.github/workflows/build_pr_documentation.yml@main
13 | with:
14 | commit_sha: ${{ github.event.pull_request.head.sha }}
15 | pr_number: ${{ github.event.number }}
16 | package: pytorch-image-models
17 | package_name: timm
18 | path_to_docs: pytorch-image-models/hfdocs/source
19 | version_tag_suffix: ""
20 |
--------------------------------------------------------------------------------
/hfdocs/source/reference/schedulers.mdx:
--------------------------------------------------------------------------------
1 | # Learning Rate Schedulers
2 |
3 | This page contains the API reference documentation for learning rate schedulers included in `timm`.
4 |
5 | ## Schedulers
6 |
7 | ### Factory functions
8 |
9 | [[autodoc]] timm.scheduler.scheduler_factory.create_scheduler
10 | [[autodoc]] timm.scheduler.scheduler_factory.create_scheduler_v2
11 |
12 | ### Scheduler Classes
13 |
14 | [[autodoc]] timm.scheduler.cosine_lr.CosineLRScheduler
15 | [[autodoc]] timm.scheduler.multistep_lr.MultiStepLRScheduler
16 | [[autodoc]] timm.scheduler.plateau_lr.PlateauLRScheduler
17 | [[autodoc]] timm.scheduler.poly_lr.PolyLRScheduler
18 | [[autodoc]] timm.scheduler.step_lr.StepLRScheduler
19 | [[autodoc]] timm.scheduler.tanh_lr.TanhLRScheduler
20 |
--------------------------------------------------------------------------------
/model-index.yml:
--------------------------------------------------------------------------------
1 | Import:
2 | - ./docs/models/*.md
3 | Library:
4 | Name: PyTorch Image Models
5 | Headline: PyTorch image models, scripts, pretrained weights
6 | Website: https://rwightman.github.io/pytorch-image-models/
7 | Repository: https://github.com/rwightman/pytorch-image-models
8 | Docs: https://rwightman.github.io/pytorch-image-models/
9 | README: "# PyTorch Image Models\r\n\r\nPyTorch Image Models (TIMM) is a library\
10 | \ for state-of-the-art image classification. With this library you can:\r\n\r\n\
11 | - Choose from 300+ pre-trained state-of-the-art image classification models.\r\
12 | \n- Train models afresh on research datasets such as ImageNet using provided scripts.\r\
13 | \n- Finetune pre-trained models on your own datasets, including the latest cutting\
14 | \ edge models."
15 |
--------------------------------------------------------------------------------
/timm/layers/linear.py:
--------------------------------------------------------------------------------
1 | """ Linear layer (alternate definition)
2 | """
3 | import torch
4 | import torch.nn.functional as F
5 | from torch import nn as nn
6 |
7 |
8 | class Linear(nn.Linear):
9 | r"""Applies a linear transformation to the incoming data: :math:`y = xA^T + b`
10 |
11 | Wraps torch.nn.Linear to support AMP + torchscript usage by manually casting
12 | weight & bias to input.dtype to work around an issue w/ torch.addmm in this use case.
13 | """
14 | def forward(self, input: torch.Tensor) -> torch.Tensor:
15 | if torch.jit.is_scripting():
16 | bias = self.bias.to(dtype=input.dtype) if self.bias is not None else None
17 | return F.linear(input, self.weight.to(dtype=input.dtype), bias=bias)
18 | else:
19 | return F.linear(input, self.weight, self.bias)
20 |
--------------------------------------------------------------------------------
/timm/utils/__init__.py:
--------------------------------------------------------------------------------
1 | from .agc import adaptive_clip_grad
2 | from .checkpoint_saver import CheckpointSaver
3 | from .clip_grad import dispatch_clip_grad
4 | from .cuda import ApexScaler, NativeScaler
5 | from .decay_batch import decay_batch_step, check_batch_size_retry
6 | from .distributed import distribute_bn, reduce_tensor, init_distributed_device,\
7 | world_info_from_env, is_distributed_env, is_primary
8 | from .jit import set_jit_legacy, set_jit_fuser
9 | from .log import setup_default_logging, FormatterNoInfo
10 | from .metrics import AverageMeter, accuracy
11 | from .misc import natural_key, add_bool_arg, ParseKwargs
12 | from .model import unwrap_model, get_state_dict, freeze, unfreeze, reparameterize_model
13 | from .model_ema import ModelEma, ModelEmaV2
14 | from .random import random_seed
15 | from .summary import update_summary, get_outdir
16 |
--------------------------------------------------------------------------------
/.github/ISSUE_TEMPLATE/feature_request.md:
--------------------------------------------------------------------------------
1 | ---
2 | name: Feature request
3 | about: Suggest an idea for this project. Hparam requests, training help are not feature requests.
4 | The discussion forum is available for asking questions or seeking help from the community.
5 | title: "[FEATURE] Feature title..."
6 | labels: enhancement
7 | assignees: ''
8 |
9 | ---
10 |
11 | **Is your feature request related to a problem? Please describe.**
12 | A clear and concise description of what the problem is.
13 |
14 | **Describe the solution you'd like**
15 | A clear and concise description of what you want to happen.
16 |
17 | **Describe alternatives you've considered**
18 | A clear and concise description of any alternative solutions or features you've considered.
19 |
20 | **Additional context**
21 | Add any other context or screenshots about the feature request here.
22 |
--------------------------------------------------------------------------------
/timm/data/__init__.py:
--------------------------------------------------------------------------------
1 | from .auto_augment import RandAugment, AutoAugment, rand_augment_ops, auto_augment_policy,\
2 | rand_augment_transform, auto_augment_transform
3 | from .config import resolve_data_config, resolve_model_data_config
4 | from .constants import *
5 | from .dataset import ImageDataset, IterableImageDataset, AugMixDataset
6 | from .dataset_factory import create_dataset
7 | from .dataset_info import DatasetInfo, CustomDatasetInfo
8 | from .imagenet_info import ImageNetInfo, infer_imagenet_subset
9 | from .loader import create_loader
10 | from .mixup import Mixup, FastCollateMixup
11 | from .readers import create_reader
12 | from .readers import get_img_extensions, is_img_extension, set_img_extensions, add_img_extensions, del_img_extensions
13 | from .real_labels import RealLabelsImagenet
14 | from .transforms import *
15 | from .transforms_factory import create_transform
16 |
--------------------------------------------------------------------------------
/timm/utils/clip_grad.py:
--------------------------------------------------------------------------------
1 | import torch
2 |
3 | from timm.utils.agc import adaptive_clip_grad
4 |
5 |
6 | def dispatch_clip_grad(parameters, value: float, mode: str = 'norm', norm_type: float = 2.0):
7 | """ Dispatch to gradient clipping method
8 |
9 | Args:
10 | parameters (Iterable): model parameters to clip
11 | value (float): clipping value/factor/norm, mode dependent
12 | mode (str): clipping mode, one of 'norm', 'value', 'agc'
13 | norm_type (float): p-norm, default 2.0
14 | """
15 | if mode == 'norm':
16 | torch.nn.utils.clip_grad_norm_(parameters, value, norm_type=norm_type)
17 | elif mode == 'value':
18 | torch.nn.utils.clip_grad_value_(parameters, value)
19 | elif mode == 'agc':
20 | adaptive_clip_grad(parameters, value, norm_type=norm_type)
21 | else:
22 | assert False, f"Unknown clip mode ({mode})."
23 |
24 |
--------------------------------------------------------------------------------
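A small sketch of `dispatch_clip_grad` in a training step (the model and clip value are illustrative):

```python
import torch
import torch.nn as nn

from timm.utils.clip_grad import dispatch_clip_grad

model = nn.Linear(10, 2)
model(torch.randn(4, 10)).sum().backward()

# mode='norm' -> clip_grad_norm_, 'value' -> clip_grad_value_, 'agc' -> adaptive_clip_grad
dispatch_clip_grad(model.parameters(), value=1.0, mode='norm')
```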
/hfdocs/source/reference/optimizers.mdx:
--------------------------------------------------------------------------------
1 | # Optimization
2 |
3 | This page contains the API reference documentation for learning rate optimizers included in `timm`.
4 |
5 | ## Optimizers
6 |
7 | ### Factory functions
8 |
9 | [[autodoc]] timm.optim.optim_factory.create_optimizer
10 | [[autodoc]] timm.optim.optim_factory.create_optimizer_v2
11 |
12 | ### Optimizer Classes
13 |
14 | [[autodoc]] timm.optim.adabelief.AdaBelief
15 | [[autodoc]] timm.optim.adafactor.Adafactor
16 | [[autodoc]] timm.optim.adahessian.Adahessian
17 | [[autodoc]] timm.optim.adamp.AdamP
18 | [[autodoc]] timm.optim.adamw.AdamW
19 | [[autodoc]] timm.optim.lamb.Lamb
20 | [[autodoc]] timm.optim.lars.Lars
21 | [[autodoc]] timm.optim.lookahead.Lookahead
22 | [[autodoc]] timm.optim.madgrad.MADGRAD
23 | [[autodoc]] timm.optim.nadam.Nadam
24 | [[autodoc]] timm.optim.nvnovograd.NvNovoGrad
25 | [[autodoc]] timm.optim.radam.RAdam
26 | [[autodoc]] timm.optim.rmsprop_tf.RMSpropTF
27 | [[autodoc]] timm.optim.sgdp.SGDP
28 |
--------------------------------------------------------------------------------
/timm/data/readers/class_map.py:
--------------------------------------------------------------------------------
1 | import os
2 | import pickle
3 |
4 |
5 | def load_class_map(map_or_filename, root=''):
6 | if isinstance(map_or_filename, dict):
7 | assert map_or_filename, 'class_map dict must be non-empty'
8 | return map_or_filename
9 | class_map_path = map_or_filename
10 | if not os.path.exists(class_map_path):
11 | class_map_path = os.path.join(root, class_map_path)
12 | assert os.path.exists(class_map_path), 'Cannot locate specified class map file (%s)' % map_or_filename
13 | class_map_ext = os.path.splitext(map_or_filename)[-1].lower()
14 | if class_map_ext == '.txt':
15 | with open(class_map_path) as f:
16 | class_to_idx = {v.strip(): k for k, v in enumerate(f)}
17 | elif class_map_ext == '.pkl':
18 | with open(class_map_path, 'rb') as f:
19 | class_to_idx = pickle.load(f)
20 | else:
21 | assert False, f'Unsupported class map file extension ({class_map_ext}).'
22 | return class_to_idx
23 |
24 |
--------------------------------------------------------------------------------
/timm/utils/metrics.py:
--------------------------------------------------------------------------------
1 | """ Eval metrics and related
2 |
3 | Hacked together by / Copyright 2020 Ross Wightman
4 | """
5 |
6 |
7 | class AverageMeter:
8 | """Computes and stores the average and current value"""
9 | def __init__(self):
10 | self.reset()
11 |
12 | def reset(self):
13 | self.val = 0
14 | self.avg = 0
15 | self.sum = 0
16 | self.count = 0
17 |
18 | def update(self, val, n=1):
19 | self.val = val
20 | self.sum += val * n
21 | self.count += n
22 | self.avg = self.sum / self.count
23 |
24 |
25 | def accuracy(output, target, topk=(1,)):
26 | """Computes the accuracy over the k top predictions for the specified values of k"""
27 | maxk = min(max(topk), output.size()[1])
28 | batch_size = target.size(0)
29 | _, pred = output.topk(maxk, 1, True, True)
30 | pred = pred.t()
31 | correct = pred.eq(target.reshape(1, -1).expand_as(pred))
32 | return [correct[:min(k, maxk)].reshape(-1).float().sum(0) * 100. / batch_size for k in topk]
33 |
--------------------------------------------------------------------------------
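A usage sketch combining the two helpers above (random tensors stand in for real model output):

```python
import torch

from timm.utils.metrics import AverageMeter, accuracy

top1 = AverageMeter()
output = torch.randn(8, 100)              # logits: 8 samples, 100 classes
target = torch.randint(0, 100, (8,))

acc1, acc5 = accuracy(output, target, topk=(1, 5))  # percentages as tensors
top1.update(acc1.item(), n=output.size(0))
print(f'running top-1: {top1.avg:.2f}%')
```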
/.github/ISSUE_TEMPLATE/bug_report.md:
--------------------------------------------------------------------------------
1 | ---
2 | name: Bug report
3 | about: Create a bug report to help us improve. Issues are for reporting bugs or requesting
4 | features, the discussion forum is available for asking questions or seeking help
5 | from the community.
6 | title: "[BUG] Issue title..."
7 | labels: bug
8 | assignees: rwightman
9 |
10 | ---
11 |
12 | **Describe the bug**
13 | A clear and concise description of what the bug is.
14 |
15 | **To Reproduce**
16 | Steps to reproduce the behavior:
17 | 1.
18 | 2.
19 |
20 | **Expected behavior**
21 | A clear and concise description of what you expected to happen.
22 |
23 | **Screenshots**
24 | If applicable, add screenshots to help explain your problem.
25 |
26 | **Desktop (please complete the following information):**
27 | - OS: [e.g. Windows 10, Ubuntu 18.04]
28 | - This repository version [e.g. pip 0.3.1 or commit ref]
29 | - PyTorch version w/ CUDA/cuDNN [e.g. from `conda list`, 1.7.0 py3.8_cuda11.0.221_cudnn8.0.3_0]
30 |
31 | **Additional context**
32 | Add any other context about the problem here.
33 |
--------------------------------------------------------------------------------
/timm/utils/log.py:
--------------------------------------------------------------------------------
1 | """ Logging helpers
2 |
3 | Hacked together by / Copyright 2020 Ross Wightman
4 | """
5 | import logging
6 | import logging.handlers
7 |
8 |
9 | class FormatterNoInfo(logging.Formatter):
10 | def __init__(self, fmt='%(levelname)s: %(message)s'):
11 | logging.Formatter.__init__(self, fmt)
12 |
13 | def format(self, record):
14 | if record.levelno == logging.INFO:
15 | return str(record.getMessage())
16 | return logging.Formatter.format(self, record)
17 |
18 |
19 | def setup_default_logging(default_level=logging.INFO, log_path=''):
20 | console_handler = logging.StreamHandler()
21 | console_handler.setFormatter(FormatterNoInfo())
22 | logging.root.addHandler(console_handler)
23 | logging.root.setLevel(default_level)
24 | if log_path:
25 | file_handler = logging.handlers.RotatingFileHandler(log_path, maxBytes=(1024 ** 2 * 2), backupCount=3)
26 | file_formatter = logging.Formatter("%(asctime)s - %(name)20s: [%(levelname)8s] - %(message)s")
27 | file_handler.setFormatter(file_formatter)
28 | logging.root.addHandler(file_handler)
29 |
--------------------------------------------------------------------------------
/timm/utils/misc.py:
--------------------------------------------------------------------------------
1 | """ Misc utils
2 |
3 | Hacked together by / Copyright 2020 Ross Wightman
4 | """
5 | import argparse
6 | import ast
7 | import re
8 |
9 |
10 | def natural_key(string_):
11 | """See http://www.codinghorror.com/blog/archives/001018.html"""
12 | return [int(s) if s.isdigit() else s for s in re.split(r'(\d+)', string_.lower())]
13 |
14 |
15 | def add_bool_arg(parser, name, default=False, help=''):
16 | dest_name = name.replace('-', '_')
17 | group = parser.add_mutually_exclusive_group(required=False)
18 | group.add_argument('--' + name, dest=dest_name, action='store_true', help=help)
19 | group.add_argument('--no-' + name, dest=dest_name, action='store_false', help=help)
20 | parser.set_defaults(**{dest_name: default})
21 |
22 |
23 | class ParseKwargs(argparse.Action):
24 | def __call__(self, parser, namespace, values, option_string=None):
25 | kw = {}
26 | for value in values:
27 | key, value = value.split('=')
28 | try:
29 | kw[key] = ast.literal_eval(value)
30 | except ValueError:
31 | kw[key] = str(value) # fallback to string (avoid need to escape on command line)
32 | setattr(namespace, self.dest, kw)
33 |
--------------------------------------------------------------------------------
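`ParseKwargs` turns `key=value` pairs on the command line into a dict, using `ast.literal_eval` so numeric and literal values keep their types. A sketch (the `--model-kwargs` flag name is just an example):

```python
import argparse

from timm.utils.misc import ParseKwargs

parser = argparse.ArgumentParser()
parser.add_argument('--model-kwargs', nargs='*', default={}, action=ParseKwargs)

args = parser.parse_args(['--model-kwargs', 'depth=12', 'act_layer=gelu'])
print(args.model_kwargs)  # {'depth': 12, 'act_layer': 'gelu'} - 'gelu' falls back to str
```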
/timm/layers/helpers.py:
--------------------------------------------------------------------------------
1 | """ Layer/Module Helpers
2 |
3 | Hacked together by / Copyright 2020 Ross Wightman
4 | """
5 | from itertools import repeat
6 | import collections.abc
7 |
8 |
9 | # From PyTorch internals
10 | def _ntuple(n):
11 | def parse(x):
12 | if isinstance(x, collections.abc.Iterable) and not isinstance(x, str):
13 | return tuple(x)
14 | return tuple(repeat(x, n))
15 | return parse
16 |
17 |
18 | to_1tuple = _ntuple(1)
19 | to_2tuple = _ntuple(2)
20 | to_3tuple = _ntuple(3)
21 | to_4tuple = _ntuple(4)
22 | to_ntuple = _ntuple
23 |
24 |
25 | def make_divisible(v, divisor=8, min_value=None, round_limit=.9):
26 | min_value = min_value or divisor
27 | new_v = max(min_value, int(v + divisor / 2) // divisor * divisor)
28 | # Make sure that round down does not go down by more than 10%.
29 | if new_v < round_limit * v:
30 | new_v += divisor
31 | return new_v
32 |
33 |
34 | def extend_tuple(x, n):
35 | # pads a tuple to specified n by padding with last value
36 | if not isinstance(x, (tuple, list)):
37 | x = (x,)
38 | else:
39 | x = tuple(x)
40 | pad_n = n - len(x)
41 | if pad_n <= 0:
42 | return x[:n]
43 | return x + (x[-1],) * pad_n
44 |
--------------------------------------------------------------------------------
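A quick sketch of the helpers' behavior (values chosen for illustration):

```python
from timm.layers.helpers import to_2tuple, make_divisible, extend_tuple

print(to_2tuple(3))             # (3, 3) - scalar repeated
print(to_2tuple((3, 5)))        # (3, 5) - iterable passed through
print(make_divisible(100))      # 104   - nearest multiple of 8, never < 90% of input
print(extend_tuple((1, 2), 4))  # (1, 2, 2, 2) - padded with the last value
```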
/timm/loss/cross_entropy.py:
--------------------------------------------------------------------------------
1 | """ Cross Entropy w/ smoothing or soft targets
2 |
3 | Hacked together by / Copyright 2021 Ross Wightman
4 | """
5 |
6 | import torch
7 | import torch.nn as nn
8 | import torch.nn.functional as F
9 |
10 |
11 | class LabelSmoothingCrossEntropy(nn.Module):
12 | """ NLL loss with label smoothing.
13 | """
14 | def __init__(self, smoothing=0.1):
15 | super(LabelSmoothingCrossEntropy, self).__init__()
16 | assert smoothing < 1.0
17 | self.smoothing = smoothing
18 | self.confidence = 1. - smoothing
19 |
20 | def forward(self, x: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
21 | logprobs = F.log_softmax(x, dim=-1)
22 | nll_loss = -logprobs.gather(dim=-1, index=target.unsqueeze(1))
23 | nll_loss = nll_loss.squeeze(1)
24 | smooth_loss = -logprobs.mean(dim=-1)
25 | loss = self.confidence * nll_loss + self.smoothing * smooth_loss
26 | return loss.mean()
27 |
28 |
29 | class SoftTargetCrossEntropy(nn.Module):
30 |
31 | def __init__(self):
32 | super(SoftTargetCrossEntropy, self).__init__()
33 |
34 | def forward(self, x: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
35 | loss = torch.sum(-target * F.log_softmax(x, dim=-1), dim=-1)
36 | return loss.mean()
37 |
--------------------------------------------------------------------------------
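A sketch of both losses (random tensors are placeholders; soft targets would normally come from mixup/cutmix):

```python
import torch

from timm.loss import LabelSmoothingCrossEntropy, SoftTargetCrossEntropy

logits = torch.randn(4, 10)
hard_target = torch.randint(0, 10, (4,))
soft_target = torch.softmax(torch.randn(4, 10), dim=-1)  # rows sum to 1

smooth_loss = LabelSmoothingCrossEntropy(smoothing=0.1)(logits, hard_target)
soft_loss = SoftTargetCrossEntropy()(logits, soft_target)
```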
/timm/layers/format.py:
--------------------------------------------------------------------------------
1 | from enum import Enum
2 | from typing import Union
3 |
4 | import torch
5 |
6 |
7 | class Format(str, Enum):
8 | NCHW = 'NCHW'
9 | NHWC = 'NHWC'
10 | NCL = 'NCL'
11 | NLC = 'NLC'
12 |
13 |
14 | FormatT = Union[str, Format]
15 |
16 |
17 | def get_spatial_dim(fmt: FormatT):
18 | fmt = Format(fmt)
19 | if fmt is Format.NLC:
20 | dim = (1,)
21 | elif fmt is Format.NCL:
22 | dim = (2,)
23 | elif fmt is Format.NHWC:
24 | dim = (1, 2)
25 | else:
26 | dim = (2, 3)
27 | return dim
28 |
29 |
30 | def get_channel_dim(fmt: FormatT):
31 | fmt = Format(fmt)
32 | if fmt is Format.NHWC:
33 | dim = 3
34 | elif fmt is Format.NLC:
35 | dim = 2
36 | else:
37 | dim = 1
38 | return dim
39 |
40 |
41 | def nchw_to(x: torch.Tensor, fmt: Format):
42 | if fmt == Format.NHWC:
43 | x = x.permute(0, 2, 3, 1)
44 | elif fmt == Format.NLC:
45 | x = x.flatten(2).transpose(1, 2)
46 | elif fmt == Format.NCL:
47 | x = x.flatten(2)
48 | return x
49 |
50 |
51 | def nhwc_to(x: torch.Tensor, fmt: Format):
52 | if fmt == Format.NCHW:
53 | x = x.permute(0, 3, 1, 2)
54 | elif fmt == Format.NLC:
55 | x = x.flatten(1, 2)
56 | elif fmt == Format.NCL:
57 | x = x.flatten(1, 2).transpose(1, 2)
58 | return x
59 |
--------------------------------------------------------------------------------
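A shape-level sketch of the conversion helpers:

```python
import torch

from timm.layers.format import Format, nchw_to, nhwc_to

x = torch.randn(2, 64, 7, 7)            # NCHW
x_nhwc = nchw_to(x, Format.NHWC)        # (2, 7, 7, 64)
x_nlc = nchw_to(x, Format.NLC)          # (2, 49, 64) - spatial dims flattened to L
x_back = nhwc_to(x_nhwc, Format.NCHW)   # (2, 64, 7, 7)
```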
/timm/layers/grn.py:
--------------------------------------------------------------------------------
1 | """ Global Response Normalization Module
2 |
3 | Based on the GRN layer presented in
4 | `ConvNeXt-V2 - Co-designing and Scaling ConvNets with Masked Autoencoders` - https://arxiv.org/abs/2301.00808
5 |
6 | This implementation
7 | * works for both NCHW and NHWC tensor layouts
8 | * uses affine param names matching existing torch norm layers
9 | * slightly improves eager mode performance via fused addcmul
10 |
11 | Hacked together by / Copyright 2023 Ross Wightman
12 | """
13 |
14 | import torch
15 | from torch import nn as nn
16 |
17 |
18 | class GlobalResponseNorm(nn.Module):
19 | """ Global Response Normalization layer
20 | """
21 | def __init__(self, dim, eps=1e-6, channels_last=True):
22 | super().__init__()
23 | self.eps = eps
24 | if channels_last:
25 | self.spatial_dim = (1, 2)
26 | self.channel_dim = -1
27 | self.wb_shape = (1, 1, 1, -1)
28 | else:
29 | self.spatial_dim = (2, 3)
30 | self.channel_dim = 1
31 | self.wb_shape = (1, -1, 1, 1)
32 |
33 | self.weight = nn.Parameter(torch.zeros(dim))
34 | self.bias = nn.Parameter(torch.zeros(dim))
35 |
36 | def forward(self, x):
37 | x_g = x.norm(p=2, dim=self.spatial_dim, keepdim=True)
38 | x_n = x_g / (x_g.mean(dim=self.channel_dim, keepdim=True) + self.eps)
39 | return x + torch.addcmul(self.bias.view(self.wb_shape), self.weight.view(self.wb_shape), x * x_n)
40 |
--------------------------------------------------------------------------------
/timm/utils/summary.py:
--------------------------------------------------------------------------------
1 | """ Summary utilities
2 |
3 | Hacked together by / Copyright 2020 Ross Wightman
4 | """
5 | import csv
6 | import os
7 | from collections import OrderedDict
8 | try:
9 | import wandb
10 | except ImportError:
11 | pass
12 |
13 |
14 | def get_outdir(path, *paths, inc=False):
15 | outdir = os.path.join(path, *paths)
16 | if not os.path.exists(outdir):
17 | os.makedirs(outdir)
18 | elif inc:
19 | count = 1
20 | outdir_inc = outdir + '-' + str(count)
21 | while os.path.exists(outdir_inc):
22 | count = count + 1
23 | outdir_inc = outdir + '-' + str(count)
24 | assert count < 100
25 | outdir = outdir_inc
26 | os.makedirs(outdir)
27 | return outdir
28 |
29 |
30 | def update_summary(
31 | epoch,
32 | train_metrics,
33 | eval_metrics,
34 | filename,
35 | lr=None,
36 | write_header=False,
37 | log_wandb=False,
38 | ):
39 | rowd = OrderedDict(epoch=epoch)
40 | rowd.update([('train_' + k, v) for k, v in train_metrics.items()])
41 | rowd.update([('eval_' + k, v) for k, v in eval_metrics.items()])
42 | if lr is not None:
43 | rowd['lr'] = lr
44 | if log_wandb:
45 | wandb.log(rowd)
46 | with open(filename, mode='a') as cf:
47 | dw = csv.DictWriter(cf, fieldnames=rowd.keys())
48 | if write_header: # first iteration (epoch == 1 can't be used)
49 | dw.writeheader()
50 | dw.writerow(rowd)
51 |
--------------------------------------------------------------------------------
/timm/data/readers/reader_factory.py:
--------------------------------------------------------------------------------
1 | import os
2 |
3 | from .reader_image_folder import ReaderImageFolder
4 | from .reader_image_in_tar import ReaderImageInTar
5 |
6 |
7 | def create_reader(name, root, split='train', **kwargs):
8 | name = name.lower()
9 | name = name.split('/', 1)
10 | prefix = ''
11 | if len(name) > 1:
12 | prefix = name[0]
13 | name = name[-1]
14 |
15 | # FIXME improve the selection right now just tfds prefix or fallback path, will need options to
16 | # explicitly select other options shortly
17 | if prefix == 'hfds':
18 | from .reader_hfds import ReaderHfds # defer tensorflow import
19 | reader = ReaderHfds(root, name, split=split, **kwargs)
20 | elif prefix == 'tfds':
21 | from .reader_tfds import ReaderTfds # defer tensorflow import
22 | reader = ReaderTfds(root, name, split=split, **kwargs)
23 | elif prefix == 'wds':
24 | from .reader_wds import ReaderWds
25 | kwargs.pop('download', False)
26 | reader = ReaderWds(root, name, split=split, **kwargs)
27 | else:
28 | assert os.path.exists(root)
29 | # default fallback path (backwards compat), use image tar if root is a .tar file, otherwise image folder
30 | # FIXME support split here or in reader?
31 | if os.path.isfile(root) and os.path.splitext(root)[1] == '.tar':
32 | reader = ReaderImageInTar(root, **kwargs)
33 | else:
34 | reader = ReaderImageFolder(root, **kwargs)
35 | return reader
36 |
--------------------------------------------------------------------------------
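A sketch of how the name prefix routes to a backend (the paths and dataset name below are placeholders, not from the repo):

```python
from timm.data import create_reader

# No prefix: local fallback - image folder, or tar reader if root is a .tar file.
folder_reader = create_reader('', root='/data/imagenet/train')

# 'hfds/' prefix: Hugging Face datasets backend ('tfds/' and 'wds/' work similarly).
hf_reader = create_reader('hfds/imagenet-1k', root='/data/hf_cache', split='train')
```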
/hfdocs/source/index.mdx:
--------------------------------------------------------------------------------
1 | # timm
2 |
3 |
4 |
5 | `timm` is a library containing SOTA computer vision models, layers, utilities, optimizers, schedulers, data-loaders, augmentations, and training/evaluation scripts.
6 |
7 | It comes packaged with >700 pretrained models, and is designed to be flexible and easy to use.
8 |
9 | Read the [quick start guide](quickstart) to get up and running with the `timm` library. You will learn how to load, discover, and use pretrained models included in the library.
10 |
11 |
23 |
--------------------------------------------------------------------------------
/timm/data/readers/img_extensions.py:
--------------------------------------------------------------------------------
1 | from copy import deepcopy
2 |
3 | __all__ = ['get_img_extensions', 'is_img_extension', 'set_img_extensions', 'add_img_extensions', 'del_img_extensions']
4 |
5 |
6 | IMG_EXTENSIONS = ('.png', '.jpg', '.jpeg') # singleton, kept public for bwd compat use
7 | _IMG_EXTENSIONS_SET = set(IMG_EXTENSIONS) # set version, private, kept in sync
8 |
9 |
10 | def _set_extensions(extensions):
11 | global IMG_EXTENSIONS
12 | global _IMG_EXTENSIONS_SET
13 | dedupe = set() # NOTE de-duping tuple while keeping original order
14 | IMG_EXTENSIONS = tuple(x for x in extensions if x not in dedupe and not dedupe.add(x))
15 | _IMG_EXTENSIONS_SET = set(extensions)
16 |
17 |
18 | def _valid_extension(x: str):
19 | return x and isinstance(x, str) and len(x) >= 2 and x.startswith('.')
20 |
21 |
22 | def is_img_extension(ext):
23 | return ext in _IMG_EXTENSIONS_SET
24 |
25 |
26 | def get_img_extensions(as_set=False):
27 | return deepcopy(_IMG_EXTENSIONS_SET if as_set else IMG_EXTENSIONS)
28 |
29 |
30 | def set_img_extensions(extensions):
31 | assert len(extensions)
32 | for x in extensions:
33 | assert _valid_extension(x)
34 | _set_extensions(extensions)
35 |
36 |
37 | def add_img_extensions(ext):
38 | if not isinstance(ext, (list, tuple, set)):
39 | ext = (ext,)
40 | for x in ext:
41 | assert _valid_extension(x)
42 | extensions = IMG_EXTENSIONS + tuple(ext)
43 | _set_extensions(extensions)
44 |
45 |
46 | def del_img_extensions(ext):
47 | if not isinstance(ext, (list, tuple, set)):
48 | ext = (ext,)
49 | extensions = tuple(x for x in IMG_EXTENSIONS if x not in ext)
50 | _set_extensions(extensions)
51 |
--------------------------------------------------------------------------------
/timm/loss/jsd.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 |
5 | from .cross_entropy import LabelSmoothingCrossEntropy
6 |
7 |
8 | class JsdCrossEntropy(nn.Module):
9 | """ Jensen-Shannon Divergence + Cross-Entropy Loss
10 |
11 | Based on impl here: https://github.com/google-research/augmix/blob/master/imagenet.py
12 | From paper: 'AugMix: A Simple Data Processing Method to Improve Robustness and Uncertainty' -
13 | https://arxiv.org/abs/1912.02781
14 |
15 | Hacked together by / Copyright 2020 Ross Wightman
16 | """
17 | def __init__(self, num_splits=3, alpha=12, smoothing=0.1):
18 | super().__init__()
19 | self.num_splits = num_splits
20 | self.alpha = alpha
21 | if smoothing is not None and smoothing > 0:
22 | self.cross_entropy_loss = LabelSmoothingCrossEntropy(smoothing)
23 | else:
24 | self.cross_entropy_loss = torch.nn.CrossEntropyLoss()
25 |
26 | def __call__(self, output, target):
27 | split_size = output.shape[0] // self.num_splits
28 | assert split_size * self.num_splits == output.shape[0]
29 | logits_split = torch.split(output, split_size)
30 |
31 | # Cross-entropy is only computed on clean images
32 | loss = self.cross_entropy_loss(logits_split[0], target[:split_size])
33 | probs = [F.softmax(logits, dim=1) for logits in logits_split]
34 |
35 | # Clamp mixture distribution to avoid exploding KL divergence
36 | logp_mixture = torch.clamp(torch.stack(probs).mean(axis=0), 1e-7, 1).log()
37 | loss += self.alpha * sum([F.kl_div(
38 | logp_mixture, p_split, reduction='batchmean') for p_split in probs]) / len(probs)
39 | return loss
40 |
--------------------------------------------------------------------------------
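A sketch of the expected AugMix-style batch layout (shapes are illustrative):

```python
import torch

from timm.loss import JsdCrossEntropy

# Batch = clean + two augmented views concatenated on dim 0, so the
# output batch is num_splits x the original size.
jsd = JsdCrossEntropy(num_splits=3, alpha=12, smoothing=0.1)
output = torch.randn(6, 10)           # 3 splits of 2 samples each
target = torch.randint(0, 10, (2,))   # labels for the clean split only
loss = jsd(output, target)
```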
/timm/layers/blur_pool.py:
--------------------------------------------------------------------------------
1 | """
2 | BlurPool layer inspired by
3 | - Kornia's Max_BlurPool2d
4 | - Making Convolutional Networks Shift-Invariant Again :cite:`zhang2019shiftinvar`
5 |
6 | Hacked together by Chris Ha and Ross Wightman
7 | """
8 |
9 | import torch
10 | import torch.nn as nn
11 | import torch.nn.functional as F
12 | import numpy as np
13 | from .padding import get_padding
14 |
15 |
16 | class BlurPool2d(nn.Module):
17 | r"""Creates a module that computes blurs and downsample a given feature map.
18 | See :cite:`zhang2019shiftinvar` for more details.
19 | Corresponds to the Downsample class, which does blurring and subsampling
20 |
21 | Args:
22 | channels (int): Number of input channels
23 | filt_size (int): binomial filter size for blurring. currently supports 3 (default) and 5.
24 | stride (int): downsampling filter stride
25 |
26 | Returns:
27 | torch.Tensor: the transformed tensor.
28 | """
29 | def __init__(self, channels, filt_size=3, stride=2) -> None:
30 | super(BlurPool2d, self).__init__()
31 | assert filt_size > 1
32 | self.channels = channels
33 | self.filt_size = filt_size
34 | self.stride = stride
35 | self.padding = [get_padding(filt_size, stride, dilation=1)] * 4
36 | coeffs = torch.tensor((np.poly1d((0.5, 0.5)) ** (self.filt_size - 1)).coeffs.astype(np.float32))
37 | blur_filter = (coeffs[:, None] * coeffs[None, :])[None, None, :, :].repeat(self.channels, 1, 1, 1)
38 | self.register_buffer('filt', blur_filter, persistent=False)
39 |
40 | def forward(self, x: torch.Tensor) -> torch.Tensor:
41 | x = F.pad(x, self.padding, 'reflect')
42 | return F.conv2d(x, self.filt, stride=self.stride, groups=self.channels)
43 |
--------------------------------------------------------------------------------
/timm/layers/create_conv2d.py:
--------------------------------------------------------------------------------
1 | """ Create Conv2d Factory Method
2 |
3 | Hacked together by / Copyright 2020 Ross Wightman
4 | """
5 |
6 | from .mixed_conv2d import MixedConv2d
7 | from .cond_conv2d import CondConv2d
8 | from .conv2d_same import create_conv2d_pad
9 |
10 |
11 | def create_conv2d(in_channels, out_channels, kernel_size, **kwargs):
12 | """ Select a 2d convolution implementation based on arguments
13 | Creates and returns one of torch.nn.Conv2d, Conv2dSame, MixedConv2d, or CondConv2d.
14 |
15 | Used extensively by EfficientNet, MobileNetv3 and related networks.
16 | """
17 | if isinstance(kernel_size, list):
18 | assert 'num_experts' not in kwargs # MixNet + CondConv combo not supported currently
19 | if 'groups' in kwargs:
20 | groups = kwargs.pop('groups')
21 | if groups == in_channels:
22 | kwargs['depthwise'] = True
23 | else:
24 | assert groups == 1
25 | # We're going to use only lists for defining the MixedConv2d kernel groups,
26 | # ints, tuples, other iterables will continue to pass to normal conv and specify h, w.
27 | m = MixedConv2d(in_channels, out_channels, kernel_size, **kwargs)
28 | else:
29 | depthwise = kwargs.pop('depthwise', False)
30 | # for DW out_channels must be multiple of in_channels as must have out_channels % groups == 0
31 | groups = in_channels if depthwise else kwargs.pop('groups', 1)
32 | if 'num_experts' in kwargs and kwargs['num_experts'] > 0:
33 | m = CondConv2d(in_channels, out_channels, kernel_size, groups=groups, **kwargs)
34 | else:
35 | m = create_conv2d_pad(in_channels, out_channels, kernel_size, groups=groups, **kwargs)
36 | return m
37 |
--------------------------------------------------------------------------------
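A sketch of how the kernel_size type steers the factory (channel counts are arbitrary):

```python
import torch

from timm.layers import create_conv2d

# int kernel_size -> plain (optionally 'same'-padded) Conv2d variant.
conv = create_conv2d(32, 64, kernel_size=3, stride=2, padding='same')
y = conv(torch.randn(1, 32, 16, 16))   # (1, 64, 8, 8)

# list kernel_size -> MixedConv2d with one kernel size per channel group.
mixed = create_conv2d(32, 64, kernel_size=[3, 5, 7])
```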
/timm/utils/agc.py:
--------------------------------------------------------------------------------
1 | """ Adaptive Gradient Clipping
2 |
3 | An impl of AGC, as per (https://arxiv.org/abs/2102.06171):
4 |
5 | @article{brock2021high,
6 | author={Andrew Brock and Soham De and Samuel L. Smith and Karen Simonyan},
7 | title={High-Performance Large-Scale Image Recognition Without Normalization},
8 | journal={arXiv preprint arXiv:2102.06171},
9 | year={2021}
10 | }
11 |
12 | Code references:
13 | * Official JAX impl (paper authors): https://github.com/deepmind/deepmind-research/tree/master/nfnets
14 | * Phil Wang's PyTorch gist: https://gist.github.com/lucidrains/0d6560077edac419ab5d3aa29e674d5c
15 |
16 | Hacked together by / Copyright 2021 Ross Wightman
17 | """
18 | import torch
19 |
20 |
21 | def unitwise_norm(x, norm_type=2.0):
22 | if x.ndim <= 1:
23 | return x.norm(norm_type)
24 | else:
25 | # works for nn.ConvNd and nn.Linear where output dim is first in the kernel/weight tensor
26 | # might need special cases for other weights (possibly MHA) where this may not be true
27 | return x.norm(norm_type, dim=tuple(range(1, x.ndim)), keepdim=True)
28 |
29 |
30 | def adaptive_clip_grad(parameters, clip_factor=0.01, eps=1e-3, norm_type=2.0):
31 | if isinstance(parameters, torch.Tensor):
32 | parameters = [parameters]
33 | for p in parameters:
34 | if p.grad is None:
35 | continue
36 | p_data = p.detach()
37 | g_data = p.grad.detach()
38 | max_norm = unitwise_norm(p_data, norm_type=norm_type).clamp_(min=eps).mul_(clip_factor)
39 | grad_norm = unitwise_norm(g_data, norm_type=norm_type)
40 | clipped_grad = g_data * (max_norm / grad_norm.clamp(min=1e-6))
41 | new_grads = torch.where(grad_norm < max_norm, g_data, clipped_grad)
42 | p.grad.detach().copy_(new_grads)
43 |
--------------------------------------------------------------------------------
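A sketch of AGC applied before the optimizer step (the model and clip_factor are illustrative):

```python
import torch
import torch.nn as nn

from timm.utils.agc import adaptive_clip_grad

model = nn.Linear(10, 2)
model(torch.randn(4, 10)).sum().backward()

# Each grad is clipped so its unit-wise norm stays within
# clip_factor * the corresponding parameter's unit-wise norm.
adaptive_clip_grad(model.parameters(), clip_factor=0.01)
```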
/timm/layers/create_norm.py:
--------------------------------------------------------------------------------
1 | """ Norm Layer Factory
2 |
3 | Create norm modules by string (to mirror create_act and create_norm_act fns)
4 |
5 | Copyright 2022 Ross Wightman
6 | """
7 | import functools
8 | import types
9 | from typing import Type
10 |
11 | import torch.nn as nn
12 |
13 | from .norm import GroupNorm, GroupNorm1, LayerNorm, LayerNorm2d, RmsNorm
14 | from torchvision.ops.misc import FrozenBatchNorm2d
15 |
16 | _NORM_MAP = dict(
17 | batchnorm=nn.BatchNorm2d,
18 | batchnorm2d=nn.BatchNorm2d,
19 | batchnorm1d=nn.BatchNorm1d,
20 | groupnorm=GroupNorm,
21 | groupnorm1=GroupNorm1,
22 | layernorm=LayerNorm,
23 | layernorm2d=LayerNorm2d,
24 | rmsnorm=RmsNorm,
25 | frozenbatchnorm2d=FrozenBatchNorm2d,
26 | )
27 | _NORM_TYPES = {m for n, m in _NORM_MAP.items()}
28 |
29 |
30 | def create_norm_layer(layer_name, num_features, **kwargs):
31 | layer = get_norm_layer(layer_name)
32 | layer_instance = layer(num_features, **kwargs)
33 | return layer_instance
34 |
35 |
36 | def get_norm_layer(norm_layer):
37 | if norm_layer is None:
38 | return None
39 | assert isinstance(norm_layer, (type, str, types.FunctionType, functools.partial))
40 | norm_kwargs = {}
41 |
42 | # unbind partial fn, so args can be rebound later
43 | if isinstance(norm_layer, functools.partial):
44 | norm_kwargs.update(norm_layer.keywords)
45 | norm_layer = norm_layer.func
46 |
47 | if isinstance(norm_layer, str):
48 | if not norm_layer:
49 | return None
50 | layer_name = norm_layer.replace('_', '')
51 | norm_layer = _NORM_MAP[layer_name]
52 | else:
53 | norm_layer = norm_layer
54 |
55 | if norm_kwargs:
56 | norm_layer = functools.partial(norm_layer, **norm_kwargs) # bind/rebind args
57 | return norm_layer
58 |
--------------------------------------------------------------------------------
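A sketch of the factory round-trip, including how partial kwargs are unbound and rebound (the momentum value is arbitrary):

```python
import functools

import torch.nn as nn

from timm.layers.create_norm import create_norm_layer, get_norm_layer

bn = create_norm_layer('batchnorm2d', 64)   # -> nn.BatchNorm2d(64)
gn = create_norm_layer('groupnorm', 64)     # -> timm GroupNorm (num_channels first)

# Partial kwargs survive: unbound inside get_norm_layer, rebound on the way out.
norm_cls = get_norm_layer(functools.partial(nn.BatchNorm2d, momentum=0.05))
bn_slow = norm_cls(64)
```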
/mkdocs.yml:
--------------------------------------------------------------------------------
1 | site_name: 'PyTorch Image Models'
2 | site_description: 'Pretrained Image Recognition Models'
3 | repo_name: 'rwightman/pytorch-image-models'
4 | repo_url: 'https://github.com/rwightman/pytorch-image-models'
5 | nav:
6 | - index.md
7 | - models.md
8 | - ... | models/*.md
9 | - results.md
10 | - scripts.md
11 | - training_hparam_examples.md
12 | - feature_extraction.md
13 | - changes.md
14 | - archived_changes.md
15 | theme:
16 | name: 'material'
17 | feature:
18 | tabs: false
19 | extra_javascript:
20 | - 'https://cdnjs.cloudflare.com/ajax/libs/mathjax/2.7.0/MathJax.js?config=TeX-MML-AM_CHTML'
21 | - https://cdnjs.cloudflare.com/ajax/libs/tablesort/5.2.1/tablesort.min.js
22 | - javascripts/tables.js
23 | markdown_extensions:
24 | - codehilite:
25 | linenums: true
26 | - admonition
27 | - pymdownx.arithmatex
28 | - pymdownx.betterem:
29 | smart_enable: all
30 | - pymdownx.caret
31 | - pymdownx.critic
32 | - pymdownx.details
33 | - pymdownx.emoji:
34 | emoji_generator: !!python/name:pymdownx.emoji.to_svg
35 | - pymdownx.inlinehilite
36 | - pymdownx.magiclink
37 | - pymdownx.mark
38 | - pymdownx.smartsymbols
39 | - pymdownx.superfences
40 | - pymdownx.tasklist:
41 | custom_checkbox: true
42 | - pymdownx.tilde
43 | - mdx_truly_sane_lists
44 | plugins:
45 | - search
46 | - awesome-pages
47 | - redirects:
48 | redirect_maps:
49 | 'index.md': 'https://huggingface.co/docs/timm/index'
50 | 'models.md': 'https://huggingface.co/docs/timm/models'
51 | 'results.md': 'https://huggingface.co/docs/timm/results'
52 | 'scripts.md': 'https://huggingface.co/docs/timm/training_script'
53 | 'training_hparam_examples.md': 'https://huggingface.co/docs/timm/training_script#training-examples'
54 | 'feature_extraction.md': 'https://huggingface.co/docs/timm/feature_extraction'
55 |
--------------------------------------------------------------------------------
/timm/utils/decay_batch.py:
--------------------------------------------------------------------------------
1 | """ Batch size decay and retry helpers.
2 |
3 | Copyright 2022 Ross Wightman
4 | """
5 | import math
6 |
7 |
8 | def decay_batch_step(batch_size, num_intra_steps=2, no_odd=False):
9 | """ power of two batch-size decay with intra steps
10 |
11 | Decay by stepping between powers of 2:
12 | * determine power-of-2 floor of current batch size (base batch size)
13 | * divide above value by num_intra_steps to determine step size
14 | * floor batch_size to nearest multiple of step_size (from base batch size)
15 | Examples:
16 |     num_intra_steps == 4 --> 64, 56, 48, 40, 32, 28, 24, 20, 16, 14, 12, 10, 8, 7, 6, 5, 4, 3, 2, 1
17 |     num_intra_steps == 4 (no_odd=True) --> 64, 56, 48, 40, 32, 28, 24, 20, 16, 14, 12, 10, 8, 6, 4, 2
18 |     num_intra_steps == 2 --> 64, 48, 32, 24, 16, 12, 8, 6, 4, 3, 2, 1
19 |     num_intra_steps == 1 --> 64, 32, 16, 8, 4, 2, 1
20 | """
21 | if batch_size <= 1:
22 |         # return 0 as the stopping value so it's easy to use in a loop
23 | return 0
24 | base_batch_size = int(2 ** (math.log(batch_size - 1) // math.log(2)))
25 | step_size = max(base_batch_size // num_intra_steps, 1)
26 | batch_size = base_batch_size + ((batch_size - base_batch_size - 1) // step_size) * step_size
27 | if no_odd and batch_size % 2:
28 | batch_size -= 1
29 | return batch_size
30 |
31 |
32 | def check_batch_size_retry(error_str):
33 | """ check failure error string for conditions where batch decay retry should not be attempted
34 | """
35 | error_str = error_str.lower()
36 | if 'required rank' in error_str:
37 | # Errors involving phrase 'required rank' typically happen when a conv is used that's
38 | # not compatible with channels_last memory format.
39 | return False
40 | if 'illegal' in error_str:
41 | # 'Illegal memory access' errors in CUDA typically leave process in unusable state
42 | return False
43 | return True
44 |
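45 | # Usage sketch (illustrative): retry an OOM-prone step, decaying the batch size
46 | # until success or until decay_batch_step() returns the 0 stopping value.
47 | # `run_one_step` is a hypothetical callable, not part of this module.
48 | #
49 | #   batch_size = 64
50 | #   while batch_size:
51 | #       try:
52 | #           run_one_step(batch_size)
53 | #           break
54 | #       except RuntimeError as e:
55 | #           if not check_batch_size_retry(str(e)):
56 | #               raise
57 | #           batch_size = decay_batch_step(batch_size)  # 64 -> 48 -> 32 -> 24 ...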
--------------------------------------------------------------------------------
/docs/scripts.md:
--------------------------------------------------------------------------------
1 | # Scripts
2 | Train, validation, inference, and checkpoint cleaning scripts are included in the GitHub root folder. Scripts are not currently packaged in the pip release.
3 |
4 | The training and validation scripts evolved from early versions of the [PyTorch Imagenet Examples](https://github.com/pytorch/examples). I have added significant functionality over time, including CUDA-specific performance enhancements based on
5 | [NVIDIA's APEX Examples](https://github.com/NVIDIA/apex/tree/master/examples).
6 |
7 | ## Training Script
8 |
9 | The variety of training args is large and not all combinations of options (or even individual options) have been fully tested. For the training dataset argument, specify the base folder that contains `train` and `validation` subfolders.
10 |
11 | To train an SE-ResNet34 on ImageNet, locally distributed, 4 GPUs, one process per GPU w/ cosine schedule, random-erasing prob of 50% and per-pixel random value:
12 |
13 | `./distributed_train.sh 4 /data/imagenet --model seresnet34 --sched cosine --epochs 150 --warmup-epochs 5 --lr 0.4 --reprob 0.5 --remode pixel --batch-size 256 --amp -j 4`
14 |
15 | NOTE: It is recommended to use PyTorch 1.9+ w/ PyTorch native AMP and DDP instead of APEX AMP. `--amp` defaults to native AMP as of timm ver 0.4.3. `--apex-amp` will force use of APEX components if they are installed.
16 |
17 | ## Validation / Inference Scripts
18 |
19 | Validation and inference scripts are similar in usage. One outputs metrics on a validation set and the other outputs top-k class ids to a csv. Specify the folder containing validation images, not the base folder as in the training script.
20 |
21 | To validate with the model's pretrained weights (if they exist):
22 |
23 | `python validate.py /imagenet/validation/ --model seresnext26_32x4d --pretrained`
24 |
25 | To run inference from a checkpoint:
26 |
27 | `python inference.py /imagenet/validation/ --model mobilenetv3_large_100 --checkpoint ./output/train/model_best.pth.tar`
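28 |
29 | ## Checkpoint Cleaning Script
30 |
31 | The checkpoint cleaning script strips training state (optimizer, etc.) from a saved checkpoint so the remaining weights are smaller and easier to share. A minimal sketch of its usage follows; the exact flags are an assumption, check `clean_checkpoint.py --help` for the authoritative options:
32 |
33 | `python clean_checkpoint.py --checkpoint ./output/train/model_best.pth.tar --output ./cleaned.pth`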
--------------------------------------------------------------------------------
/timm/data/_info/imagenet_r_indices.txt:
--------------------------------------------------------------------------------
1 | 1
2 | 2
3 | 4
4 | 6
5 | 8
6 | 9
7 | 11
8 | 13
9 | 22
10 | 23
11 | 26
12 | 29
13 | 31
14 | 39
15 | 47
16 | 63
17 | 71
18 | 76
19 | 79
20 | 84
21 | 90
22 | 94
23 | 96
24 | 97
25 | 99
26 | 100
27 | 105
28 | 107
29 | 113
30 | 122
31 | 125
32 | 130
33 | 132
34 | 144
35 | 145
36 | 147
37 | 148
38 | 150
39 | 151
40 | 155
41 | 160
42 | 161
43 | 162
44 | 163
45 | 171
46 | 172
47 | 178
48 | 187
49 | 195
50 | 199
51 | 203
52 | 207
53 | 208
54 | 219
55 | 231
56 | 232
57 | 234
58 | 235
59 | 242
60 | 245
61 | 247
62 | 250
63 | 251
64 | 254
65 | 259
66 | 260
67 | 263
68 | 265
69 | 267
70 | 269
71 | 276
72 | 277
73 | 281
74 | 288
75 | 289
76 | 291
77 | 292
78 | 293
79 | 296
80 | 299
81 | 301
82 | 308
83 | 309
84 | 310
85 | 311
86 | 314
87 | 315
88 | 319
89 | 323
90 | 327
91 | 330
92 | 334
93 | 335
94 | 337
95 | 338
96 | 340
97 | 341
98 | 344
99 | 347
100 | 353
101 | 355
102 | 361
103 | 362
104 | 365
105 | 366
106 | 367
107 | 368
108 | 372
109 | 388
110 | 390
111 | 393
112 | 397
113 | 401
114 | 407
115 | 413
116 | 414
117 | 425
118 | 428
119 | 430
120 | 435
121 | 437
122 | 441
123 | 447
124 | 448
125 | 457
126 | 462
127 | 463
128 | 469
129 | 470
130 | 471
131 | 472
132 | 476
133 | 483
134 | 487
135 | 515
136 | 546
137 | 555
138 | 558
139 | 570
140 | 579
141 | 583
142 | 587
143 | 593
144 | 594
145 | 596
146 | 609
147 | 613
148 | 617
149 | 621
150 | 629
151 | 637
152 | 657
153 | 658
154 | 701
155 | 717
156 | 724
157 | 763
158 | 768
159 | 774
160 | 776
161 | 779
162 | 780
163 | 787
164 | 805
165 | 812
166 | 815
167 | 820
168 | 824
169 | 833
170 | 847
171 | 852
172 | 866
173 | 875
174 | 883
175 | 889
176 | 895
177 | 907
178 | 928
179 | 931
180 | 932
181 | 933
182 | 934
183 | 936
184 | 937
185 | 943
186 | 945
187 | 947
188 | 948
189 | 949
190 | 951
191 | 953
192 | 954
193 | 957
194 | 963
195 | 965
196 | 967
197 | 980
198 | 981
199 | 983
200 | 988
201 |
--------------------------------------------------------------------------------
/timm/layers/median_pool.py:
--------------------------------------------------------------------------------
1 | """ Median Pool
2 | Hacked together by / Copyright 2020 Ross Wightman
3 | """
4 | import torch.nn as nn
5 | import torch.nn.functional as F
6 | from .helpers import to_2tuple, to_4tuple
7 |
8 |
9 | class MedianPool2d(nn.Module):
10 | """ Median pool (usable as median filter when stride=1) module.
11 |
12 | Args:
13 | kernel_size: size of pooling kernel, int or 2-tuple
14 | stride: pool stride, int or 2-tuple
15 | padding: pool padding, int or 4-tuple (l, r, t, b) as in pytorch F.pad
16 | same: override padding and enforce same padding, boolean
17 | """
18 | def __init__(self, kernel_size=3, stride=1, padding=0, same=False):
19 | super(MedianPool2d, self).__init__()
20 | self.k = to_2tuple(kernel_size)
21 | self.stride = to_2tuple(stride)
22 | self.padding = to_4tuple(padding) # convert to l, r, t, b
23 | self.same = same
24 |
25 | def _padding(self, x):
26 | if self.same:
27 | ih, iw = x.size()[2:]
28 | if ih % self.stride[0] == 0:
29 | ph = max(self.k[0] - self.stride[0], 0)
30 | else:
31 | ph = max(self.k[0] - (ih % self.stride[0]), 0)
32 | if iw % self.stride[1] == 0:
33 | pw = max(self.k[1] - self.stride[1], 0)
34 | else:
35 | pw = max(self.k[1] - (iw % self.stride[1]), 0)
36 | pl = pw // 2
37 | pr = pw - pl
38 | pt = ph // 2
39 | pb = ph - pt
40 | padding = (pl, pr, pt, pb)
41 | else:
42 | padding = self.padding
43 | return padding
44 |
45 | def forward(self, x):
46 | x = F.pad(x, self._padding(x), mode='reflect')
47 | x = x.unfold(2, self.k[0], self.stride[0]).unfold(3, self.k[1], self.stride[1])
48 | x = x.contiguous().view(x.size()[:4] + (-1,)).median(dim=-1)[0]
49 | return x
50 |
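51 | # Usage sketch (illustrative): a 3x3 median filter that preserves spatial size
52 | # when `same=True` forces TF-style 'SAME' padding.
53 | #
54 | #   import torch
55 | #   m = MedianPool2d(kernel_size=3, stride=1, same=True)
56 | #   y = m(torch.randn(1, 3, 8, 8))   # -> torch.Size([1, 3, 8, 8])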
--------------------------------------------------------------------------------
/timm/data/_info/imagenet_a_indices.txt:
--------------------------------------------------------------------------------
1 | 6
2 | 11
3 | 13
4 | 15
5 | 17
6 | 22
7 | 23
8 | 27
9 | 30
10 | 37
11 | 39
12 | 42
13 | 47
14 | 50
15 | 57
16 | 70
17 | 71
18 | 76
19 | 79
20 | 89
21 | 90
22 | 94
23 | 96
24 | 97
25 | 99
26 | 105
27 | 107
28 | 108
29 | 110
30 | 113
31 | 124
32 | 125
33 | 130
34 | 132
35 | 143
36 | 144
37 | 150
38 | 151
39 | 207
40 | 234
41 | 235
42 | 254
43 | 277
44 | 283
45 | 287
46 | 291
47 | 295
48 | 298
49 | 301
50 | 306
51 | 307
52 | 308
53 | 309
54 | 310
55 | 311
56 | 313
57 | 314
58 | 315
59 | 317
60 | 319
61 | 323
62 | 324
63 | 326
64 | 327
65 | 330
66 | 334
67 | 335
68 | 336
69 | 347
70 | 361
71 | 363
72 | 372
73 | 378
74 | 386
75 | 397
76 | 400
77 | 401
78 | 402
79 | 404
80 | 407
81 | 411
82 | 416
83 | 417
84 | 420
85 | 425
86 | 428
87 | 430
88 | 437
89 | 438
90 | 445
91 | 456
92 | 457
93 | 461
94 | 462
95 | 470
96 | 472
97 | 483
98 | 486
99 | 488
100 | 492
101 | 496
102 | 514
103 | 516
104 | 528
105 | 530
106 | 539
107 | 542
108 | 543
109 | 549
110 | 552
111 | 557
112 | 561
113 | 562
114 | 569
115 | 572
116 | 573
117 | 575
118 | 579
119 | 589
120 | 606
121 | 607
122 | 609
123 | 614
124 | 626
125 | 627
126 | 640
127 | 641
128 | 642
129 | 643
130 | 658
131 | 668
132 | 677
133 | 682
134 | 684
135 | 687
136 | 701
137 | 704
138 | 719
139 | 736
140 | 746
141 | 749
142 | 752
143 | 758
144 | 763
145 | 765
146 | 768
147 | 773
148 | 774
149 | 776
150 | 779
151 | 780
152 | 786
153 | 792
154 | 797
155 | 802
156 | 803
157 | 804
158 | 813
159 | 815
160 | 820
161 | 823
162 | 831
163 | 833
164 | 835
165 | 839
166 | 845
167 | 847
168 | 850
169 | 859
170 | 862
171 | 870
172 | 879
173 | 880
174 | 888
175 | 890
176 | 897
177 | 900
178 | 907
179 | 913
180 | 924
181 | 932
182 | 933
183 | 934
184 | 937
185 | 943
186 | 945
187 | 947
188 | 951
189 | 954
190 | 956
191 | 957
192 | 959
193 | 971
194 | 972
195 | 980
196 | 981
197 | 984
198 | 986
199 | 987
200 | 988
201 |
--------------------------------------------------------------------------------
/timm/layers/patch_dropout.py:
--------------------------------------------------------------------------------
1 | from typing import Optional, Tuple, Union
2 |
3 | import torch
4 | import torch.nn as nn
5 |
6 |
7 | class PatchDropout(nn.Module):
8 |     """ Patch Dropout: randomly drop a fraction `prob` of patch tokens during training.
9 |     https://arxiv.org/abs/2212.00794
10 |     """
11 | return_indices: torch.jit.Final[bool]
12 |
13 | def __init__(
14 | self,
15 | prob: float = 0.5,
16 | num_prefix_tokens: int = 1,
17 | ordered: bool = False,
18 | return_indices: bool = False,
19 | ):
20 | super().__init__()
21 | assert 0 <= prob < 1.
22 | self.prob = prob
23 | self.num_prefix_tokens = num_prefix_tokens # exclude CLS token (or other prefix tokens)
24 | self.ordered = ordered
25 | self.return_indices = return_indices
26 |
27 | def forward(self, x) -> Union[torch.Tensor, Tuple[torch.Tensor, Optional[torch.Tensor]]]:
28 | if not self.training or self.prob == 0.:
29 | if self.return_indices:
30 | return x, None
31 | return x
32 |
33 | if self.num_prefix_tokens:
34 | prefix_tokens, x = x[:, :self.num_prefix_tokens], x[:, self.num_prefix_tokens:]
35 | else:
36 | prefix_tokens = None
37 |
38 | B = x.shape[0]
39 | L = x.shape[1]
40 | num_keep = max(1, int(L * (1. - self.prob)))
41 | keep_indices = torch.argsort(torch.randn(B, L, device=x.device), dim=-1)[:, :num_keep]
42 | if self.ordered:
43 | # NOTE does not need to maintain patch order in typical transformer use,
44 | # but possibly useful for debug / visualization
45 | keep_indices = keep_indices.sort(dim=-1)[0]
46 | x = x.gather(1, keep_indices.unsqueeze(-1).expand((-1, -1) + x.shape[2:]))
47 |
48 | if prefix_tokens is not None:
49 | x = torch.cat((prefix_tokens, x), dim=1)
50 |
51 | if self.return_indices:
52 | return x, keep_indices
53 | return x
54 |
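55 | # Usage sketch (illustrative): drop ~25% of patch tokens during training while
56 | # keeping a single CLS prefix token intact.
57 | #
58 | #   pd = PatchDropout(prob=0.25, num_prefix_tokens=1).train()
59 | #   x = torch.randn(2, 1 + 196, 768)   # (batch, CLS + patches, dim)
60 | #   out = pd(x)                        # -> (2, 1 + 147, 768); 147 = int(196 * 0.75)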
--------------------------------------------------------------------------------
/timm/data/real_labels.py:
--------------------------------------------------------------------------------
1 | """ Real labels evaluator for ImageNet
2 | Paper: `Are we done with ImageNet?` - https://arxiv.org/abs/2006.07159
3 | Based on Numpy example at https://github.com/google-research/reassessed-imagenet
4 |
5 | Hacked together by / Copyright 2020 Ross Wightman
6 | """
7 | import os
8 | import json
9 | import numpy as np
10 | import pkgutil
11 |
12 |
13 | class RealLabelsImagenet:
14 |
15 | def __init__(self, filenames, real_json=None, topk=(1, 5)):
16 | if real_json is not None:
17 | with open(real_json) as real_labels:
18 | real_labels = json.load(real_labels)
19 | else:
20 | real_labels = json.loads(
21 | pkgutil.get_data(__name__, os.path.join('_info', 'imagenet_real_labels.json')).decode('utf-8'))
22 | real_labels = {f'ILSVRC2012_val_{i + 1:08d}.JPEG': labels for i, labels in enumerate(real_labels)}
23 | self.real_labels = real_labels
24 | self.filenames = filenames
25 | assert len(self.filenames) == len(self.real_labels)
26 | self.topk = topk
27 | self.is_correct = {k: [] for k in topk}
28 | self.sample_idx = 0
29 |
30 | def add_result(self, output):
31 | maxk = max(self.topk)
32 | _, pred_batch = output.topk(maxk, 1, True, True)
33 | pred_batch = pred_batch.cpu().numpy()
34 | for pred in pred_batch:
35 | filename = self.filenames[self.sample_idx]
36 | filename = os.path.basename(filename)
37 | if self.real_labels[filename]:
38 | for k in self.topk:
39 | self.is_correct[k].append(
40 | any([p in self.real_labels[filename] for p in pred[:k]]))
41 | self.sample_idx += 1
42 |
43 | def get_accuracy(self, k=None):
44 | if k is None:
45 | return {k: float(np.mean(self.is_correct[k])) * 100 for k in self.topk}
46 | else:
47 | return float(np.mean(self.is_correct[k])) * 100
48 |
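49 | # Usage sketch (illustrative): accumulate ReaL accuracy over a validation pass.
50 | # `model` and `loader` are hypothetical; filenames must match the loader's sample order.
51 | #
52 | #   real_labels = RealLabelsImagenet(filenames)
53 | #   for input, _ in loader:
54 | #       real_labels.add_result(model(input))
55 | #   print(real_labels.get_accuracy(k=1))   # top-1 ReaL accuracy in percent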
--------------------------------------------------------------------------------
/timm/scheduler/step_lr.py:
--------------------------------------------------------------------------------
1 | """ Step Scheduler
2 |
3 | Basic step LR schedule with warmup, noise.
4 |
5 | Hacked together by / Copyright 2020 Ross Wightman
6 | """
7 | import math
8 | import torch
9 |
10 | from .scheduler import Scheduler
11 |
12 |
13 | class StepLRScheduler(Scheduler):
14 |     """ Decay LR by decay_rate every decay_t steps, w/ optional warmup.
15 |     """
16 |
17 | def __init__(
18 | self,
19 | optimizer: torch.optim.Optimizer,
20 | decay_t: float,
21 | decay_rate: float = 1.,
22 | warmup_t=0,
23 | warmup_lr_init=0,
24 | warmup_prefix=True,
25 | t_in_epochs=True,
26 | noise_range_t=None,
27 | noise_pct=0.67,
28 | noise_std=1.0,
29 | noise_seed=42,
30 | initialize=True,
31 | ) -> None:
32 | super().__init__(
33 | optimizer,
34 | param_group_field="lr",
35 | t_in_epochs=t_in_epochs,
36 | noise_range_t=noise_range_t,
37 | noise_pct=noise_pct,
38 | noise_std=noise_std,
39 | noise_seed=noise_seed,
40 | initialize=initialize,
41 | )
42 |
43 | self.decay_t = decay_t
44 | self.decay_rate = decay_rate
45 | self.warmup_t = warmup_t
46 | self.warmup_lr_init = warmup_lr_init
47 | self.warmup_prefix = warmup_prefix
48 | if self.warmup_t:
49 | self.warmup_steps = [(v - warmup_lr_init) / self.warmup_t for v in self.base_values]
50 | super().update_groups(self.warmup_lr_init)
51 | else:
52 | self.warmup_steps = [1 for _ in self.base_values]
53 |
54 | def _get_lr(self, t):
55 | if t < self.warmup_t:
56 | lrs = [self.warmup_lr_init + t * s for s in self.warmup_steps]
57 | else:
58 | if self.warmup_prefix:
59 | t = t - self.warmup_t
60 | lrs = [v * (self.decay_rate ** (t // self.decay_t)) for v in self.base_values]
61 | return lrs
62 |
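63 | # Usage sketch (illustrative): 10x LR decay every 30 epochs after a 5 epoch warmup.
64 | # `optimizer` is a hypothetical torch.optim.Optimizer; step() is inherited from Scheduler.
65 | #
66 | #   sched = StepLRScheduler(optimizer, decay_t=30, decay_rate=0.1, warmup_t=5, warmup_lr_init=1e-5)
67 | #   for epoch in range(num_epochs):
68 | #       sched.step(epoch)   # w/ warmup_prefix, epoch 65 -> base_lr * 0.1 ** ((65 - 5) // 30)
69 | #       ...                 # train one epoch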
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | # Byte-compiled / optimized / DLL files
2 | __pycache__/
3 | *.py[cod]
4 | *$py.class
5 |
6 | # C extensions
7 | *.so
8 |
9 | # Distribution / packaging
10 | .Python
11 | build/
12 | develop-eggs/
13 | dist/
14 | downloads/
15 | eggs/
16 | .eggs/
17 | lib/
18 | lib64/
19 | parts/
20 | sdist/
21 | var/
22 | wheels/
23 | *.egg-info/
24 | .installed.cfg
25 | *.egg
26 | MANIFEST
27 |
28 | # PyInstaller
29 | # Usually these files are written by a python script from a template
30 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
31 | *.manifest
32 | *.spec
33 |
34 | # Installer logs
35 | pip-log.txt
36 | pip-delete-this-directory.txt
37 |
38 | # Unit test / coverage reports
39 | htmlcov/
40 | .tox/
41 | .coverage
42 | .coverage.*
43 | .cache
44 | nosetests.xml
45 | coverage.xml
46 | *.cover
47 | .hypothesis/
48 | .pytest_cache/
49 |
50 | # Translations
51 | *.mo
52 | *.pot
53 |
54 | # Django stuff:
55 | *.log
56 | local_settings.py
57 | db.sqlite3
58 |
59 | # Flask stuff:
60 | instance/
61 | .webassets-cache
62 |
63 | # Scrapy stuff:
64 | .scrapy
65 |
66 | # Sphinx documentation
67 | docs/_build/
68 |
69 | # PyBuilder
70 | target/
71 |
72 | # Jupyter Notebook
73 | .ipynb_checkpoints
74 |
75 | # pyenv
76 | .python-version
77 |
78 | # celery beat schedule file
79 | celerybeat-schedule
80 |
81 | # SageMath parsed files
82 | *.sage.py
83 |
84 | # Environments
85 | .env
86 | .venv
87 | env/
88 | venv/
89 | ENV/
90 | env.bak/
91 | venv.bak/
92 |
93 | # Spyder project settings
94 | .spyderproject
95 | .spyproject
96 |
97 | # Rope project settings
98 | .ropeproject
99 |
100 | # PyCharm
101 | .idea
102 |
103 | output/
104 |
105 | # PyTorch weights
106 | *.tar
107 | *.pth
108 | *.pt
109 | *.torch
110 | *.gz
111 | Untitled.ipynb
112 | Testing notebook.ipynb
113 |
114 | # Root dir exclusions
115 | /*.csv
116 | /*.yaml
117 | /*.json
118 | /*.jpg
119 | /*.png
120 | /*.zip
121 | /*.tar.*
--------------------------------------------------------------------------------
/timm/layers/space_to_depth.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 |
4 |
5 | class SpaceToDepth(nn.Module):
6 | bs: torch.jit.Final[int]
7 |
8 | def __init__(self, block_size=4):
9 | super().__init__()
10 | assert block_size == 4
11 | self.bs = block_size
12 |
13 | def forward(self, x):
14 | N, C, H, W = x.size()
15 | x = x.view(N, C, H // self.bs, self.bs, W // self.bs, self.bs) # (N, C, H//bs, bs, W//bs, bs)
16 | x = x.permute(0, 3, 5, 1, 2, 4).contiguous() # (N, bs, bs, C, H//bs, W//bs)
17 | x = x.view(N, C * self.bs * self.bs, H // self.bs, W // self.bs) # (N, C*bs^2, H//bs, W//bs)
18 | return x
19 |
20 |
21 | @torch.jit.script
22 | class SpaceToDepthJit:
23 | def __call__(self, x: torch.Tensor):
24 | # assuming hard-coded that block_size==4 for acceleration
25 | N, C, H, W = x.size()
26 | x = x.view(N, C, H // 4, 4, W // 4, 4) # (N, C, H//bs, bs, W//bs, bs)
27 | x = x.permute(0, 3, 5, 1, 2, 4).contiguous() # (N, bs, bs, C, H//bs, W//bs)
28 | x = x.view(N, C * 16, H // 4, W // 4) # (N, C*bs^2, H//bs, W//bs)
29 | return x
30 |
31 |
32 | class SpaceToDepthModule(nn.Module):
33 | def __init__(self, no_jit=False):
34 | super().__init__()
35 | if not no_jit:
36 | self.op = SpaceToDepthJit()
37 | else:
38 | self.op = SpaceToDepth()
39 |
40 | def forward(self, x):
41 | return self.op(x)
42 |
43 |
44 | class DepthToSpace(nn.Module):
45 |
46 | def __init__(self, block_size):
47 | super().__init__()
48 | self.bs = block_size
49 |
50 | def forward(self, x):
51 | N, C, H, W = x.size()
52 | x = x.view(N, self.bs, self.bs, C // (self.bs ** 2), H, W) # (N, bs, bs, C//bs^2, H, W)
53 | x = x.permute(0, 3, 4, 1, 5, 2).contiguous() # (N, C//bs^2, H, bs, W, bs)
54 | x = x.view(N, C // (self.bs ** 2), H * self.bs, W * self.bs) # (N, C//bs^2, H * bs, W * bs)
55 | return x
56 |
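57 | # Usage sketch (illustrative): SpaceToDepth and DepthToSpace apply inverse
58 | # rearrangements, trading spatial resolution for channels and back.
59 | #
60 | #   x = torch.randn(1, 3, 8, 8)
61 | #   y = SpaceToDepth(block_size=4)(x)   # -> (1, 48, 2, 2)
62 | #   z = DepthToSpace(block_size=4)(y)   # -> (1, 3, 8, 8), recovers x exactly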
--------------------------------------------------------------------------------
/timm/layers/mixed_conv2d.py:
--------------------------------------------------------------------------------
1 | """ PyTorch Mixed Convolution
2 |
3 | Paper: MixConv: Mixed Depthwise Convolutional Kernels (https://arxiv.org/abs/1907.09595)
4 |
5 | Hacked together by / Copyright 2020 Ross Wightman
6 | """
7 |
8 | import torch
9 | from torch import nn as nn
10 |
11 | from .conv2d_same import create_conv2d_pad
12 |
13 |
14 | def _split_channels(num_chan, num_groups):
15 | split = [num_chan // num_groups for _ in range(num_groups)]
16 | split[0] += num_chan - sum(split)
17 | return split
18 |
19 |
20 | class MixedConv2d(nn.ModuleDict):
21 | """ Mixed Grouped Convolution
22 |
23 | Based on MDConv and GroupedConv in MixNet impl:
24 | https://github.com/tensorflow/tpu/blob/master/models/official/mnasnet/mixnet/custom_layers.py
25 | """
26 | def __init__(self, in_channels, out_channels, kernel_size=3,
27 | stride=1, padding='', dilation=1, depthwise=False, **kwargs):
28 | super(MixedConv2d, self).__init__()
29 |
30 | kernel_size = kernel_size if isinstance(kernel_size, list) else [kernel_size]
31 | num_groups = len(kernel_size)
32 | in_splits = _split_channels(in_channels, num_groups)
33 | out_splits = _split_channels(out_channels, num_groups)
34 | self.in_channels = sum(in_splits)
35 | self.out_channels = sum(out_splits)
36 | for idx, (k, in_ch, out_ch) in enumerate(zip(kernel_size, in_splits, out_splits)):
37 | conv_groups = in_ch if depthwise else 1
38 | # use add_module to keep key space clean
39 | self.add_module(
40 | str(idx),
41 | create_conv2d_pad(
42 | in_ch, out_ch, k, stride=stride,
43 | padding=padding, dilation=dilation, groups=conv_groups, **kwargs)
44 | )
45 | self.splits = in_splits
46 |
47 | def forward(self, x):
48 | x_split = torch.split(x, self.splits, 1)
49 | x_out = [c(x_split[i]) for i, c in enumerate(self.values())]
50 | x = torch.cat(x_out, 1)
51 | return x
52 |
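53 | # Usage sketch (illustrative): three kernel sizes, each applied to its own
54 | # channel group, then concatenated.
55 | #
56 | #   m = MixedConv2d(16, 32, kernel_size=[3, 5, 7])
57 | #   # _split_channels gives input splits [6, 5, 5] and output splits [12, 10, 10]
58 | #   y = m(torch.randn(2, 16, 32, 32))   # -> (2, 32, 32, 32)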
--------------------------------------------------------------------------------
/setup.py:
--------------------------------------------------------------------------------
1 | """ Setup
2 | """
3 | from setuptools import setup, find_packages
4 | from codecs import open
5 | from os import path
6 |
7 | here = path.abspath(path.dirname(__file__))
8 |
9 | # Get the long description from the README file
10 | with open(path.join(here, 'README.md'), encoding='utf-8') as f:
11 | long_description = f.read()
12 |
13 | exec(open('timm/version.py').read())
14 | setup(
15 | name='timm',
16 | version=__version__,
17 | description='PyTorch Image Models',
18 | long_description=long_description,
19 | long_description_content_type='text/markdown',
20 | url='https://github.com/huggingface/pytorch-image-models',
21 | author='Ross Wightman',
22 | author_email='ross@huggingface.co',
23 | classifiers=[
24 | # How mature is this project? Common values are
25 | # 3 - Alpha
26 | # 4 - Beta
27 | # 5 - Production/Stable
28 | 'Development Status :: 4 - Beta',
29 | 'Intended Audience :: Education',
30 | 'Intended Audience :: Science/Research',
31 | 'License :: OSI Approved :: Apache Software License',
32 | 'Programming Language :: Python :: 3.8',
33 | 'Programming Language :: Python :: 3.9',
34 | 'Programming Language :: Python :: 3.10',
35 | 'Programming Language :: Python :: 3.11',
36 | 'Topic :: Scientific/Engineering',
37 | 'Topic :: Scientific/Engineering :: Artificial Intelligence',
38 | 'Topic :: Software Development',
39 | 'Topic :: Software Development :: Libraries',
40 | 'Topic :: Software Development :: Libraries :: Python Modules',
41 | ],
42 |
43 | # Note that this is a string of words separated by whitespace, not a list.
44 | keywords='pytorch pretrained models efficientnet mobilenetv3 mnasnet resnet vision transformer vit',
45 | packages=find_packages(exclude=['convert', 'tests', 'results']),
46 | include_package_data=True,
47 | install_requires=['torch >= 1.7', 'torchvision', 'pyyaml', 'huggingface_hub', 'safetensors'],
48 | python_requires='>=3.7',
49 | )
50 |
51 |
--------------------------------------------------------------------------------
/hfdocs/source/hf_hub.mdx:
--------------------------------------------------------------------------------
1 | # Sharing and Loading Models From the Hugging Face Hub
2 |
3 | The `timm` library has a built-in integration with the Hugging Face Hub, making it easy to share and load models from the 🤗 Hub.
4 |
5 | In this short guide, we'll see how to:
6 | 1. Share a `timm` model on the Hub
7 | 2. How to load that model back from the Hub
8 |
9 | ## Authenticating
10 |
11 | First, you'll need to make sure you have the `huggingface_hub` package installed.
12 |
13 | ```bash
14 | pip install huggingface_hub
15 | ```
16 |
17 | Then, you'll need to authenticate yourself. You can do this by running the following command:
18 |
19 | ```bash
20 | huggingface-cli login
21 | ```
22 |
23 | Or, if you're using a notebook, you can use the `notebook_login` helper:
24 |
25 | ```py
26 | >>> from huggingface_hub import notebook_login
27 | >>> notebook_login()
28 | ```
29 |
30 | ## Sharing a Model
31 |
32 | ```py
33 | >>> import timm
34 | >>> model = timm.create_model('resnet18', pretrained=True, num_classes=4)
35 | ```
36 |
37 | Here is where you would normally train or fine-tune the model. We'll skip that for the sake of this tutorial.
38 |
39 | Let's pretend we've now fine-tuned the model. The next step would be to push it to the Hub! We can do this with the `timm.models.hub.push_to_hf_hub` function.
40 |
41 | ```py
42 | >>> model_cfg = dict(labels=['a', 'b', 'c', 'd'])
43 | >>> timm.models.hub.push_to_hf_hub(model, 'resnet18-random', model_config=model_cfg)
44 | ```
45 |
46 | Running the above would push the model to `<your-username>/resnet18-random` on the Hub. You can now share this model with your friends, or use it in your own code!
47 |
48 | ## Loading a Model
49 |
50 | Loading a model from the Hub is as simple as calling `timm.create_model` with the `pretrained` argument set to the name of the model you want to load. In this case, we'll use [`nateraw/resnet18-random`](https://huggingface.co/nateraw/resnet18-random), which is the model we just pushed to the Hub.
51 |
52 | ```py
53 | >>> model_reloaded = timm.create_model('hf_hub:nateraw/resnet18-random', pretrained=True)
54 | ```
55 |
--------------------------------------------------------------------------------
/docs/models/.templates/generate_readmes.py:
--------------------------------------------------------------------------------
1 | """
2 | Run this script to generate the model-index files in `models` from the templates in `.templates/models`.
3 | """
4 |
5 | import argparse
6 | from pathlib import Path
7 |
8 | from jinja2 import Environment, FileSystemLoader
9 |
10 | import modelindex
11 |
12 |
13 | def generate_readmes(templates_path: Path, dest_path: Path):
14 | """Add the code snippet template to the readmes"""
15 | readme_templates_path = templates_path / "models"
16 | code_template_path = templates_path / "code_snippets.md"
17 |
18 | env = Environment(
19 | loader=FileSystemLoader([readme_templates_path, readme_templates_path.parent]),
20 | )
21 |
22 | for readme in readme_templates_path.iterdir():
23 | if readme.suffix == ".md":
24 | template = env.get_template(readme.name)
25 |
26 | # get the first model_name for this model family
27 | mi = modelindex.load(str(readme))
28 | model_name = mi.models[0].name
29 |
30 | full_content = template.render(model_name=model_name)
31 |
32 | # generate full_readme
33 | with open(dest_path / readme.name, "w") as f:
34 | f.write(full_content)
35 |
36 |
37 | def main():
38 | parser = argparse.ArgumentParser(description="Model index generation config")
39 | parser.add_argument(
40 | "-t",
41 | "--templates",
42 | default=Path(__file__).parent / ".templates",
43 | type=str,
44 | help="Location of the markdown templates",
45 | )
46 | parser.add_argument(
47 | "-d",
48 | "--dest",
49 | default=Path(__file__).parent / "models",
50 | type=str,
51 | help="Destination folder that contains the generated model-index files.",
52 | )
53 | args = parser.parse_args()
54 | templates_path = Path(args.templates)
55 | dest_readmes_path = Path(args.dest)
56 |
57 | generate_readmes(
58 | templates_path,
59 | dest_readmes_path,
60 | )
61 |
62 |
63 | if __name__ == "__main__":
64 | main()
65 |
--------------------------------------------------------------------------------
/docs/models/.templates/models/spnasnet.md:
--------------------------------------------------------------------------------
1 | # SPNASNet
2 |
3 | **Single-Path NAS** is a novel differentiable NAS method for designing hardware-efficient ConvNets in less than 4 hours.
4 |
5 | {% include 'code_snippets.md' %}
6 |
7 | ## How do I train this model?
8 |
9 | You can follow the [timm recipe scripts](https://rwightman.github.io/pytorch-image-models/scripts/) for training a new model afresh.
10 |
11 | ## Citation
12 |
13 | ```BibTeX
14 | @misc{stamoulis2019singlepath,
15 | title={Single-Path NAS: Designing Hardware-Efficient ConvNets in less than 4 Hours},
16 | author={Dimitrios Stamoulis and Ruizhou Ding and Di Wang and Dimitrios Lymberopoulos and Bodhi Priyantha and Jie Liu and Diana Marculescu},
17 | year={2019},
18 | eprint={1904.02877},
19 | archivePrefix={arXiv},
20 | primaryClass={cs.LG}
21 | }
22 | ```
23 |
24 |
63 |
--------------------------------------------------------------------------------
/tests/test_layers.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 |
4 | from timm.layers import create_act_layer, set_layer_config
5 |
6 | import importlib
7 | import os
8 |
9 | torch_backend = os.environ.get('TORCH_BACKEND')
10 | if torch_backend is not None:
11 | importlib.import_module(torch_backend)
12 | torch_device = os.environ.get('TORCH_DEVICE', 'cpu')
13 |
14 | class MLP(nn.Module):
15 | def __init__(self, act_layer="relu", inplace=True):
16 | super(MLP, self).__init__()
17 | self.fc1 = nn.Linear(1000, 100)
18 | self.act = create_act_layer(act_layer, inplace=inplace)
19 | self.fc2 = nn.Linear(100, 10)
20 |
21 | def forward(self, x):
22 | x = self.fc1(x)
23 | x = self.act(x)
24 | x = self.fc2(x)
25 | return x
26 |
27 |
28 | def _run_act_layer_grad(act_type, inplace=True):
29 | x = torch.rand(10, 1000) * 10
30 | m = MLP(act_layer=act_type, inplace=inplace)
31 |
32 | def _run(x, act_layer=''):
33 | if act_layer:
34 | # replace act layer if set
35 | m.act = create_act_layer(act_layer, inplace=inplace)
36 | out = m(x)
37 | l = (out - 0).pow(2).sum()
38 | return l
39 |
40 | x = x.to(device=torch_device)
41 | m.to(device=torch_device)
42 |
43 | out_me = _run(x)
44 |
45 | with set_layer_config(scriptable=True):
46 | out_jit = _run(x, act_type)
47 |
48 | assert torch.isclose(out_jit, out_me)
49 |
50 | with set_layer_config(no_jit=True):
51 | out_basic = _run(x, act_type)
52 |
53 | assert torch.isclose(out_basic, out_jit)
54 |
55 |
56 | def test_swish_grad():
57 | for _ in range(100):
58 | _run_act_layer_grad('swish')
59 |
60 |
61 | def test_mish_grad():
62 | for _ in range(100):
63 | _run_act_layer_grad('mish')
64 |
65 |
66 | def test_hard_sigmoid_grad():
67 | for _ in range(100):
68 | _run_act_layer_grad('hard_sigmoid', inplace=None)
69 |
70 |
71 | def test_hard_swish_grad():
72 | for _ in range(100):
73 | _run_act_layer_grad('hard_swish')
74 |
75 |
76 | def test_hard_mish_grad():
77 | for _ in range(100):
78 | _run_act_layer_grad('hard_mish')
79 |
--------------------------------------------------------------------------------
/docs/models/.templates/models/gloun-senet.md:
--------------------------------------------------------------------------------
1 | # (Gluon) SENet
2 |
3 | A **SENet** is a convolutional neural network architecture that employs [squeeze-and-excitation blocks](https://paperswithcode.com/method/squeeze-and-excitation-block) to enable the network to perform dynamic channel-wise feature recalibration.
4 |
5 | The weights from this model were ported from [Gluon](https://cv.gluon.ai/model_zoo/classification.html).
6 |
7 | {% include 'code_snippets.md' %}
8 |
9 | ## How do I train this model?
10 |
11 | You can follow the [timm recipe scripts](https://rwightman.github.io/pytorch-image-models/scripts/) for training a new model afresh.
12 |
13 | ## Citation
14 |
15 | ```BibTeX
16 | @misc{hu2019squeezeandexcitation,
17 | title={Squeeze-and-Excitation Networks},
18 | author={Jie Hu and Li Shen and Samuel Albanie and Gang Sun and Enhua Wu},
19 | year={2019},
20 | eprint={1709.01507},
21 | archivePrefix={arXiv},
22 | primaryClass={cs.CV}
23 | }
24 | ```
25 |
26 |
64 |
--------------------------------------------------------------------------------
/tests/test_utils.py:
--------------------------------------------------------------------------------
1 | from torch.nn.modules.batchnorm import BatchNorm2d
2 | from torchvision.ops.misc import FrozenBatchNorm2d
3 |
4 | import timm
5 | from timm.utils.model import freeze, unfreeze
6 |
7 |
8 | def test_freeze_unfreeze():
9 | model = timm.create_model('resnet18')
10 |
11 | # Freeze all
12 | freeze(model)
13 | # Check top level module
14 | assert model.fc.weight.requires_grad == False
15 | # Check submodule
16 | assert model.layer1[0].conv1.weight.requires_grad == False
17 | # Check BN
18 | assert isinstance(model.layer1[0].bn1, FrozenBatchNorm2d)
19 |
20 | # Unfreeze all
21 | unfreeze(model)
22 | # Check top level module
23 | assert model.fc.weight.requires_grad == True
24 | # Check submodule
25 | assert model.layer1[0].conv1.weight.requires_grad == True
26 | # Check BN
27 | assert isinstance(model.layer1[0].bn1, BatchNorm2d)
28 |
29 | # Freeze some
30 | freeze(model, ['layer1', 'layer2.0'])
31 | # Check frozen
32 | assert model.layer1[0].conv1.weight.requires_grad == False
33 | assert isinstance(model.layer1[0].bn1, FrozenBatchNorm2d)
34 | assert model.layer2[0].conv1.weight.requires_grad == False
35 | # Check not frozen
36 | assert model.layer3[0].conv1.weight.requires_grad == True
37 | assert isinstance(model.layer3[0].bn1, BatchNorm2d)
38 | assert model.layer2[1].conv1.weight.requires_grad == True
39 |
40 | # Unfreeze some
41 | unfreeze(model, ['layer1', 'layer2.0'])
42 | # Check not frozen
43 | assert model.layer1[0].conv1.weight.requires_grad == True
44 | assert isinstance(model.layer1[0].bn1, BatchNorm2d)
45 | assert model.layer2[0].conv1.weight.requires_grad == True
46 |
47 | # Freeze/unfreeze BN
48 | # From root
49 | freeze(model, ['layer1.0.bn1'])
50 | assert isinstance(model.layer1[0].bn1, FrozenBatchNorm2d)
51 | unfreeze(model, ['layer1.0.bn1'])
52 | assert isinstance(model.layer1[0].bn1, BatchNorm2d)
53 | # From direct parent
54 | freeze(model.layer1[0], ['bn1'])
55 | assert isinstance(model.layer1[0].bn1, FrozenBatchNorm2d)
56 | unfreeze(model.layer1[0], ['bn1'])
57 | assert isinstance(model.layer1[0].bn1, BatchNorm2d)
--------------------------------------------------------------------------------
/timm/scheduler/multistep_lr.py:
--------------------------------------------------------------------------------
1 | """ MultiStep LR Scheduler
2 |
3 | Basic multi step LR schedule with warmup, noise.
4 | """
5 | import torch
6 | import bisect
7 | from timm.scheduler.scheduler import Scheduler
8 | from typing import List
9 |
10 | class MultiStepLRScheduler(Scheduler):
11 |     """ Decay LR by decay_rate at each step boundary in decay_t, w/ optional warmup.
12 |     """
13 |
14 | def __init__(
15 | self,
16 | optimizer: torch.optim.Optimizer,
17 | decay_t: List[int],
18 | decay_rate: float = 1.,
19 | warmup_t=0,
20 | warmup_lr_init=0,
21 | warmup_prefix=True,
22 | t_in_epochs=True,
23 | noise_range_t=None,
24 | noise_pct=0.67,
25 | noise_std=1.0,
26 | noise_seed=42,
27 | initialize=True,
28 | ) -> None:
29 | super().__init__(
30 | optimizer,
31 | param_group_field="lr",
32 | t_in_epochs=t_in_epochs,
33 | noise_range_t=noise_range_t,
34 | noise_pct=noise_pct,
35 | noise_std=noise_std,
36 | noise_seed=noise_seed,
37 | initialize=initialize,
38 | )
39 |
40 | self.decay_t = decay_t
41 | self.decay_rate = decay_rate
42 | self.warmup_t = warmup_t
43 | self.warmup_lr_init = warmup_lr_init
44 | self.warmup_prefix = warmup_prefix
45 | if self.warmup_t:
46 | self.warmup_steps = [(v - warmup_lr_init) / self.warmup_t for v in self.base_values]
47 | super().update_groups(self.warmup_lr_init)
48 | else:
49 | self.warmup_steps = [1 for _ in self.base_values]
50 |
51 | def get_curr_decay_steps(self, t):
52 | # find where in the array t goes,
53 | # assumes self.decay_t is sorted
54 | return bisect.bisect_right(self.decay_t, t + 1)
55 |
56 | def _get_lr(self, t):
57 | if t < self.warmup_t:
58 | lrs = [self.warmup_lr_init + t * s for s in self.warmup_steps]
59 | else:
60 | if self.warmup_prefix:
61 | t = t - self.warmup_t
62 | lrs = [v * (self.decay_rate ** self.get_curr_decay_steps(t)) for v in self.base_values]
63 | return lrs
64 |
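65 | # Usage sketch (illustrative): 10x LR drops at the epochs listed in decay_t.
66 | # `optimizer` is a hypothetical torch.optim.Optimizer; step() comes from Scheduler.
67 | #
68 | #   sched = MultiStepLRScheduler(optimizer, decay_t=[30, 60], decay_rate=0.1)
69 | #   # _get_lr(28) -> base_lr; _get_lr(29) -> base_lr * 0.1, since
70 | #   # bisect_right([30, 60], 29 + 1) == 1 (drops take effect on the boundary epoch)
71 | #   for epoch in range(num_epochs):
72 | #       sched.step(epoch)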
--------------------------------------------------------------------------------
/timm/layers/test_time_pool.py:
--------------------------------------------------------------------------------
1 | """ Test Time Pooling (Average-Max Pool)
2 |
3 | Hacked together by / Copyright 2020 Ross Wightman
4 | """
5 |
6 | import logging
7 | from torch import nn
8 | import torch.nn.functional as F
9 |
10 | from .adaptive_avgmax_pool import adaptive_avgmax_pool2d
11 |
12 |
13 | _logger = logging.getLogger(__name__)
14 |
15 |
16 | class TestTimePoolHead(nn.Module):
17 | def __init__(self, base, original_pool=7):
18 | super(TestTimePoolHead, self).__init__()
19 | self.base = base
20 | self.original_pool = original_pool
21 | base_fc = self.base.get_classifier()
22 | if isinstance(base_fc, nn.Conv2d):
23 | self.fc = base_fc
24 | else:
25 | self.fc = nn.Conv2d(
26 | self.base.num_features, self.base.num_classes, kernel_size=1, bias=True)
27 | self.fc.weight.data.copy_(base_fc.weight.data.view(self.fc.weight.size()))
28 | self.fc.bias.data.copy_(base_fc.bias.data.view(self.fc.bias.size()))
29 | self.base.reset_classifier(0) # delete original fc layer
30 |
31 | def forward(self, x):
32 | x = self.base.forward_features(x)
33 | x = F.avg_pool2d(x, kernel_size=self.original_pool, stride=1)
34 | x = self.fc(x)
35 | x = adaptive_avgmax_pool2d(x, 1)
36 | return x.view(x.size(0), -1)
37 |
38 |
39 | def apply_test_time_pool(model, config, use_test_size=False):
40 | test_time_pool = False
41 | if not hasattr(model, 'default_cfg') or not model.default_cfg:
42 | return model, False
43 | if use_test_size and 'test_input_size' in model.default_cfg:
44 | df_input_size = model.default_cfg['test_input_size']
45 | else:
46 | df_input_size = model.default_cfg['input_size']
47 | if config['input_size'][-1] > df_input_size[-1] and config['input_size'][-2] > df_input_size[-2]:
48 | _logger.info('Target input size %s > pretrained default %s, using test time pooling' %
49 | (str(config['input_size'][-2:]), str(df_input_size[-2:])))
50 | model = TestTimePoolHead(model, original_pool=model.default_cfg['pool_size'])
51 | test_time_pool = True
52 | return model, test_time_pool
53 |
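54 | # Usage sketch (illustrative): wrap a model when evaluating above its default
55 | # train resolution, e.g. a model w/ an assumed 224x224 default validated at 288x288.
56 | #
57 | #   model = timm.create_model('resnet50', pretrained=True)
58 | #   model, enabled = apply_test_time_pool(model, {'input_size': (3, 288, 288)})
59 | #   # enabled is True when 288 > the default 224: the head now conv-classifies the
60 | #   # larger feature map and avg-max pools the per-position logits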
--------------------------------------------------------------------------------
/timm/loss/binary_cross_entropy.py:
--------------------------------------------------------------------------------
1 | """ Binary Cross Entropy w/ a few extras
2 |
3 | Hacked together by / Copyright 2021 Ross Wightman
4 | """
5 | from typing import Optional
6 |
7 | import torch
8 | import torch.nn as nn
9 | import torch.nn.functional as F
10 |
11 |
12 | class BinaryCrossEntropy(nn.Module):
13 | """ BCE with optional one-hot from dense targets, label smoothing, thresholding
14 |     NOTE for experiments comparing CE to BCE w/ label smoothing, may remove
15 | """
16 | def __init__(
17 | self, smoothing=0.1, target_threshold: Optional[float] = None, weight: Optional[torch.Tensor] = None,
18 | reduction: str = 'mean', pos_weight: Optional[torch.Tensor] = None):
19 | super(BinaryCrossEntropy, self).__init__()
20 | assert 0. <= smoothing < 1.0
21 | self.smoothing = smoothing
22 | self.target_threshold = target_threshold
23 | self.reduction = reduction
24 | self.register_buffer('weight', weight)
25 | self.register_buffer('pos_weight', pos_weight)
26 |
27 | def forward(self, x: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
28 | assert x.shape[0] == target.shape[0]
29 | if target.shape != x.shape:
30 | # NOTE currently assume smoothing or other label softening is applied upstream if targets are already sparse
31 | num_classes = x.shape[-1]
32 | # FIXME should off/on be different for smoothing w/ BCE? Other impl out there differ
33 | off_value = self.smoothing / num_classes
34 | on_value = 1. - self.smoothing + off_value
35 | target = target.long().view(-1, 1)
36 | target = torch.full(
37 | (target.size()[0], num_classes),
38 | off_value,
39 | device=x.device, dtype=x.dtype).scatter_(1, target, on_value)
40 | if self.target_threshold is not None:
41 | # Make target 0, or 1 if threshold set
42 | target = target.gt(self.target_threshold).to(dtype=target.dtype)
43 | return F.binary_cross_entropy_with_logits(
44 | x, target,
45 | self.weight,
46 | pos_weight=self.pos_weight,
47 | reduction=self.reduction)
48 |
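49 | # Usage sketch (illustrative): dense integer targets are expanded to smoothed
50 | # one-hot vectors before the BCE computation.
51 | #
52 | #   loss_fn = BinaryCrossEntropy(smoothing=0.1)
53 | #   x = torch.randn(4, 10)                 # logits
54 | #   target = torch.randint(0, 10, (4,))    # class indices, shape != x.shape
55 | #   loss = loss_fn(x, target)              # on-value 0.91, off-value 0.01 w/ 10 classes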
--------------------------------------------------------------------------------
/docs/models/.templates/models/nasnet.md:
--------------------------------------------------------------------------------
1 | # NASNet
2 |
3 | **NASNet** is a type of convolutional neural network discovered through neural architecture search. The building blocks consist of normal and reduction cells.
4 |
5 | {% include 'code_snippets.md' %}
6 |
7 | ## How do I train this model?
8 |
9 | You can follow the [timm recipe scripts](https://rwightman.github.io/pytorch-image-models/scripts/) for training a new model afresh.
10 |
11 | ## Citation
12 |
13 | ```BibTeX
14 | @misc{zoph2018learning,
15 | title={Learning Transferable Architectures for Scalable Image Recognition},
16 | author={Barret Zoph and Vijay Vasudevan and Jonathon Shlens and Quoc V. Le},
17 | year={2018},
18 | eprint={1707.07012},
19 | archivePrefix={arXiv},
20 | primaryClass={cs.CV}
21 | }
22 | ```
23 |
24 |
71 |
--------------------------------------------------------------------------------
/docs/models/.templates/models/gloun-xception.md:
--------------------------------------------------------------------------------
1 | # (Gluon) Xception
2 |
3 | **Xception** is a convolutional neural network architecture that relies solely on [depthwise separable convolution](https://paperswithcode.com/method/depthwise-separable-convolution) layers.
4 |
5 | The weights from this model were ported from [Gluon](https://cv.gluon.ai/model_zoo/classification.html).
6 |
7 | {% include 'code_snippets.md' %}
8 |
9 | ## How do I train this model?
10 |
11 | You can follow the [timm recipe scripts](https://rwightman.github.io/pytorch-image-models/scripts/) for training a new model afresh.
12 |
13 | ## Citation
14 |
15 | ```BibTeX
16 | @misc{chollet2017xception,
17 | title={Xception: Deep Learning with Depthwise Separable Convolutions},
18 | author={François Chollet},
19 | year={2017},
20 | eprint={1610.02357},
21 | archivePrefix={arXiv},
22 | primaryClass={cs.CV}
23 | }
24 | ```
25 |
26 |
67 |
--------------------------------------------------------------------------------
/timm/utils/cuda.py:
--------------------------------------------------------------------------------
1 | """ CUDA / AMP utils
2 |
3 | Hacked together by / Copyright 2020 Ross Wightman
4 | """
5 | import torch
6 |
7 | try:
8 | from apex import amp
9 | has_apex = True
10 | except ImportError:
11 | amp = None
12 | has_apex = False
13 |
14 | from .clip_grad import dispatch_clip_grad
15 |
16 |
17 | class ApexScaler:
18 | state_dict_key = "amp"
19 |
20 | def __call__(
21 | self,
22 | loss,
23 | optimizer,
24 | clip_grad=None,
25 | clip_mode='norm',
26 | parameters=None,
27 | create_graph=False,
28 | need_update=True,
29 | ):
30 | with amp.scale_loss(loss, optimizer) as scaled_loss:
31 | scaled_loss.backward(create_graph=create_graph)
32 | if need_update:
33 | if clip_grad is not None:
34 | dispatch_clip_grad(amp.master_params(optimizer), clip_grad, mode=clip_mode)
35 | optimizer.step()
36 |
37 | def state_dict(self):
38 | if 'state_dict' in amp.__dict__:
39 | return amp.state_dict()
40 |
41 | def load_state_dict(self, state_dict):
42 | if 'load_state_dict' in amp.__dict__:
43 | amp.load_state_dict(state_dict)
44 |
45 |
46 | class NativeScaler:
47 | state_dict_key = "amp_scaler"
48 |
49 | def __init__(self):
50 | self._scaler = torch.cuda.amp.GradScaler()
51 |
52 | def __call__(
53 | self,
54 | loss,
55 | optimizer,
56 | clip_grad=None,
57 | clip_mode='norm',
58 | parameters=None,
59 | create_graph=False,
60 | need_update=True,
61 | ):
62 | self._scaler.scale(loss).backward(create_graph=create_graph)
63 | if need_update:
64 | if clip_grad is not None:
65 | assert parameters is not None
66 | self._scaler.unscale_(optimizer) # unscale the gradients of optimizer's assigned params in-place
67 | dispatch_clip_grad(parameters, clip_grad, mode=clip_mode)
68 | self._scaler.step(optimizer)
69 | self._scaler.update()
70 |
71 | def state_dict(self):
72 | return self._scaler.state_dict()
73 |
74 | def load_state_dict(self, state_dict):
75 | self._scaler.load_state_dict(state_dict)
76 |
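77 | # Usage sketch (illustrative): one native-AMP training step w/ gradient clipping.
78 | # `model`, `optimizer`, and `input` are hypothetical.
79 | #
80 | #   scaler = NativeScaler()
81 | #   with torch.cuda.amp.autocast():
82 | #       loss = model(input).mean()
83 | #   optimizer.zero_grad()
84 | #   scaler(loss, optimizer, clip_grad=1.0, parameters=model.parameters())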
--------------------------------------------------------------------------------
/hfdocs/source/installation.mdx:
--------------------------------------------------------------------------------
1 | # Installation
2 |
3 | Before you start, you'll need to set up your environment and install the appropriate packages. `timm` is tested on **Python 3+**.
4 |
5 | ## Virtual Environment
6 |
7 | You should install `timm` in a [virtual environment](https://docs.python.org/3/library/venv.html) to keep things tidy and avoid dependency conflicts.
8 |
9 | 1. Create and navigate to your project directory:
10 |
11 | ```bash
12 | mkdir ~/my-project
13 | cd ~/my-project
14 | ```
15 |
16 | 2. Start a virtual environment inside your directory:
17 |
18 | ```bash
19 | python -m venv .env
20 | ```
21 |
22 | 3. Activate and deactivate the virtual environment with the following commands:
23 |
24 | ```bash
25 | # Activate the virtual environment
26 | source .env/bin/activate
27 |
28 | # Deactivate the virtual environment
29 | source .env/bin/deactivate
30 | ```
31 |
32 | Once you've created your virtual environment, you can install `timm` in it.
33 |
34 | ## Using pip
35 |
36 | The most straightforward way to install `timm` is with pip:
37 |
38 | ```bash
39 | pip install timm
40 | ```
41 |
42 | Alternatively, you can install `timm` from GitHub directly to get the latest, bleeding-edge version:
43 |
44 | ```bash
45 | pip install git+https://github.com/rwightman/pytorch-image-models.git
46 | ```
47 |
48 | Run the following command to check if `timm` has been properly installed:
49 |
50 | ```bash
51 | python -c "from timm import list_models; print(list_models(pretrained=True)[:5])"
52 | ```
53 |
54 | This command lists the first five pretrained models available in `timm` (which are sorted alphabetically). You should see the following output:
55 |
56 | ```python
57 | ['adv_inception_v3', 'bat_resnext26ts', 'beit_base_patch16_224', 'beit_base_patch16_224_in22k', 'beit_base_patch16_384']
58 | ```
59 |
60 | ## From Source
61 |
62 | Building `timm` from source lets you make changes to the code base. To install from the source, clone the repository and install with the following commands:
63 |
64 | ```bash
65 | git clone https://github.com/rwightman/pytorch-image-models.git
66 | cd pytorch-image-models
67 | pip install -e .
68 | ```
69 |
70 | Again, you can check if `timm` was properly installed with the following command:
71 |
72 | ```bash
73 | python -c "from timm import list_models; print(list_models(pretrained=True)[:5])"
74 | ```
75 |
--------------------------------------------------------------------------------
/docs/models/.templates/models/legacy-senet.md:
--------------------------------------------------------------------------------
1 | # (Legacy) SENet
2 |
3 | A **SENet** is a convolutional neural network architecture that employs [squeeze-and-excitation blocks](https://paperswithcode.com/method/squeeze-and-excitation-block) to enable the network to perform dynamic channel-wise feature recalibration.
4 |
5 | The weights from this model were ported from Gluon.
6 |
7 | {% include 'code_snippets.md' %}
8 |
9 | ## How do I train this model?
10 |
11 | You can follow the [timm recipe scripts](https://rwightman.github.io/pytorch-image-models/scripts/) for training a new model afresh.
12 |
13 | ## Citation
14 |
15 | ```BibTeX
16 | @misc{hu2019squeezeandexcitation,
17 | title={Squeeze-and-Excitation Networks},
18 | author={Jie Hu and Li Shen and Samuel Albanie and Gang Sun and Enhua Wu},
19 | year={2019},
20 | eprint={1709.01507},
21 | archivePrefix={arXiv},
22 | primaryClass={cs.CV}
23 | }
24 | ```
25 |
26 |
75 |
--------------------------------------------------------------------------------
/docs/models/.templates/models/inception-v4.md:
--------------------------------------------------------------------------------
1 | # Inception v4
2 |
3 | **Inception-v4** is a convolutional neural network architecture that builds on previous iterations of the Inception family by simplifying the architecture and using more inception modules than [Inception-v3](https://paperswithcode.com/method/inception-v3).
4 | {% include 'code_snippets.md' %}
5 |
6 | ## How do I train this model?
7 |
8 | You can follow the [timm recipe scripts](https://rwightman.github.io/pytorch-image-models/scripts/) for training a new model afresh.
9 |
10 | ## Citation
11 |
12 | ```BibTeX
13 | @misc{szegedy2016inceptionv4,
14 | title={Inception-v4, Inception-ResNet and the Impact of Residual Connections on Learning},
15 | author={Christian Szegedy and Sergey Ioffe and Vincent Vanhoucke and Alex Alemi},
16 | year={2016},
17 | eprint={1602.07261},
18 | archivePrefix={arXiv},
19 | primaryClass={cs.CV}
20 | }
21 | ```
22 |
23 |
72 |
--------------------------------------------------------------------------------
/timm/utils/jit.py:
--------------------------------------------------------------------------------
1 | """ JIT scripting/tracing utils
2 |
3 | Hacked together by / Copyright 2020 Ross Wightman
4 | """
5 | import os
6 |
7 | import torch
8 |
9 |
10 | def set_jit_legacy():
11 | """ Set JIT executor to legacy w/ support for op fusion
12 | This is hopefully a temporary need in 1.5/1.5.1/1.6 to restore performance due to changes
13 |     in the JIT executor. These APIs are not supported and could change.
14 | """
15 | #
16 | assert hasattr(torch._C, '_jit_set_profiling_executor'), "Old JIT behavior doesn't exist!"
17 | torch._C._jit_set_profiling_executor(False)
18 | torch._C._jit_set_profiling_mode(False)
19 | torch._C._jit_override_can_fuse_on_gpu(True)
20 | #torch._C._jit_set_texpr_fuser_enabled(True)
21 |
22 |
23 | def set_jit_fuser(fuser):
24 | if fuser == "te":
25 | # default fuser should be == 'te'
26 | torch._C._jit_set_profiling_executor(True)
27 | torch._C._jit_set_profiling_mode(True)
28 | torch._C._jit_override_can_fuse_on_cpu(False)
29 | torch._C._jit_override_can_fuse_on_gpu(True)
30 | torch._C._jit_set_texpr_fuser_enabled(True)
31 | try:
32 | torch._C._jit_set_nvfuser_enabled(False)
33 | except Exception:
34 | pass
35 | elif fuser == "old" or fuser == "legacy":
36 | torch._C._jit_set_profiling_executor(False)
37 | torch._C._jit_set_profiling_mode(False)
38 | torch._C._jit_override_can_fuse_on_gpu(True)
39 | torch._C._jit_set_texpr_fuser_enabled(False)
40 | try:
41 | torch._C._jit_set_nvfuser_enabled(False)
42 | except Exception:
43 | pass
44 | elif fuser == "nvfuser" or fuser == "nvf":
45 | os.environ['PYTORCH_NVFUSER_DISABLE_FALLBACK'] = '1'
46 | #os.environ['PYTORCH_NVFUSER_DISABLE_FMA'] = '1'
47 | #os.environ['PYTORCH_NVFUSER_JIT_OPT_LEVEL'] = '0'
48 | torch._C._jit_set_texpr_fuser_enabled(False)
49 | torch._C._jit_set_profiling_executor(True)
50 | torch._C._jit_set_profiling_mode(True)
51 | torch._C._jit_can_fuse_on_cpu()
52 | torch._C._jit_can_fuse_on_gpu()
53 | torch._C._jit_override_can_fuse_on_cpu(False)
54 | torch._C._jit_override_can_fuse_on_gpu(False)
55 | torch._C._jit_set_nvfuser_guard_mode(True)
56 | torch._C._jit_set_nvfuser_enabled(True)
57 | else:
58 | assert False, f"Invalid jit fuser ({fuser})"
59 |
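60 | # Usage sketch (illustrative): pick a fuser before scripting, since these calls
61 | # mutate global JIT state.
62 | #
63 | #   set_jit_fuser('te')                  # or 'old' / 'legacy' / 'nvfuser' / 'nvf'
64 | #   scripted = torch.jit.script(model)   # hypothetical model; runs use the chosen fuser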
--------------------------------------------------------------------------------
/docs/models/.templates/models/skresnext.md:
--------------------------------------------------------------------------------
1 | # SK-ResNeXt
2 |
3 | **SK ResNeXt** is a variant of a [ResNeXt](https://www.paperswithcode.com/method/resnext) that employs a [Selective Kernel](https://paperswithcode.com/method/selective-kernel) unit. In general, all the large kernel convolutions in the original bottleneck blocks in ResNeXt are replaced by the proposed [SK convolutions](https://paperswithcode.com/method/selective-kernel-convolution), enabling the network to choose appropriate receptive field sizes in an adaptive manner.
4 |
5 | {% include 'code_snippets.md' %}
6 |
7 | ## How do I train this model?
8 |
9 | You can follow the [timm recipe scripts](https://rwightman.github.io/pytorch-image-models/scripts/) for training a new model afresh.
10 |
11 | ## Citation
12 |
13 | ```BibTeX
14 | @misc{li2019selective,
15 | title={Selective Kernel Networks},
16 | author={Xiang Li and Wenhai Wang and Xiaolin Hu and Jian Yang},
17 | year={2019},
18 | eprint={1903.06586},
19 | archivePrefix={arXiv},
20 | primaryClass={cs.CV}
21 | }
22 | ```
23 |
24 |
71 |
--------------------------------------------------------------------------------
/docs/index.md:
--------------------------------------------------------------------------------
1 | # Getting Started
2 |
3 | ## Welcome
4 |
5 | Welcome to the `timm` documentation, a lean set of docs that covers the basics of `timm`.
6 |
7 | For a more comprehensive set of docs (currently under development), please visit [timmdocs](http://timm.fast.ai) by [Aman Arora](https://github.com/amaarora).
8 |
9 | ## Install
10 |
11 | The library can be installed with pip:
12 |
13 | ```
14 | pip install timm
15 | ```
16 |
17 | I update the PyPI (pip) packages when I'm confident there are no significant model regressions from previous releases. If you want to pip install the bleeding edge from GitHub, use:
18 | ```
19 | pip install git+https://github.com/rwightman/pytorch-image-models.git
20 | ```
21 |
22 | !!! info "Conda Environment"
23 | All development and testing has been done in Conda Python 3 environments on Linux x86-64 systems, specifically 3.7, 3.8, 3.9, 3.10
24 |
25 |     Little to no care has been taken to be Python 2.x friendly, and it will not be supported. If you run into any challenges running on Windows or another OS, I'm definitely open to looking into those issues so long as it's in a reproducible (read: Conda) environment.
26 |
27 | PyTorch versions 1.9, 1.10, 1.11 have been tested with the latest versions of this code.
28 |
29 | I've tried to keep the dependencies minimal; the setup is as per the PyTorch default install instructions for Conda:
30 | ```
31 | conda create -n torch-env
32 | conda activate torch-env
33 | conda install pytorch torchvision cudatoolkit=11.3 -c pytorch
34 | conda install pyyaml
35 | ```
36 |
37 | ## Load a Pretrained Model
38 |
39 | Pretrained models can be loaded using `timm.create_model`
40 |
41 | ```python
42 | import timm
43 |
44 | m = timm.create_model('mobilenetv3_large_100', pretrained=True)
45 | m.eval()
46 | ```
47 |
48 | ## List Models with Pretrained Weights
49 | ```python
50 | import timm
51 | from pprint import pprint
52 | model_names = timm.list_models(pretrained=True)
53 | pprint(model_names)
54 | >>> ['adv_inception_v3',
55 | 'cspdarknet53',
56 | 'cspresnext50',
57 | 'densenet121',
58 | 'densenet161',
59 | 'densenet169',
60 | 'densenet201',
61 | 'densenetblur121d',
62 | 'dla34',
63 | 'dla46_c',
64 | ...
65 | ]
66 | ```
67 |
68 | ## List Model Architectures by Wildcard
69 | ```python
70 | import timm
71 | from pprint import pprint
72 | model_names = timm.list_models('*resne*t*')
73 | pprint(model_names)
74 | >>> ['cspresnet50',
75 | 'cspresnet50d',
76 | 'cspresnet50w',
77 | 'cspresnext50',
78 | ...
79 | ]
80 | ```
81 |
--------------------------------------------------------------------------------
/docs/models/.templates/models/pnasnet.md:
--------------------------------------------------------------------------------
1 | # PNASNet
2 |
3 | **Progressive Neural Architecture Search**, or **PNAS**, is a method for learning the structure of convolutional neural networks (CNNs). It uses a sequential model-based optimization (SMBO) strategy, where we search the space of cell structures, starting with simple (shallow) models and progressing to complex ones, pruning out unpromising structures as we go.
4 |
5 | {% include 'code_snippets.md' %}
6 |
7 | ## How do I train this model?
8 |
9 | You can follow the [timm recipe scripts](https://rwightman.github.io/pytorch-image-models/scripts/) for training a new model afresh.
10 |
11 | ## Citation
12 |
13 | ```BibTeX
14 | @misc{liu2018progressive,
15 | title={Progressive Neural Architecture Search},
16 | author={Chenxi Liu and Barret Zoph and Maxim Neumann and Jonathon Shlens and Wei Hua and Li-Jia Li and Li Fei-Fei and Alan Yuille and Jonathan Huang and Kevin Murphy},
17 | year={2018},
18 | eprint={1712.00559},
19 | archivePrefix={arXiv},
20 | primaryClass={cs.CV}
21 | }
22 | ```
23 |
24 |
72 |
--------------------------------------------------------------------------------
/docs/models/.templates/models/inception-resnet-v2.md:
--------------------------------------------------------------------------------
1 | # Inception ResNet v2
2 |
3 | **Inception-ResNet-v2** is a convolutional neural architecture that builds on the Inception family of architectures but incorporates [residual connections](https://paperswithcode.com/method/residual-connection) (replacing the filter concatenation stage of the Inception architecture).
4 |
5 | {% include 'code_snippets.md' %}
6 |
7 | ## How do I train this model?
8 |
9 | You can follow the [timm recipe scripts](https://rwightman.github.io/pytorch-image-models/scripts/) for training a new model afresh.
10 |
11 | ## Citation
12 |
13 | ```BibTeX
14 | @misc{szegedy2016inceptionv4,
15 | title={Inception-v4, Inception-ResNet and the Impact of Residual Connections on Learning},
16 | author={Christian Szegedy and Sergey Ioffe and Vincent Vanhoucke and Alex Alemi},
17 | year={2016},
18 | eprint={1602.07261},
19 | archivePrefix={arXiv},
20 | primaryClass={cs.CV}
21 | }
22 | ```
23 |
24 |
73 |
--------------------------------------------------------------------------------
/.github/workflows/tests.yml:
--------------------------------------------------------------------------------
1 | name: Python tests
2 |
3 | on:
4 | push:
5 | branches: [ main ]
6 | pull_request:
7 | branches: [ main ]
8 |
9 | env:
10 | OMP_NUM_THREADS: 2
11 | MKL_NUM_THREADS: 2
12 |
13 | jobs:
14 | test:
15 | name: Run tests on ${{ matrix.os }} with Python ${{ matrix.python }}
16 | strategy:
17 | matrix:
18 | os: [ubuntu-latest]
19 | python: ['3.10', '3.11']
20 | torch: [{base: '1.13.0', vision: '0.14.0'}, {base: '2.1.0', vision: '0.16.0'}]
21 | testmarker: ['-k "not test_models"', '-m base', '-m cfg', '-m torchscript', '-m features', '-m fxforward', '-m fxbackward']
22 | exclude:
23 | - python: '3.11'
24 | torch: {base: '1.13.0', vision: '0.14.0'}
25 | runs-on: ${{ matrix.os }}
26 |
27 | steps:
28 | - uses: actions/checkout@v2
29 | - name: Set up Python ${{ matrix.python }}
30 | uses: actions/setup-python@v1
31 | with:
32 | python-version: ${{ matrix.python }}
33 | - name: Install testing dependencies
34 | run: |
35 | python -m pip install --upgrade pip
36 | pip install -r requirements-dev.txt
37 | - name: Install torch on mac
38 | if: startsWith(matrix.os, 'macOS')
39 | run: pip install --no-cache-dir torch==${{ matrix.torch.base }} torchvision==${{ matrix.torch.vision }}
40 | - name: Install torch on Windows
41 | if: startsWith(matrix.os, 'windows')
42 | run: pip install --no-cache-dir torch==${{ matrix.torch.base }} torchvision==${{ matrix.torch.vision }}
43 | - name: Install torch on ubuntu
44 | if: startsWith(matrix.os, 'ubuntu')
45 | run: |
46 | sudo sed -i 's/azure\.//' /etc/apt/sources.list
47 | sudo apt update
48 | sudo apt install -y google-perftools
49 | pip install --no-cache-dir torch==${{ matrix.torch.base }}+cpu torchvision==${{ matrix.torch.vision }}+cpu -f https://download.pytorch.org/whl/torch_stable.html
50 | - name: Install requirements
51 | run: |
52 | pip install -r requirements.txt
53 | - name: Run tests on Windows
54 | if: startsWith(matrix.os, 'windows')
55 | env:
56 | PYTHONDONTWRITEBYTECODE: 1
57 | run: |
58 | pytest -vv tests
59 | - name: Run '${{ matrix.testmarker }}' tests on Linux / Mac
60 | if: ${{ !startsWith(matrix.os, 'windows') }}
61 | env:
62 | LD_PRELOAD: /usr/lib/x86_64-linux-gnu/libtcmalloc.so.4
63 | PYTHONDONTWRITEBYTECODE: 1
64 | run: |
65 | pytest -vv --forked --durations=0 ${{ matrix.testmarker }} tests
66 |
--------------------------------------------------------------------------------
/timm/optim/sgdp.py:
--------------------------------------------------------------------------------
1 | """
2 | SGDP Optimizer Implementation copied from https://github.com/clovaai/AdamP/blob/master/adamp/sgdp.py
3 |
4 | Paper: `Slowing Down the Weight Norm Increase in Momentum-based Optimizers` - https://arxiv.org/abs/2006.08217
5 | Code: https://github.com/clovaai/AdamP
6 |
7 | Copyright (c) 2020-present NAVER Corp.
8 | MIT license
9 | """
10 |
11 | import torch
12 | import torch.nn.functional as F
13 | from torch.optim.optimizer import Optimizer, required
14 | import math
15 |
16 | from .adamp import projection
17 |
18 |
19 | class SGDP(Optimizer):
20 | def __init__(self, params, lr=required, momentum=0, dampening=0,
21 | weight_decay=0, nesterov=False, eps=1e-8, delta=0.1, wd_ratio=0.1):
22 | defaults = dict(
23 | lr=lr, momentum=momentum, dampening=dampening, weight_decay=weight_decay,
24 | nesterov=nesterov, eps=eps, delta=delta, wd_ratio=wd_ratio)
25 | super(SGDP, self).__init__(params, defaults)
26 |
27 | @torch.no_grad()
28 | def step(self, closure=None):
29 | loss = None
30 | if closure is not None:
31 | with torch.enable_grad():
32 | loss = closure()
33 |
34 | for group in self.param_groups:
35 | weight_decay = group['weight_decay']
36 | momentum = group['momentum']
37 | dampening = group['dampening']
38 | nesterov = group['nesterov']
39 |
40 | for p in group['params']:
41 | if p.grad is None:
42 | continue
43 | grad = p.grad
44 | state = self.state[p]
45 |
46 | # State initialization
47 | if len(state) == 0:
48 | state['momentum'] = torch.zeros_like(p)
49 |
50 | # SGD
51 | buf = state['momentum']
52 | buf.mul_(momentum).add_(grad, alpha=1. - dampening)
53 | if nesterov:
54 | d_p = grad + momentum * buf
55 | else:
56 | d_p = buf
57 |
58 | # Projection
59 | wd_ratio = 1.
60 | if len(p.shape) > 1:
61 | d_p, wd_ratio = projection(p, grad, d_p, group['delta'], group['wd_ratio'], group['eps'])
62 |
63 | # Weight decay
64 | if weight_decay != 0:
65 | p.mul_(1. - group['lr'] * group['weight_decay'] * wd_ratio / (1-momentum))
66 |
67 | # Step
68 | p.add_(d_p, alpha=-group['lr'])
69 |
70 | return loss
71 |
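# Illustrative usage (not part of the original module; run as a module, e.g.
# `python -m timm.optim.sgdp`, since this file uses relative imports):
if __name__ == '__main__':
    model = torch.nn.Linear(4, 2)
    optimizer = SGDP(model.parameters(), lr=0.1, momentum=0.9, nesterov=True, weight_decay=1e-4)
    loss = model(torch.randn(8, 4)).pow(2).mean()
    loss.backward()
    optimizer.step()  # momentum step with projection-scaled weight decay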
--------------------------------------------------------------------------------
/docs/models/.templates/models/csp-resnet.md:
--------------------------------------------------------------------------------
1 | # CSP-ResNet
2 |
3 | **CSPResNet** is a convolutional neural network where we apply the Cross Stage Partial Network (CSPNet) approach to [ResNet](https://paperswithcode.com/method/resnet). The CSPNet partitions the feature map of the base layer into two parts and then merges them through a cross-stage hierarchy. The use of a split and merge strategy allows for more gradient flow through the network.
4 |
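A minimal sketch of the cross-stage split and merge (illustrative only, not timm's actual `CspNet` implementation; `block_fn` stands in for the stage's stack of residual blocks):

```python
import torch
import torch.nn as nn

class CrossStagePartialBlock(nn.Module):
    """Split the stage input channels in two: one part goes through the heavy
    blocks, the other bypasses them, and a transition conv merges both."""
    def __init__(self, channels, block_fn):
        super().__init__()
        self.split = channels // 2
        self.blocks = block_fn(channels - self.split)        # transform one partition
        self.transition = nn.Conv2d(channels, channels, 1)   # cross-stage merge

    def forward(self, x):
        x_bypass, x_blocks = x.split((self.split, x.shape[1] - self.split), dim=1)
        x_blocks = self.blocks(x_blocks)
        return self.transition(torch.cat([x_bypass, x_blocks], dim=1))
```
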
5 | {% include 'code_snippets.md' %}
6 |
7 | ## How do I train this model?
8 |
9 | You can follow the [timm recipe scripts](https://rwightman.github.io/pytorch-image-models/scripts/) for training a new model afresh.
10 |
11 | ## Citation
12 |
13 | ```BibTeX
14 | @misc{wang2019cspnet,
15 | title={CSPNet: A New Backbone that can Enhance Learning Capability of CNN},
16 | author={Chien-Yao Wang and Hong-Yuan Mark Liao and I-Hau Yeh and Yueh-Hua Wu and Ping-Yang Chen and Jun-Wei Hsieh},
17 | year={2019},
18 | eprint={1911.11929},
19 | archivePrefix={arXiv},
20 | primaryClass={cs.CV}
21 | }
22 | ```
23 |
24 |
77 |
--------------------------------------------------------------------------------
/timm/data/dataset_info.py:
--------------------------------------------------------------------------------
1 | from abc import ABC, abstractmethod
2 | from typing import Dict, List, Optional, Union
3 |
4 |
5 | class DatasetInfo(ABC):
6 |
7 | def __init__(self):
8 | pass
9 |
10 | @abstractmethod
11 | def num_classes(self):
12 | pass
13 |
14 | @abstractmethod
15 | def label_names(self):
16 | pass
17 |
18 | @abstractmethod
19 | def label_descriptions(self, detailed: bool = False, as_dict: bool = False) -> Union[List[str], Dict[str, str]]:
20 | pass
21 |
22 | @abstractmethod
23 | def index_to_label_name(self, index) -> str:
24 | pass
25 |
26 | @abstractmethod
27 | def index_to_description(self, index: int, detailed: bool = False) -> str:
28 | pass
29 |
30 | @abstractmethod
31 | def label_name_to_description(self, label: str, detailed: bool = False) -> str:
32 | pass
33 |
34 |
35 | class CustomDatasetInfo(DatasetInfo):
36 | """ DatasetInfo that wraps passed values for custom datasets."""
37 |
38 | def __init__(
39 | self,
40 | label_names: Union[List[str], Dict[int, str]],
41 | label_descriptions: Optional[Dict[str, str]] = None
42 | ):
43 | super().__init__()
44 | assert len(label_names) > 0
45 | self._label_names = label_names # label index => label name mapping
46 | self._label_descriptions = label_descriptions # label name => label description mapping
47 | if self._label_descriptions is not None:
48 | # validate descriptions (label names required)
49 | assert isinstance(self._label_descriptions, dict)
50 | for n in self._label_names:
51 | assert n in self._label_descriptions
52 |
53 | def num_classes(self):
54 | return len(self._label_names)
55 |
56 | def label_names(self):
57 | return self._label_names
58 |
59 | def label_descriptions(self, detailed: bool = False, as_dict: bool = False) -> Union[List[str], Dict[str, str]]:
60 | return self._label_descriptions
61 |
62 | def label_name_to_description(self, label: str, detailed: bool = False) -> str:
63 | if self._label_descriptions:
64 | return self._label_descriptions[label]
65 |         return label  # return the label name itself if a description is not present
66 |
67 | def index_to_label_name(self, index) -> str:
68 | assert 0 <= index < len(self._label_names)
69 | return self._label_names[index]
70 |
71 | def index_to_description(self, index: int, detailed: bool = False) -> str:
72 | label = self.index_to_label_name(index)
73 | return self.label_name_to_description(label, detailed=detailed)
74 |
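# Illustrative usage (not part of the original module):
if __name__ == '__main__':
    info = CustomDatasetInfo(
        label_names=['cat', 'dog'],
        label_descriptions={'cat': 'a domestic cat', 'dog': 'a domestic dog'},
    )
    print(info.num_classes())            # 2
    print(info.index_to_description(1))  # a domestic dog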
--------------------------------------------------------------------------------
/timm/layers/interpolate.py:
--------------------------------------------------------------------------------
1 | """ Interpolation helpers for timm layers
2 |
3 | RegularGridInterpolator from https://github.com/sbarratt/torch_interpolations
4 | Copyright Shane Barratt, Apache 2.0 license
5 | """
6 | import torch
7 | from itertools import product
8 |
9 |
10 | class RegularGridInterpolator:
11 | """ Interpolate data defined on a rectilinear grid with even or uneven spacing.
12 | Produces similar results to scipy RegularGridInterpolator or interp2d
13 | in 'linear' mode.
14 |
15 | Taken from https://github.com/sbarratt/torch_interpolations
16 | """
17 |
18 | def __init__(self, points, values):
19 | self.points = points
20 | self.values = values
21 |
22 | assert isinstance(self.points, tuple) or isinstance(self.points, list)
23 | assert isinstance(self.values, torch.Tensor)
24 |
25 | self.ms = list(self.values.shape)
26 | self.n = len(self.points)
27 |
28 | assert len(self.ms) == self.n
29 |
30 | for i, p in enumerate(self.points):
31 | assert isinstance(p, torch.Tensor)
32 | assert p.shape[0] == self.values.shape[i]
33 |
34 | def __call__(self, points_to_interp):
35 | assert self.points is not None
36 | assert self.values is not None
37 |
38 | assert len(points_to_interp) == len(self.points)
39 | K = points_to_interp[0].shape[0]
40 | for x in points_to_interp:
41 | assert x.shape[0] == K
42 |
43 | idxs = []
44 | dists = []
45 | overalls = []
46 | for p, x in zip(self.points, points_to_interp):
47 | idx_right = torch.bucketize(x, p)
48 | idx_right[idx_right >= p.shape[0]] = p.shape[0] - 1
49 | idx_left = (idx_right - 1).clamp(0, p.shape[0] - 1)
50 | dist_left = x - p[idx_left]
51 | dist_right = p[idx_right] - x
52 | dist_left[dist_left < 0] = 0.
53 | dist_right[dist_right < 0] = 0.
54 | both_zero = (dist_left == 0) & (dist_right == 0)
55 | dist_left[both_zero] = dist_right[both_zero] = 1.
56 |
57 | idxs.append((idx_left, idx_right))
58 | dists.append((dist_left, dist_right))
59 | overalls.append(dist_left + dist_right)
60 |
61 | numerator = 0.
62 | for indexer in product([0, 1], repeat=self.n):
63 | as_s = [idx[onoff] for onoff, idx in zip(indexer, idxs)]
64 | bs_s = [dist[1 - onoff] for onoff, dist in zip(indexer, dists)]
65 | numerator += self.values[as_s] * \
66 | torch.prod(torch.stack(bs_s), dim=0)
67 | denominator = torch.prod(torch.stack(overalls), dim=0)
68 | return numerator / denominator
69 |
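# Illustrative usage (not part of the original module): 1D linear interpolation
# on an unevenly spaced grid.
if __name__ == '__main__':
    points = (torch.tensor([0.0, 1.0, 3.0]),)
    values = torch.tensor([0.0, 2.0, 6.0])
    interp = RegularGridInterpolator(points, values)
    print(interp([torch.tensor([0.5, 2.0])]))  # tensor([1., 4.])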
--------------------------------------------------------------------------------
/docs/models/.templates/code_snippets.md:
--------------------------------------------------------------------------------
1 | ## How do I use this model on an image?
2 | To load a pretrained model:
3 |
4 | ```python
5 | import timm
6 | model = timm.create_model('{{ model_name }}', pretrained=True)
7 | model.eval()
8 | ```
9 |
10 | To load and preprocess the image:
11 | ```python
12 | import urllib
13 | from PIL import Image
14 | from timm.data import resolve_data_config
15 | from timm.data.transforms_factory import create_transform
16 |
17 | config = resolve_data_config({}, model=model)
18 | transform = create_transform(**config)
19 |
20 | url, filename = ("https://github.com/pytorch/hub/raw/master/images/dog.jpg", "dog.jpg")
21 | urllib.request.urlretrieve(url, filename)
22 | img = Image.open(filename).convert('RGB')
23 | tensor = transform(img).unsqueeze(0) # transform and add batch dimension
24 | ```
25 |
26 | To get the model predictions:
27 | ```python
28 | import torch
29 | with torch.no_grad():
30 | out = model(tensor)
31 | probabilities = torch.nn.functional.softmax(out[0], dim=0)
32 | print(probabilities.shape)
33 | # prints: torch.Size([1000])
34 | ```
35 |
36 | To get the top-5 predictions class names:
37 | ```python
38 | # Get imagenet class mappings
39 | url, filename = ("https://raw.githubusercontent.com/pytorch/hub/master/imagenet_classes.txt", "imagenet_classes.txt")
40 | urllib.request.urlretrieve(url, filename)
41 | with open("imagenet_classes.txt", "r") as f:
42 | categories = [s.strip() for s in f.readlines()]
43 |
44 | # Print top categories per image
45 | top5_prob, top5_catid = torch.topk(probabilities, 5)
46 | for i in range(top5_prob.size(0)):
47 | print(categories[top5_catid[i]], top5_prob[i].item())
48 | # prints one class name and probability per line, e.g.:
49 | # Samoyed 0.6425196528434753
50 | ```
51 |
52 | Replace the model name with the variant you want to use, e.g. `{{ model_name }}`. You can find the IDs in the model summaries at the top of this page.
53 |
54 | To extract image features with this model, follow the [timm feature extraction examples](https://rwightman.github.io/pytorch-image-models/feature_extraction/), just change the name of the model you want to use.
55 |
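For example, a minimal sketch of pooled feature extraction (creating the model with `num_classes=0` removes the classifier head and returns the pooled features; `features_only=True` would instead return intermediate feature maps):

```python
model = timm.create_model('{{ model_name }}', pretrained=True, num_classes=0)
with torch.no_grad():
    features = model(tensor)  # `tensor` from the preprocessing snippet above
print(features.shape)  # (1, num_features); the width varies by model
```
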
56 | ## How do I finetune this model?
57 | You can finetune any of the pre-trained models just by changing the classifier (the last layer).
58 | ```python
59 | model = timm.create_model('{{ model_name }}', pretrained=True, num_classes=NUM_FINETUNE_CLASSES)
60 | ```
61 | To finetune on your own dataset, you have to write a training loop or adapt [timm's training
62 | script](https://github.com/rwightman/pytorch-image-models/blob/master/train.py) to use your dataset.
63 |
--------------------------------------------------------------------------------
/timm/layers/global_context.py:
--------------------------------------------------------------------------------
1 | """ Global Context Attention Block
2 |
3 | Paper: `GCNet: Non-local Networks Meet Squeeze-Excitation Networks and Beyond`
4 | - https://arxiv.org/abs/1904.11492
5 |
6 | Official code consulted as reference: https://github.com/xvjiarui/GCNet
7 |
8 | Hacked together by / Copyright 2021 Ross Wightman
9 | """
10 | from torch import nn as nn
11 | import torch.nn.functional as F
12 |
13 | from .create_act import create_act_layer, get_act_layer
14 | from .helpers import make_divisible
15 | from .mlp import ConvMlp
16 | from .norm import LayerNorm2d
17 |
18 |
19 | class GlobalContext(nn.Module):
20 |
21 | def __init__(self, channels, use_attn=True, fuse_add=False, fuse_scale=True, init_last_zero=False,
22 | rd_ratio=1./8, rd_channels=None, rd_divisor=1, act_layer=nn.ReLU, gate_layer='sigmoid'):
23 | super(GlobalContext, self).__init__()
24 | act_layer = get_act_layer(act_layer)
25 |
26 | self.conv_attn = nn.Conv2d(channels, 1, kernel_size=1, bias=True) if use_attn else None
27 |
28 | if rd_channels is None:
29 | rd_channels = make_divisible(channels * rd_ratio, rd_divisor, round_limit=0.)
30 | if fuse_add:
31 | self.mlp_add = ConvMlp(channels, rd_channels, act_layer=act_layer, norm_layer=LayerNorm2d)
32 | else:
33 | self.mlp_add = None
34 | if fuse_scale:
35 | self.mlp_scale = ConvMlp(channels, rd_channels, act_layer=act_layer, norm_layer=LayerNorm2d)
36 | else:
37 | self.mlp_scale = None
38 |
39 | self.gate = create_act_layer(gate_layer)
40 | self.init_last_zero = init_last_zero
41 | self.reset_parameters()
42 |
43 | def reset_parameters(self):
44 | if self.conv_attn is not None:
45 | nn.init.kaiming_normal_(self.conv_attn.weight, mode='fan_in', nonlinearity='relu')
46 | if self.mlp_add is not None:
47 | nn.init.zeros_(self.mlp_add.fc2.weight)
48 |
49 | def forward(self, x):
50 | B, C, H, W = x.shape
51 |
52 | if self.conv_attn is not None:
53 | attn = self.conv_attn(x).reshape(B, 1, H * W) # (B, 1, H * W)
54 | attn = F.softmax(attn, dim=-1).unsqueeze(3) # (B, 1, H * W, 1)
55 | context = x.reshape(B, C, H * W).unsqueeze(1) @ attn
56 | context = context.view(B, C, 1, 1)
57 | else:
58 | context = x.mean(dim=(2, 3), keepdim=True)
59 |
60 | if self.mlp_scale is not None:
61 | mlp_x = self.mlp_scale(context)
62 | x = x * self.gate(mlp_x)
63 | if self.mlp_add is not None:
64 | mlp_x = self.mlp_add(context)
65 | x = x + mlp_x
66 |
67 | return x
68 |
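# Illustrative smoke test (not part of the original module; run as a module,
# e.g. `python -m timm.layers.global_context`, due to the relative imports):
if __name__ == '__main__':
    import torch
    gc_block = GlobalContext(64)
    print(gc_block(torch.randn(2, 64, 14, 14)).shape)  # torch.Size([2, 64, 14, 14])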
--------------------------------------------------------------------------------
/docs/models/.templates/models/csp-resnext.md:
--------------------------------------------------------------------------------
1 | # CSP-ResNeXt
2 |
3 | **CSPResNeXt** is a convolutional neural network where we apply the Cross Stage Partial Network (CSPNet) approach to [ResNeXt](https://paperswithcode.com/method/resnext). The CSPNet partitions the feature map of the base layer into two parts and then merges them through a cross-stage hierarchy. The use of a split and merge strategy allows for more gradient flow through the network.
4 |
5 | {% include 'code_snippets.md' %}
6 |
7 | ## How do I train this model?
8 |
9 | You can follow the [timm recipe scripts](https://rwightman.github.io/pytorch-image-models/scripts/) for training a new model afresh.
10 |
11 | ## Citation
12 |
13 | ```BibTeX
14 | @misc{wang2019cspnet,
15 | title={CSPNet: A New Backbone that can Enhance Learning Capability of CNN},
16 | author={Chien-Yao Wang and Hong-Yuan Mark Liao and I-Hau Yeh and Yueh-Hua Wu and Ping-Yang Chen and Jun-Wei Hsieh},
17 | year={2019},
18 | eprint={1911.11929},
19 | archivePrefix={arXiv},
20 | primaryClass={cs.CV}
21 | }
22 | ```
23 |
24 |
78 |
--------------------------------------------------------------------------------
/docs/models/.templates/models/fbnet.md:
--------------------------------------------------------------------------------
1 | # FBNet
2 |
3 | **FBNet** is a family of convolutional neural architectures discovered through [DNAS](https://paperswithcode.com/method/dnas) neural architecture search. It utilises a basic image model block inspired by [MobileNetv2](https://paperswithcode.com/method/mobilenetv2) that employs depthwise convolutions and an inverted residual structure (see components).
4 |
5 | The principal building block is the [FBNet Block](https://paperswithcode.com/method/fbnet-block).
6 |
7 | {% include 'code_snippets.md' %}
8 |
9 | ## How do I train this model?
10 |
11 | You can follow the [timm recipe scripts](https://rwightman.github.io/pytorch-image-models/scripts/) for training a new model afresh.
12 |
13 | ## Citation
14 |
15 | ```BibTeX
16 | @misc{wu2019fbnet,
17 | title={FBNet: Hardware-Aware Efficient ConvNet Design via Differentiable Neural Architecture Search},
18 | author={Bichen Wu and Xiaoliang Dai and Peizhao Zhang and Yanghan Wang and Fei Sun and Yiming Wu and Yuandong Tian and Peter Vajda and Yangqing Jia and Kurt Keutzer},
19 | year={2019},
20 | eprint={1812.03443},
21 | archivePrefix={arXiv},
22 | primaryClass={cs.CV}
23 | }
24 | ```
25 |
26 |
77 |
--------------------------------------------------------------------------------
/timm/data/readers/reader_hfds.py:
--------------------------------------------------------------------------------
1 | """ Dataset reader that wraps Hugging Face datasets
2 |
3 | Hacked together by / Copyright 2022 Ross Wightman
4 | """
5 | import io
6 | import math
7 | import torch
8 | import torch.distributed as dist
9 | from PIL import Image
10 |
11 | try:
12 | import datasets
13 | except ImportError as e:
14 | print("Please install Hugging Face datasets package `pip install datasets`.")
15 | exit(1)
16 | from .class_map import load_class_map
17 | from .reader import Reader
18 |
19 |
20 | def get_class_labels(info, label_key='label'):
21 |     if label_key not in info.features:
22 | return {}
23 | class_label = info.features[label_key]
24 | class_to_idx = {n: class_label.str2int(n) for n in class_label.names}
25 | return class_to_idx
26 |
27 |
28 | class ReaderHfds(Reader):
29 |
30 | def __init__(
31 | self,
32 | root,
33 | name,
34 | split='train',
35 | class_map=None,
36 | label_key='label',
37 | download=False,
38 | ):
39 | """
40 | """
41 | super().__init__()
42 | self.root = root
43 | self.split = split
44 | self.dataset = datasets.load_dataset(
45 | name, # 'name' maps to path arg in hf datasets
46 | split=split,
47 | cache_dir=self.root, # timm doesn't expect hidden cache dir for datasets, specify a path
48 | )
49 | # leave decode for caller, plus we want easy access to original path names...
50 | self.dataset = self.dataset.cast_column('image', datasets.Image(decode=False))
51 |
52 | self.label_key = label_key
53 | self.remap_class = False
54 | if class_map:
55 | self.class_to_idx = load_class_map(class_map)
56 | self.remap_class = True
57 | else:
58 | self.class_to_idx = get_class_labels(self.dataset.info, self.label_key)
59 | self.split_info = self.dataset.info.splits[split]
60 | self.num_samples = self.split_info.num_examples
61 |
62 | def __getitem__(self, index):
63 | item = self.dataset[index]
64 | image = item['image']
65 | if 'bytes' in image and image['bytes']:
66 | image = io.BytesIO(image['bytes'])
67 | else:
68 | assert 'path' in image and image['path']
69 | image = open(image['path'], 'rb')
70 | label = item[self.label_key]
71 | if self.remap_class:
72 | label = self.class_to_idx[label]
73 | return image, label
74 |
75 | def __len__(self):
76 | return len(self.dataset)
77 |
78 | def _filename(self, index, basename=False, absolute=False):
79 | item = self.dataset[index]
80 | return item['image']['path']
81 |
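# Illustrative usage (the dataset name and label column are placeholders; the
# `datasets` package and, for remote datasets, network access are required):
#
#   reader = ReaderHfds(root='./hf_cache', name='some/dataset', split='train', label_key='label')
#   image_fileobj, label = reader[0]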
--------------------------------------------------------------------------------
/docs/models/.templates/models/res2next.md:
--------------------------------------------------------------------------------
1 | # Res2NeXt
2 |
3 | **Res2NeXt** is an image model that employs a variation on [ResNeXt](https://paperswithcode.com/method/resnext) bottleneck residual blocks. The motivation is to be able to represent features at multiple scales. This is achieved through a novel building block for CNNs that constructs hierarchical residual-like connections within one single residual block. This represents multi-scale features at a granular level and increases the range of receptive fields for each network layer.
4 |
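A minimal sketch of that hierarchical split (illustrative only, not timm's actual implementation; `channels` must be divisible by `scale`):

```python
import torch
import torch.nn as nn

class Res2NetSplit(nn.Module):
    """Multi-scale split inside a Res2Net/Res2NeXt bottleneck: each chunk's 3x3
    conv also receives the previous chunk's output, forming hierarchical
    residual-like connections within a single block."""
    def __init__(self, channels, scale=4):
        super().__init__()
        self.width = channels // scale
        self.convs = nn.ModuleList(
            nn.Conv2d(self.width, self.width, 3, padding=1) for _ in range(scale - 1))

    def forward(self, x):
        chunks = torch.split(x, self.width, dim=1)
        out, y = [chunks[0]], 0                # first chunk passes straight through
        for i, conv in enumerate(self.convs):
            y = conv(chunks[i + 1] + y)        # hierarchical connection to the next chunk
            out.append(y)
        return torch.cat(out, dim=1)           # widths sum back to `channels`
```
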
5 | {% include 'code_snippets.md' %}
6 |
7 | ## How do I train this model?
8 |
9 | You can follow the [timm recipe scripts](https://rwightman.github.io/pytorch-image-models/scripts/) for training a new model afresh.
10 |
11 | ## Citation
12 |
13 | ```BibTeX
14 | @article{Gao_2021,
15 | title={Res2Net: A New Multi-Scale Backbone Architecture},
16 | volume={43},
17 | ISSN={1939-3539},
18 | url={http://dx.doi.org/10.1109/TPAMI.2019.2938758},
19 | DOI={10.1109/tpami.2019.2938758},
20 | number={2},
21 | journal={IEEE Transactions on Pattern Analysis and Machine Intelligence},
22 | publisher={Institute of Electrical and Electronics Engineers (IEEE)},
23 | author={Gao, Shang-Hua and Cheng, Ming-Ming and Zhao, Kai and Zhang, Xin-Yu and Yang, Ming-Hsuan and Torr, Philip},
24 | year={2021},
25 | month={Feb},
26 | pages={652–662}
27 | }
28 | ```
29 |
30 |
76 |
--------------------------------------------------------------------------------
/timm/layers/filter_response_norm.py:
--------------------------------------------------------------------------------
1 | """ Filter Response Norm in PyTorch
2 |
3 | Based on `Filter Response Normalization Layer` - https://arxiv.org/abs/1911.09737
4 |
5 | Hacked together by / Copyright 2021 Ross Wightman
6 | """
7 | import torch
8 | import torch.nn as nn
9 |
10 | from .create_act import create_act_layer
11 | from .trace_utils import _assert
12 |
13 |
14 | def inv_instance_rms(x, eps: float = 1e-5):
15 | rms = x.square().float().mean(dim=(2, 3), keepdim=True).add(eps).rsqrt().to(x.dtype)
16 | return rms.expand(x.shape)
17 |
18 |
19 | class FilterResponseNormTlu2d(nn.Module):
20 | def __init__(self, num_features, apply_act=True, eps=1e-5, rms=True, **_):
21 | super(FilterResponseNormTlu2d, self).__init__()
22 | self.apply_act = apply_act # apply activation (non-linearity)
23 | self.rms = rms
24 | self.eps = eps
25 | self.weight = nn.Parameter(torch.ones(num_features))
26 | self.bias = nn.Parameter(torch.zeros(num_features))
27 | self.tau = nn.Parameter(torch.zeros(num_features)) if apply_act else None
28 | self.reset_parameters()
29 |
30 | def reset_parameters(self):
31 | nn.init.ones_(self.weight)
32 | nn.init.zeros_(self.bias)
33 | if self.tau is not None:
34 | nn.init.zeros_(self.tau)
35 |
36 | def forward(self, x):
37 | _assert(x.dim() == 4, 'expected 4D input')
38 | x_dtype = x.dtype
39 | v_shape = (1, -1, 1, 1)
40 | x = x * inv_instance_rms(x, self.eps)
41 | x = x * self.weight.view(v_shape).to(dtype=x_dtype) + self.bias.view(v_shape).to(dtype=x_dtype)
42 | return torch.maximum(x, self.tau.reshape(v_shape).to(dtype=x_dtype)) if self.tau is not None else x
43 |
44 |
45 | class FilterResponseNormAct2d(nn.Module):
46 | def __init__(self, num_features, apply_act=True, act_layer=nn.ReLU, inplace=None, rms=True, eps=1e-5, **_):
47 | super(FilterResponseNormAct2d, self).__init__()
48 | if act_layer is not None and apply_act:
49 | self.act = create_act_layer(act_layer, inplace=inplace)
50 | else:
51 | self.act = nn.Identity()
52 | self.rms = rms
53 | self.eps = eps
54 | self.weight = nn.Parameter(torch.ones(num_features))
55 | self.bias = nn.Parameter(torch.zeros(num_features))
56 | self.reset_parameters()
57 |
58 | def reset_parameters(self):
59 | nn.init.ones_(self.weight)
60 | nn.init.zeros_(self.bias)
61 |
62 | def forward(self, x):
63 | _assert(x.dim() == 4, 'expected 4D input')
64 | x_dtype = x.dtype
65 | v_shape = (1, -1, 1, 1)
66 | x = x * inv_instance_rms(x, self.eps)
67 | x = x * self.weight.view(v_shape).to(dtype=x_dtype) + self.bias.view(v_shape).to(dtype=x_dtype)
68 | return self.act(x)
69 |
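# Illustrative smoke test (not part of the original module; run as a module,
# e.g. `python -m timm.layers.filter_response_norm`, due to the relative imports):
if __name__ == '__main__':
    frn = FilterResponseNormTlu2d(32)
    print(frn(torch.randn(2, 32, 8, 8)).shape)  # torch.Size([2, 32, 8, 8])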
--------------------------------------------------------------------------------
/docs/models/.templates/models/csp-darknet.md:
--------------------------------------------------------------------------------
1 | # CSP-DarkNet
2 |
3 | **CSPDarknet53** is a convolutional neural network and backbone for object detection that uses [DarkNet-53](https://paperswithcode.com/method/darknet-53). It employs a CSPNet strategy to partition the feature map of the base layer into two parts and then merges them through a cross-stage hierarchy. The use of a split and merge strategy allows for more gradient flow through the network.
4 |
5 | This CNN is used as the backbone for [YOLOv4](https://paperswithcode.com/method/yolov4).
6 |
7 | {% include 'code_snippets.md' %}
8 |
9 | ## How do I train this model?
10 |
11 | You can follow the [timm recipe scripts](https://rwightman.github.io/pytorch-image-models/scripts/) for training a new model afresh.
12 |
13 | ## Citation
14 |
15 | ```BibTeX
16 | @misc{bochkovskiy2020yolov4,
17 | title={YOLOv4: Optimal Speed and Accuracy of Object Detection},
18 | author={Alexey Bochkovskiy and Chien-Yao Wang and Hong-Yuan Mark Liao},
19 | year={2020},
20 | eprint={2004.10934},
21 | archivePrefix={arXiv},
22 | primaryClass={cs.CV}
23 | }
24 | ```
25 |
26 |
82 |
--------------------------------------------------------------------------------
/results/generate_csv_results.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import pandas as pd
3 |
4 |
5 | results = {
6 | 'results-imagenet.csv': [
7 | 'results-imagenet-real.csv',
8 | 'results-imagenetv2-matched-frequency.csv',
9 | 'results-sketch.csv'
10 | ],
11 | 'results-imagenet-a-clean.csv': [
12 | 'results-imagenet-a.csv',
13 | ],
14 | 'results-imagenet-r-clean.csv': [
15 | 'results-imagenet-r.csv',
16 | ],
17 | }
18 |
19 |
20 | def diff(base_df, test_csv):
21 | base_models = base_df['model'].values
22 | test_df = pd.read_csv(test_csv)
23 | test_models = test_df['model'].values
24 |
25 | rank_diff = np.zeros_like(test_models, dtype='object')
26 | top1_diff = np.zeros_like(test_models, dtype='object')
27 | top5_diff = np.zeros_like(test_models, dtype='object')
28 |
29 | for rank, model in enumerate(test_models):
30 | if model in base_models:
31 |             base_rank = int(np.where(base_models == model)[0][0])
32 | top1_d = test_df['top1'][rank] - base_df['top1'][base_rank]
33 | top5_d = test_df['top5'][rank] - base_df['top5'][base_rank]
34 |
35 | # rank_diff
36 | if rank == base_rank:
37 | rank_diff[rank] = f'0'
38 | elif rank > base_rank:
39 | rank_diff[rank] = f'-{rank - base_rank}'
40 | else:
41 | rank_diff[rank] = f'+{base_rank - rank}'
42 |
43 | # top1_diff
44 | if top1_d >= .0:
45 | top1_diff[rank] = f'+{top1_d:.3f}'
46 | else:
47 | top1_diff[rank] = f'-{abs(top1_d):.3f}'
48 |
49 | # top5_diff
50 | if top5_d >= .0:
51 | top5_diff[rank] = f'+{top5_d:.3f}'
52 | else:
53 | top5_diff[rank] = f'-{abs(top5_d):.3f}'
54 |
55 | else:
56 | rank_diff[rank] = ''
57 | top1_diff[rank] = ''
58 | top5_diff[rank] = ''
59 |
60 | test_df['top1_diff'] = top1_diff
61 | test_df['top5_diff'] = top5_diff
62 | test_df['rank_diff'] = rank_diff
63 |
64 | test_df['param_count'] = test_df['param_count'].map('{:,.2f}'.format)
65 | test_df.sort_values(['top1', 'top5', 'model'], ascending=[False, False, True], inplace=True)
66 | test_df.to_csv(test_csv, index=False, float_format='%.3f')
67 |
68 |
69 | for base_results, test_results in results.items():
70 | base_df = pd.read_csv(base_results)
71 | base_df.sort_values(['top1', 'top5', 'model'], ascending=[False, False, True], inplace=True)
72 | for test_csv in test_results:
73 | diff(base_df, test_csv)
74 | base_df['param_count'] = base_df['param_count'].map('{:,.2f}'.format)
75 | base_df.to_csv(base_results, index=False, float_format='%.3f')
76 |
--------------------------------------------------------------------------------
/timm/layers/activations_jit.py:
--------------------------------------------------------------------------------
1 | """ Activations
2 |
3 | A collection of jit-scripted activations fn and modules with a common interface so that they can
4 | easily be swapped. All have an `inplace` arg even if not used.
5 |
6 | All jit scripted activations are lacking in-place variations on purpose, scripted kernel fusion does not
7 | currently work across in-place op boundaries, thus performance is equal to or less than the non-scripted
8 | versions if they contain in-place ops.
9 |
10 | Hacked together by / Copyright 2020 Ross Wightman
11 | """
12 |
13 | import torch
14 | from torch import nn as nn
15 | from torch.nn import functional as F
16 |
17 |
18 | @torch.jit.script
19 | def swish_jit(x, inplace: bool = False):
20 | """Swish - Described in: https://arxiv.org/abs/1710.05941
21 | """
22 | return x.mul(x.sigmoid())
23 |
24 |
25 | @torch.jit.script
26 | def mish_jit(x, _inplace: bool = False):
27 | """Mish: A Self Regularized Non-Monotonic Neural Activation Function - https://arxiv.org/abs/1908.08681
28 | """
29 | return x.mul(F.softplus(x).tanh())
30 |
31 |
32 | class SwishJit(nn.Module):
33 | def __init__(self, inplace: bool = False):
34 | super(SwishJit, self).__init__()
35 |
36 | def forward(self, x):
37 | return swish_jit(x)
38 |
39 |
40 | class MishJit(nn.Module):
41 | def __init__(self, inplace: bool = False):
42 | super(MishJit, self).__init__()
43 |
44 | def forward(self, x):
45 | return mish_jit(x)
46 |
47 |
48 | @torch.jit.script
49 | def hard_sigmoid_jit(x, inplace: bool = False):
50 | # return F.relu6(x + 3.) / 6.
51 | return (x + 3).clamp(min=0, max=6).div(6.) # clamp seems ever so slightly faster?
52 |
53 |
54 | class HardSigmoidJit(nn.Module):
55 | def __init__(self, inplace: bool = False):
56 | super(HardSigmoidJit, self).__init__()
57 |
58 | def forward(self, x):
59 | return hard_sigmoid_jit(x)
60 |
61 |
62 | @torch.jit.script
63 | def hard_swish_jit(x, inplace: bool = False):
64 | # return x * (F.relu6(x + 3.) / 6)
65 | return x * (x + 3).clamp(min=0, max=6).div(6.) # clamp seems ever so slightly faster?
66 |
67 |
68 | class HardSwishJit(nn.Module):
69 | def __init__(self, inplace: bool = False):
70 | super(HardSwishJit, self).__init__()
71 |
72 | def forward(self, x):
73 | return hard_swish_jit(x)
74 |
75 |
76 | @torch.jit.script
77 | def hard_mish_jit(x, inplace: bool = False):
78 | """ Hard Mish
79 | Experimental, based on notes by Mish author Diganta Misra at
80 | https://github.com/digantamisra98/H-Mish/blob/0da20d4bc58e696b6803f2523c58d3c8a82782d0/README.md
81 | """
82 | return 0.5 * x * (x + 2).clamp(min=0, max=2)
83 |
84 |
85 | class HardMishJit(nn.Module):
86 | def __init__(self, inplace: bool = False):
87 | super(HardMishJit, self).__init__()
88 |
89 | def forward(self, x):
90 | return hard_mish_jit(x)
91 |
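# Illustrative check (not part of the original module): the scripted swish
# matches its eager definition x * sigmoid(x).
if __name__ == '__main__':
    x = torch.linspace(-3, 3, 7)
    print(torch.allclose(swish_jit(x), x * torch.sigmoid(x)))  # True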
--------------------------------------------------------------------------------
/timm/layers/pos_embed.py:
--------------------------------------------------------------------------------
1 | """ Position Embedding Utilities
2 |
3 | Hacked together by / Copyright 2022 Ross Wightman
4 | """
5 | import logging
6 | import math
7 | from typing import List, Tuple, Optional, Union
8 |
9 | import torch
10 | import torch.nn.functional as F
11 |
12 | from .helpers import to_2tuple
13 |
14 | _logger = logging.getLogger(__name__)
15 |
16 |
17 | def resample_abs_pos_embed(
18 | posemb,
19 | new_size: List[int],
20 | old_size: Optional[List[int]] = None,
21 | num_prefix_tokens: int = 1,
22 | interpolation: str = 'bicubic',
23 | antialias: bool = True,
24 | verbose: bool = False,
25 | ):
26 | # sort out sizes, assume square if old size not provided
27 | num_pos_tokens = posemb.shape[1]
28 | num_new_tokens = new_size[0] * new_size[1] + num_prefix_tokens
29 | if num_new_tokens == num_pos_tokens and new_size[0] == new_size[1]:
30 | return posemb
31 |
32 | if old_size is None:
33 | hw = int(math.sqrt(num_pos_tokens - num_prefix_tokens))
34 | old_size = hw, hw
35 |
36 | if num_prefix_tokens:
37 | posemb_prefix, posemb = posemb[:, :num_prefix_tokens], posemb[:, num_prefix_tokens:]
38 | else:
39 | posemb_prefix, posemb = None, posemb
40 |
41 | # do the interpolation
42 | embed_dim = posemb.shape[-1]
43 | orig_dtype = posemb.dtype
44 | posemb = posemb.float() # interpolate needs float32
45 | posemb = posemb.reshape(1, old_size[0], old_size[1], -1).permute(0, 3, 1, 2)
46 | posemb = F.interpolate(posemb, size=new_size, mode=interpolation, antialias=antialias)
47 | posemb = posemb.permute(0, 2, 3, 1).reshape(1, -1, embed_dim)
48 | posemb = posemb.to(orig_dtype)
49 |
50 | # add back extra (class, etc) prefix tokens
51 | if posemb_prefix is not None:
52 | posemb = torch.cat([posemb_prefix, posemb], dim=1)
53 |
54 | if not torch.jit.is_scripting() and verbose:
55 | _logger.info(f'Resized position embedding: {old_size} to {new_size}.')
56 |
57 | return posemb
58 |
59 |
60 | def resample_abs_pos_embed_nhwc(
61 | posemb,
62 | new_size: List[int],
63 | interpolation: str = 'bicubic',
64 | antialias: bool = True,
65 | verbose: bool = False,
66 | ):
67 | if new_size[0] == posemb.shape[-3] and new_size[1] == posemb.shape[-2]:
68 | return posemb
69 |
70 | orig_dtype = posemb.dtype
71 | posemb = posemb.float()
72 | # do the interpolation
73 | posemb = posemb.reshape(1, posemb.shape[-3], posemb.shape[-2], posemb.shape[-1]).permute(0, 3, 1, 2)
74 | posemb = F.interpolate(posemb, size=new_size, mode=interpolation, antialias=antialias)
75 | posemb = posemb.permute(0, 2, 3, 1).to(orig_dtype)
76 |
77 | if not torch.jit.is_scripting() and verbose:
78 | _logger.info(f'Resized position embedding: {posemb.shape[-3:-1]} to {new_size}.')
79 |
80 | return posemb
81 |
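# Illustrative usage (not part of the original module; run as a module, e.g.
# `python -m timm.layers.pos_embed`): resample a 14x14 ViT position grid plus
# one class token to 16x16.
if __name__ == '__main__':
    pe = torch.randn(1, 14 * 14 + 1, 768)
    pe_new = resample_abs_pos_embed(pe, new_size=[16, 16], num_prefix_tokens=1)
    print(pe_new.shape)  # torch.Size([1, 257, 768])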
--------------------------------------------------------------------------------
/timm/layers/separable_conv.py:
--------------------------------------------------------------------------------
1 | """ Depthwise Separable Conv Modules
2 |
3 | Basic DWS convs. Other variations of DWS exist with batch norm or activations between the
4 | DW and PW convs such as the Depthwise modules in MobileNetV2 / EfficientNet and Xception.
5 |
6 | Hacked together by / Copyright 2020 Ross Wightman
7 | """
8 | from torch import nn as nn
9 |
10 | from .create_conv2d import create_conv2d
11 | from .create_norm_act import get_norm_act_layer
12 |
13 |
14 | class SeparableConvNormAct(nn.Module):
15 | """ Separable Conv w/ trailing Norm and Activation
16 | """
17 | def __init__(self, in_channels, out_channels, kernel_size=3, stride=1, dilation=1, padding='', bias=False,
18 | channel_multiplier=1.0, pw_kernel_size=1, norm_layer=nn.BatchNorm2d, act_layer=nn.ReLU,
19 | apply_act=True, drop_layer=None):
20 | super(SeparableConvNormAct, self).__init__()
21 |
22 | self.conv_dw = create_conv2d(
23 | in_channels, int(in_channels * channel_multiplier), kernel_size,
24 | stride=stride, dilation=dilation, padding=padding, depthwise=True)
25 |
26 | self.conv_pw = create_conv2d(
27 | int(in_channels * channel_multiplier), out_channels, pw_kernel_size, padding=padding, bias=bias)
28 |
29 | norm_act_layer = get_norm_act_layer(norm_layer, act_layer)
30 | norm_kwargs = dict(drop_layer=drop_layer) if drop_layer is not None else {}
31 | self.bn = norm_act_layer(out_channels, apply_act=apply_act, **norm_kwargs)
32 |
33 | @property
34 | def in_channels(self):
35 | return self.conv_dw.in_channels
36 |
37 | @property
38 | def out_channels(self):
39 | return self.conv_pw.out_channels
40 |
41 | def forward(self, x):
42 | x = self.conv_dw(x)
43 | x = self.conv_pw(x)
44 | x = self.bn(x)
45 | return x
46 |
47 |
48 | SeparableConvBnAct = SeparableConvNormAct
49 |
50 |
51 | class SeparableConv2d(nn.Module):
52 | """ Separable Conv
53 | """
54 | def __init__(self, in_channels, out_channels, kernel_size=3, stride=1, dilation=1, padding='', bias=False,
55 | channel_multiplier=1.0, pw_kernel_size=1):
56 | super(SeparableConv2d, self).__init__()
57 |
58 | self.conv_dw = create_conv2d(
59 | in_channels, int(in_channels * channel_multiplier), kernel_size,
60 | stride=stride, dilation=dilation, padding=padding, depthwise=True)
61 |
62 | self.conv_pw = create_conv2d(
63 | int(in_channels * channel_multiplier), out_channels, pw_kernel_size, padding=padding, bias=bias)
64 |
65 | @property
66 | def in_channels(self):
67 | return self.conv_dw.in_channels
68 |
69 | @property
70 | def out_channels(self):
71 | return self.conv_pw.out_channels
72 |
73 | def forward(self, x):
74 | x = self.conv_dw(x)
75 | x = self.conv_pw(x)
76 | return x
77 |
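# Illustrative smoke test (not part of the original module; run as a module,
# e.g. `python -m timm.layers.separable_conv`, due to the relative imports):
if __name__ == '__main__':
    import torch
    conv = SeparableConv2d(32, 64)
    print(conv(torch.randn(2, 32, 16, 16)).shape)  # torch.Size([2, 64, 16, 16])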
--------------------------------------------------------------------------------
/timm/optim/lookahead.py:
--------------------------------------------------------------------------------
1 | """ Lookahead Optimizer Wrapper.
2 | Implementation modified from: https://github.com/alphadl/lookahead.pytorch
3 | Paper: `Lookahead Optimizer: k steps forward, 1 step back` - https://arxiv.org/abs/1907.08610
4 |
5 | Hacked together by / Copyright 2020 Ross Wightman
6 | """
7 | from collections import OrderedDict
8 | from typing import Callable, Dict
9 |
10 | import torch
11 | from torch.optim.optimizer import Optimizer
12 | from collections import defaultdict
13 |
14 |
15 | class Lookahead(Optimizer):
16 | def __init__(self, base_optimizer, alpha=0.5, k=6):
17 | # NOTE super().__init__() not called on purpose
18 | self._optimizer_step_pre_hooks: Dict[int, Callable] = OrderedDict()
19 | self._optimizer_step_post_hooks: Dict[int, Callable] = OrderedDict()
20 | if not 0.0 <= alpha <= 1.0:
21 | raise ValueError(f'Invalid slow update rate: {alpha}')
22 | if not 1 <= k:
23 | raise ValueError(f'Invalid lookahead steps: {k}')
24 | defaults = dict(lookahead_alpha=alpha, lookahead_k=k, lookahead_step=0)
25 | self._base_optimizer = base_optimizer
26 | self.param_groups = base_optimizer.param_groups
27 | self.defaults = base_optimizer.defaults
28 | self.defaults.update(defaults)
29 | self.state = defaultdict(dict)
30 | # manually add our defaults to the param groups
31 | for name, default in defaults.items():
32 | for group in self._base_optimizer.param_groups:
33 | group.setdefault(name, default)
34 |
35 | @torch.no_grad()
36 | def update_slow(self, group):
37 | for fast_p in group["params"]:
38 | if fast_p.grad is None:
39 | continue
40 | param_state = self._base_optimizer.state[fast_p]
41 | if 'lookahead_slow_buff' not in param_state:
42 | param_state['lookahead_slow_buff'] = torch.empty_like(fast_p)
43 | param_state['lookahead_slow_buff'].copy_(fast_p)
44 | slow = param_state['lookahead_slow_buff']
45 | slow.add_(fast_p - slow, alpha=group['lookahead_alpha'])
46 | fast_p.copy_(slow)
47 |
48 | def sync_lookahead(self):
49 | for group in self._base_optimizer.param_groups:
50 | self.update_slow(group)
51 |
52 | @torch.no_grad()
53 | def step(self, closure=None):
54 | loss = self._base_optimizer.step(closure)
55 | for group in self._base_optimizer.param_groups:
56 | group['lookahead_step'] += 1
57 | if group['lookahead_step'] % group['lookahead_k'] == 0:
58 | self.update_slow(group)
59 | return loss
60 |
61 | def state_dict(self):
62 | return self._base_optimizer.state_dict()
63 |
64 | def load_state_dict(self, state_dict):
65 | self._base_optimizer.load_state_dict(state_dict)
66 | self.param_groups = self._base_optimizer.param_groups
67 |
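# Illustrative usage (not part of the original module): wrap SGD so the slow
# weights sync toward the fast weights every k=6 inner steps.
if __name__ == '__main__':
    model = torch.nn.Linear(4, 2)
    opt = Lookahead(torch.optim.SGD(model.parameters(), lr=0.1), alpha=0.5, k=6)
    for _ in range(6):
        opt.zero_grad()
        model(torch.randn(8, 4)).pow(2).mean().backward()
        opt.step()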
--------------------------------------------------------------------------------
/timm/data/readers/reader_image_tar.py:
--------------------------------------------------------------------------------
1 | """ A dataset reader that reads single tarfile based datasets
2 |
3 | This reader can read datasets consisting of a single tarfile containing images.
4 | I am planning to deprecate it in favour of ReaderImageInTar.
5 |
6 | Hacked together by / Copyright 2020 Ross Wightman
7 | """
8 | import os
9 | import tarfile
10 |
11 | from timm.utils.misc import natural_key
12 |
13 | from .class_map import load_class_map
14 | from .img_extensions import get_img_extensions
15 | from .reader import Reader
16 |
17 |
18 | def extract_tarinfo(tarfile, class_to_idx=None, sort=True):
19 | extensions = get_img_extensions(as_set=True)
20 | files = []
21 | labels = []
22 | for ti in tarfile.getmembers():
23 | if not ti.isfile():
24 | continue
25 | dirname, basename = os.path.split(ti.path)
26 | label = os.path.basename(dirname)
27 | ext = os.path.splitext(basename)[1]
28 | if ext.lower() in extensions:
29 | files.append(ti)
30 | labels.append(label)
31 | if class_to_idx is None:
32 | unique_labels = set(labels)
33 | sorted_labels = list(sorted(unique_labels, key=natural_key))
34 | class_to_idx = {c: idx for idx, c in enumerate(sorted_labels)}
35 | tarinfo_and_targets = [(f, class_to_idx[l]) for f, l in zip(files, labels) if l in class_to_idx]
36 | if sort:
37 | tarinfo_and_targets = sorted(tarinfo_and_targets, key=lambda k: natural_key(k[0].path))
38 | return tarinfo_and_targets, class_to_idx
39 |
40 |
41 | class ReaderImageTar(Reader):
42 | """ Single tarfile dataset where classes are mapped to folders within tar
43 | NOTE: This class is being deprecated in favour of the more capable ReaderImageInTar that can
44 | operate on folders of tars or tars in tars.
45 | """
46 | def __init__(self, root, class_map=''):
47 | super().__init__()
48 |
49 | class_to_idx = None
50 | if class_map:
51 | class_to_idx = load_class_map(class_map, root)
52 | assert os.path.isfile(root)
53 | self.root = root
54 |
55 | with tarfile.open(root) as tf: # cannot keep this open across processes, reopen later
56 | self.samples, self.class_to_idx = extract_tarinfo(tf, class_to_idx)
57 | self.imgs = self.samples
58 | self.tarfile = None # lazy init in __getitem__
59 |
60 | def __getitem__(self, index):
61 | if self.tarfile is None:
62 | self.tarfile = tarfile.open(self.root)
63 | tarinfo, target = self.samples[index]
64 | fileobj = self.tarfile.extractfile(tarinfo)
65 | return fileobj, target
66 |
67 | def __len__(self):
68 | return len(self.samples)
69 |
70 | def _filename(self, index, basename=False, absolute=False):
71 | filename = self.samples[index][0].name
72 | if basename:
73 | filename = os.path.basename(filename)
74 | return filename
75 |
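# Illustrative usage (the path is a placeholder; the tar must contain class
# folders of images, e.g. train/<class_name>/<image>.jpg):
#
#   reader = ReaderImageTar('/data/imagenet-train.tar')
#   fileobj, target = reader[0]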
--------------------------------------------------------------------------------
/docs/models/.templates/models/gloun-inception-v3.md:
--------------------------------------------------------------------------------
1 | # (Gluon) Inception v3
2 |
3 | **Inception v3** is a convolutional neural network architecture from the Inception family that makes several improvements, including [Label Smoothing](https://paperswithcode.com/method/label-smoothing), factorized 7 x 7 convolutions, and the use of an [auxiliary classifier](https://paperswithcode.com/method/auxiliary-classifier) to propagate label information lower down the network (along with the use of batch normalization for layers in the side head). The key building block is an [Inception Module](https://paperswithcode.com/method/inception-v3-module).
4 |
5 | The weights from this model were ported from [Gluon](https://cv.gluon.ai/model_zoo/classification.html).
6 |
7 | {% include 'code_snippets.md' %}
8 |
9 | ## How do I train this model?
10 |
11 | You can follow the [timm recipe scripts](https://rwightman.github.io/pytorch-image-models/scripts/) for training a new model afresh.
12 |
13 | ## Citation
14 |
15 | ```BibTeX
16 | @article{DBLP:journals/corr/SzegedyVISW15,
17 | author = {Christian Szegedy and
18 | Vincent Vanhoucke and
19 | Sergey Ioffe and
20 | Jonathon Shlens and
21 | Zbigniew Wojna},
22 | title = {Rethinking the Inception Architecture for Computer Vision},
23 | journal = {CoRR},
24 | volume = {abs/1512.00567},
25 | year = {2015},
26 | url = {http://arxiv.org/abs/1512.00567},
27 | archivePrefix = {arXiv},
28 | eprint = {1512.00567},
29 | timestamp = {Mon, 13 Aug 2018 16:49:07 +0200},
30 | biburl = {https://dblp.org/rec/journals/corr/SzegedyVISW15.bib},
31 | bibsource = {dblp computer science bibliography, https://dblp.org}
32 | }
33 | ```
34 |
35 |
79 |
--------------------------------------------------------------------------------
/timm/data/_info/imagenet_a_synsets.txt:
--------------------------------------------------------------------------------
1 | n01498041
2 | n01531178
3 | n01534433
4 | n01558993
5 | n01580077
6 | n01614925
7 | n01616318
8 | n01631663
9 | n01641577
10 | n01669191
11 | n01677366
12 | n01687978
13 | n01694178
14 | n01698640
15 | n01735189
16 | n01770081
17 | n01770393
18 | n01774750
19 | n01784675
20 | n01819313
21 | n01820546
22 | n01833805
23 | n01843383
24 | n01847000
25 | n01855672
26 | n01882714
27 | n01910747
28 | n01914609
29 | n01924916
30 | n01944390
31 | n01985128
32 | n01986214
33 | n02007558
34 | n02009912
35 | n02037110
36 | n02051845
37 | n02077923
38 | n02085620
39 | n02099601
40 | n02106550
41 | n02106662
42 | n02110958
43 | n02119022
44 | n02123394
45 | n02127052
46 | n02129165
47 | n02133161
48 | n02137549
49 | n02165456
50 | n02174001
51 | n02177972
52 | n02190166
53 | n02206856
54 | n02219486
55 | n02226429
56 | n02231487
57 | n02233338
58 | n02236044
59 | n02259212
60 | n02268443
61 | n02279972
62 | n02280649
63 | n02281787
64 | n02317335
65 | n02325366
66 | n02346627
67 | n02356798
68 | n02361337
69 | n02410509
70 | n02445715
71 | n02454379
72 | n02486410
73 | n02492035
74 | n02504458
75 | n02655020
76 | n02669723
77 | n02672831
78 | n02676566
79 | n02690373
80 | n02701002
81 | n02730930
82 | n02777292
83 | n02782093
84 | n02787622
85 | n02793495
86 | n02797295
87 | n02802426
88 | n02814860
89 | n02815834
90 | n02837789
91 | n02879718
92 | n02883205
93 | n02895154
94 | n02906734
95 | n02948072
96 | n02951358
97 | n02980441
98 | n02992211
99 | n02999410
100 | n03014705
101 | n03026506
102 | n03124043
103 | n03125729
104 | n03187595
105 | n03196217
106 | n03223299
107 | n03250847
108 | n03255030
109 | n03291819
110 | n03325584
111 | n03355925
112 | n03384352
113 | n03388043
114 | n03417042
115 | n03443371
116 | n03444034
117 | n03445924
118 | n03452741
119 | n03483316
120 | n03584829
121 | n03590841
122 | n03594945
123 | n03617480
124 | n03666591
125 | n03670208
126 | n03717622
127 | n03720891
128 | n03721384
129 | n03724870
130 | n03775071
131 | n03788195
132 | n03804744
133 | n03837869
134 | n03840681
135 | n03854065
136 | n03888257
137 | n03891332
138 | n03935335
139 | n03982430
140 | n04019541
141 | n04033901
142 | n04039381
143 | n04067472
144 | n04086273
145 | n04099969
146 | n04118538
147 | n04131690
148 | n04133789
149 | n04141076
150 | n04146614
151 | n04147183
152 | n04179913
153 | n04208210
154 | n04235860
155 | n04252077
156 | n04252225
157 | n04254120
158 | n04270147
159 | n04275548
160 | n04310018
161 | n04317175
162 | n04344873
163 | n04347754
164 | n04355338
165 | n04366367
166 | n04376876
167 | n04389033
168 | n04399382
169 | n04442312
170 | n04456115
171 | n04482393
172 | n04507155
173 | n04509417
174 | n04532670
175 | n04540053
176 | n04554684
177 | n04562935
178 | n04591713
179 | n04606251
180 | n07583066
181 | n07695742
182 | n07697313
183 | n07697537
184 | n07714990
185 | n07718472
186 | n07720875
187 | n07734744
188 | n07749582
189 | n07753592
190 | n07760859
191 | n07768694
192 | n07831146
193 | n09229709
194 | n09246464
195 | n09472597
196 | n09835506
197 | n11879895
198 | n12057211
199 | n12144580
200 | n12267677
201 |
--------------------------------------------------------------------------------
/timm/data/_info/imagenet_r_synsets.txt:
--------------------------------------------------------------------------------
1 | n01443537
2 | n01484850
3 | n01494475
4 | n01498041
5 | n01514859
6 | n01518878
7 | n01531178
8 | n01534433
9 | n01614925
10 | n01616318
11 | n01630670
12 | n01632777
13 | n01644373
14 | n01677366
15 | n01694178
16 | n01748264
17 | n01770393
18 | n01774750
19 | n01784675
20 | n01806143
21 | n01820546
22 | n01833805
23 | n01843383
24 | n01847000
25 | n01855672
26 | n01860187
27 | n01882714
28 | n01910747
29 | n01944390
30 | n01983481
31 | n01986214
32 | n02007558
33 | n02009912
34 | n02051845
35 | n02056570
36 | n02066245
37 | n02071294
38 | n02077923
39 | n02085620
40 | n02086240
41 | n02088094
42 | n02088238
43 | n02088364
44 | n02088466
45 | n02091032
46 | n02091134
47 | n02092339
48 | n02094433
49 | n02096585
50 | n02097298
51 | n02098286
52 | n02099601
53 | n02099712
54 | n02102318
55 | n02106030
56 | n02106166
57 | n02106550
58 | n02106662
59 | n02108089
60 | n02108915
61 | n02109525
62 | n02110185
63 | n02110341
64 | n02110958
65 | n02112018
66 | n02112137
67 | n02113023
68 | n02113624
69 | n02113799
70 | n02114367
71 | n02117135
72 | n02119022
73 | n02123045
74 | n02128385
75 | n02128757
76 | n02129165
77 | n02129604
78 | n02130308
79 | n02134084
80 | n02138441
81 | n02165456
82 | n02190166
83 | n02206856
84 | n02219486
85 | n02226429
86 | n02233338
87 | n02236044
88 | n02268443
89 | n02279972
90 | n02317335
91 | n02325366
92 | n02346627
93 | n02356798
94 | n02363005
95 | n02364673
96 | n02391049
97 | n02395406
98 | n02398521
99 | n02410509
100 | n02423022
101 | n02437616
102 | n02445715
103 | n02447366
104 | n02480495
105 | n02480855
106 | n02481823
107 | n02483362
108 | n02486410
109 | n02510455
110 | n02526121
111 | n02607072
112 | n02655020
113 | n02672831
114 | n02701002
115 | n02749479
116 | n02769748
117 | n02793495
118 | n02797295
119 | n02802426
120 | n02808440
121 | n02814860
122 | n02823750
123 | n02841315
124 | n02843684
125 | n02883205
126 | n02906734
127 | n02909870
128 | n02939185
129 | n02948072
130 | n02950826
131 | n02951358
132 | n02966193
133 | n02980441
134 | n02992529
135 | n03124170
136 | n03272010
137 | n03345487
138 | n03372029
139 | n03424325
140 | n03452741
141 | n03467068
142 | n03481172
143 | n03494278
144 | n03495258
145 | n03498962
146 | n03594945
147 | n03602883
148 | n03630383
149 | n03649909
150 | n03676483
151 | n03710193
152 | n03773504
153 | n03775071
154 | n03888257
155 | n03930630
156 | n03947888
157 | n04086273
158 | n04118538
159 | n04133789
160 | n04141076
161 | n04146614
162 | n04147183
163 | n04192698
164 | n04254680
165 | n04266014
166 | n04275548
167 | n04310018
168 | n04325704
169 | n04347754
170 | n04389033
171 | n04409515
172 | n04465501
173 | n04487394
174 | n04522168
175 | n04536866
176 | n04552348
177 | n04591713
178 | n07614500
179 | n07693725
180 | n07695742
181 | n07697313
182 | n07697537
183 | n07714571
184 | n07714990
185 | n07718472
186 | n07720875
187 | n07734744
188 | n07742313
189 | n07745940
190 | n07749582
191 | n07753275
192 | n07753592
193 | n07768694
194 | n07873807
195 | n07880968
196 | n07920052
197 | n09472597
198 | n09835506
199 | n10565667
200 | n12267677
201 |
--------------------------------------------------------------------------------
/docs/models/.templates/models/inception-v3.md:
--------------------------------------------------------------------------------
1 | # Inception v3
2 |
3 | **Inception v3** is a convolutional neural network architecture from the Inception family that makes several improvements, including [Label Smoothing](https://paperswithcode.com/method/label-smoothing), factorized 7 x 7 convolutions, and the use of an [auxiliary classifier](https://paperswithcode.com/method/auxiliary-classifier) to propagate label information lower down the network (along with the use of batch normalization for layers in the side head). The key building block is an [Inception Module](https://paperswithcode.com/method/inception-v3-module).
4 |
5 | {% include 'code_snippets.md' %}
6 |
7 | ## How do I train this model?
8 |
9 | You can follow the [timm recipe scripts](https://rwightman.github.io/pytorch-image-models/scripts/) for training a new model afresh.
10 |
11 | ## Citation
12 |
13 | ```BibTeX
14 | @article{DBLP:journals/corr/SzegedyVISW15,
15 | author = {Christian Szegedy and
16 | Vincent Vanhoucke and
17 | Sergey Ioffe and
18 | Jonathon Shlens and
19 | Zbigniew Wojna},
20 | title = {Rethinking the Inception Architecture for Computer Vision},
21 | journal = {CoRR},
22 | volume = {abs/1512.00567},
23 | year = {2015},
24 | url = {http://arxiv.org/abs/1512.00567},
25 | archivePrefix = {arXiv},
26 | eprint = {1512.00567},
27 | timestamp = {Mon, 13 Aug 2018 16:49:07 +0200},
28 | biburl = {https://dblp.org/rec/journals/corr/SzegedyVISW15.bib},
29 | bibsource = {dblp computer science bibliography, https://dblp.org}
30 | }
31 | ```
32 |
33 |
86 |
--------------------------------------------------------------------------------
/timm/layers/padding.py:
--------------------------------------------------------------------------------
1 | """ Padding Helpers
2 |
3 | Hacked together by / Copyright 2020 Ross Wightman
4 | """
5 | import math
6 | from typing import List, Tuple
7 |
8 | import torch
9 | import torch.nn.functional as F
10 |
11 |
12 | # Calculate symmetric padding for a convolution
13 | def get_padding(kernel_size: int, stride: int = 1, dilation: int = 1, **_) -> int:
14 | padding = ((stride - 1) + dilation * (kernel_size - 1)) // 2
15 | return padding
16 |
17 |
18 | # Calculate asymmetric TensorFlow-like 'SAME' padding for a convolution
19 | def get_same_padding(x: int, kernel_size: int, stride: int, dilation: int):
20 | if isinstance(x, torch.Tensor):
21 | return torch.clamp(((x / stride).ceil() - 1) * stride + (kernel_size - 1) * dilation + 1 - x, min=0)
22 | else:
23 | return max((math.ceil(x / stride) - 1) * stride + (kernel_size - 1) * dilation + 1 - x, 0)
24 |
25 |
26 | # Can SAME padding for given args be done statically?
27 | def is_static_pad(kernel_size: int, stride: int = 1, dilation: int = 1, **_):
28 | return stride == 1 and (dilation * (kernel_size - 1)) % 2 == 0
29 |
30 |
31 | def pad_same_arg(
32 | input_size: List[int],
33 | kernel_size: List[int],
34 | stride: List[int],
35 | dilation: List[int] = (1, 1),
36 | ) -> List[int]:
37 | ih, iw = input_size
38 | kh, kw = kernel_size
39 | pad_h = get_same_padding(ih, kh, stride[0], dilation[0])
40 | pad_w = get_same_padding(iw, kw, stride[1], dilation[1])
41 | return [pad_w // 2, pad_w - pad_w // 2, pad_h // 2, pad_h - pad_h // 2]
42 |
43 |
44 | # Dynamically pad input x with 'SAME' padding for conv with specified args
45 | def pad_same(
46 | x,
47 | kernel_size: List[int],
48 | stride: List[int],
49 | dilation: List[int] = (1, 1),
50 | value: float = 0,
51 | ):
52 | ih, iw = x.size()[-2:]
53 | pad_h = get_same_padding(ih, kernel_size[0], stride[0], dilation[0])
54 | pad_w = get_same_padding(iw, kernel_size[1], stride[1], dilation[1])
55 | x = F.pad(x, (pad_w // 2, pad_w - pad_w // 2, pad_h // 2, pad_h - pad_h // 2), value=value)
56 | return x
57 |
58 |
59 | def get_padding_value(padding, kernel_size, **kwargs) -> Tuple[Tuple, bool]:
60 | dynamic = False
61 | if isinstance(padding, str):
62 | # for any string padding, the padding will be calculated for you, one of three ways
63 | padding = padding.lower()
64 | if padding == 'same':
65 | # TF compatible 'SAME' padding, has a performance and GPU memory allocation impact
66 | if is_static_pad(kernel_size, **kwargs):
67 | # static case, no extra overhead
68 | padding = get_padding(kernel_size, **kwargs)
69 | else:
70 | # dynamic 'SAME' padding, has runtime/GPU memory overhead
71 | padding = 0
72 | dynamic = True
73 | elif padding == 'valid':
74 | # 'VALID' padding, same as padding=0
75 | padding = 0
76 | else:
77 | # Default to PyTorch style 'same'-ish symmetric padding
78 | padding = get_padding(kernel_size, **kwargs)
79 | return padding, dynamic
80 |
--------------------------------------------------------------------------------
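A minimal usage sketch of the helpers above (tensor shapes are illustrative assumptions): `get_padding_value` resolves a padding spec to either a static integer pad or a dynamic-padding flag, and `pad_same` applies the TF-style 'SAME' padding at runtime.

```python
import torch
from timm.layers.padding import get_padding, get_padding_value, pad_same

# Static symmetric padding: a 3x3 conv at stride 1 needs pad=1 on each side.
assert get_padding(kernel_size=3, stride=1, dilation=1) == 1

# 'same' at stride 1 can be satisfied statically...
padding, dynamic = get_padding_value('same', kernel_size=3, stride=1)
assert padding == 1 and not dynamic

# ...but stride 2 cannot, so the result falls back to dynamic runtime padding.
padding, dynamic = get_padding_value('same', kernel_size=3, stride=2)
assert padding == 0 and dynamic

# Dynamic path: pad a 224x224 input so a stride-2 3x3 conv covers every pixel.
x = torch.randn(1, 3, 224, 224)
x_padded = pad_same(x, kernel_size=[3, 3], stride=[2, 2])
print(x_padded.shape)  # torch.Size([1, 3, 225, 225]) -- one extra row/col, bottom-right
```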
/docs/models/.templates/models/ese-vovnet.md:
--------------------------------------------------------------------------------
1 | # ESE-VoVNet
2 |
3 | **VoVNet** is a convolutional neural network that seeks to make [DenseNet](https://paperswithcode.com/method/densenet) more efficient by concatenating all features only once, in the last feature map, which keeps each layer's input size constant and allows the number of output channels to be enlarged.
4 |
5 | Read about [one-shot aggregation here](https://paperswithcode.com/method/one-shot-aggregation).
6 |
7 | {% include 'code_snippets.md' %}
8 |
9 | ## How do I train this model?
10 |
11 | You can follow the [timm recipe scripts](https://rwightman.github.io/pytorch-image-models/scripts/) to train a new model from scratch.
12 |
13 | ## Citation
14 |
15 | ```BibTeX
16 | @misc{lee2019energy,
17 | title={An Energy and GPU-Computation Efficient Backbone Network for Real-Time Object Detection},
18 | author={Youngwan Lee and Joong-won Hwang and Sangrok Lee and Yuseok Bae and Jongyoul Park},
19 | year={2019},
20 | eprint={1904.09730},
21 | archivePrefix={arXiv},
22 | primaryClass={cs.CV}
23 | }
24 | ```
25 |
26 |
93 |
--------------------------------------------------------------------------------
/timm/layers/pool2d_same.py:
--------------------------------------------------------------------------------
1 | """ AvgPool2d w/ Same Padding
2 |
3 | Hacked together by / Copyright 2020 Ross Wightman
4 | """
5 | import torch
6 | import torch.nn as nn
7 | import torch.nn.functional as F
8 | from typing import List, Tuple, Optional
9 |
10 | from .helpers import to_2tuple
11 | from .padding import pad_same, get_padding_value
12 |
13 |
14 | def avg_pool2d_same(x, kernel_size: List[int], stride: List[int], padding: List[int] = (0, 0),
15 | ceil_mode: bool = False, count_include_pad: bool = True):
16 | # FIXME how to deal with count_include_pad vs not for external padding?
17 | x = pad_same(x, kernel_size, stride)
18 | return F.avg_pool2d(x, kernel_size, stride, (0, 0), ceil_mode, count_include_pad)
19 |
20 |
21 | class AvgPool2dSame(nn.AvgPool2d):
22 | """ Tensorflow like 'SAME' wrapper for 2D average pooling
23 | """
24 | def __init__(self, kernel_size: int, stride=None, padding=0, ceil_mode=False, count_include_pad=True):
25 | kernel_size = to_2tuple(kernel_size)
26 | stride = to_2tuple(stride)
27 | super(AvgPool2dSame, self).__init__(kernel_size, stride, (0, 0), ceil_mode, count_include_pad)
28 |
29 | def forward(self, x):
30 | x = pad_same(x, self.kernel_size, self.stride)
31 | return F.avg_pool2d(
32 | x, self.kernel_size, self.stride, self.padding, self.ceil_mode, self.count_include_pad)
33 |
34 |
35 | def max_pool2d_same(
36 | x, kernel_size: List[int], stride: List[int], padding: List[int] = (0, 0),
37 | dilation: List[int] = (1, 1), ceil_mode: bool = False):
38 | x = pad_same(x, kernel_size, stride, value=-float('inf'))
39 | return F.max_pool2d(x, kernel_size, stride, (0, 0), dilation, ceil_mode)
40 |
41 |
42 | class MaxPool2dSame(nn.MaxPool2d):
43 | """ Tensorflow like 'SAME' wrapper for 2D max pooling
44 | """
45 | def __init__(self, kernel_size: int, stride=None, padding=0, dilation=1, ceil_mode=False):
46 | kernel_size = to_2tuple(kernel_size)
47 | stride = to_2tuple(stride)
48 | dilation = to_2tuple(dilation)
49 | super(MaxPool2dSame, self).__init__(kernel_size, stride, (0, 0), dilation, ceil_mode)
50 |
51 | def forward(self, x):
52 | x = pad_same(x, self.kernel_size, self.stride, value=-float('inf'))
53 | return F.max_pool2d(x, self.kernel_size, self.stride, (0, 0), self.dilation, self.ceil_mode)
54 |
55 |
56 | def create_pool2d(pool_type, kernel_size, stride=None, **kwargs):
57 | stride = stride or kernel_size
58 | padding = kwargs.pop('padding', '')
59 | padding, is_dynamic = get_padding_value(padding, kernel_size, stride=stride, **kwargs)
60 | if is_dynamic:
61 | if pool_type == 'avg':
62 | return AvgPool2dSame(kernel_size, stride=stride, **kwargs)
63 | elif pool_type == 'max':
64 | return MaxPool2dSame(kernel_size, stride=stride, **kwargs)
65 | else:
66 | assert False, f'Unsupported pool type {pool_type}'
67 | else:
68 | if pool_type == 'avg':
69 | return nn.AvgPool2d(kernel_size, stride=stride, padding=padding, **kwargs)
70 | elif pool_type == 'max':
71 | return nn.MaxPool2d(kernel_size, stride=stride, padding=padding, **kwargs)
72 | else:
73 | assert False, f'Unsupported pool type {pool_type}'
74 |
--------------------------------------------------------------------------------
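For illustration, a small sketch of the `create_pool2d` factory above (shapes are assumed): a stride-2 `'same'` spec cannot be expressed as static symmetric padding, so the dynamic wrapper is returned; an explicit integer pad takes the standard torch path.

```python
import torch
from timm.layers.pool2d_same import create_pool2d

x = torch.randn(1, 32, 7, 7)

# stride-2 'same' is not statically expressible -> dynamic MaxPool2dSame wrapper
pool = create_pool2d('max', kernel_size=3, stride=2, padding='same')
print(type(pool).__name__)   # MaxPool2dSame
print(pool(x).shape)         # torch.Size([1, 32, 4, 4]) -- ceil(7 / 2)

# an explicit integer pad routes to plain nn.MaxPool2d
pool = create_pool2d('max', kernel_size=3, stride=2, padding=1)
print(type(pool).__name__)   # MaxPool2d
print(pool(x).shape)         # torch.Size([1, 32, 4, 4])
```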
/docs/models/.templates/models/tf-inception-v3.md:
--------------------------------------------------------------------------------
1 | # (Tensorflow) Inception v3
2 |
3 | **Inception v3** is a convolutional neural network architecture from the Inception family that makes several improvements, including the use of [Label Smoothing](https://paperswithcode.com/method/label-smoothing), factorized 7 x 7 convolutions, and an [auxiliary classifier](https://paperswithcode.com/method/auxiliary-classifier) to propagate label information lower down the network (along with batch normalization for layers in the side head). The key building block is an [Inception Module](https://paperswithcode.com/method/inception-v3-module).
4 |
5 | The weights from this model were ported from [Tensorflow/Models](https://github.com/tensorflow/models).
6 |
7 | {% include 'code_snippets.md' %}
8 |
9 | ## How do I train this model?
10 |
11 | You can follow the [timm recipe scripts](https://rwightman.github.io/pytorch-image-models/scripts/) to train a new model from scratch.
12 |
13 | ## Citation
14 |
15 | ```BibTeX
16 | @article{DBLP:journals/corr/SzegedyVISW15,
17 | author = {Christian Szegedy and
18 | Vincent Vanhoucke and
19 | Sergey Ioffe and
20 | Jonathon Shlens and
21 | Zbigniew Wojna},
22 | title = {Rethinking the Inception Architecture for Computer Vision},
23 | journal = {CoRR},
24 | volume = {abs/1512.00567},
25 | year = {2015},
26 | url = {http://arxiv.org/abs/1512.00567},
27 | archivePrefix = {arXiv},
28 | eprint = {1512.00567},
29 | timestamp = {Mon, 13 Aug 2018 16:49:07 +0200},
30 | biburl = {https://dblp.org/rec/journals/corr/SzegedyVISW15.bib},
31 | bibsource = {dblp computer science bibliography, https://dblp.org}
32 | }
33 | ```
34 |
35 |
88 |
--------------------------------------------------------------------------------
/timm/layers/split_attn.py:
--------------------------------------------------------------------------------
1 | """ Split Attention Conv2d (for ResNeSt Models)
2 |
3 | Paper: `ResNeSt: Split-Attention Networks` - https://arxiv.org/abs/2004.08955
4 |
5 | Adapted from original PyTorch impl at https://github.com/zhanghang1989/ResNeSt
6 |
7 | Modified for torchscript compat, performance, and consistency with timm by Ross Wightman
8 | """
9 | import torch
10 | import torch.nn.functional as F
11 | from torch import nn
12 |
13 | from .helpers import make_divisible
14 |
15 |
16 | class RadixSoftmax(nn.Module):
17 | def __init__(self, radix, cardinality):
18 | super(RadixSoftmax, self).__init__()
19 | self.radix = radix
20 | self.cardinality = cardinality
21 |
22 | def forward(self, x):
23 | batch = x.size(0)
24 | if self.radix > 1:
25 | x = x.view(batch, self.cardinality, self.radix, -1).transpose(1, 2)
26 | x = F.softmax(x, dim=1)
27 | x = x.reshape(batch, -1)
28 | else:
29 | x = torch.sigmoid(x)
30 | return x
31 |
32 |
33 | class SplitAttn(nn.Module):
34 | """Split-Attention (aka Splat)
35 | """
36 | def __init__(self, in_channels, out_channels=None, kernel_size=3, stride=1, padding=None,
37 | dilation=1, groups=1, bias=False, radix=2, rd_ratio=0.25, rd_channels=None, rd_divisor=8,
38 | act_layer=nn.ReLU, norm_layer=None, drop_layer=None, **kwargs):
39 | super(SplitAttn, self).__init__()
40 | out_channels = out_channels or in_channels
41 | self.radix = radix
42 | mid_chs = out_channels * radix
43 | if rd_channels is None:
44 | attn_chs = make_divisible(in_channels * radix * rd_ratio, min_value=32, divisor=rd_divisor)
45 | else:
46 | attn_chs = rd_channels * radix
47 |
48 | padding = kernel_size // 2 if padding is None else padding
49 | self.conv = nn.Conv2d(
50 | in_channels, mid_chs, kernel_size, stride, padding, dilation,
51 | groups=groups * radix, bias=bias, **kwargs)
52 | self.bn0 = norm_layer(mid_chs) if norm_layer else nn.Identity()
53 | self.drop = drop_layer() if drop_layer is not None else nn.Identity()
54 | self.act0 = act_layer(inplace=True)
55 | self.fc1 = nn.Conv2d(out_channels, attn_chs, 1, groups=groups)
56 | self.bn1 = norm_layer(attn_chs) if norm_layer else nn.Identity()
57 | self.act1 = act_layer(inplace=True)
58 | self.fc2 = nn.Conv2d(attn_chs, mid_chs, 1, groups=groups)
59 | self.rsoftmax = RadixSoftmax(radix, groups)
60 |
61 | def forward(self, x):
62 | x = self.conv(x)
63 | x = self.bn0(x)
64 | x = self.drop(x)
65 | x = self.act0(x)
66 |
67 | B, RC, H, W = x.shape
68 | if self.radix > 1:
69 | x = x.reshape((B, self.radix, RC // self.radix, H, W))
70 | x_gap = x.sum(dim=1)
71 | else:
72 | x_gap = x
73 | x_gap = x_gap.mean((2, 3), keepdim=True)
74 | x_gap = self.fc1(x_gap)
75 | x_gap = self.bn1(x_gap)
76 | x_gap = self.act1(x_gap)
77 | x_attn = self.fc2(x_gap)
78 |
79 | x_attn = self.rsoftmax(x_attn).view(B, -1, 1, 1)
80 | if self.radix > 1:
81 | out = (x * x_attn.reshape((B, self.radix, RC // self.radix, 1, 1))).sum(dim=1)
82 | else:
83 | out = x * x_attn
84 | return out.contiguous()
85 |
--------------------------------------------------------------------------------
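A quick shape-check sketch of the module above (channel counts and batch size are arbitrary assumptions): with radix=2 the grouped conv produces `out_channels * radix` feature maps, which are split, attended over via the radix softmax, and summed back to `out_channels`.

```python
import torch
from torch import nn
from timm.layers.split_attn import SplitAttn

# radix=2, groups(cardinality)=1: 64 in -> 128 internal channels -> 64 out
attn = SplitAttn(in_channels=64, out_channels=64, radix=2, norm_layer=nn.BatchNorm2d)
x = torch.randn(2, 64, 16, 16)
out = attn(x)
print(out.shape)  # torch.Size([2, 64, 16, 16]) -- radix splits are re-fused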
/docs/models/.templates/models/wide-resnet.md:
--------------------------------------------------------------------------------
1 | # Wide ResNet
2 |
3 | **Wide Residual Networks** are a variant of [ResNets](https://paperswithcode.com/method/resnet) in which the depth of the network is decreased and its width increased. This is achieved through the use of [wide residual blocks](https://paperswithcode.com/method/wide-residual-block).
4 |
5 | {% include 'code_snippets.md' %}
6 |
7 | ## How do I train this model?
8 |
9 | You can follow the [timm recipe scripts](https://rwightman.github.io/pytorch-image-models/scripts/) to train a new model from scratch.
10 |
11 | ## Citation
12 |
13 | ```BibTeX
14 | @article{DBLP:journals/corr/ZagoruykoK16,
15 | author = {Sergey Zagoruyko and
16 | Nikos Komodakis},
17 | title = {Wide Residual Networks},
18 | journal = {CoRR},
19 | volume = {abs/1605.07146},
20 | year = {2016},
21 | url = {http://arxiv.org/abs/1605.07146},
22 | archivePrefix = {arXiv},
23 | eprint = {1605.07146},
24 | timestamp = {Mon, 13 Aug 2018 16:46:42 +0200},
25 | biburl = {https://dblp.org/rec/journals/corr/ZagoruykoK16.bib},
26 | bibsource = {dblp computer science bibliography, https://dblp.org}
27 | }
28 | ```
29 |
30 |
103 |
--------------------------------------------------------------------------------
/timm/models/layers/__init__.py:
--------------------------------------------------------------------------------
1 | # NOTE timm.models.layers is DEPRECATED; please use timm.layers. This shim exists only to reduce breakage during the transition.
2 | from timm.layers.activations import *
3 | from timm.layers.adaptive_avgmax_pool import \
4 | adaptive_avgmax_pool2d, select_adaptive_pool2d, AdaptiveAvgMaxPool2d, SelectAdaptivePool2d
5 | from timm.layers.attention_pool2d import AttentionPool2d, RotAttentionPool2d, RotaryEmbedding
6 | from timm.layers.blur_pool import BlurPool2d
7 | from timm.layers.classifier import ClassifierHead, create_classifier
8 | from timm.layers.cond_conv2d import CondConv2d, get_condconv_initializer
9 | from timm.layers.config import is_exportable, is_scriptable, is_no_jit, set_exportable, set_scriptable, set_no_jit,\
10 | set_layer_config
11 | from timm.layers.conv2d_same import Conv2dSame, conv2d_same
12 | from timm.layers.conv_bn_act import ConvNormAct, ConvNormActAa, ConvBnAct
13 | from timm.layers.create_act import create_act_layer, get_act_layer, get_act_fn
14 | from timm.layers.create_attn import get_attn, create_attn
15 | from timm.layers.create_conv2d import create_conv2d
16 | from timm.layers.create_norm import get_norm_layer, create_norm_layer
17 | from timm.layers.create_norm_act import get_norm_act_layer, create_norm_act_layer
18 | from timm.layers.drop import DropBlock2d, DropPath, drop_block_2d, drop_path
19 | from timm.layers.eca import EcaModule, CecaModule, EfficientChannelAttn, CircularEfficientChannelAttn
20 | from timm.layers.evo_norm import EvoNorm2dB0, EvoNorm2dB1, EvoNorm2dB2,\
21 | EvoNorm2dS0, EvoNorm2dS0a, EvoNorm2dS1, EvoNorm2dS1a, EvoNorm2dS2, EvoNorm2dS2a
22 | from timm.layers.fast_norm import is_fast_norm, set_fast_norm, fast_group_norm, fast_layer_norm
23 | from timm.layers.filter_response_norm import FilterResponseNormTlu2d, FilterResponseNormAct2d
24 | from timm.layers.gather_excite import GatherExcite
25 | from timm.layers.global_context import GlobalContext
26 | from timm.layers.helpers import to_ntuple, to_2tuple, to_3tuple, to_4tuple, make_divisible, extend_tuple
27 | from timm.layers.inplace_abn import InplaceAbn
28 | from timm.layers.linear import Linear
29 | from timm.layers.mixed_conv2d import MixedConv2d
30 | from timm.layers.mlp import Mlp, GluMlp, GatedMlp, ConvMlp
31 | from timm.layers.non_local_attn import NonLocalAttn, BatNonLocalAttn
32 | from timm.layers.norm import GroupNorm, GroupNorm1, LayerNorm, LayerNorm2d
33 | from timm.layers.norm_act import BatchNormAct2d, GroupNormAct, convert_sync_batchnorm
34 | from timm.layers.padding import get_padding, get_same_padding, pad_same
35 | from timm.layers.patch_embed import PatchEmbed
36 | from timm.layers.pool2d_same import AvgPool2dSame, create_pool2d
37 | from timm.layers.squeeze_excite import SEModule, SqueezeExcite, EffectiveSEModule, EffectiveSqueezeExcite
38 | from timm.layers.selective_kernel import SelectiveKernel
39 | from timm.layers.separable_conv import SeparableConv2d, SeparableConvNormAct
40 | from timm.layers.space_to_depth import SpaceToDepthModule
41 | from timm.layers.split_attn import SplitAttn
42 | from timm.layers.split_batchnorm import SplitBatchNorm2d, convert_splitbn_model
43 | from timm.layers.std_conv import StdConv2d, StdConv2dSame, ScaledStdConv2d, ScaledStdConv2dSame
44 | from timm.layers.test_time_pool import TestTimePoolHead, apply_test_time_pool
45 | from timm.layers.trace_utils import _assert, _float_to_int
46 | from timm.layers.weight_init import trunc_normal_, trunc_normal_tf_, variance_scaling_, lecun_normal_
47 |
48 | import warnings
49 | warnings.warn(f"Importing from {__name__} is deprecated, please import via timm.layers", DeprecationWarning)
50 |
--------------------------------------------------------------------------------
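As a usage note for the shim above, a minimal sketch: the legacy path still resolves (emitting the DeprecationWarning, subject to the active warning filters), while the flat timm.layers path is the one to use going forward.

```python
# Legacy path: still works via this shim, but warns (filter-dependent).
from timm.models.layers import DropPath, trunc_normal_  # noqa: F401

# Preferred path going forward:
from timm.layers import DropPath, trunc_normal_  # noqa: F811
```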
/timm/models/__init__.py:
--------------------------------------------------------------------------------
1 | from .beit import *
2 | from .byoanet import *
3 | from .byobnet import *
4 | from .cait import *
5 | from .coat import *
6 | from .convit import *
7 | from .convmixer import *
8 | from .convnext import *
9 | from .crossvit import *
10 | from .cspnet import *
11 | from .davit import *
12 | from .deit import *
13 | from .densenet import *
14 | from .dla import *
15 | from .dpn import *
16 | from .edgenext import *
17 | from .efficientformer import *
18 | from .efficientformer_v2 import *
19 | from .efficientnet import *
20 | from .efficientvit_mit import *
21 | from .efficientvit_msra import *
22 | from .eva import *
23 | from .fastvit import *
24 | from .focalnet import *
25 | from .gcvit import *
26 | from .ghostnet import *
27 | from .hardcorenas import *
28 | from .hrnet import *
29 | from .inception_next import *
30 | from .inception_resnet_v2 import *
31 | from .inception_v3 import *
32 | from .inception_v4 import *
33 | from .levit import *
34 | from .maxxvit import *
35 | from .metaformer import *
36 | from .mlp_mixer import *
37 | from .mobilenetv3 import *
38 | from .mobilevit import *
39 | from .mvitv2 import *
40 | from .nasnet import *
41 | from .nest import *
42 | from .nfnet import *
43 | from .pit import *
44 | from .pnasnet import *
45 | from .pvt_v2 import *
46 | from .regnet import *
47 | from .repghost import *
48 | from .repvit import *
49 | from .res2net import *
50 | from .resnest import *
51 | from .resnet import *
52 | from .resnetv2 import *
53 | from .rexnet import *
54 | from .selecsls import *
55 | from .senet import *
56 | from .sequencer import *
57 | from .sknet import *
58 | from .swin_transformer import *
59 | from .swin_transformer_v2 import *
60 | from .swin_transformer_v2_cr import *
61 | from .tiny_vit import *
62 | from .tnt import *
63 | from .tresnet import *
64 | from .twins import *
65 | from .vgg import *
66 | from .visformer import *
67 | from .vision_transformer import *
68 | from .vision_transformer_hybrid import *
69 | from .vision_transformer_relpos import *
70 | from .vision_transformer_sam import *
71 | from .volo import *
72 | from .vovnet import *
73 | from .xception import *
74 | from .xception_aligned import *
75 | from .xcit import *
76 |
77 | from ._builder import build_model_with_cfg, load_pretrained, load_custom_pretrained, resolve_pretrained_cfg, \
78 | set_pretrained_download_progress, set_pretrained_check_hash
79 | from ._factory import create_model, parse_model_name, safe_model_name
80 | from ._features import FeatureInfo, FeatureHooks, FeatureHookNet, FeatureListNet, FeatureDictNet
81 | from ._features_fx import FeatureGraphNet, GraphExtractNet, create_feature_extractor, \
82 | register_notrace_module, is_notrace_module, get_notrace_modules, \
83 | register_notrace_function, is_notrace_function, get_notrace_functions
84 | from ._helpers import clean_state_dict, load_state_dict, load_checkpoint, remap_state_dict, resume_checkpoint
85 | from ._hub import load_model_config_from_hf, load_state_dict_from_hf, push_to_hf_hub
86 | from ._manipulate import model_parameters, named_apply, named_modules, named_modules_with_params, \
87 | group_modules, group_parameters, checkpoint_seq, adapt_input_conv
88 | from ._pretrained import PretrainedCfg, DefaultCfg, filter_pretrained_cfg
89 | from ._prune import adapt_model_from_string
90 | from ._registry import split_model_name_tag, get_arch_name, generate_default_cfgs, register_model, \
91 | register_model_deprecations, model_entrypoint, list_models, list_pretrained, get_deprecated_models, \
92 | is_model, list_modules, is_model_in_modules, is_model_pretrained, get_pretrained_cfg, get_pretrained_cfg_value
93 |
--------------------------------------------------------------------------------
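For orientation, a minimal sketch of the registry and factory exports above in use (the wildcard pattern and model name are illustrative):

```python
import timm

# The registry answers queries over everything imported above.
print(timm.list_models('inception*'))

# The factory builds any registered architecture by name;
# pretrained=True would additionally download and load weights.
model = timm.create_model('inception_v3', pretrained=False, num_classes=10)
print(sum(p.numel() for p in model.parameters()))
```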
/docs/models/.templates/models/ensemble-adversarial.md:
--------------------------------------------------------------------------------
1 | # Ensemble Adversarial Inception ResNet v2
2 |
3 | **Inception-ResNet-v2** is a convolutional neural architecture that builds on the Inception family of architectures but incorporates [residual connections](https://paperswithcode.com/method/residual-connection) (replacing the filter concatenation stage of the Inception architecture).
4 |
5 | This particular model was trained for the study of adversarial examples (adversarial training).
6 |
7 | The weights from this model were ported from [Tensorflow/Models](https://github.com/tensorflow/models).
8 |
9 | {% include 'code_snippets.md' %}
10 |
11 | ## How do I train this model?
12 |
13 | You can follow the [timm recipe scripts](https://rwightman.github.io/pytorch-image-models/scripts/) to train a new model from scratch.
14 |
15 | ## Citation
16 |
17 | ```BibTeX
18 | @article{DBLP:journals/corr/abs-1804-00097,
19 | author = {Alexey Kurakin and
20 | Ian J. Goodfellow and
21 | Samy Bengio and
22 | Yinpeng Dong and
23 | Fangzhou Liao and
24 | Ming Liang and
25 | Tianyu Pang and
26 | Jun Zhu and
27 | Xiaolin Hu and
28 | Cihang Xie and
29 | Jianyu Wang and
30 | Zhishuai Zhang and
31 | Zhou Ren and
32 | Alan L. Yuille and
33 | Sangxia Huang and
34 | Yao Zhao and
35 | Yuzhe Zhao and
36 | Zhonglin Han and
37 | Junjiajia Long and
38 | Yerkebulan Berdibekov and
39 | Takuya Akiba and
40 | Seiya Tokui and
41 | Motoki Abe},
42 | title = {Adversarial Attacks and Defences Competition},
43 | journal = {CoRR},
44 | volume = {abs/1804.00097},
45 | year = {2018},
46 | url = {http://arxiv.org/abs/1804.00097},
47 | archivePrefix = {arXiv},
48 | eprint = {1804.00097},
49 | timestamp = {Thu, 31 Oct 2019 16:31:22 +0100},
50 | biburl = {https://dblp.org/rec/journals/corr/abs-1804-00097.bib},
51 | bibsource = {dblp computer science bibliography, https://dblp.org}
52 | }
53 | ```
54 |
55 |
99 |
--------------------------------------------------------------------------------
/timm/data/readers/reader_image_folder.py:
--------------------------------------------------------------------------------
1 | """ A dataset reader that extracts images from folders
2 |
3 | Folders are scanned recursively to find image files. Labels are based
4 | on the folder hierarchy, just leaf folders by default.
5 |
6 | Hacked together by / Copyright 2020 Ross Wightman
7 | """
8 | import os
9 | from typing import Dict, List, Optional, Set, Tuple, Union
10 |
11 | from timm.utils.misc import natural_key
12 |
13 | from .class_map import load_class_map
14 | from .img_extensions import get_img_extensions
15 | from .reader import Reader
16 |
17 |
18 | def find_images_and_targets(
19 | folder: str,
20 | types: Optional[Union[List, Tuple, Set]] = None,
21 | class_to_idx: Optional[Dict] = None,
22 | leaf_name_only: bool = True,
23 | sort: bool = True
24 | ):
25 | """ Walk folder recursively to discover images and map them to classes by folder names.
26 |
27 | Args:
28 |         folder: root of folder to recursively search
29 | types: types (file extensions) to search for in path
30 | class_to_idx: specify mapping for class (folder name) to class index if set
31 | leaf_name_only: use only leaf-name of folder walk for class names
32 | sort: re-sort found images by name (for consistent ordering)
33 |
34 | Returns:
35 | A list of image and target tuples, class_to_idx mapping
36 | """
37 | types = get_img_extensions(as_set=True) if not types else set(types)
38 | labels = []
39 | filenames = []
40 | for root, subdirs, files in os.walk(folder, topdown=False, followlinks=True):
41 | rel_path = os.path.relpath(root, folder) if (root != folder) else ''
42 | label = os.path.basename(rel_path) if leaf_name_only else rel_path.replace(os.path.sep, '_')
43 | for f in files:
44 | base, ext = os.path.splitext(f)
45 | if ext.lower() in types:
46 | filenames.append(os.path.join(root, f))
47 | labels.append(label)
48 | if class_to_idx is None:
49 | # building class index
50 | unique_labels = set(labels)
51 | sorted_labels = list(sorted(unique_labels, key=natural_key))
52 | class_to_idx = {c: idx for idx, c in enumerate(sorted_labels)}
53 | images_and_targets = [(f, class_to_idx[l]) for f, l in zip(filenames, labels) if l in class_to_idx]
54 | if sort:
55 | images_and_targets = sorted(images_and_targets, key=lambda k: natural_key(k[0]))
56 | return images_and_targets, class_to_idx
57 |
58 |
59 | class ReaderImageFolder(Reader):
60 |
61 | def __init__(
62 | self,
63 | root,
64 | class_map=''):
65 | super().__init__()
66 |
67 | self.root = root
68 | class_to_idx = None
69 | if class_map:
70 | class_to_idx = load_class_map(class_map, root)
71 | self.samples, self.class_to_idx = find_images_and_targets(root, class_to_idx=class_to_idx)
72 | if len(self.samples) == 0:
73 | raise RuntimeError(
74 | f'Found 0 images in subfolders of {root}. '
75 | f'Supported image extensions are {", ".join(get_img_extensions())}')
76 |
77 | def __getitem__(self, index):
78 | path, target = self.samples[index]
79 | return open(path, 'rb'), target
80 |
81 | def __len__(self):
82 | return len(self.samples)
83 |
84 | def _filename(self, index, basename=False, absolute=False):
85 | filename = self.samples[index][0]
86 | if basename:
87 | filename = os.path.basename(filename)
88 | elif not absolute:
89 | filename = os.path.relpath(filename, self.root)
90 | return filename
91 |
--------------------------------------------------------------------------------
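A small sketch of the reader in use, assuming a hypothetical two-class directory layout (`cat/`, `dog/`): labels come from leaf folder names, and indexing returns an open binary file handle plus the class index.

```python
import os
import tempfile

from PIL import Image

from timm.data.readers.reader_image_folder import ReaderImageFolder

# Hypothetical layout: <root>/cat/0.jpg, <root>/dog/0.jpg
root = tempfile.mkdtemp()
for cls in ('cat', 'dog'):
    os.makedirs(os.path.join(root, cls))
    Image.new('RGB', (8, 8)).save(os.path.join(root, cls, '0.jpg'))

reader = ReaderImageFolder(root)
print(reader.class_to_idx)          # {'cat': 0, 'dog': 1}
fobj, target = reader[0]            # open file handle + class index
print(reader.filename(0), target)   # relative path by default, target 0
```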
/timm/layers/inplace_abn.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch import nn as nn
3 |
4 | try:
5 | from inplace_abn.functions import inplace_abn, inplace_abn_sync
6 | has_iabn = True
7 | except ImportError:
8 | has_iabn = False
9 |
10 | def inplace_abn(x, weight, bias, running_mean, running_var,
11 | training=True, momentum=0.1, eps=1e-05, activation="leaky_relu", activation_param=0.01):
12 | raise ImportError(
13 |             "Please install InplaceABN: 'pip install git+https://github.com/mapillary/inplace_abn.git@v1.0.12'")
14 |
15 | def inplace_abn_sync(**kwargs):
16 | inplace_abn(**kwargs)
17 |
18 |
19 | class InplaceAbn(nn.Module):
20 | """Activated Batch Normalization
21 |
22 | This gathers a BatchNorm and an activation function in a single module
23 |
24 | Parameters
25 | ----------
26 | num_features : int
27 | Number of feature channels in the input and output.
28 | eps : float
29 | Small constant to prevent numerical issues.
30 | momentum : float
31 | Momentum factor applied to compute running statistics.
32 | affine : bool
33 | If `True` apply learned scale and shift transformation after normalization.
34 | act_layer : str or nn.Module type
35 |         Name or type of the activation function, one of: `leaky_relu`, `elu`
36 | act_param : float
37 | Negative slope for the `leaky_relu` activation.
38 | """
39 |
40 | def __init__(self, num_features, eps=1e-5, momentum=0.1, affine=True, apply_act=True,
41 | act_layer="leaky_relu", act_param=0.01, drop_layer=None):
42 | super(InplaceAbn, self).__init__()
43 | self.num_features = num_features
44 | self.affine = affine
45 | self.eps = eps
46 | self.momentum = momentum
47 | if apply_act:
48 | if isinstance(act_layer, str):
49 | assert act_layer in ('leaky_relu', 'elu', 'identity', '')
50 | self.act_name = act_layer if act_layer else 'identity'
51 | else:
52 | # convert act layer passed as type to string
53 | if act_layer == nn.ELU:
54 | self.act_name = 'elu'
55 | elif act_layer == nn.LeakyReLU:
56 | self.act_name = 'leaky_relu'
57 | elif act_layer is None or act_layer == nn.Identity:
58 | self.act_name = 'identity'
59 | else:
60 | assert False, f'Invalid act layer {act_layer.__name__} for IABN'
61 | else:
62 | self.act_name = 'identity'
63 | self.act_param = act_param
64 | if self.affine:
65 | self.weight = nn.Parameter(torch.ones(num_features))
66 | self.bias = nn.Parameter(torch.zeros(num_features))
67 | else:
68 | self.register_parameter('weight', None)
69 | self.register_parameter('bias', None)
70 | self.register_buffer('running_mean', torch.zeros(num_features))
71 | self.register_buffer('running_var', torch.ones(num_features))
72 | self.reset_parameters()
73 |
74 | def reset_parameters(self):
75 | nn.init.constant_(self.running_mean, 0)
76 | nn.init.constant_(self.running_var, 1)
77 | if self.affine:
78 | nn.init.constant_(self.weight, 1)
79 | nn.init.constant_(self.bias, 0)
80 |
81 | def forward(self, x):
82 | output = inplace_abn(
83 | x, self.weight, self.bias, self.running_mean, self.running_var,
84 | self.training, self.momentum, self.eps, self.act_name, self.act_param)
85 | if isinstance(output, tuple):
86 | output = output[0]
87 | return output
88 |
--------------------------------------------------------------------------------
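A usage sketch for the module above: it acts like BatchNorm2d fused with its activation, but the real kernel exists only when the optional inplace_abn package is installed; otherwise forward hits the ImportError stub defined at the top of the file.

```python
import torch

from timm.layers.inplace_abn import InplaceAbn, has_iabn

if has_iabn:  # only exercised when the optional package is installed
    abn = InplaceAbn(64, act_layer='leaky_relu', act_param=0.01)
    x = torch.randn(2, 64, 8, 8)
    print(abn(x).shape)  # torch.Size([2, 64, 8, 8]) -- shape-preserving
else:
    print('inplace_abn not installed; InplaceAbn.forward would raise ImportError')
```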