├── LICENSE
├── README.md
├── assets
│   ├── KAT.png
│   ├── kat3-1.png
│   └── logo.webp
├── dist_train.sh
├── example.py
├── katransformer.py
├── scripts
│   ├── train_kat_base_8x128.sh
│   ├── train_kat_base_8x128_vitft.sh
│   ├── train_kat_small_8x128.sh
│   ├── train_kat_small_8x128_vitft.sh
│   ├── train_kat_tiny_8x128.sh
│   └── train_kat_tiny_8x128_vitft.sh
├── tools
│   └── calculate_flops.py
├── train.py
└── validate.py
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2024 Xingyi Yang
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
5 | # Kolmogorov–Arnold Transformer: A PyTorch Implementation
6 |
7 |
8 |
9 |
10 |
11 |
12 |
13 |
14 |
15 |
16 |
17 | ICLR 2025
18 |
19 |
20 |
21 |
22 | Yes, I kan!
23 |
24 |
25 | 🎉 This is a PyTorch/GPU implementation of the paper **Kolmogorov–Arnold Transformer (KAT)**, which replaces the MLP layers in transformers with KAN layers.
26 |
27 | For more technical details, please refer to our ICLR'25 paper.
28 |
29 | > **Kolmogorov–Arnold Transformer**
30 | > 📝[[Paper](https://arxiv.org/abs/2409.10594)] >[[code](https://github.com/Adamdad/kat)] >[[Triton/CUDA kernel](https://github.com/Adamdad/rational_kat_cu)]
31 | > [Xingyi Yang](https://adamdad.github.io/), [Xinchao Wang](https://sites.google.com/site/sitexinchaowang/)
32 | > National University of Singapore
33 | > International Conference on Learning Representations (**ICLR'25**)
34 |
35 | ### 🔑 Key Insight:
36 |
37 | Vanilla ViT + KAN struggles to scale effectively. We introduce KAT, which integrates GR-KAN layers into transformers for large-scale training scenarios such as ImageNet, achieving significant performance improvements.
38 |
39 |
40 |
41 |
42 |
43 | ### 🎯 Our Solutions:
44 | 1. **Base Function**: Replace the B-spline basis with a CUDA-implemented rational function.
45 | 2. **Group KAN**: Share weights among groups of edges for efficiency (see the sketch below).
46 | 3. **Initialization**: Maintain activation magnitudes across layers.
47 |
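To make the Group-KAN idea concrete, below is a minimal PyTorch sketch of a group-shared rational activation. It is an illustration only, not the fused `KAT_Group` CUDA/Triton kernel from `rational_kat_cu`; the class name, group count, and coefficient shapes are assumptions made for the example.

```python
# Illustration only: a naive group-shared "safe" rational activation, y = P(x) / (1 + |Q(x)|).
# The real KAT_Group kernel in rational_kat_cu fuses this computation on GPU.
import torch
import torch.nn as nn

class NaiveGroupRational(nn.Module):
    def __init__(self, channels: int, groups: int = 8, p_order: int = 5, q_order: int = 4):
        super().__init__()
        assert channels % groups == 0, "channels must be divisible by groups"
        self.groups = groups
        # one small coefficient set per group, shared by every channel (edge) in that group
        self.a = nn.Parameter(torch.randn(groups, p_order + 1) * 0.1)  # numerator coefficients
        self.b = nn.Parameter(torch.randn(groups, q_order) * 0.1)      # denominator coefficients

    def forward(self, x: torch.Tensor) -> torch.Tensor:  # x: (..., channels)
        *lead, c = x.shape
        xg = x.reshape(*lead, self.groups, c // self.groups)
        p = sum(self.a[:, k, None] * xg.pow(k) for k in range(self.a.shape[1]))
        q = 1.0 + sum(self.b[:, k, None] * xg.pow(k + 1) for k in range(self.b.shape[1])).abs()
        return (p / q).reshape(*lead, c)

x = torch.randn(2, 196, 192)                        # (batch, tokens, channels)
print(NaiveGroupRational(192, groups=8)(x).shape)   # torch.Size([2, 196, 192])
```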
48 | ### ✅ Updates
49 | - [x] Release the KAT paper, CUDA implementation and IN-1k training code.
50 | - [x] 🎉🎉🎉🎉 Triton implementation for 1D and 2D tasks. This is much easier to install than the CUDA version. Please see [https://github.com/Adamdad/rational_kat_cu](https://github.com/Adamdad/rational_kat_cu).
51 | - [ ] KAT Detection and segmentation code.
52 | - [ ] KAT on NLP tasks.
53 |
54 | ## 🛠️ Installation and Dataset
55 | Please find our CUDA implementation in [https://github.com/Adamdad/rational_kat_cu.git](https://github.com/Adamdad/rational_kat_cu.git).
56 | ```shell
57 | # install timm, wandb and the rational_kat_cu kernel (PyTorch itself is assumed to be installed)
58 | pip install timm==1.0.3
59 | pip install wandb # I personally use wandb for results visualizations
60 | git clone https://github.com/Adamdad/rational_kat_cu.git
61 | cd rational_kat_cu
62 | pip install -e .
63 | ```
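After `pip install -e .`, a quick import check (a smoke test, assuming the extension built correctly) confirms the kernel package is visible to Python; `kat_rational` and `KAT_Group` are the names imported by `katransformer.py`:

```python
# smoke test: should succeed once rational_kat_cu is installed
from kat_rational import KAT_Group  # same import used in katransformer.py
print(KAT_Group)
```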
64 |
65 | 📦 Data preparation: ImageNet with the following folder structure; you can extract ImageNet with this [script](https://gist.github.com/BIGBALLON/8a71d225eff18d88e469e6ea9b39cef4).
66 |
67 | ```
68 | │imagenet/
69 | ├──train/
70 | │ ├── n01440764
71 | │ │ ├── n01440764_10026.JPEG
72 | │ │ ├── n01440764_10027.JPEG
73 | │ │ ├── ......
74 | │ ├── ......
75 | ├──val/
76 | │ ├── n01440764
77 | │ │ ├── ILSVRC2012_val_00000293.JPEG
78 | │ │ ├── ILSVRC2012_val_00002138.JPEG
79 | │ │ ├── ......
80 | │ ├── ......
81 | ```
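Before launching training, you can optionally sanity-check the layout with `torchvision.datasets.ImageFolder` (torchvision is assumed to be installed; the root path below is a placeholder):

```python
# optional layout check: expect 1000 classes per split, ~1.28M train / 50,000 val images
from torchvision.datasets import ImageFolder

root = "/path/to/imagenet"  # placeholder, point this at your extracted dataset
for split in ("train", "val"):
    ds = ImageFolder(f"{root}/{split}")
    print(split, len(ds.classes), "classes,", len(ds), "images")
```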
82 |
83 | ## Usage
84 |
85 | Refer to `example.py` for a detailed use case demonstrating how to use KAT with timm to classify an image.
86 |
87 | ## 📊 Model Checkpoints
88 | Download pre-trained models or access training checkpoints:
89 |
90 | |🏷️ Model |⚙️ Setup |📦 Param| 📈 Top1 |🔗 Link|
91 | | ---|---|---| ---|---|
92 | |KAT-T| From Scratch|5.7M | 74.6| [link](https://github.com/Adamdad/kat/releases/download/checkpoint/kat_tiny_patch16_224_1f3ad3b2e69821f3d412f2924cf159a0e266f142d739cb68f68f796f5a0fe289.pth)/[huggingface](https://huggingface.co/adamdad/kat_tiny_patch16_224)
93 | |KAT-T | From ViT | 5.7M | 75.7| [link](https://github.com/Adamdad/kat/releases/download/checkpoint/kat_tiny_patch16_224-finetune_64f124d003803e4a7e1aba1ba23500ace359b544e8a5f0110993f25052e402fb.pth)/[huggingface](https://huggingface.co/adamdad/kat_tiny_patch16_224.vitft)
94 | |KAT-S| From Scratch| 22.1M | 81.2| [link](https://github.com/Adamdad/kat/releases/download/checkpoint/kat_small_patch16_224_32487885cf13d2c14e461c9016fac8ad43f7c769171f132530941e930aeb5fe2.pth)/[huggingface](https://huggingface.co/adamdad/kat_small_patch16_224)
95 | |KAT-S | From ViT |22.1M | 82.0| [link](https://github.com/Adamdad/kat/releases/download/checkpoint/kat_small_patch_224-finetune_3ae087a4c28e2993468eb377d5151350c52c80b2a70cc48ceec63d1328ba58e0.pth)/[huggingface](https://huggingface.co/adamdad/kat_small_patch16_224.vitft)
96 | | KAT-B| From Scratch |86.6M| 82.3 | [link](https://github.com/Adamdad/kat/releases/download/checkpoint/kat_base_patch16_224_abff874d925d756d15cde97303f772a3460ddbd44b9c53fb9ce5cf15be230fb6.pth)/[huggingface](https://huggingface.co/adamdad/kat_base_patch16_224)
97 | | KAT-B | From ViT |86.6M| 82.8 | [link](https://huggingface.co/adamdad/kat_pretained/resolve/main/kat_base_patch16_224-finetune_440bf1ead9dd8ecab642078cfb60ae542f1fa33ca65517260501e02c011e38f2.pth)/[huggingface](https://huggingface.co/adamdad/kat_base_patch16_224.vitft)|
98 |
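For the GitHub-release `.pth` files above, a minimal local-loading sketch looks like the following; it mirrors what `validate.py` does with `--checkpoint`. The filename is the KAT-T from-scratch asset, and `strict=False` is used only to tolerate minor key differences:

```python
# sketch: load a downloaded release checkpoint into a locally built KAT model
import torch
import timm
import katransformer  # registers the kat_* architectures with timm

model = timm.create_model("kat_tiny_patch16_224", pretrained=False)
ckpt = torch.load(
    "kat_tiny_patch16_224_1f3ad3b2e69821f3d412f2924cf159a0e266f142d739cb68f68f796f5a0fe289.pth",
    map_location="cpu",
)
state_dict = ckpt.get("model", ckpt)  # the repo's load_pretrained expects weights under the "model" key
print(model.load_state_dict(state_dict, strict=False))
model.eval()
```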
99 | ## 🎓Model Training
100 |
101 | All training scripts are under `scripts/`
102 | ```shell
103 | bash scripts/train_kat_tiny_8x128.sh
104 | ```
105 |
106 | If you want to change the hyper-parameters, you can edit the script, for example:
107 | ```shell
108 | #!/bin/bash
109 | DATA_PATH=/local_home/dataset/imagenet/
110 | # kat_tiny_swish_patch16_224: rationals are initialized to be swish functions
111 | bash ./dist_train.sh 8 $DATA_PATH \
112 | --model kat_tiny_swish_patch16_224 \
113 | -b 128 \
114 | --opt adamw \
115 | --lr 1e-3 \
116 | --weight-decay 0.05 \
117 | --epochs 300 \
118 | --mixup 0.8 \
119 | --cutmix 1.0 \
120 | --sched cosine \
121 | --smoothing 0.1 \
122 | --drop-path 0.1 \
123 | --aa rand-m9-mstd0.5 \
124 | --remode pixel --reprob 0.25 \
125 | --amp \
126 | --crop-pct 0.875 \
127 | --mean 0.485 0.456 0.406 \
128 | --std 0.229 0.224 0.225 \
129 | --model-ema \
130 | --model-ema-decay 0.9999 \
131 | --output output/kat_tiny_swish_patch16_224 \
132 | --log-wandb
133 | ```
134 |
135 | ## 🧪 Evaluation
136 | To evaluate our `kat_tiny_patch16_224` model, run:
137 |
138 | ```shell
139 | DATA_PATH=/local_home/dataset/imagenet/
140 | CHECKPOINT_PATH=kat_tiny_patch16_224_1f3ad3b2e69821f3d412f2924cf159a0e266f142d739cb68f68f796f5a0fe289.pth
141 | python validate.py $DATA_PATH --model kat_tiny_patch16_224 \
142 | --checkpoint $CHECKPOINT_PATH -b 512
143 |
144 | ###################
145 | Validating in float32. AMP not enabled.
146 | Loaded state_dict from checkpoint 'kat_tiny_patch16_224_1f3ad3b2e69821f3d412f2924cf159a0e266f142d739cb68f68f796f5a0fe289.pth'
147 | Model kat_tiny_patch16_224 created, param count: 5718328
148 | Data processing configuration for current model + dataset:
149 | input_size: (3, 224, 224)
150 | interpolation: bicubic
151 | mean: (0.485, 0.456, 0.406)
152 | std: (0.229, 0.224, 0.225)
153 | crop_pct: 0.875
154 | crop_mode: center
155 | Test: [ 0/98] Time: 3.453s (3.453s, 148.28/s) Loss: 0.6989 (0.6989) Acc@1: 84.375 ( 84.375) Acc@5: 96.875 ( 96.875)
156 | .......
157 | Test: [ 90/98] Time: 0.212s (0.592s, 864.23/s) Loss: 1.1640 (1.1143) Acc@1: 71.875 ( 74.270) Acc@5: 93.750 ( 92.220)
158 | * Acc@1 74.558 (25.442) Acc@5 92.390 (7.610)
159 | --result
160 | {
161 | "model": "kat_tiny_patch16_224",
162 | "top1": 74.558,
163 | "top1_err": 25.442,
164 | "top5": 92.39,
165 | "top5_err": 7.61,
166 | "param_count": 5.72,
167 | "img_size": 224,
168 | "crop_pct": 0.875,
169 | "interpolation": "bicubic"
170 | }
171 | ```
172 |
173 |
174 | ## 🙏 Acknowledgments
175 | We extend our gratitude to the authors of [rational_activations](https://github.com/ml-research/rational_activations) for their contributions to CUDA rational function implementations that inspired parts of this work. We thank [@yuweihao](https://github.com/yuweihao), [@florinshen](https://github.com/florinshen), [@Huage001](https://github.com/Huage001) and [@yu-rp](https://github.com/yu-rp) for valuable discussions.
176 |
177 | ## 📚 Bibtex
178 | If you use this repository, please cite:
179 | ```bibtex
180 | @inproceedings{
181 | yang2025kolmogorovarnold,
182 | title={Kolmogorov-Arnold Transformer},
183 | author={Xingyi Yang and Xinchao Wang},
184 | booktitle={The Thirteenth International Conference on Learning Representations},
185 | year={2025},
186 | url={https://openreview.net/forum?id=BCeock53nt}
187 | }
188 | ```
189 |
--------------------------------------------------------------------------------
/assets/KAT.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Adamdad/kat/d254de7c14b6c050bd00cac3689b0a5614659a7f/assets/KAT.png
--------------------------------------------------------------------------------
/assets/kat3-1.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Adamdad/kat/d254de7c14b6c050bd00cac3689b0a5614659a7f/assets/kat3-1.png
--------------------------------------------------------------------------------
/assets/logo.webp:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Adamdad/kat/d254de7c14b6c050bd00cac3689b0a5614659a7f/assets/logo.webp
--------------------------------------------------------------------------------
/dist_train.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 | NUM_PROC=$1   # number of processes (GPUs) per node
3 | shift         # drop NUM_PROC; the remaining args are forwarded to train.py
4 | torchrun --nproc_per_node=$NUM_PROC train.py "$@"
--------------------------------------------------------------------------------
/example.py:
--------------------------------------------------------------------------------
1 | from urllib.request import urlopen
2 | from PIL import Image
3 | import timm
4 | import torch
5 | import json
6 | import katransformer
7 |
8 | # Load the image
9 | img = Image.open(urlopen(
10 | 'https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/beignets-task-guide.png'
11 | ))
12 |
13 | # Select device (use CUDA if available)
14 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
15 |
16 | # Load the pre-trained KAT model
17 | model = timm.create_model('hf_hub:adamdad/kat_tiny_patch16_224.vitft', pretrained=True)
18 | model = model.to(device)
19 | model = model.eval()
20 |
21 | # Get model-specific transforms (normalization, resize)
22 | data_config = timm.data.resolve_model_data_config(model)
23 | transforms = timm.data.create_transform(**data_config, is_training=False)
24 |
25 | # Preprocess image and make predictions
26 | output = model(transforms(img).unsqueeze(0).to(device)) # unsqueeze single image into batch of 1
27 | top5_probabilities, top5_class_indices = torch.topk(output.softmax(dim=1) * 100, k=5)
28 |
29 | # Load ImageNet class names
30 | imagenet_classes_url = 'https://raw.githubusercontent.com/anishathalye/imagenet-simple-labels/master/imagenet-simple-labels.json'
31 | class_idx = json.load(urlopen(imagenet_classes_url))
32 |
33 | # Map class indices to class names
34 | top5_class_names = [class_idx[idx] for idx in top5_class_indices[0].tolist()]
35 |
36 | # Print top 5 probabilities and corresponding class names
37 | print("Top-5 Class Names:", top5_class_names)
38 | print("Top-5 Probabilities:", top5_probabilities)
39 |
--------------------------------------------------------------------------------
/katransformer.py:
--------------------------------------------------------------------------------
1 | """ Kolmogorov-Arnold Transformer (KAT) model
2 | author = "Xingyi Yang"
3 | """
4 | import logging
5 | import math
6 | from collections import OrderedDict
7 | from functools import partial
8 | from typing import Any, Callable, Dict, Optional, Set, Tuple, Type, Union, List
9 | try:
10 | from typing import Literal
11 | except ImportError:
12 | from typing_extensions import Literal
13 |
14 | import torch
15 | import torch.nn as nn
16 | import torch.nn.functional as F
17 | import torch.utils.checkpoint
18 | from torch.jit import Final
19 |
20 | from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD, IMAGENET_INCEPTION_MEAN, IMAGENET_INCEPTION_STD, \
21 | OPENAI_CLIP_MEAN, OPENAI_CLIP_STD
22 | from timm.layers import PatchEmbed, DropPath, AttentionPoolLatent, RmsNorm, PatchDropout, SwiGLUPacked, \
23 | trunc_normal_, lecun_normal_, resample_patch_embed, resample_abs_pos_embed, use_fused_attn, \
24 | get_act_layer, get_norm_layer, LayerType
25 | from timm.models._builder import build_model_with_cfg
26 | from timm.models._features import feature_take_indices
27 | from timm.models._manipulate import named_apply, checkpoint_seq, adapt_input_conv
28 | from timm.models._registry import generate_default_cfgs, register_model, register_model_deprecations
29 | from timm.models.layers import to_2tuple
30 |
31 | import numpy as np
32 |
33 |
34 | __all__ = ['KAT'] # model_registry will add each entrypoint fn to this
35 |
36 |
37 | _logger = logging.getLogger(__name__)
38 |
39 | import sys
40 | sys.path.insert(0, 'rational_kat_cu')
41 | from kat_rational import KAT_Group
42 |
43 |
44 | def calculate_gain(nonlinearity, param=None):
45 | r"""Return the recommended gain value for the given nonlinearity function.
46 | The values are as follows:
47 |
48 | ================= ====================================================
49 | nonlinearity gain
50 | ================= ====================================================
51 | Linear / Identity :math:`1`
52 | Conv{1,2,3}D :math:`1`
53 | Sigmoid :math:`1`
54 | Tanh :math:`\frac{5}{3}`
55 | ReLU :math:`\sqrt{2}`
56 | Leaky Relu :math:`\sqrt{\frac{2}{1 + \text{negative\_slope}^2}}`
57 | SELU :math:`\frac{3}{4}`
58 | ================= ====================================================
59 |
60 | .. warning::
61 | In order to implement `Self-Normalizing Neural Networks`_ ,
62 | you should use ``nonlinearity='linear'`` instead of ``nonlinearity='selu'``.
63 | This gives the initial weights a variance of ``1 / N``,
64 | which is necessary to induce a stable fixed point in the forward pass.
65 | In contrast, the default gain for ``SELU`` sacrifices the normalisation
66 | effect for more stable gradient flow in rectangular layers.
67 |
68 | Args:
69 | nonlinearity: the non-linear function (`nn.functional` name)
70 | param: optional parameter for the non-linear function
71 |
72 | Examples:
73 | >>> gain = nn.init.calculate_gain('leaky_relu', 0.2) # leaky_relu with negative_slope=0.2
74 |
75 | .. _Self-Normalizing Neural Networks: https://papers.nips.cc/paper/2017/hash/5d44ee6f2c3f71b73125876103c8f6c4-Abstract.html
76 | """
77 | linear_fns = ['linear', 'conv1d', 'conv2d', 'conv3d', 'conv_transpose1d', 'conv_transpose2d', 'conv_transpose3d']
78 | if nonlinearity in linear_fns or nonlinearity == 'sigmoid':
79 | return 1
80 | elif nonlinearity == 'tanh':
81 | return 5.0 / 3
82 | elif nonlinearity == 'relu':
83 | return math.sqrt(2.0)
84 | elif nonlinearity == 'gelu':
85 | return math.sqrt(2.3567850379928976)
86 | elif nonlinearity == 'silu' or nonlinearity == 'swish':
87 | return math.sqrt(2.8178205270359653)
88 | elif nonlinearity == 'leaky_relu':
89 | if param is None:
90 | negative_slope = 0.01
91 | elif not isinstance(param, bool) and isinstance(param, int) or isinstance(param, float):
92 | # True/False are instances of int, hence check above
93 | negative_slope = param
94 | else:
95 | raise ValueError("negative_slope {} not a valid number".format(param))
96 | return math.sqrt(2.0 / (1 + negative_slope ** 2))
97 | elif nonlinearity == 'selu':
98 | return 3.0 / 4 # Value found empirically (https://github.com/pytorch/pytorch/pull/50664)
99 | else:
100 | raise ValueError("Unsupported nonlinearity {}".format(nonlinearity))
101 |
102 | torch.nn.init.calculate_gain = calculate_gain  # patch torch so kaiming_normal_ accepts the 'gelu'/'swish' gains defined above
103 |
104 |
105 | class KAN(nn.Module):
106 | """ MLP as used in Vision Transformer, MLP-Mixer and related networks
107 | """
108 | def __init__(
109 | self,
110 | in_features,
111 | hidden_features=None,
112 | out_features=None,
113 | act_layer=KAT_Group,
114 | norm_layer=None,
115 | bias=True,
116 | drop=0.,
117 | use_conv=False,
118 | act_init="gelu",
119 | device=None
120 | ):
121 | super().__init__()
122 | if device is None:
123 | device = "cuda" if torch.cuda.is_available() else "cpu"
124 | out_features = out_features or in_features
125 | hidden_features = hidden_features or in_features
126 | bias = to_2tuple(bias)
127 | drop_probs = to_2tuple(drop)
128 | linear_layer = partial(nn.Conv2d, kernel_size=1) if use_conv else nn.Linear
129 |
130 | self.fc1 = linear_layer(in_features, hidden_features, bias=bias[0])
131 | self.act1 = KAT_Group(mode="identity", device=device)
132 | self.drop1 = nn.Dropout(drop_probs[0])
133 | self.norm = norm_layer(hidden_features) if norm_layer is not None else nn.Identity()
134 | self.act2 = KAT_Group(mode=act_init, device=device)
135 | self.fc2 = linear_layer(hidden_features, out_features, bias=bias[1])
136 | self.drop2 = nn.Dropout(drop_probs[1])
137 |
138 | def forward(self, x):
139 | x = self.act1(x)  # identity-mode group rational applied to the block input (GR-KAN: activation precedes the linear layer)
140 | x = self.drop1(x)
141 | x = self.fc1(x)
142 | x = self.act2(x)  # group rational initialized to act_init (e.g. gelu/swish)
143 | x = self.drop2(x)
144 | x = self.fc2(x)
145 | return x
146 |
147 | def get_ortho_like(dim, alpha, beta, sign=1, dist='uniform'):
148 | """
149 | Generate an orthogonal-like matrix with specified dimensions and properties.
150 |
151 | Args:
152 | dim (int): The dimension of the matrix.
153 | alpha (float): The scaling factor for the random matrix.
154 | beta (float): The scaling factor for the identity matrix.
155 | sign (int, optional): The sign of the identity matrix. Defaults to 1.
156 | dist (str, optional): The distribution of the random matrix.
157 | Can be either 'normal' or 'uniform'. Defaults to 'uniform'.
158 |
159 | Returns:
160 | tuple: The left and right factors (L, R) such that L @ R reconstructs A (scaled random matrix plus sign * beta * identity); used for mimetic attention initialization.
161 | """
162 | if dist == 'normal':
163 | A = alpha * np.random.normal(size=(dim,dim)) / (dim**0.5) + sign * beta * np.eye(dim)
164 | if dist == 'uniform':
165 | A = alpha * np.random.uniform(size=(dim,dim), low=-3**0.5 / (dim**0.5), high = 3**0.5 / (dim**0.5)) + sign * beta * np.eye(dim)
166 |
167 | U, S, V = np.linalg.svd(A)
168 | L = U @ np.diag(np.sqrt(S))
169 | R = np.diag(np.sqrt(S)) @ V
170 | return L, R
171 |
172 |
173 | def get_ortho_like_gaussian2d(dim, alpha, beta, sigma=1.0, sign=1, dist='uniform'):
174 | if dist == 'normal':
175 | A = alpha * np.random.normal(size=(dim, dim)) / (dim**0.5)
176 | elif dist == 'uniform':
177 | A = alpha * np.random.uniform(low=-3**0.5 / (dim**0.5), high=3**0.5 / (dim**0.5), size=(dim, dim))
178 |
179 | size = int(np.sqrt(dim))
180 | gauss_diag = beta * gaussian_correlation_matrix((size, size), sigma=sigma)
181 | gauss_diag = sign * gauss_diag
182 |
183 | A += gauss_diag
184 |
185 | # Compute the SVD of the updated matrix A
186 | U, S, V = np.linalg.svd(A)
187 |
188 | # Create the left and right matrices
189 | L = U @ np.diag(np.sqrt(S))
190 | R = np.diag(np.sqrt(S)) @ V
191 |
192 | return L, R
193 |
194 |
195 | def gaussian_correlation_matrix(image_shape, sigma=1.0):
196 | """
197 | Generate a Gaussian correlation matrix for a flattened 2D image.
198 |
199 | Args:
200 | image_shape (tuple): The shape of the image (height, width).
201 | sigma (float): The standard deviation of the Gaussian distribution.
202 |
203 | Returns:
204 | np.array: A 2D array where each entry represents the correlation
205 | between two flattened image pixels.
206 | """
207 | # Create an array of all pixel coordinates
208 | y, x = np.indices(image_shape)
209 |
210 | # Flatten the coordinate arrays
211 | y = y.ravel()
212 | x = x.ravel()
213 |
214 | # Compute distances between every pair of pixels
215 | distance_matrix = np.sqrt((y[:, np.newaxis] - y[np.newaxis, :])**2 +
216 | (x[:, np.newaxis] - x[np.newaxis, :])**2)
217 |
218 | # Compute the Gaussian correlation matrix using the distance matrix
219 | correlation_matrix = np.exp(-distance_matrix**2 / (2 * sigma**2))
220 |
221 | return correlation_matrix
222 |
223 | class Attention(nn.Module):
224 | fused_attn: Final[bool]
225 |
226 | def __init__(
227 | self,
228 | dim: int,
229 | num_heads: int = 8,
230 | qkv_bias: bool = False,
231 | qk_norm: bool = False,
232 | attn_drop: float = 0.,
233 | proj_drop: float = 0.,
234 | norm_layer: nn.Module = nn.LayerNorm,
235 | ) -> None:
236 | super().__init__()
237 | assert dim % num_heads == 0, 'dim should be divisible by num_heads'
238 | self.num_heads = num_heads
239 | self.head_dim = dim // num_heads
240 | self.scale = self.head_dim ** -0.5
241 | self.fused_attn = use_fused_attn()
242 |
243 | self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
244 | self.q_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity()
245 | self.k_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity()
246 | self.attn_drop = nn.Dropout(attn_drop)
247 | self.proj = nn.Linear(dim, dim)
248 | self.proj_drop = nn.Dropout(proj_drop)
249 |
250 | def forward(self, x: torch.Tensor) -> torch.Tensor:
251 | B, N, C = x.shape
252 | qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, self.head_dim).permute(2, 0, 3, 1, 4)
253 | q, k, v = qkv.unbind(0)
254 | q, k = self.q_norm(q), self.k_norm(k)
255 |
256 | if self.fused_attn:
257 | x = F.scaled_dot_product_attention(
258 | q, k, v,
259 | dropout_p=self.attn_drop.p if self.training else 0.,
260 | )
261 | else:
262 | q = q * self.scale
263 | attn = q @ k.transpose(-2, -1)
264 | attn = attn.softmax(dim=-1)
265 | attn = self.attn_drop(attn)
266 | x = attn @ v
267 |
268 | x = x.transpose(1, 2).reshape(B, N, C)
269 | x = self.proj(x)
270 | x = self.proj_drop(x)
271 | return x
272 |
273 | class LayerScale(nn.Module):
274 | def __init__(
275 | self,
276 | dim: int,
277 | init_values: float = 1e-5,
278 | inplace: bool = False,
279 | ) -> None:
280 | super().__init__()
281 | self.inplace = inplace
282 | self.gamma = nn.Parameter(init_values * torch.ones(dim))
283 |
284 | def forward(self, x: torch.Tensor) -> torch.Tensor:
285 | return x.mul_(self.gamma) if self.inplace else x * self.gamma
286 |
287 |
288 | class Block(nn.Module):
289 | def __init__(
290 | self,
291 | dim: int,
292 | num_heads: int,
293 | mlp_ratio: float = 4.,
294 | qkv_bias: bool = False,
295 | qk_norm: bool = False,
296 | proj_drop: float = 0.,
297 | attn_drop: float = 0.,
298 | init_values: Optional[float] = None,
299 | drop_path: float = 0.,
300 | act_layer: nn.Module = nn.GELU,
301 | norm_layer: nn.Module = nn.LayerNorm,
302 | mlp_layer: nn.Module = KAN,
303 | act_init: str = 'gelu',
304 | ) -> None:
305 | super().__init__()
306 | self.norm1 = norm_layer(dim)
307 | self.attn = Attention(
308 | dim,
309 | num_heads=num_heads,
310 | qkv_bias=qkv_bias,
311 | qk_norm=qk_norm,
312 | attn_drop=attn_drop,
313 | proj_drop=proj_drop,
314 | norm_layer=norm_layer,
315 | )
316 | self.ls1 = LayerScale(dim, init_values=init_values) if init_values else nn.Identity()
317 | self.drop_path1 = DropPath(drop_path) if drop_path > 0. else nn.Identity()
318 |
319 | self.norm2 = norm_layer(dim)
320 | self.mlp = mlp_layer(
321 | in_features=dim,
322 | hidden_features=int(dim * mlp_ratio),
323 | act_layer=act_layer,
324 | drop=proj_drop,
325 | act_init=act_init,
326 | )
327 | self.ls2 = LayerScale(dim, init_values=init_values) if init_values else nn.Identity()
328 | self.drop_path2 = DropPath(drop_path) if drop_path > 0. else nn.Identity()
329 |
330 | def forward(self, x: torch.Tensor) -> torch.Tensor:
331 | x = x + self.drop_path1(self.ls1(self.attn(self.norm1(x))))
332 | x = x + self.drop_path2(self.ls2(self.mlp(self.norm2(x))))
333 | return x
334 |
335 |
336 |
337 | class KATVisionTransformer(nn.Module):
338 | """ Vision Transformer
339 |
340 | A PyTorch impl of : `An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale`
341 | - https://arxiv.org/abs/2010.11929
342 | """
343 | dynamic_img_size: Final[bool]
344 |
345 | def __init__(
346 | self,
347 | img_size: Union[int, Tuple[int, int]] = 224,
348 | patch_size: Union[int, Tuple[int, int]] = 16,
349 | in_chans: int = 3,
350 | num_classes: int = 1000,
351 | global_pool: Literal['', 'avg', 'token', 'map'] = 'token',
352 | embed_dim: int = 768,
353 | depth: int = 12,
354 | num_heads: int = 12,
355 | mlp_ratio: float = 4.,
356 | qkv_bias: bool = True,
357 | qk_norm: bool = False,
358 | init_values: Optional[float] = None,
359 | class_token: bool = True,
360 | pos_embed: str = 'learn',
361 | no_embed_class: bool = False,
362 | reg_tokens: int = 0,
363 | pre_norm: bool = False,
364 | fc_norm: Optional[bool] = None,
365 | dynamic_img_size: bool = False,
366 | dynamic_img_pad: bool = False,
367 | drop_rate: float = 0.,
368 | pos_drop_rate: float = 0.,
369 | patch_drop_rate: float = 0.,
370 | proj_drop_rate: float = 0.,
371 | attn_drop_rate: float = 0.,
372 | drop_path_rate: float = 0.,
373 | weight_init: Literal['skip', 'jax', 'jax_nlhb', 'moco', '', 'kan', 'kan_mimetic', 'mimetic', 'kan_mimetic2d'] = '',
374 | fix_init: bool = False,
375 | embed_layer: Callable = PatchEmbed,
376 | norm_layer: Optional[LayerType] = None,
377 | act_layer: Optional[LayerType] = None, # None,
378 | block_fn: Type[nn.Module] = Block,
379 | mlp_layer: Type[nn.Module] = KAN,
380 | act_init: str = 'gelu',
381 | ) -> None:
382 | """
383 | Args:
384 | img_size: Input image size.
385 | patch_size: Patch size.
386 | in_chans: Number of image input channels.
387 | num_classes: Number of classes for classification head.
388 | global_pool: Type of global pooling for final sequence (default: 'token').
389 | embed_dim: Transformer embedding dimension.
390 | depth: Depth of transformer.
391 | num_heads: Number of attention heads.
392 | mlp_ratio: Ratio of mlp hidden dim to embedding dim.
393 | qkv_bias: Enable bias for qkv projections if True.
394 | init_values: Layer-scale init values (layer-scale enabled if not None).
395 | class_token: Use class token.
396 | no_embed_class: Don't include position embeddings for class (or reg) tokens.
397 | reg_tokens: Number of register tokens.
398 | fc_norm: Pre head norm after pool (instead of before), if None, enabled when global_pool == 'avg'.
399 | drop_rate: Head dropout rate.
400 | pos_drop_rate: Position embedding dropout rate.
401 | attn_drop_rate: Attention dropout rate.
402 | drop_path_rate: Stochastic depth rate.
403 | weight_init: Weight initialization scheme.
404 | fix_init: Apply weight initialization fix (scaling w/ layer index).
405 | embed_layer: Patch embedding layer.
406 | norm_layer: Normalization layer.
407 | act_layer: MLP activation layer.
408 | block_fn: Transformer block layer.
409 | """
410 | super().__init__()
411 | assert global_pool in ('', 'avg', 'token', 'map')
412 | assert class_token or global_pool != 'token'
413 | assert pos_embed in ('', 'none', 'learn')
414 | use_fc_norm = global_pool == 'avg' if fc_norm is None else fc_norm
415 | norm_layer = get_norm_layer(norm_layer) or partial(nn.LayerNorm, eps=1e-6)
416 | act_layer = get_act_layer(act_layer) or nn.GELU
417 |
418 | self.num_classes = num_classes
419 | self.global_pool = global_pool
420 | self.num_features = self.head_hidden_size = self.embed_dim = embed_dim # for consistency with other models
421 | self.num_prefix_tokens = 1 if class_token else 0
422 | self.num_prefix_tokens += reg_tokens
423 | self.num_reg_tokens = reg_tokens
424 | self.has_class_token = class_token
425 | self.no_embed_class = no_embed_class # don't embed prefix positions (includes reg)
426 | self.dynamic_img_size = dynamic_img_size
427 | self.grad_checkpointing = False
428 | self.act_init = act_init
429 |
430 | embed_args = {}
431 | if dynamic_img_size:
432 | # flatten deferred until after pos embed
433 | embed_args.update(dict(strict_img_size=False, output_fmt='NHWC'))
434 | self.patch_embed = embed_layer(
435 | img_size=img_size,
436 | patch_size=patch_size,
437 | in_chans=in_chans,
438 | embed_dim=embed_dim,
439 | bias=not pre_norm, # disable bias if pre-norm is used (e.g. CLIP)
440 | dynamic_img_pad=dynamic_img_pad,
441 | **embed_args,
442 | )
443 | num_patches = self.patch_embed.num_patches
444 | reduction = self.patch_embed.feat_ratio() if hasattr(self.patch_embed, 'feat_ratio') else patch_size
445 |
446 | self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim)) if class_token else None
447 | self.reg_token = nn.Parameter(torch.zeros(1, reg_tokens, embed_dim)) if reg_tokens else None
448 | embed_len = num_patches if no_embed_class else num_patches + self.num_prefix_tokens
449 | if not pos_embed or pos_embed == 'none':
450 | self.pos_embed = None
451 | else:
452 | self.pos_embed = nn.Parameter(torch.randn(1, embed_len, embed_dim) * .02)
453 | self.pos_drop = nn.Dropout(p=pos_drop_rate)
454 | if patch_drop_rate > 0:
455 | self.patch_drop = PatchDropout(
456 | patch_drop_rate,
457 | num_prefix_tokens=self.num_prefix_tokens,
458 | )
459 | else:
460 | self.patch_drop = nn.Identity()
461 | self.norm_pre = norm_layer(embed_dim) if pre_norm else nn.Identity()
462 |
463 | dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)] # stochastic depth decay rule
464 | self.blocks = nn.Sequential(*[
465 | block_fn(
466 | dim=embed_dim,
467 | num_heads=num_heads,
468 | mlp_ratio=mlp_ratio,
469 | qkv_bias=qkv_bias,
470 | qk_norm=qk_norm,
471 | init_values=init_values,
472 | proj_drop=proj_drop_rate,
473 | attn_drop=attn_drop_rate,
474 | drop_path=dpr[i],
475 | norm_layer=norm_layer,
476 | act_layer=act_layer,
477 | mlp_layer=mlp_layer,
478 | act_init=act_init,
479 | )
480 | for i in range(depth)])
481 | self.feature_info = [
482 | dict(module=f'blocks.{i}', num_chs=embed_dim, reduction=reduction) for i in range(depth)]
483 | self.norm = norm_layer(embed_dim) if not use_fc_norm else nn.Identity()
484 |
485 | # Classifier Head
486 | if global_pool == 'map':
487 | self.attn_pool = AttentionPoolLatent(
488 | self.embed_dim,
489 | num_heads=num_heads,
490 | mlp_ratio=mlp_ratio,
491 | norm_layer=norm_layer,
492 | )
493 | else:
494 | self.attn_pool = None
495 | self.fc_norm = norm_layer(embed_dim) if use_fc_norm else nn.Identity()
496 | self.head_drop = nn.Dropout(drop_rate)
497 | self.head = nn.Linear(self.embed_dim, num_classes) if num_classes > 0 else nn.Identity()
498 |
499 | if weight_init != 'skip':
500 | self.init_weights(weight_init)
501 | if fix_init:
502 | self.fix_init_weight()
503 |
504 | def fix_init_weight(self):
505 | def rescale(param, _layer_id):
506 | param.div_(math.sqrt(2.0 * _layer_id))
507 |
508 | for layer_id, layer in enumerate(self.blocks):
509 | rescale(layer.attn.proj.weight.data, layer_id + 1)
510 | rescale(layer.mlp.fc2.weight.data, layer_id + 1)
511 |
512 | def init_weights(self, mode: str = '') -> None:
513 | assert mode in ('jax', 'jax_nlhb', 'moco', '', 'kan', 'kan_mimetic', 'mimetic')
514 | head_bias = -math.log(self.num_classes) if 'nlhb' in mode else 0.
515 | if self.pos_embed is not None:
516 | trunc_normal_(self.pos_embed, std=.02)
517 | if self.cls_token is not None:
518 | nn.init.normal_(self.cls_token, std=1e-6)
519 |
520 | mode = mode + f'_{self.act_init}' if 'kan' in mode else mode
521 | named_apply(get_init_weights_vit(mode, head_bias), self)
522 |
523 | if 'mimetic' in mode:
524 | named_apply(init_weights_attn_mimetic, self)
525 | # named_apply(init_weights_attn_mimetic, self)
526 |
527 | def _init_weights(self, m: nn.Module) -> None:
528 | # this fn left here for compat with downstream users
529 | init_weights_vit_timm(m)
530 |
531 | @torch.jit.ignore()
532 | def load_pretrained(self, checkpoint_path: str, prefix: str = '') -> None:
533 | if checkpoint_path.endswith('.npz'):
534 | # load from .npz checkpoint
535 | _load_weights(self, checkpoint_path, prefix)
536 | else:
537 | print(f'Loading model weights from: {checkpoint_path}')
538 | # load from .pth checkpoint
539 | state_dict = torch.load(checkpoint_path, map_location='cpu')['model']
540 | if 'state_dict' in state_dict:
541 | state_dict = state_dict['state_dict']
542 | state_dict = {k[len(prefix):]: v for k, v in state_dict.items() if k.startswith(prefix)}
543 | state_dict['pos_embed'] = state_dict['pos_embed'][:, 1:]
544 | msg = self.load_state_dict(state_dict, strict=False)
545 | print(msg)
546 |
547 | @torch.jit.ignore
548 | def no_weight_decay(self) -> Set:
549 | return {'pos_embed', 'cls_token', 'dist_token'}
550 |
551 | @torch.jit.ignore
552 | def group_matcher(self, coarse: bool = False) -> Dict:
553 | return dict(
554 | stem=r'^cls_token|pos_embed|patch_embed', # stem and embed
555 | blocks=[(r'^blocks\.(\d+)', None), (r'^norm', (99999,))]
556 | )
557 |
558 | @torch.jit.ignore
559 | def set_grad_checkpointing(self, enable: bool = True) -> None:
560 | self.grad_checkpointing = enable
561 | if hasattr(self.patch_embed, 'set_grad_checkpointing'):
562 | self.patch_embed.set_grad_checkpointing(enable)
563 |
564 | @torch.jit.ignore
565 | def get_classifier(self) -> nn.Module:
566 | return self.head
567 |
568 | def reset_classifier(self, num_classes: int, global_pool = None) -> None:
569 | self.num_classes = num_classes
570 | if global_pool is not None:
571 | assert global_pool in ('', 'avg', 'token', 'map')
572 | if global_pool == 'map' and self.attn_pool is None:
573 | assert False, "Cannot currently add attention pooling in reset_classifier()."
574 | elif global_pool != 'map' and self.attn_pool is not None:
575 | self.attn_pool = None # remove attention pooling
576 | self.global_pool = global_pool
577 | self.head = nn.Linear(self.embed_dim, num_classes) if num_classes > 0 else nn.Identity()
578 |
579 | def _pos_embed(self, x: torch.Tensor) -> torch.Tensor:
580 | if self.pos_embed is None:
581 | return x.view(x.shape[0], -1, x.shape[-1])
582 |
583 | if self.dynamic_img_size:
584 | B, H, W, C = x.shape
585 | pos_embed = resample_abs_pos_embed(
586 | self.pos_embed,
587 | (H, W),
588 | num_prefix_tokens=0 if self.no_embed_class else self.num_prefix_tokens,
589 | )
590 | x = x.view(B, -1, C)
591 | else:
592 | pos_embed = self.pos_embed
593 |
594 | to_cat = []
595 | if self.cls_token is not None:
596 | to_cat.append(self.cls_token.expand(x.shape[0], -1, -1))
597 | if self.reg_token is not None:
598 | to_cat.append(self.reg_token.expand(x.shape[0], -1, -1))
599 |
600 | if self.no_embed_class:
601 | # deit-3, updated JAX (big vision)
602 | # position embedding does not overlap with class token, add then concat
603 | x = x + pos_embed
604 | if to_cat:
605 | x = torch.cat(to_cat + [x], dim=1)
606 | else:
607 | # original timm, JAX, and deit vit impl
608 | # pos_embed has entry for class token, concat then add
609 | if to_cat:
610 | x = torch.cat(to_cat + [x], dim=1)
611 | x = x + pos_embed
612 |
613 | return self.pos_drop(x)
614 |
615 | def forward_intermediates(
616 | self,
617 | x: torch.Tensor,
618 | indices: Optional[Union[int, List[int], Tuple[int]]] = None,
619 | return_prefix_tokens: bool = False,
620 | norm: bool = False,
621 | stop_early: bool = False,
622 | output_fmt: str = 'NCHW',
623 | intermediates_only: bool = False,
624 | ) -> Union[List[torch.Tensor], Tuple[torch.Tensor, List[torch.Tensor]]]:
625 | """ Forward features that returns intermediates.
626 |
627 | Args:
628 | x: Input image tensor
629 | indices: Take last n blocks if int, all if None, select matching indices if sequence
630 | return_prefix_tokens: Return both prefix and spatial intermediate tokens
631 | norm: Apply norm layer to all intermediates
632 | stop_early: Stop iterating over blocks when last desired intermediate hit
633 | output_fmt: Shape of intermediate feature outputs
634 | intermediates_only: Only return intermediate features
635 | Returns:
636 |
637 | """
638 | assert output_fmt in ('NCHW', 'NLC'), 'Output format must be one of NCHW or NLC.'
639 | reshape = output_fmt == 'NCHW'
640 | intermediates = []
641 | take_indices, max_index = feature_take_indices(len(self.blocks), indices)
642 |
643 | # forward pass
644 | B, _, height, width = x.shape
645 | x = self.patch_embed(x)
646 | x = self._pos_embed(x)
647 | x = self.patch_drop(x)
648 | x = self.norm_pre(x)
649 |
650 | if torch.jit.is_scripting() or not stop_early: # can't slice blocks in torchscript
651 | blocks = self.blocks
652 | else:
653 | blocks = self.blocks[:max_index + 1]
654 | for i, blk in enumerate(blocks):
655 | x = blk(x)
656 | if i in take_indices:
657 | # normalize intermediates with final norm layer if enabled
658 | intermediates.append(self.norm(x) if norm else x)
659 |
660 | # process intermediates
661 | if self.num_prefix_tokens:
662 | # split prefix (e.g. class, distill) and spatial feature tokens
663 | prefix_tokens = [y[:, 0:self.num_prefix_tokens] for y in intermediates]
664 | intermediates = [y[:, self.num_prefix_tokens:] for y in intermediates]
665 | if reshape:
666 | # reshape to BCHW output format
667 | H, W = self.patch_embed.dynamic_feat_size((height, width))
668 | intermediates = [y.reshape(B, H, W, -1).permute(0, 3, 1, 2).contiguous() for y in intermediates]
669 | if not torch.jit.is_scripting() and return_prefix_tokens:
670 | # return_prefix not support in torchscript due to poor type handling
671 | intermediates = list(zip(intermediates, prefix_tokens))
672 |
673 | if intermediates_only:
674 | return intermediates
675 |
676 | x = self.norm(x)
677 |
678 | return x, intermediates
679 |
680 | def prune_intermediate_layers(
681 | self,
682 | indices: Union[int, List[int], Tuple[int]] = 1,
683 | prune_norm: bool = False,
684 | prune_head: bool = True,
685 | ):
686 | """ Prune layers not required for specified intermediates.
687 | """
688 | take_indices, max_index = feature_take_indices(len(self.blocks), indices)
689 | self.blocks = self.blocks[:max_index + 1] # truncate blocks
690 | if prune_norm:
691 | self.norm = nn.Identity()
692 | if prune_head:
693 | self.fc_norm = nn.Identity()
694 | self.reset_classifier(0, '')
695 | return take_indices
696 |
697 | def get_intermediate_layers(
698 | self,
699 | x: torch.Tensor,
700 | n: Union[int, List[int], Tuple[int]] = 1,
701 | reshape: bool = False,
702 | return_prefix_tokens: bool = False,
703 | norm: bool = False,
704 | ) -> List[torch.Tensor]:
705 | """ Intermediate layer accessor inspired by DINO / DINOv2 interface.
706 | NOTE: This API is for backwards compat, favour using forward_intermediates() directly.
707 | """
708 | return self.forward_intermediates(
709 | x, n,
710 | return_prefix_tokens=return_prefix_tokens,
711 | norm=norm,
712 | output_fmt='NCHW' if reshape else 'NLC',
713 | intermediates_only=True,
714 | )
715 |
716 | def forward_features(self, x: torch.Tensor) -> torch.Tensor:
717 | x = self.patch_embed(x)
718 | x = self._pos_embed(x)
719 | x = self.patch_drop(x)
720 | x = self.norm_pre(x)
721 | if self.grad_checkpointing and not torch.jit.is_scripting():
722 | x = checkpoint_seq(self.blocks, x)
723 | else:
724 | x = self.blocks(x)
725 | x = self.norm(x)
726 | return x
727 |
728 | def forward_head(self, x: torch.Tensor, pre_logits: bool = False) -> torch.Tensor:
729 | if self.attn_pool is not None:
730 | x = self.attn_pool(x)
731 | elif self.global_pool == 'avg':
732 | x = x[:, self.num_prefix_tokens:].mean(dim=1)
733 | elif self.global_pool:
734 | x = x[:, 0] # class token
735 | x = self.fc_norm(x)
736 | x = self.head_drop(x)
737 | return x if pre_logits else self.head(x)
738 |
739 | def forward(self, x: torch.Tensor) -> torch.Tensor:
740 | x = self.forward_features(x)
741 | x = self.forward_head(x)
742 | return x
743 |
744 |
745 | def init_weights_vit_timm(module: nn.Module, name: str = '') -> None:
746 | """ ViT weight initialization, original timm impl (for reproducibility) """
747 | if isinstance(module, nn.Linear):
748 | trunc_normal_(module.weight, std=.02)
749 | if module.bias is not None:
750 | nn.init.zeros_(module.bias)
751 | elif hasattr(module, 'init_weights'):
752 | module.init_weights()
753 |
754 |
755 | def init_weights_vit_kan_gelu(module: nn.Module, name: str = '') -> None:
756 | """ ViT weight initialization, original timm impl (for reproducibility) """
757 | if isinstance(module, nn.Linear):
758 | if 'mlp' in name:
759 | if 'fc1' in name:
760 | nn.init.kaiming_normal_(module.weight, mode='fan_in', nonlinearity='linear')
761 | elif 'fc2' in name:
762 | nn.init.kaiming_normal_(module.weight, mode='fan_in', nonlinearity='gelu')
763 |
764 | # nn.init.kaiming_normal_(module.weight, mode='fan_in', nonlinearity='relu')
765 |
766 | if 'qkv' in name:
767 | # treat the weights of Q, K, V separately
768 | val = math.sqrt(6. / float(module.weight.shape[0] // 3 + module.weight.shape[1]))
769 | nn.init.uniform_(module.weight, -val, val)
770 | else:
771 | trunc_normal_(module.weight, std=.02)
772 | if module.bias is not None:
773 | nn.init.zeros_(module.bias)
774 | elif hasattr(module, 'init_weights'):
775 | module.init_weights()
776 |
777 | def init_weights_vit_kan_swish(module: nn.Module, name: str = '') -> None:
778 | """ ViT weight initialization, original timm impl (for reproducibility) """
779 | if isinstance(module, nn.Linear):
780 | if 'mlp' in name:
781 | if 'fc1' in name:
782 | nn.init.kaiming_normal_(module.weight, mode='fan_in', nonlinearity='linear')
783 | elif 'fc2' in name:
784 | nn.init.kaiming_normal_(module.weight, mode='fan_in', nonlinearity='swish')
785 | if 'qkv' in name:
786 | # treat the weights of Q, K, V separately
787 | val = math.sqrt(6. / float(module.weight.shape[0] // 3 + module.weight.shape[1]))
788 | nn.init.uniform_(module.weight, -val, val)
789 | else:
790 | trunc_normal_(module.weight, std=.02)
791 | if module.bias is not None:
792 | nn.init.zeros_(module.bias)
793 | elif hasattr(module, 'init_weights'):
794 | module.init_weights()
795 |
796 | def init_weights_attn_mimetic(module: nn.Module, name: str = '') -> None:
797 | """ ViT weight initialization, original timm impl (for reproducibility) """
798 | if isinstance(module, Attention):
799 | alpha1 = 0.7
800 | beta1 = 0.7
801 | alpha2 = 0.4
802 | beta2 = 0.4
803 | head_dim = module.head_dim
804 | embed_dim = module.head_dim * module.num_heads
805 |
806 | for h in range(module.num_heads):
807 | Q, K = get_ortho_like(embed_dim, alpha1, beta1, 1)
808 | Q = Q[:,:head_dim]
809 | K = K.T[:,:head_dim]
810 |
811 | module.qkv.weight.data[(h*head_dim):((h+1)*head_dim)] = torch.tensor(Q.T).float()
812 | module.qkv.weight.data[embed_dim+(h*head_dim):embed_dim+((h+1)*head_dim)] = torch.tensor(K.T).float()
813 |
814 | V, Proj = get_ortho_like(embed_dim, alpha2, beta2, -1)
815 | module.qkv.weight.data[2*embed_dim:] = torch.tensor(V).float()
816 | module.proj.weight.data = torch.tensor(Proj).float()
817 |
818 |
819 | def init_weights_vit_jax(module: nn.Module, name: str = '', head_bias: float = 0.0) -> None:
820 | """ ViT weight initialization, matching JAX (Flax) impl """
821 | if isinstance(module, nn.Linear):
822 | if name.startswith('head'):
823 | nn.init.zeros_(module.weight)
824 | nn.init.constant_(module.bias, head_bias)
825 | else:
826 | nn.init.xavier_uniform_(module.weight)
827 | if module.bias is not None:
828 | nn.init.normal_(module.bias, std=1e-6) if 'mlp' in name else nn.init.zeros_(module.bias)
829 | elif isinstance(module, nn.Conv2d):
830 | lecun_normal_(module.weight)
831 | if module.bias is not None:
832 | nn.init.zeros_(module.bias)
833 | elif hasattr(module, 'init_weights'):
834 | module.init_weights()
835 |
836 |
837 | def init_weights_vit_moco(module: nn.Module, name: str = '') -> None:
838 | """ ViT weight initialization, matching moco-v3 impl minus fixed PatchEmbed """
839 | if isinstance(module, nn.Linear):
840 | if 'qkv' in name:
841 | # treat the weights of Q, K, V separately
842 | val = math.sqrt(6. / float(module.weight.shape[0] // 3 + module.weight.shape[1]))
843 | nn.init.uniform_(module.weight, -val, val)
844 | else:
845 | nn.init.xavier_uniform_(module.weight)
846 | if module.bias is not None:
847 | nn.init.zeros_(module.bias)
848 | elif hasattr(module, 'init_weights'):
849 | module.init_weights()
850 |
851 |
852 | def get_init_weights_vit(mode: str = 'jax', head_bias: float = 0.0) -> Callable:
853 | if 'jax' in mode:
854 | return partial(init_weights_vit_jax, head_bias=head_bias)
855 | elif 'moco' in mode:
856 | return init_weights_vit_moco
857 | elif 'kan' in mode:
858 | if 'gelu' in mode:
859 | return init_weights_vit_kan_gelu
860 | elif 'swish' in mode:
861 | return init_weights_vit_kan_swish
862 | else:
863 | raise AssertionError(f'Unknown mode {mode}')
864 | else:
865 | return init_weights_vit_timm
866 |
867 |
868 | def resize_pos_embed(
869 | posemb: torch.Tensor,
870 | posemb_new: torch.Tensor,
871 | num_prefix_tokens: int = 1,
872 | gs_new: Tuple[int, int] = (),
873 | interpolation: str = 'bicubic',
874 | antialias: bool = False,
875 | ) -> torch.Tensor:
876 | """ Rescale the grid of position embeddings when loading from state_dict.
877 | *DEPRECATED* This function is being deprecated in favour of using resample_abs_pos_embed
878 | """
879 | ntok_new = posemb_new.shape[1] - num_prefix_tokens
880 | ntok_old = posemb.shape[1] - num_prefix_tokens
881 | gs_old = [int(math.sqrt(ntok_old))] * 2
882 | if not len(gs_new): # backwards compatibility
883 | gs_new = [int(math.sqrt(ntok_new))] * 2
884 | return resample_abs_pos_embed(
885 | posemb, gs_new, gs_old,
886 | num_prefix_tokens=num_prefix_tokens,
887 | interpolation=interpolation,
888 | antialias=antialias,
889 | verbose=True,
890 | )
891 |
892 |
893 | @torch.no_grad()
894 | def _load_weights(model: KATVisionTransformer, checkpoint_path: str, prefix: str = '') -> None:
895 | """ Load weights from .npz checkpoints for official Google Brain Flax implementation
896 | """
897 | import numpy as np
898 |
899 | def _n2p(w, t=True, idx=None):
900 | if idx is not None:
901 | w = w[idx]
902 | if w.ndim == 4 and w.shape[0] == w.shape[1] == w.shape[2] == 1:
903 | w = w.flatten()
904 | if t:
905 | if w.ndim == 4:
906 | w = w.transpose([3, 2, 0, 1])
907 | elif w.ndim == 3:
908 | w = w.transpose([2, 0, 1])
909 | elif w.ndim == 2:
910 | w = w.transpose([1, 0])
911 | return torch.from_numpy(w)
912 |
913 | w = np.load(checkpoint_path)
914 | interpolation = 'bilinear'
915 | antialias = False
916 | big_vision = False
917 | if not prefix:
918 | if 'opt/target/embedding/kernel' in w:
919 | prefix = 'opt/target/'
920 | elif 'params/embedding/kernel' in w:
921 | prefix = 'params/'
922 | big_vision = True
923 | elif 'params/img/embedding/kernel' in w:
924 | prefix = 'params/img/'
925 | big_vision = True
926 |
927 | if hasattr(model.patch_embed, 'backbone'):
928 | # hybrid
929 | backbone = model.patch_embed.backbone
930 | stem_only = not hasattr(backbone, 'stem')
931 | stem = backbone if stem_only else backbone.stem
932 | stem.conv.weight.copy_(adapt_input_conv(stem.conv.weight.shape[1], _n2p(w[f'{prefix}conv_root/kernel'])))
933 | stem.norm.weight.copy_(_n2p(w[f'{prefix}gn_root/scale']))
934 | stem.norm.bias.copy_(_n2p(w[f'{prefix}gn_root/bias']))
935 | if not stem_only:
936 | for i, stage in enumerate(backbone.stages):
937 | for j, block in enumerate(stage.blocks):
938 | bp = f'{prefix}block{i + 1}/unit{j + 1}/'
939 | for r in range(3):
940 | getattr(block, f'conv{r + 1}').weight.copy_(_n2p(w[f'{bp}conv{r + 1}/kernel']))
941 | getattr(block, f'norm{r + 1}').weight.copy_(_n2p(w[f'{bp}gn{r + 1}/scale']))
942 | getattr(block, f'norm{r + 1}').bias.copy_(_n2p(w[f'{bp}gn{r + 1}/bias']))
943 | if block.downsample is not None:
944 | block.downsample.conv.weight.copy_(_n2p(w[f'{bp}conv_proj/kernel']))
945 | block.downsample.norm.weight.copy_(_n2p(w[f'{bp}gn_proj/scale']))
946 | block.downsample.norm.bias.copy_(_n2p(w[f'{bp}gn_proj/bias']))
947 | embed_conv_w = _n2p(w[f'{prefix}embedding/kernel'])
948 | else:
949 | embed_conv_w = adapt_input_conv(
950 | model.patch_embed.proj.weight.shape[1], _n2p(w[f'{prefix}embedding/kernel']))
951 | if embed_conv_w.shape[-2:] != model.patch_embed.proj.weight.shape[-2:]:
952 | embed_conv_w = resample_patch_embed(
953 | embed_conv_w,
954 | model.patch_embed.proj.weight.shape[-2:],
955 | interpolation=interpolation,
956 | antialias=antialias,
957 | verbose=True,
958 | )
959 |
960 | model.patch_embed.proj.weight.copy_(embed_conv_w)
961 | model.patch_embed.proj.bias.copy_(_n2p(w[f'{prefix}embedding/bias']))
962 | if model.cls_token is not None:
963 | model.cls_token.copy_(_n2p(w[f'{prefix}cls'], t=False))
964 | if big_vision:
965 | pos_embed_w = _n2p(w[f'{prefix}pos_embedding'], t=False)
966 | else:
967 | pos_embed_w = _n2p(w[f'{prefix}Transformer/posembed_input/pos_embedding'], t=False)
968 | if pos_embed_w.shape != model.pos_embed.shape:
969 | old_shape = pos_embed_w.shape
970 | num_prefix_tokens = 0 if getattr(model, 'no_embed_class', False) else getattr(model, 'num_prefix_tokens', 1)
971 | pos_embed_w = resample_abs_pos_embed( # resize pos embedding when different size from pretrained weights
972 | pos_embed_w,
973 | new_size=model.patch_embed.grid_size,
974 | num_prefix_tokens=num_prefix_tokens,
975 | interpolation=interpolation,
976 | antialias=antialias,
977 | verbose=True,
978 | )
979 | model.pos_embed.copy_(pos_embed_w)
980 | model.norm.weight.copy_(_n2p(w[f'{prefix}Transformer/encoder_norm/scale']))
981 | model.norm.bias.copy_(_n2p(w[f'{prefix}Transformer/encoder_norm/bias']))
982 | if (isinstance(model.head, nn.Linear) and
983 | f'{prefix}head/bias' in w and
984 | model.head.bias.shape[0] == w[f'{prefix}head/bias'].shape[-1]):
985 | model.head.weight.copy_(_n2p(w[f'{prefix}head/kernel']))
986 | model.head.bias.copy_(_n2p(w[f'{prefix}head/bias']))
987 | # NOTE representation layer has been removed, not used in latest 21k/1k pretrained weights
988 | # if isinstance(getattr(model.pre_logits, 'fc', None), nn.Linear) and f'{prefix}pre_logits/bias' in w:
989 | # model.pre_logits.fc.weight.copy_(_n2p(w[f'{prefix}pre_logits/kernel']))
990 | # model.pre_logits.fc.bias.copy_(_n2p(w[f'{prefix}pre_logits/bias']))
991 | if model.attn_pool is not None:
992 | block_prefix = f'{prefix}MAPHead_0/'
993 | mha_prefix = block_prefix + f'MultiHeadDotProductAttention_0/'
994 | model.attn_pool.latent.copy_(_n2p(w[f'{block_prefix}probe'], t=False))
995 | model.attn_pool.kv.weight.copy_(torch.cat([
996 | _n2p(w[f'{mha_prefix}{n}/kernel'], t=False).flatten(1).T for n in ('key', 'value')]))
997 | model.attn_pool.kv.bias.copy_(torch.cat([
998 | _n2p(w[f'{mha_prefix}{n}/bias'], t=False).reshape(-1) for n in ('key', 'value')]))
999 | model.attn_pool.q.weight.copy_(_n2p(w[f'{mha_prefix}query/kernel'], t=False).flatten(1).T)
1000 | model.attn_pool.q.bias.copy_(_n2p(w[f'{mha_prefix}query/bias'], t=False).reshape(-1))
1001 | model.attn_pool.proj.weight.copy_(_n2p(w[f'{mha_prefix}out/kernel']).flatten(1))
1002 | model.attn_pool.proj.bias.copy_(_n2p(w[f'{mha_prefix}out/bias']))
1003 | model.attn_pool.norm.weight.copy_(_n2p(w[f'{block_prefix}LayerNorm_0/scale']))
1004 | model.attn_pool.norm.bias.copy_(_n2p(w[f'{block_prefix}LayerNorm_0/bias']))
1005 | for r in range(2):
1006 | getattr(model.attn_pool.mlp, f'fc{r + 1}').weight.copy_(_n2p(w[f'{block_prefix}MlpBlock_0/Dense_{r}/kernel']))
1007 | getattr(model.attn_pool.mlp, f'fc{r + 1}').bias.copy_(_n2p(w[f'{block_prefix}MlpBlock_0/Dense_{r}/bias']))
1008 |
1009 | mha_sub, b_sub, ln1_sub = (0, 0, 1) if big_vision else (1, 3, 2)
1010 | for i, block in enumerate(model.blocks.children()):
1011 | if f'{prefix}Transformer/encoderblock/LayerNorm_0/scale' in w:
1012 | block_prefix = f'{prefix}Transformer/encoderblock/'
1013 | idx = i
1014 | else:
1015 | block_prefix = f'{prefix}Transformer/encoderblock_{i}/'
1016 | idx = None
1017 | mha_prefix = block_prefix + f'MultiHeadDotProductAttention_{mha_sub}/'
1018 | block.norm1.weight.copy_(_n2p(w[f'{block_prefix}LayerNorm_0/scale'], idx=idx))
1019 | block.norm1.bias.copy_(_n2p(w[f'{block_prefix}LayerNorm_0/bias'], idx=idx))
1020 | block.attn.qkv.weight.copy_(torch.cat([
1021 | _n2p(w[f'{mha_prefix}{n}/kernel'], t=False, idx=idx).flatten(1).T for n in ('query', 'key', 'value')]))
1022 | block.attn.qkv.bias.copy_(torch.cat([
1023 | _n2p(w[f'{mha_prefix}{n}/bias'], t=False, idx=idx).reshape(-1) for n in ('query', 'key', 'value')]))
1024 | block.attn.proj.weight.copy_(_n2p(w[f'{mha_prefix}out/kernel'], idx=idx).flatten(1))
1025 | block.attn.proj.bias.copy_(_n2p(w[f'{mha_prefix}out/bias'], idx=idx))
1026 | block.norm2.weight.copy_(_n2p(w[f'{block_prefix}LayerNorm_{ln1_sub}/scale'], idx=idx))
1027 | block.norm2.bias.copy_(_n2p(w[f'{block_prefix}LayerNorm_{ln1_sub}/bias'], idx=idx))
1028 | for r in range(2):
1029 | getattr(block.mlp, f'fc{r + 1}').weight.copy_(
1030 | _n2p(w[f'{block_prefix}MlpBlock_{b_sub}/Dense_{r}/kernel'], idx=idx))
1031 | getattr(block.mlp, f'fc{r + 1}').bias.copy_(
1032 | _n2p(w[f'{block_prefix}MlpBlock_{b_sub}/Dense_{r}/bias'], idx=idx))
1033 |
1034 |
1035 | def _convert_openai_clip(
1036 | state_dict: Dict[str, torch.Tensor],
1037 | model: KATVisionTransformer,
1038 | prefix: str = 'visual.',
1039 | ) -> Dict[str, torch.Tensor]:
1040 | out_dict = {}
1041 | swaps = [
1042 | ('conv1', 'patch_embed.proj'),
1043 | ('positional_embedding', 'pos_embed'),
1044 | ('transformer.resblocks.', 'blocks.'),
1045 | ('ln_pre', 'norm_pre'),
1046 | ('ln_post', 'norm'),
1047 | ('ln_', 'norm'),
1048 | ('in_proj_', 'qkv.'),
1049 | ('out_proj', 'proj'),
1050 | ('mlp.c_fc', 'mlp.fc1'),
1051 | ('mlp.c_proj', 'mlp.fc2'),
1052 | ]
1053 | for k, v in state_dict.items():
1054 | if not k.startswith(prefix):
1055 | continue
1056 | k = k.replace(prefix, '')
1057 | for sp in swaps:
1058 | k = k.replace(sp[0], sp[1])
1059 |
1060 | if k == 'proj':
1061 | k = 'head.weight'
1062 | v = v.transpose(0, 1)
1063 | out_dict['head.bias'] = torch.zeros(v.shape[0])
1064 | elif k == 'class_embedding':
1065 | k = 'cls_token'
1066 | v = v.unsqueeze(0).unsqueeze(1)
1067 | elif k == 'pos_embed':
1068 | v = v.unsqueeze(0)
1069 | out_dict[k] = v
1070 | return out_dict
1071 |
1072 |
1073 | def _convert_dinov2(
1074 | state_dict: Dict[str, torch.Tensor],
1075 | model: KATVisionTransformer,
1076 | ) -> Dict[str, torch.Tensor]:
1077 | import re
1078 | out_dict = {}
1079 | state_dict.pop("mask_token", None)
1080 | if 'register_tokens' in state_dict:
1081 | # convert dinov2 w/ registers to no_embed_class timm model (neither cls or reg tokens overlap pos embed)
1082 | out_dict['reg_token'] = state_dict.pop('register_tokens')
1083 | out_dict['cls_token'] = state_dict.pop('cls_token') + state_dict['pos_embed'][:, 0]
1084 | out_dict['pos_embed'] = state_dict.pop('pos_embed')[:, 1:]
1085 | for k, v in state_dict.items():
1086 | if re.match(r"blocks\.(\d+)\.mlp\.w12\.(?:weight|bias)", k):
1087 | out_dict[k.replace("w12", "fc1")] = v
1088 | continue
1089 | elif re.match(r"blocks\.(\d+)\.mlp\.w3\.(?:weight|bias)", k):
1090 | out_dict[k.replace("w3", "fc2")] = v
1091 | continue
1092 | out_dict[k] = v
1093 | return out_dict
1094 |
1095 |
1096 | def checkpoint_filter_fn(
1097 | state_dict: Dict[str, torch.Tensor],
1098 | model: KATVisionTransformer,
1099 | adapt_layer_scale: bool = False,
1100 | interpolation: str = 'bicubic',
1101 | antialias: bool = True,
1102 | ) -> Dict[str, torch.Tensor]:
1103 | """ convert patch embedding weight from manual patchify + linear proj to conv"""
1104 | import re
1105 | out_dict = {}
1106 | state_dict = state_dict.get('model', state_dict)
1107 | state_dict = state_dict.get('state_dict', state_dict)
1108 | prefix = ''
1109 |
1110 | if 'visual.class_embedding' in state_dict:
1111 | state_dict = _convert_openai_clip(state_dict, model)
1112 | elif 'module.visual.class_embedding' in state_dict:
1113 | state_dict = _convert_openai_clip(state_dict, model, prefix='module.visual.')
1114 | elif "mask_token" in state_dict:
1115 | state_dict = _convert_dinov2(state_dict, model)
1116 | elif "encoder" in state_dict:
1117 | # IJEPA, vit in an 'encoder' submodule
1118 | state_dict = state_dict['encoder']
1119 | prefix = 'module.'
1120 | elif 'visual.trunk.pos_embed' in state_dict or 'visual.trunk.blocks.0.norm1.weight' in state_dict:
1121 | # OpenCLIP model with timm vision encoder
1122 | prefix = 'visual.trunk.'
1123 | if 'visual.head.proj.weight' in state_dict and isinstance(model.head, nn.Linear):
1124 | # remap final nn.Linear if it exists outside of the timm .trunk (ie in visual.head.proj)
1125 | out_dict['head.weight'] = state_dict['visual.head.proj.weight']
1126 | out_dict['head.bias'] = torch.zeros(state_dict['visual.head.proj.weight'].shape[0])
1127 |
1128 | if prefix:
1129 | # filter on & remove prefix string from keys
1130 | state_dict = {k[len(prefix):]: v for k, v in state_dict.items() if k.startswith(prefix)}
1131 |
1132 | for k, v in state_dict.items():
1133 | if 'patch_embed.proj.weight' in k:
1134 | O, I, H, W = model.patch_embed.proj.weight.shape
1135 | if len(v.shape) < 4:
1136 | # For old models that I trained prior to conv based patchification
1137 | O, I, H, W = model.patch_embed.proj.weight.shape
1138 | v = v.reshape(O, -1, H, W)
1139 | if v.shape[-1] != W or v.shape[-2] != H:
1140 | v = resample_patch_embed(
1141 | v,
1142 | (H, W),
1143 | interpolation=interpolation,
1144 | antialias=antialias,
1145 | verbose=True,
1146 | )
1147 | elif k == 'pos_embed' and v.shape[1] != model.pos_embed.shape[1]:
1148 | # To resize pos embedding when using model at different size from pretrained weights
1149 | num_prefix_tokens = 0 if getattr(model, 'no_embed_class', False) else getattr(model, 'num_prefix_tokens', 1)
1150 | v = resample_abs_pos_embed(
1151 | v,
1152 | new_size=model.patch_embed.grid_size,
1153 | num_prefix_tokens=num_prefix_tokens,
1154 | interpolation=interpolation,
1155 | antialias=antialias,
1156 | verbose=True,
1157 | )
1158 | elif adapt_layer_scale and 'gamma_' in k:
1159 | # remap layer-scale gamma into sub-module (deit3 models)
1160 | k = re.sub(r'gamma_([0-9])', r'ls\1.gamma', k)
1161 | elif 'pre_logits' in k:
1162 | # NOTE representation layer removed as not used in latest 21k/1k pretrained weights
1163 | continue
1164 | out_dict[k] = v
1165 | return out_dict
1166 |
1167 |
1168 | def _cfg(url: str = '', **kwargs) -> Dict[str, Any]:
1169 | return {
1170 | 'url': url,
1171 | 'num_classes': 1000,
1172 | 'input_size': (3, 224, 224),
1173 | 'pool_size': None,
1174 | 'crop_pct': 0.875,
1175 | 'interpolation': 'bicubic',
1176 | 'fixed_input_size': True,
1177 | 'mean': IMAGENET_DEFAULT_MEAN,
1178 | 'std': IMAGENET_DEFAULT_STD,
1179 | 'first_conv': 'patch_embed.proj',
1180 | 'classifier': 'head',
1181 | **kwargs,
1182 | }
1183 |
1184 | default_cfgs = {
1185 | 'kat_tiny_patch16_224': _cfg(
1186 | url='https://huggingface.co/adamdad/kat_pretained/resolve/main/kat_tiny_patch16_224_1f3ad3b2e69821f3d412f2924cf159a0e266f142d739cb68f68f796f5a0fe289.pth'),
1187 | 'kat_tiny_patch16_224.vitft': _cfg(
1188 | url='https://huggingface.co/adamdad/kat_pretained/resolve/main/kat_tiny_patch16_224-finetune_64f124d003803e4a7e1aba1ba23500ace359b544e8a5f0110993f25052e402fb.pth'),
1189 | 'kat_small_patch16_224': _cfg(
1190 | url='https://huggingface.co/adamdad/kat_pretained/resolve/main/kat_small_patch16_224_32487885cf13d2c14e461c9016fac8ad43f7c769171f132530941e930aeb5fe2.pth'),
1191 | 'kat_small_patch16_224.vitft': _cfg(
1192 | url='https://huggingface.co/adamdad/kat_pretained/resolve/main/kat_small_patch_224-finetune_3ae087a4c28e2993468eb377d5151350c52c80b2a70cc48ceec63d1328ba58e0.pth'),
1193 | 'kat_base_patch16_224': _cfg(
1194 | url='https://huggingface.co/adamdad/kat_pretained/resolve/main/kat_base_patch16_224_abff874d925d756d15cde97303f772a3460ddbd44b9c53fb9ce5cf15be230fb6.pth'),
1195 | 'kat_base_patch16_224.vitft': _cfg(
1196 | url='https://huggingface.co/adamdad/kat_pretained/resolve/main/kat_base_patch16_224-finetune_440bf1ead9dd8ecab642078cfb60ae542f1fa33ca65517260501e02c011e38f2.pth')
1197 | }
1198 |
1199 |
1200 | def _create_kat_transformer(variant: str, pretrained: bool = False, **kwargs) -> KATVisionTransformer:
1201 | out_indices = kwargs.pop('out_indices', 3)
1202 | if 'flexi' in variant:
1203 | # FIXME Google FlexiViT pretrained models have a strong preference for bilinear patch / embed
1204 | # interpolation, other pretrained models resize better w/ anti-aliased bicubic interpolation.
1205 | _filter_fn = partial(checkpoint_filter_fn, interpolation='bilinear', antialias=False)
1206 | else:
1207 | _filter_fn = checkpoint_filter_fn
1208 |
1209 | # FIXME attn pool (currently only in siglip) params removed if pool disabled, is there a better soln?
1210 | strict = True
1211 | if 'siglip' in variant and kwargs.get('global_pool', None) != 'map':
1212 | strict = False
1213 |
1214 | return build_model_with_cfg(
1215 | KATVisionTransformer,
1216 | variant,
1217 | pretrained,
1218 | pretrained_filter_fn=_filter_fn,
1219 | pretrained_strict=strict,
1220 | feature_cfg=dict(out_indices=out_indices, feature_cls='getter'),
1221 | **kwargs,
1222 | )
1223 |
1224 | @register_model
1225 | def kat_tiny_patch16_224(pretrained: bool = False, **kwargs) -> KATVisionTransformer:
1226 | """ KAT-Tiny with rational activations (ViT-S/16)
1227 | """
1228 | model_args = dict(patch_size=16, embed_dim=192, depth=12, num_heads=3,
1229 | act_layer=KAT_Group,
1230 | act_init='swish',
1231 | mlp_layer=KAN,
1232 | weight_init="kan_mimetic")
1233 | model = _create_kat_transformer('kat_tiny_patch16_224', pretrained=pretrained, **dict(model_args, **kwargs))
1234 | return model
1235 |
1236 | @register_model
1237 | def kat_tiny_gelu_patch16_224(pretrained: bool = False, **kwargs) -> KATVisionTransformer:
1238 | """ KAT-Tiny with rational activations (ViT-S/16)
1239 | """
1240 | model_args = dict(patch_size=16, embed_dim=192, depth=12, num_heads=3,
1241 | act_layer=KAT_Group,
1242 | act_init='gelu',
1243 | mlp_layer=KAN,
1244 | weight_init="kan_mimetic")
1245 | model = _create_kat_transformer('kat_tiny_patch16_224', pretrained=pretrained, **dict(model_args, **kwargs))
1246 | return model
1247 |
1248 | @register_model
1249 | def kat_tiny_swish_patch16_224(pretrained: bool = False, **kwargs) -> KATVisionTransformer:
1250 | """ KAT-Tiny with rational activations (ViT-Ti/16)
1251 | """
1252 | model_args = dict(patch_size=16, embed_dim=192, depth=12, num_heads=3,
1253 | act_layer=KAT_Group,
1254 | act_init='swish',
1255 | mlp_layer=KAN,
1256 | weight_init="kan_mimetic")
1257 | model = _create_kat_transformer('kat_tiny_patch16_224', pretrained=pretrained, **dict(model_args, **kwargs))
1258 | return model
1259 |
1260 | @register_model
1261 | def kat_small_patch16_224(pretrained: bool = False, **kwargs) -> KATVisionTransformer:
1262 | """ KAT-Small with rational activations (ViT-S/16)"""
1263 | model_args = dict(patch_size=16,
1264 | embed_dim=384,
1265 | depth=12,
1266 | num_heads=6,
1267 | act_init='swish',
1268 | act_layer=KAT_Group,
1269 | mlp_layer=KAN,
1270 | weight_init="kan_mimetic")
1271 | model = _create_kat_transformer('kat_small_patch16_224', pretrained=pretrained, **dict(model_args, **kwargs))
1272 | return model
1273 |
1274 | @register_model
1275 | def kat_small_gelu_patch16_224(pretrained: bool = False, **kwargs) -> KATVisionTransformer:
1276 | """ KAT-Small with rational activations (ViT-S/16)"""
1277 | model_args = dict(patch_size=16,
1278 | embed_dim=384,
1279 | depth=12,
1280 | num_heads=6,
1281 | act_init='gelu',
1282 | act_layer=KAT_Group,
1283 | mlp_layer=KAN, weight_init="kan_mimetic") # , init_values=1e-5
1284 | model = _create_kat_transformer('kat_small_patch16_224', pretrained=pretrained, **dict(model_args, **kwargs))
1285 | return model
1286 |
1287 | @register_model
1288 | def kat_small_swish_patch16_224(pretrained: bool = False, **kwargs) -> KATVisionTransformer:
1289 | """ KAT-Small with rational activations (ViT-S/16)"""
1290 | model_args = dict(patch_size=16,
1291 | embed_dim=384,
1292 | depth=12, num_heads=6,
1293 | act_init='swish',
1294 | act_layer=KAT_Group,
1295 | mlp_layer=KAN, weight_init="kan_mimetic") # , init_values=1e-5
1296 |     model = _create_kat_transformer('kat_small_patch16_224', pretrained=pretrained, **dict(model_args, **kwargs))
1297 | return model
1298 |
1299 | @register_model
1300 | def kat_base_patch16_224(pretrained: bool = False, **kwargs) -> KATVisionTransformer:
1301 | """ KAT-Base with rational activations (ViT-B/16)"""
1302 | model_args = dict(patch_size=16,
1303 | embed_dim=768,
1304 | depth=12,
1305 | num_heads=12,
1306 | act_layer=KAT_Group,
1307 | mlp_layer=KAN,
1308 | act_init='swish',
1309 | weight_init="kan_mimetic")
1310 | model = _create_kat_transformer('kat_base_patch16_224', pretrained=pretrained, **dict(model_args, **kwargs))
1311 | return model
1312 |
1313 | @register_model
1314 | def kat_base_gelu_patch16_224(pretrained: bool = False, **kwargs) -> KATVisionTransformer:
1315 | """ KAT-Base with rational activations (ViT-B/16)"""
1316 | model_args = dict(patch_size=16, embed_dim=768, depth=12, num_heads=12,
1317 | act_layer=KAT_Group,
1318 | mlp_layer=KAN,
1319 | act_init='gelu',
1320 | weight_init="kan_mimetic")
1321 | model = _create_kat_transformer('kat_base_patch16_224', pretrained=pretrained, **dict(model_args, **kwargs))
1322 | return model
1323 |
1324 | @register_model
1325 | def kat_base_swish_patch16_224(pretrained: bool = False, **kwargs) -> KATVisionTransformer:
1326 | """ KAT-Base with rational activations (ViT-B/16)"""
1327 | model_args = dict(patch_size=16, embed_dim=768, depth=12, num_heads=12,
1328 | act_layer=KAT_Group,
1329 | mlp_layer=KAN,
1330 | act_init='swish',
1331 | weight_init="kan_mimetic")
1332 | model = _create_kat_transformer('kat_base_patch16_224', pretrained=pretrained, **dict(model_args, **kwargs))
1333 | return model
1334 |
--------------------------------------------------------------------------------
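The `@register_model` entries above make the KAT variants available through timm's model registry as soon as `katransformer` is imported. A minimal usage sketch (not part of the repo; it assumes timm and the CUDA/Triton rational-activation kernel behind `KAT_Group` are installed, and runs on a GPU for that reason):

    import timm
    import torch

    import katransformer  # importing this module registers the kat_* variants with timm

    # Build KAT-Tiny; pretrained=True would instead download the checkpoint URL from default_cfgs.
    model = timm.create_model('kat_tiny_patch16_224', pretrained=False).cuda().eval()

    with torch.no_grad():
        logits = model(torch.randn(1, 3, 224, 224, device='cuda'))
    print(logits.shape)  # expected: torch.Size([1, 1000])
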
/scripts/train_kat_base_8x128.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 | DATA_PATH=/local_home/dataset/imagenet/
3 |
4 | bash ./dist_train.sh 8 $DATA_PATH \
5 | --model kat_base_swish_patch16_224 \
6 | -b 128 \
7 | --opt adamw \
8 | --lr 1e-3 \
9 | --weight-decay 0.05 \
10 | --epochs 300 \
11 | --mixup 0.8 \
12 | --cutmix 1.0 \
13 | --sched cosine \
14 | --smoothing 0.1 \
15 | --drop-path 0.4 \
16 | --aa rand-m9-mstd0.5 \
17 | --remode pixel --reprob 0.25 \
18 | --amp \
19 | --crop-pct 0.875 \
20 | --mean 0.485 0.456 0.406 \
21 | --std 0.229 0.224 0.225 \
22 | --model-ema \
23 | --model-ema-decay 0.9999 \
24 | --output output/kat_base_swish_patch16_224 \
25 | --log-wandb
--------------------------------------------------------------------------------
/scripts/train_kat_base_8x128_vitft.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 | DATA_PATH=/local_home/dataset/imagenet/
3 |
4 | bash ./dist_train.sh 8 $DATA_PATH \
5 | --model kat_base_swish_patch16_224 \
6 | -b 128 \
7 | --opt adamw \
8 | --lr 1e-3 \
9 | --weight-decay 0.05 \
10 | --epochs 300 \
11 | --mixup 0.8 \
12 | --cutmix 1.0 \
13 | --sched cosine \
14 | --smoothing 0.1 \
15 | --drop-path 0.4 \
16 | --aa rand-m9-mstd0.5 \
17 | --remode pixel --reprob 0.25 \
18 | --amp \
19 | --crop-pct 0.875 \
20 | --mean 0.485 0.456 0.406 \
21 | --std 0.229 0.224 0.225 \
22 | --model-ema \
23 | --model-ema-decay 0.9999 \
24 | --output output/kat_base_swish_patch16_224 \
25 | --initial-checkpoint ./checkpoints/B_16-i21k-300ep-lr_0.001-aug_medium1-wd_0.1-do_0.0-sd_0.0.npz \
26 | --log-wandb
--------------------------------------------------------------------------------
/scripts/train_kat_small_8x128.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 | DATA_PATH=/local_home/dataset/imagenet/
3 |
4 | bash ./dist_train.sh 8 $DATA_PATH \
5 | --model kat_small_swish_patch16_224 \
6 | -b 128 \
7 | --opt adamw \
8 | --lr 1e-3 \
9 | --weight-decay 0.05 \
10 | --epochs 300 \
11 | --mixup 0.8 \
12 | --cutmix 1.0 \
13 | --sched cosine \
14 | --smoothing 0.1 \
15 | --drop-path 0.1 \
16 | --aa rand-m9-mstd0.5 \
17 | --remode pixel --reprob 0.25 \
18 | --amp \
19 | --crop-pct 0.875 \
20 | --mean 0.485 0.456 0.406 \
21 | --std 0.229 0.224 0.225 \
22 | --model-ema \
23 | --model-ema-decay 0.9999 \
24 | --output output/kat_small_swish_patch16_224 \
25 | --log-wandb
--------------------------------------------------------------------------------
/scripts/train_kat_small_8x128_vitft.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 | DATA_PATH=/local_home/dataset/imagenet/
3 |
4 | bash ./dist_train.sh 8 $DATA_PATH \
5 | --model kat_small_swish_patch16_224 \
6 | -b 128 \
7 | --opt adamw \
8 | --lr 1e-3 \
9 | --weight-decay 0.05 \
10 | --epochs 300 \
11 | --mixup 0.8 \
12 | --cutmix 1.0 \
13 | --sched cosine \
14 | --smoothing 0.1 \
15 | --drop-path 0.1 \
16 | --aa rand-m9-mstd0.5 \
17 | --remode pixel --reprob 0.25 \
18 | --amp \
19 | --crop-pct 0.875 \
20 | --mean 0.485 0.456 0.406 \
21 | --std 0.229 0.224 0.225 \
22 | --model-ema \
23 | --model-ema-decay 0.9999 \
24 | --output output/kat_small_swish_patch16_224 \
25 | --pretrained-path ./checkpoints/S_16-i21k-300ep-lr_0.001-aug_light1-wd_0.03-do_0.0-sd_0.0.npz \
26 | --pretrained \
27 | --log-wandb
--------------------------------------------------------------------------------
/scripts/train_kat_tiny_8x128.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 | DATA_PATH=/local_home/dataset/imagenet/
3 |
4 | bash ./dist_train.sh 8 $DATA_PATH \
5 | --model kat_tiny_swish_patch16_224 \
6 | -b 128 \
7 | --opt adamw \
8 | --lr 1e-3 \
9 | --weight-decay 0.05 \
10 | --epochs 300 \
11 | --mixup 0.8 \
12 | --cutmix 1.0 \
13 | --sched cosine \
14 | --smoothing 0.1 \
15 | --drop-path 0.1 \
16 | --aa rand-m9-mstd0.5 \
17 | --remode pixel --reprob 0.25 \
18 | --amp \
19 | --crop-pct 0.875 \
20 | --mean 0.485 0.456 0.406 \
21 | --std 0.229 0.224 0.225 \
22 | --model-ema \
23 | --model-ema-decay 0.9999 \
24 | --output output/kat_tiny_swish_patch16_224 \
25 | --log-wandb
26 |
--------------------------------------------------------------------------------
/scripts/train_kat_tiny_8x128_vitft.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 | DATA_PATH=/local_home/dataset/imagenet/
3 |
4 | bash ./dist_train.sh 8 $DATA_PATH \
5 | --model kat_tiny_swish_patch16_224 \
6 | -b 128 \
7 | --opt adamw \
8 | --lr 1e-3 \
9 | --weight-decay 0.05 \
10 | --epochs 300 \
11 | --mixup 0.8 \
12 | --cutmix 1.0 \
13 | --sched cosine \
14 | --smoothing 0.1 \
15 | --drop-path 0.1 \
16 | --aa rand-m9-mstd0.5 \
17 | --remode pixel --reprob 0.25 \
18 | --amp \
19 | --crop-pct 0.875 \
20 | --mean 0.485 0.456 0.406 \
21 | --std 0.229 0.224 0.225 \
22 | --model-ema \
23 | --model-ema-decay 0.9999 \
24 | --output output/kat_tiny_swish_patch16_224 \
25 | --initial-checkpoint ./checkpoints/Ti_16-i21k-300ep-lr_0.001-aug_none-wd_0.03-do_0.0-sd_0.0--imagenet2012-steps_20k-lr_0.03-res_224.npz \
26 | --log-wandb
--------------------------------------------------------------------------------
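Apart from the model name, output directory, and ViT initialization checkpoint, the six scripts above share one recipe: 8 GPUs at a per-GPU batch of 128 (effective batch 1024), AdamW with an explicit --lr 1e-3, weight decay 0.05, 300 epochs on a cosine schedule, RandAugment, Mixup/CutMix, random erasing, AMP, and model EMA; only --drop-path changes with model size (0.4 for base, 0.1 for small/tiny). Because --lr is given explicitly, the base-LR scaling in train.py (further below) never runs; the snippet that follows is only an illustrative sketch of what that rule would derive for this batch size, using the defaults visible in train.py:

    # Sketch of train.py's lr-base derivation (used only when --lr is NOT set on the command line).
    global_batch = 128 * 8 * 1          # --batch-size * world size * --grad-accum-steps
    batch_ratio = global_batch / 256    # --lr-base-size default
    batch_ratio = batch_ratio ** 0.5    # 'ada*'/'lamb' optimizers (here adamw) use sqrt scaling
    derived_lr = 0.1 * batch_ratio      # --lr-base default 0.1 -> derived lr 0.2
    print(derived_lr)
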
/tools/calculate_flops.py:
--------------------------------------------------------------------------------
1 | import torch
2 |
3 | import katransformer
4 | import timm
5 | from fvcore.nn import FlopCountAnalysis, flop_count_str
6 | from fvcore.nn.jit_handles import get_shape
7 | from math import prod
8 |
9 |
10 | def compute_flops(model_name, input_size, custom_op_name, device='cuda'):
11 | """
12 | Computes the FLOPs for a given TIMM model, including a customized operator, using FlopCountAnalysis.
13 |
14 | Args:
15 | model_name (str): The name of the TIMM model.
16 | input_size (tuple): The input size (height, width) of the model.
17 | custom_op_name (str): The name of the customized operator.
18 |         device (str): Device to run the dummy forward pass on (default: 'cuda').
19 |
20 | Returns:
21 |         float: The total number of FLOPs, in GFLOPs.
22 | """
23 |
24 | # Load the TIMM model
25 | model = timm.create_model(model_name, pretrained=False)
26 |
27 | # Ensure the model is in evaluation mode to avoid unnecessary computations
28 | model.eval()
29 | model.to(device)
30 |
31 | # Create a dummy input tensor with the specified size
32 | input_tensor = torch.randn(1, 3, *input_size).to(device)
33 |
34 | # Set the FLOPs for the customized operator
35 | def _custom_op_flops_fn(inputs, outputs):
36 | # Assuming each operation in custom_op involves 'n' flops per input element
37 | n = 21 # This should be adjusted based on what the custom operation does
38 |
39 | input_shape = get_shape(inputs[0])
40 | total_elements = prod(input_shape)
41 | return total_elements * n
42 |
43 |
44 | # Use FlopCountAnalysis to compute the FLOPs
45 | analysis = FlopCountAnalysis(model,
46 | input_tensor)
47 | analysis.set_op_handle(custom_op_name, _custom_op_flops_fn)
48 |
49 | print(flop_count_str(analysis))
50 |     total_flops = analysis.total()
51 |     # print total_flops in GigaFLOPs
52 |     print("Total FLOPs: ", total_flops/1e9, "GFLOPs")
53 | total_params = sum(p.numel() for p in model.parameters())
54 | print("Total Params: ", total_params/1e6, "M")
55 |
56 |     return total_flops / 1e9
57 |
58 |
59 | # return int(flops)
60 |
61 | if __name__ == "__main__":
62 | # model_name = "vit_base_kat_mimetic_patch16_224" # Replace with your desired model name
63 | model_name = 'kat_tiny_gelu_patch16_224' # Replace with your desired model name
64 | input_size = (224, 224) # Replace with your desired input size
65 | custom_op_name = "prim::PythonOp.rational_1dgroup" # Replace with your customized operator name
66 | # custom_op_flops = 1000 # Replace with the actual FLOPs of your customized operator
67 |
68 | flops = compute_flops(model_name, input_size, custom_op_name)
69 | # print(f"FLOPs for {model_name}: {flops}")
--------------------------------------------------------------------------------
/train.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | """ ImageNet Training Script
3 |
4 | This is intended to be a lean and easily modifiable ImageNet training script that reproduces ImageNet
5 | training results with some of the latest networks and training techniques. It favours canonical PyTorch
6 | and standard Python style over trying to be able to 'do it all.' That said, it offers quite a few speed
7 | and training result improvements over the usual PyTorch example scripts. Repurpose as you see fit.
8 |
9 | This script was started from an early version of the PyTorch ImageNet example
10 | (https://github.com/pytorch/examples/tree/master/imagenet)
11 |
12 | NVIDIA CUDA specific speedups adopted from NVIDIA Apex examples
13 | (https://github.com/NVIDIA/apex/tree/master/examples/imagenet)
14 |
15 | Hacked together by / Copyright 2020 Ross Wightman (https://github.com/rwightman)
16 | """
17 | import argparse
18 | import importlib
19 | import json
20 | import logging
21 | import os
22 | import time
23 | from collections import OrderedDict
24 | from contextlib import suppress
25 | from datetime import datetime
26 | from functools import partial
27 |
28 | import torch
29 | import torch.nn as nn
30 | import torchvision.utils
31 | import yaml
32 | from torch.nn.parallel import DistributedDataParallel as NativeDDP
33 |
34 | from timm import utils
35 | from timm.data import create_dataset, create_loader, resolve_data_config, Mixup, FastCollateMixup, AugMixDataset
36 | from timm.layers import convert_splitbn_model, convert_sync_batchnorm, set_fast_norm
37 | from timm.loss import JsdCrossEntropy, SoftTargetCrossEntropy, BinaryCrossEntropy, LabelSmoothingCrossEntropy
38 | from timm.models import create_model, safe_model_name, resume_checkpoint, load_checkpoint, model_parameters
39 | from timm.optim import create_optimizer_v2, optimizer_kwargs
40 | from timm.scheduler import create_scheduler_v2, scheduler_kwargs
41 | from timm.utils import ApexScaler, NativeScaler
42 |
43 | try:
44 | from apex import amp
45 | from apex.parallel import DistributedDataParallel as ApexDDP
46 | from apex.parallel import convert_syncbn_model
47 | has_apex = True
48 | except ImportError:
49 | has_apex = False
50 |
51 | has_native_amp = False
52 | try:
53 | if getattr(torch.cuda.amp, 'autocast') is not None:
54 | has_native_amp = True
55 | except AttributeError:
56 | pass
57 |
58 | try:
59 | import wandb
60 | has_wandb = True
61 | except ImportError:
62 | has_wandb = False
63 |
64 | try:
65 | from functorch.compile import memory_efficient_fusion
66 | has_functorch = True
67 | except ImportError as e:
68 | has_functorch = False
69 |
70 | import katransformer
71 |
72 | has_compile = hasattr(torch, 'compile')
73 |
74 |
75 | _logger = logging.getLogger('train')
76 |
77 | # The first arg parser parses out only the --config argument, which is used to
78 | # load a yaml file containing key-values that override the defaults for the main parser below
79 | config_parser = parser = argparse.ArgumentParser(description='Training Config', add_help=False)
80 | parser.add_argument('-c', '--config', default='', type=str, metavar='FILE',
81 | help='YAML config file specifying default arguments')
82 |
83 |
84 | parser = argparse.ArgumentParser(description='PyTorch ImageNet Training')
85 |
86 | # Dataset parameters
87 | group = parser.add_argument_group('Dataset parameters')
88 | # Keep this argument outside the dataset group because it is positional.
89 | parser.add_argument('data', nargs='?', metavar='DIR', const=None,
90 | help='path to dataset (positional is *deprecated*, use --data-dir)')
91 | parser.add_argument('--data-dir', metavar='DIR',
92 | help='path to dataset (root dir)')
93 | parser.add_argument('--dataset', metavar='NAME', default='',
94 |                     help='dataset type + name ("<type>/<name>") (default: ImageFolder or ImageTar if empty)')
95 | group.add_argument('--train-split', metavar='NAME', default='train',
96 | help='dataset train split (default: train)')
97 | group.add_argument('--val-split', metavar='NAME', default='validation',
98 | help='dataset validation split (default: validation)')
99 | parser.add_argument('--train-num-samples', default=None, type=int,
100 | metavar='N', help='Manually specify num samples in train split, for IterableDatasets.')
101 | parser.add_argument('--val-num-samples', default=None, type=int,
102 | metavar='N', help='Manually specify num samples in validation split, for IterableDatasets.')
103 | group.add_argument('--dataset-download', action='store_true', default=False,
104 | help='Allow download of dataset for torch/ and tfds/ datasets that support it.')
105 | group.add_argument('--class-map', default='', type=str, metavar='FILENAME',
106 | help='path to class to idx mapping file (default: "")')
107 | group.add_argument('--input-img-mode', default=None, type=str,
108 | help='Dataset image conversion mode for input images.')
109 | group.add_argument('--input-key', default=None, type=str,
110 | help='Dataset key for input images.')
111 | group.add_argument('--target-key', default=None, type=str,
112 | help='Dataset key for target labels.')
113 |
114 | # Model parameters
115 | group = parser.add_argument_group('Model parameters')
116 | group.add_argument('--model', default='resnet50', type=str, metavar='MODEL',
117 | help='Name of model to train (default: "resnet50")')
118 | group.add_argument('--pretrained', action='store_true', default=False,
119 | help='Start with pretrained version of specified network (if avail)')
120 | group.add_argument('--pretrained-path', default=None, type=str,
121 | help='Load this checkpoint as if they were the pretrained weights (with adaptation).')
122 | group.add_argument('--initial-checkpoint', default='', type=str, metavar='PATH',
123 | help='Load this checkpoint into model after initialization (default: none)')
124 | group.add_argument('--resume', default='', type=str, metavar='PATH',
125 | help='Resume full model and optimizer state from checkpoint (default: none)')
126 | group.add_argument('--no-resume-opt', action='store_true', default=False,
127 | help='prevent resume of optimizer state when resuming model')
128 | group.add_argument('--num-classes', type=int, default=None, metavar='N',
129 | help='number of label classes (Model default if None)')
130 | group.add_argument('--gp', default=None, type=str, metavar='POOL',
131 | help='Global pool type, one of (fast, avg, max, avgmax, avgmaxc). Model default if None.')
132 | group.add_argument('--img-size', type=int, default=None, metavar='N',
133 | help='Image size (default: None => model default)')
134 | group.add_argument('--in-chans', type=int, default=None, metavar='N',
135 | help='Image input channels (default: None => 3)')
136 | group.add_argument('--input-size', default=None, nargs=3, type=int,
137 | metavar='N N N',
138 | help='Input all image dimensions (d h w, e.g. --input-size 3 224 224), uses model default if empty')
139 | group.add_argument('--crop-pct', default=None, type=float,
140 | metavar='N', help='Input image center crop percent (for validation only)')
141 | group.add_argument('--mean', type=float, nargs='+', default=None, metavar='MEAN',
142 | help='Override mean pixel value of dataset')
143 | group.add_argument('--std', type=float, nargs='+', default=None, metavar='STD',
144 | help='Override std deviation of dataset')
145 | group.add_argument('--interpolation', default='', type=str, metavar='NAME',
146 | help='Image resize interpolation type (overrides model)')
147 | group.add_argument('-b', '--batch-size', type=int, default=128, metavar='N',
148 | help='Input batch size for training (default: 128)')
149 | group.add_argument('-vb', '--validation-batch-size', type=int, default=None, metavar='N',
150 | help='Validation batch size override (default: None)')
151 | group.add_argument('--channels-last', action='store_true', default=False,
152 | help='Use channels_last memory layout')
153 | group.add_argument('--fuser', default='', type=str,
154 | help="Select jit fuser. One of ('', 'te', 'old', 'nvfuser')")
155 | group.add_argument('--grad-accum-steps', type=int, default=1, metavar='N',
156 | help='The number of steps to accumulate gradients (default: 1)')
157 | group.add_argument('--grad-checkpointing', action='store_true', default=False,
158 | help='Enable gradient checkpointing through model blocks/stages')
159 | group.add_argument('--fast-norm', default=False, action='store_true',
160 | help='enable experimental fast-norm')
161 | group.add_argument('--model-kwargs', nargs='*', default={}, action=utils.ParseKwargs)
162 | group.add_argument('--head-init-scale', default=None, type=float,
163 | help='Head initialization scale')
164 | group.add_argument('--head-init-bias', default=None, type=float,
165 | help='Head initialization bias value')
166 |
167 | # scripting / codegen
168 | scripting_group = group.add_mutually_exclusive_group()
169 | scripting_group.add_argument('--torchscript', dest='torchscript', action='store_true',
170 | help='torch.jit.script the full model')
171 | scripting_group.add_argument('--torchcompile', nargs='?', type=str, default=None, const='inductor',
172 | help="Enable compilation w/ specified backend (default: inductor).")
173 |
174 | # Device & distributed
175 | group = parser.add_argument_group('Device parameters')
176 | group.add_argument('--device', default='cuda', type=str,
177 | help="Device (accelerator) to use.")
178 | group.add_argument('--amp', action='store_true', default=False,
179 | help='use NVIDIA Apex AMP or Native AMP for mixed precision training')
180 | group.add_argument('--amp-dtype', default='float16', type=str,
181 | help='lower precision AMP dtype (default: float16)')
182 | group.add_argument('--amp-impl', default='native', type=str,
183 | help='AMP impl to use, "native" or "apex" (default: native)')
184 | group.add_argument('--no-ddp-bb', action='store_true', default=False,
185 | help='Force broadcast buffers for native DDP to off.')
186 | group.add_argument('--synchronize-step', action='store_true', default=False,
187 | help='torch.cuda.synchronize() end of each step')
188 | group.add_argument("--local_rank", default=0, type=int)
189 | parser.add_argument('--device-modules', default=None, type=str, nargs='+',
190 | help="Python imports for device backend modules.")
191 |
192 | # Optimizer parameters
193 | group = parser.add_argument_group('Optimizer parameters')
194 | group.add_argument('--opt', default='sgd', type=str, metavar='OPTIMIZER',
195 | help='Optimizer (default: "sgd")')
196 | group.add_argument('--opt-eps', default=None, type=float, metavar='EPSILON',
197 | help='Optimizer Epsilon (default: None, use opt default)')
198 | group.add_argument('--opt-betas', default=None, type=float, nargs='+', metavar='BETA',
199 | help='Optimizer Betas (default: None, use opt default)')
200 | group.add_argument('--momentum', type=float, default=0.9, metavar='M',
201 | help='Optimizer momentum (default: 0.9)')
202 | group.add_argument('--weight-decay', type=float, default=2e-5,
203 | help='weight decay (default: 2e-5)')
204 | group.add_argument('--clip-grad', type=float, default=None, metavar='NORM',
205 | help='Clip gradient norm (default: None, no clipping)')
206 | group.add_argument('--clip-mode', type=str, default='norm',
207 | help='Gradient clipping mode. One of ("norm", "value", "agc")')
208 | group.add_argument('--layer-decay', type=float, default=None,
209 | help='layer-wise learning rate decay (default: None)')
210 | group.add_argument('--opt-kwargs', nargs='*', default={}, action=utils.ParseKwargs)
211 |
212 | # Learning rate schedule parameters
213 | group = parser.add_argument_group('Learning rate schedule parameters')
214 | group.add_argument('--sched', type=str, default='cosine', metavar='SCHEDULER',
215 |                    help='LR scheduler (default: "cosine")')
216 | group.add_argument('--sched-on-updates', action='store_true', default=False,
217 | help='Apply LR scheduler step on update instead of epoch end.')
218 | group.add_argument('--lr', type=float, default=None, metavar='LR',
219 | help='learning rate, overrides lr-base if set (default: None)')
220 | group.add_argument('--lr-base', type=float, default=0.1, metavar='LR',
221 | help='base learning rate: lr = lr_base * global_batch_size / base_size')
222 | group.add_argument('--lr-base-size', type=int, default=256, metavar='DIV',
223 | help='base learning rate batch size (divisor, default: 256).')
224 | group.add_argument('--lr-base-scale', type=str, default='', metavar='SCALE',
225 | help='base learning rate vs batch_size scaling ("linear", "sqrt", based on opt if empty)')
226 | group.add_argument('--lr-noise', type=float, nargs='+', default=None, metavar='pct, pct',
227 | help='learning rate noise on/off epoch percentages')
228 | group.add_argument('--lr-noise-pct', type=float, default=0.67, metavar='PERCENT',
229 | help='learning rate noise limit percent (default: 0.67)')
230 | group.add_argument('--lr-noise-std', type=float, default=1.0, metavar='STDDEV',
231 | help='learning rate noise std-dev (default: 1.0)')
232 | group.add_argument('--lr-cycle-mul', type=float, default=1.0, metavar='MULT',
233 | help='learning rate cycle len multiplier (default: 1.0)')
234 | group.add_argument('--lr-cycle-decay', type=float, default=0.5, metavar='MULT',
235 | help='amount to decay each learning rate cycle (default: 0.5)')
236 | group.add_argument('--lr-cycle-limit', type=int, default=1, metavar='N',
237 | help='learning rate cycle limit, cycles enabled if > 1')
238 | group.add_argument('--lr-k-decay', type=float, default=1.0,
239 | help='learning rate k-decay for cosine/poly (default: 1.0)')
240 | group.add_argument('--warmup-lr', type=float, default=1e-5, metavar='LR',
241 | help='warmup learning rate (default: 1e-5)')
242 | group.add_argument('--min-lr', type=float, default=0, metavar='LR',
243 | help='lower lr bound for cyclic schedulers that hit 0 (default: 0)')
244 | group.add_argument('--epochs', type=int, default=300, metavar='N',
245 | help='number of epochs to train (default: 300)')
246 | group.add_argument('--epoch-repeats', type=float, default=0., metavar='N',
247 | help='epoch repeat multiplier (number of times to repeat dataset epoch per train epoch).')
248 | group.add_argument('--start-epoch', default=None, type=int, metavar='N',
249 | help='manual epoch number (useful on restarts)')
250 | group.add_argument('--decay-milestones', default=[90, 180, 270], type=int, nargs='+', metavar="MILESTONES",
251 | help='list of decay epoch indices for multistep lr. must be increasing')
252 | group.add_argument('--decay-epochs', type=float, default=90, metavar='N',
253 | help='epoch interval to decay LR')
254 | group.add_argument('--warmup-epochs', type=int, default=5, metavar='N',
255 | help='epochs to warmup LR, if scheduler supports')
256 | group.add_argument('--warmup-prefix', action='store_true', default=False,
257 | help='Exclude warmup period from decay schedule.'),
258 | group.add_argument('--cooldown-epochs', type=int, default=0, metavar='N',
259 | help='epochs to cooldown LR at min_lr, after cyclic schedule ends')
260 | group.add_argument('--patience-epochs', type=int, default=10, metavar='N',
261 | help='patience epochs for Plateau LR scheduler (default: 10)')
262 | group.add_argument('--decay-rate', '--dr', type=float, default=0.1, metavar='RATE',
263 | help='LR decay rate (default: 0.1)')
264 |
265 | # Augmentation & regularization parameters
266 | group = parser.add_argument_group('Augmentation and regularization parameters')
267 | group.add_argument('--no-aug', action='store_true', default=False,
268 | help='Disable all training augmentation, override other train aug args')
269 | group.add_argument('--train-crop-mode', type=str, default=None,
270 | help='Crop-mode in train'),
271 | group.add_argument('--scale', type=float, nargs='+', default=[0.08, 1.0], metavar='PCT',
272 | help='Random resize scale (default: 0.08 1.0)')
273 | group.add_argument('--ratio', type=float, nargs='+', default=[3. / 4., 4. / 3.], metavar='RATIO',
274 | help='Random resize aspect ratio (default: 0.75 1.33)')
275 | group.add_argument('--hflip', type=float, default=0.5,
276 | help='Horizontal flip training aug probability')
277 | group.add_argument('--vflip', type=float, default=0.,
278 | help='Vertical flip training aug probability')
279 | group.add_argument('--color-jitter', type=float, default=0.4, metavar='PCT',
280 | help='Color jitter factor (default: 0.4)')
281 | group.add_argument('--color-jitter-prob', type=float, default=None, metavar='PCT',
282 | help='Probability of applying any color jitter.')
283 | group.add_argument('--grayscale-prob', type=float, default=None, metavar='PCT',
284 | help='Probability of applying random grayscale conversion.')
285 | group.add_argument('--gaussian-blur-prob', type=float, default=None, metavar='PCT',
286 | help='Probability of applying gaussian blur.')
287 | group.add_argument('--aa', type=str, default=None, metavar='NAME',
288 | help='Use AutoAugment policy. "v0" or "original". (default: None)'),
289 | group.add_argument('--aug-repeats', type=float, default=0,
290 | help='Number of augmentation repetitions (distributed training only) (default: 0)')
291 | group.add_argument('--aug-splits', type=int, default=0,
292 | help='Number of augmentation splits (default: 0, valid: 0 or >=2)')
293 | group.add_argument('--jsd-loss', action='store_true', default=False,
294 | help='Enable Jensen-Shannon Divergence + CE loss. Use with `--aug-splits`.')
295 | group.add_argument('--bce-loss', action='store_true', default=False,
296 | help='Enable BCE loss w/ Mixup/CutMix use.')
297 | group.add_argument('--bce-sum', action='store_true', default=False,
298 | help='Sum over classes when using BCE loss.')
299 | group.add_argument('--bce-target-thresh', type=float, default=None,
300 | help='Threshold for binarizing softened BCE targets (default: None, disabled).')
301 | group.add_argument('--bce-pos-weight', type=float, default=None,
302 | help='Positive weighting for BCE loss.')
303 | group.add_argument('--reprob', type=float, default=0., metavar='PCT',
304 | help='Random erase prob (default: 0.)')
305 | group.add_argument('--remode', type=str, default='pixel',
306 | help='Random erase mode (default: "pixel")')
307 | group.add_argument('--recount', type=int, default=1,
308 | help='Random erase count (default: 1)')
309 | group.add_argument('--resplit', action='store_true', default=False,
310 | help='Do not random erase first (clean) augmentation split')
311 | group.add_argument('--mixup', type=float, default=0.0,
312 | help='mixup alpha, mixup enabled if > 0. (default: 0.)')
313 | group.add_argument('--cutmix', type=float, default=0.0,
314 | help='cutmix alpha, cutmix enabled if > 0. (default: 0.)')
315 | group.add_argument('--cutmix-minmax', type=float, nargs='+', default=None,
316 | help='cutmix min/max ratio, overrides alpha and enables cutmix if set (default: None)')
317 | group.add_argument('--mixup-prob', type=float, default=1.0,
318 | help='Probability of performing mixup or cutmix when either/both is enabled')
319 | group.add_argument('--mixup-switch-prob', type=float, default=0.5,
320 | help='Probability of switching to cutmix when both mixup and cutmix enabled')
321 | group.add_argument('--mixup-mode', type=str, default='batch',
322 | help='How to apply mixup/cutmix params. Per "batch", "pair", or "elem"')
323 | group.add_argument('--mixup-off-epoch', default=0, type=int, metavar='N',
324 | help='Turn off mixup after this epoch, disabled if 0 (default: 0)')
325 | group.add_argument('--smoothing', type=float, default=0.1,
326 | help='Label smoothing (default: 0.1)')
327 | group.add_argument('--train-interpolation', type=str, default='random',
328 | help='Training interpolation (random, bilinear, bicubic default: "random")')
329 | group.add_argument('--drop', type=float, default=0.0, metavar='PCT',
330 | help='Dropout rate (default: 0.)')
331 | group.add_argument('--drop-connect', type=float, default=None, metavar='PCT',
332 | help='Drop connect rate, DEPRECATED, use drop-path (default: None)')
333 | group.add_argument('--drop-path', type=float, default=None, metavar='PCT',
334 | help='Drop path rate (default: None)')
335 | group.add_argument('--drop-block', type=float, default=None, metavar='PCT',
336 | help='Drop block rate (default: None)')
337 |
338 | # Batch norm parameters (only works with gen_efficientnet based models currently)
339 | group = parser.add_argument_group('Batch norm parameters', 'Only works with gen_efficientnet based models currently.')
340 | group.add_argument('--bn-momentum', type=float, default=None,
341 | help='BatchNorm momentum override (if not None)')
342 | group.add_argument('--bn-eps', type=float, default=None,
343 | help='BatchNorm epsilon override (if not None)')
344 | group.add_argument('--sync-bn', action='store_true',
345 | help='Enable NVIDIA Apex or Torch synchronized BatchNorm.')
346 | group.add_argument('--dist-bn', type=str, default='reduce',
347 | help='Distribute BatchNorm stats between nodes after each epoch ("broadcast", "reduce", or "")')
348 | group.add_argument('--split-bn', action='store_true',
349 | help='Enable separate BN layers per augmentation split.')
350 |
351 | # Model Exponential Moving Average
352 | group = parser.add_argument_group('Model exponential moving average parameters')
353 | group.add_argument('--model-ema', action='store_true', default=False,
354 | help='Enable tracking moving average of model weights.')
355 | group.add_argument('--model-ema-force-cpu', action='store_true', default=False,
356 | help='Force ema to be tracked on CPU, rank=0 node only. Disables EMA validation.')
357 | group.add_argument('--model-ema-decay', type=float, default=0.9998,
358 | help='Decay factor for model weights moving average (default: 0.9998)')
359 | group.add_argument('--model-ema-warmup', action='store_true',
360 | help='Enable warmup for model EMA decay.')
361 |
362 | # Misc
363 | group = parser.add_argument_group('Miscellaneous parameters')
364 | group.add_argument('--seed', type=int, default=42, metavar='S',
365 | help='random seed (default: 42)')
366 | group.add_argument('--worker-seeding', type=str, default='all',
367 | help='worker seed mode (default: all)')
368 | group.add_argument('--log-interval', type=int, default=50, metavar='N',
369 | help='how many batches to wait before logging training status')
370 | group.add_argument('--recovery-interval', type=int, default=0, metavar='N',
371 | help='how many batches to wait before writing recovery checkpoint')
372 | group.add_argument('--checkpoint-hist', type=int, default=10, metavar='N',
373 | help='number of checkpoints to keep (default: 10)')
374 | group.add_argument('-j', '--workers', type=int, default=4, metavar='N',
375 | help='how many training processes to use (default: 4)')
376 | group.add_argument('--save-images', action='store_true', default=False,
377 |                    help='save images of input batches every log interval for debugging')
378 | group.add_argument('--pin-mem', action='store_true', default=False,
379 | help='Pin CPU memory in DataLoader for more efficient (sometimes) transfer to GPU.')
380 | group.add_argument('--no-prefetcher', action='store_true', default=False,
381 | help='disable fast prefetcher')
382 | group.add_argument('--output', default='', type=str, metavar='PATH',
383 | help='path to output folder (default: none, current dir)')
384 | group.add_argument('--experiment', default='', type=str, metavar='NAME',
385 | help='name of train experiment, name of sub-folder for output')
386 | group.add_argument('--eval-metric', default='top1', type=str, metavar='EVAL_METRIC',
387 |                    help='Best metric (default: "top1")')
388 | group.add_argument('--tta', type=int, default=0, metavar='N',
389 | help='Test/inference time augmentation (oversampling) factor. 0=None (default: 0)')
390 | group.add_argument('--use-multi-epochs-loader', action='store_true', default=False,
391 | help='use the multi-epochs-loader to save time at the beginning of every epoch')
392 | group.add_argument('--log-wandb', action='store_true', default=False,
393 | help='log training and validation metrics to wandb')
394 |
395 |
396 | def _parse_args():
397 | # Do we have a config file to parse?
398 | args_config, remaining = config_parser.parse_known_args()
399 | if args_config.config:
400 | with open(args_config.config, 'r') as f:
401 | cfg = yaml.safe_load(f)
402 | parser.set_defaults(**cfg)
403 |
404 | # The main arg parser parses the rest of the args, the usual
405 | # defaults will have been overridden if config file specified.
406 | args = parser.parse_args(remaining)
407 |
408 | # Cache the args as a text string to save them in the output dir later
409 | args_text = yaml.safe_dump(args.__dict__, default_flow_style=False)
410 | return args, args_text
411 |
412 |
413 | def main():
414 | utils.setup_default_logging()
415 | args, args_text = _parse_args()
416 |
417 | if args.device_modules:
418 | for module in args.device_modules:
419 | importlib.import_module(module)
420 |
421 | if torch.cuda.is_available():
422 | torch.backends.cuda.matmul.allow_tf32 = True
423 | torch.backends.cudnn.benchmark = True
424 |
425 | args.prefetcher = not args.no_prefetcher
426 | args.grad_accum_steps = max(1, args.grad_accum_steps)
427 | device = utils.init_distributed_device(args)
428 | if args.distributed:
429 | _logger.info(
430 |             'Training in distributed mode with multiple processes, 1 device per process. '
431 | f'Process {args.rank}, total {args.world_size}, device {args.device}.')
432 | else:
433 | _logger.info(f'Training with a single process on 1 device ({args.device}).')
434 | assert args.rank >= 0
435 |
436 | # resolve AMP arguments based on PyTorch / Apex availability
437 | use_amp = None
438 | amp_dtype = torch.float16
439 |
440 | if args.amp:
441 | if args.amp_impl == 'apex':
442 | assert has_apex, 'AMP impl specified as APEX but APEX is not installed.'
443 | use_amp = 'apex'
444 | assert args.amp_dtype == 'float16'
445 | else:
446 | assert has_native_amp, 'Please update PyTorch to a version with native AMP (or use APEX).'
447 | use_amp = 'native'
448 | assert args.amp_dtype in ('float16', 'bfloat16')
449 | if args.amp_dtype == 'bfloat16':
450 | amp_dtype = torch.bfloat16
451 |
452 | utils.random_seed(args.seed, args.rank)
453 |
454 | if args.fuser:
455 | utils.set_jit_fuser(args.fuser)
456 | if args.fast_norm:
457 | set_fast_norm()
458 |
459 | in_chans = 3
460 | if args.in_chans is not None:
461 | in_chans = args.in_chans
462 | elif args.input_size is not None:
463 | in_chans = args.input_size[0]
464 |
465 | factory_kwargs = {}
466 | if args.pretrained_path:
467 | # merge with pretrained_cfg of model, 'file' has priority over 'url' and 'hf_hub'.
468 | factory_kwargs['pretrained_cfg_overlay'] = dict(
469 | file=args.pretrained_path,
470 | num_classes=-1, # force head adaptation
471 | )
472 |
473 | model = create_model(
474 | args.model,
475 | pretrained=args.pretrained,
476 | in_chans=in_chans,
477 | num_classes=args.num_classes,
478 | drop_rate=args.drop,
479 | drop_path_rate=args.drop_path,
480 | drop_block_rate=args.drop_block,
481 | global_pool=args.gp,
482 | bn_momentum=args.bn_momentum,
483 | bn_eps=args.bn_eps,
484 | scriptable=args.torchscript,
485 | checkpoint_path=args.initial_checkpoint,
486 | **factory_kwargs,
487 | **args.model_kwargs,
488 | )
489 | if args.head_init_scale is not None:
490 | with torch.no_grad():
491 | model.get_classifier().weight.mul_(args.head_init_scale)
492 | model.get_classifier().bias.mul_(args.head_init_scale)
493 | if args.head_init_bias is not None:
494 | nn.init.constant_(model.get_classifier().bias, args.head_init_bias)
495 |
496 | if args.num_classes is None:
497 | assert hasattr(model, 'num_classes'), 'Model must have `num_classes` attr if not set on cmd line/config.'
498 | args.num_classes = model.num_classes # FIXME handle model default vs config num_classes more elegantly
499 |
500 | if args.grad_checkpointing:
501 | model.set_grad_checkpointing(enable=True)
502 |
503 | if utils.is_primary(args):
504 | _logger.info(
505 | f'Model {safe_model_name(args.model)} created, param count:{sum([m.numel() for m in model.parameters()])}')
506 | print(model)
507 |
508 | data_config = resolve_data_config(vars(args), model=model, verbose=utils.is_primary(args))
509 |
510 | # setup augmentation batch splits for contrastive loss or split bn
511 | num_aug_splits = 0
512 | if args.aug_splits > 0:
513 | assert args.aug_splits > 1, 'A split of 1 makes no sense'
514 | num_aug_splits = args.aug_splits
515 |
516 | # enable split bn (separate bn stats per batch-portion)
517 | if args.split_bn:
518 | assert num_aug_splits > 1 or args.resplit
519 | model = convert_splitbn_model(model, max(num_aug_splits, 2))
520 |
521 | # move model to GPU, enable channels last layout if set
522 | model.to(device=device)
523 | if args.channels_last:
524 | model.to(memory_format=torch.channels_last)
525 |
526 | # setup synchronized BatchNorm for distributed training
527 | if args.distributed and args.sync_bn:
528 | args.dist_bn = '' # disable dist_bn when sync BN active
529 | assert not args.split_bn
530 | if has_apex and use_amp == 'apex':
531 | # Apex SyncBN used with Apex AMP
532 | # WARNING this won't currently work with models using BatchNormAct2d
533 | model = convert_syncbn_model(model)
534 | else:
535 | model = convert_sync_batchnorm(model)
536 | if utils.is_primary(args):
537 | _logger.info(
538 | 'Converted model to use Synchronized BatchNorm. WARNING: You may have issues if using '
539 | 'zero initialized BN layers (enabled by default for ResNets) while sync-bn enabled.')
540 |
541 | if args.torchscript:
542 | assert not args.torchcompile
543 | assert not use_amp == 'apex', 'Cannot use APEX AMP with torchscripted model'
544 | assert not args.sync_bn, 'Cannot use SyncBatchNorm with torchscripted model'
545 | model = torch.jit.script(model)
546 |
547 | if not args.lr:
548 | global_batch_size = args.batch_size * args.world_size * args.grad_accum_steps
549 | batch_ratio = global_batch_size / args.lr_base_size
550 | if not args.lr_base_scale:
551 | on = args.opt.lower()
552 | args.lr_base_scale = 'sqrt' if any([o in on for o in ('ada', 'lamb')]) else 'linear'
553 | if args.lr_base_scale == 'sqrt':
554 | batch_ratio = batch_ratio ** 0.5
555 | args.lr = args.lr_base * batch_ratio
556 | if utils.is_primary(args):
557 | _logger.info(
558 | f'Learning rate ({args.lr}) calculated from base learning rate ({args.lr_base}) '
559 | f'and effective global batch size ({global_batch_size}) with {args.lr_base_scale} scaling.')
560 |
561 | optimizer = create_optimizer_v2(
562 | model,
563 | **optimizer_kwargs(cfg=args),
564 | **args.opt_kwargs,
565 | )
566 |
567 | # setup automatic mixed-precision (AMP) loss scaling and op casting
568 | amp_autocast = suppress # do nothing
569 | loss_scaler = None
570 | if use_amp == 'apex':
571 | assert device.type == 'cuda'
572 | model, optimizer = amp.initialize(model, optimizer, opt_level='O1')
573 | loss_scaler = ApexScaler()
574 | if utils.is_primary(args):
575 | _logger.info('Using NVIDIA APEX AMP. Training in mixed precision.')
576 | elif use_amp == 'native':
577 | try:
578 | amp_autocast = partial(torch.autocast, device_type=device.type, dtype=amp_dtype)
579 | except (AttributeError, TypeError):
580 | # fallback to CUDA only AMP for PyTorch < 1.10
581 | assert device.type == 'cuda'
582 | amp_autocast = torch.cuda.amp.autocast
583 | if device.type == 'cuda' and amp_dtype == torch.float16:
584 | # loss scaler only used for float16 (half) dtype, bfloat16 does not need it
585 | loss_scaler = NativeScaler()
586 | if utils.is_primary(args):
587 | _logger.info('Using native Torch AMP. Training in mixed precision.')
588 | else:
589 | if utils.is_primary(args):
590 | _logger.info('AMP not enabled. Training in float32.')
591 |
592 | # optionally resume from a checkpoint
593 | resume_epoch = None
594 | if args.resume:
595 | resume_epoch = resume_checkpoint(
596 | model,
597 | args.resume,
598 | optimizer=None if args.no_resume_opt else optimizer,
599 | loss_scaler=None if args.no_resume_opt else loss_scaler,
600 | log_info=utils.is_primary(args),
601 | )
602 |
603 | # setup exponential moving average of model weights, SWA could be used here too
604 | model_ema = None
605 | if args.model_ema:
606 | # Important to create EMA model after cuda(), DP wrapper, and AMP but before DDP wrapper
607 | model_ema = utils.ModelEmaV3(
608 | model,
609 | decay=args.model_ema_decay,
610 | use_warmup=args.model_ema_warmup,
611 | device='cpu' if args.model_ema_force_cpu else None,
612 | )
613 | if args.resume:
614 | load_checkpoint(model_ema.module, args.resume, use_ema=True)
615 | if args.torchcompile:
616 | model_ema = torch.compile(model_ema, backend=args.torchcompile)
617 |
618 | # setup distributed training
619 | if args.distributed:
620 | if has_apex and use_amp == 'apex':
621 | # Apex DDP preferred unless native amp is activated
622 | if utils.is_primary(args):
623 | _logger.info("Using NVIDIA APEX DistributedDataParallel.")
624 | model = ApexDDP(model, delay_allreduce=True)
625 | else:
626 | if utils.is_primary(args):
627 | _logger.info("Using native Torch DistributedDataParallel.")
628 | model = NativeDDP(model, device_ids=[device], broadcast_buffers=not args.no_ddp_bb)
629 | # NOTE: EMA model does not need to be wrapped by DDP
630 |
631 | if args.torchcompile:
632 | # torch compile should be done after DDP
633 | assert has_compile, 'A version of torch w/ torch.compile() is required for --compile, possibly a nightly.'
634 | model = torch.compile(model, backend=args.torchcompile)
635 |
636 | # create the train and eval datasets
637 | if args.data and not args.data_dir:
638 | args.data_dir = args.data
639 | if args.input_img_mode is None:
640 | input_img_mode = 'RGB' if data_config['input_size'][0] == 3 else 'L'
641 | else:
642 | input_img_mode = args.input_img_mode
643 |
644 | dataset_train = create_dataset(
645 | args.dataset,
646 | root=args.data_dir,
647 | split=args.train_split,
648 | is_training=True,
649 | class_map=args.class_map,
650 | download=args.dataset_download,
651 | batch_size=args.batch_size,
652 | seed=args.seed,
653 | repeats=args.epoch_repeats,
654 | input_img_mode=input_img_mode,
655 | input_key=args.input_key,
656 | target_key=args.target_key,
657 | num_samples=args.train_num_samples,
658 | )
659 |
660 | if args.val_split:
661 | dataset_eval = create_dataset(
662 | args.dataset,
663 | root=args.data_dir,
664 | split=args.val_split,
665 | is_training=False,
666 | class_map=args.class_map,
667 | download=args.dataset_download,
668 | batch_size=args.batch_size,
669 | input_img_mode=input_img_mode,
670 | input_key=args.input_key,
671 | target_key=args.target_key,
672 | num_samples=args.val_num_samples,
673 | )
674 |
675 | # setup mixup / cutmix
676 | collate_fn = None
677 | mixup_fn = None
678 | mixup_active = args.mixup > 0 or args.cutmix > 0. or args.cutmix_minmax is not None
679 | if mixup_active:
680 | mixup_args = dict(
681 | mixup_alpha=args.mixup,
682 | cutmix_alpha=args.cutmix,
683 | cutmix_minmax=args.cutmix_minmax,
684 | prob=args.mixup_prob,
685 | switch_prob=args.mixup_switch_prob,
686 | mode=args.mixup_mode,
687 | label_smoothing=args.smoothing,
688 | num_classes=args.num_classes
689 | )
690 | if args.prefetcher:
691 | assert not num_aug_splits # collate conflict (need to support de-interleaving in collate mixup)
692 | collate_fn = FastCollateMixup(**mixup_args)
693 | else:
694 | mixup_fn = Mixup(**mixup_args)
695 |
696 | # wrap dataset in AugMix helper
697 | if num_aug_splits > 1:
698 | dataset_train = AugMixDataset(dataset_train, num_splits=num_aug_splits)
699 |
700 | # create data loaders w/ augmentation pipeline
701 | train_interpolation = args.train_interpolation
702 | if args.no_aug or not train_interpolation:
703 | train_interpolation = data_config['interpolation']
704 | loader_train = create_loader(
705 | dataset_train,
706 | input_size=data_config['input_size'],
707 | batch_size=args.batch_size,
708 | is_training=True,
709 | no_aug=args.no_aug,
710 | re_prob=args.reprob,
711 | re_mode=args.remode,
712 | re_count=args.recount,
713 | re_split=args.resplit,
714 | train_crop_mode=args.train_crop_mode,
715 | scale=args.scale,
716 | ratio=args.ratio,
717 | hflip=args.hflip,
718 | vflip=args.vflip,
719 | color_jitter=args.color_jitter,
720 | color_jitter_prob=args.color_jitter_prob,
721 | grayscale_prob=args.grayscale_prob,
722 | gaussian_blur_prob=args.gaussian_blur_prob,
723 | auto_augment=args.aa,
724 | num_aug_repeats=args.aug_repeats,
725 | num_aug_splits=num_aug_splits,
726 | interpolation=train_interpolation,
727 | mean=data_config['mean'],
728 | std=data_config['std'],
729 | num_workers=args.workers,
730 | distributed=args.distributed,
731 | collate_fn=collate_fn,
732 | pin_memory=args.pin_mem,
733 | device=device,
734 | use_prefetcher=args.prefetcher,
735 | use_multi_epochs_loader=args.use_multi_epochs_loader,
736 | worker_seeding=args.worker_seeding,
737 | )
738 |
739 | loader_eval = None
740 | if args.val_split:
741 | eval_workers = args.workers
742 | if args.distributed and ('tfds' in args.dataset or 'wds' in args.dataset):
743 | # FIXME reduces validation padding issues when using TFDS, WDS w/ workers and distributed training
744 | eval_workers = min(2, args.workers)
745 | loader_eval = create_loader(
746 | dataset_eval,
747 | input_size=data_config['input_size'],
748 | batch_size=args.validation_batch_size or args.batch_size,
749 | is_training=False,
750 | interpolation=data_config['interpolation'],
751 | mean=data_config['mean'],
752 | std=data_config['std'],
753 | num_workers=eval_workers,
754 | distributed=args.distributed,
755 | crop_pct=data_config['crop_pct'],
756 | pin_memory=args.pin_mem,
757 | device=device,
758 | use_prefetcher=args.prefetcher,
759 | )
760 |
761 | # setup loss function
762 | if args.jsd_loss:
763 | assert num_aug_splits > 1 # JSD only valid with aug splits set
764 | train_loss_fn = JsdCrossEntropy(num_splits=num_aug_splits, smoothing=args.smoothing)
765 | elif mixup_active:
766 | # smoothing is handled with mixup target transform which outputs sparse, soft targets
767 | if args.bce_loss:
768 | train_loss_fn = BinaryCrossEntropy(
769 | target_threshold=args.bce_target_thresh,
770 | sum_classes=args.bce_sum,
771 | pos_weight=args.bce_pos_weight,
772 | )
773 | else:
774 | train_loss_fn = SoftTargetCrossEntropy()
775 | elif args.smoothing:
776 | if args.bce_loss:
777 | train_loss_fn = BinaryCrossEntropy(
778 | smoothing=args.smoothing,
779 | target_threshold=args.bce_target_thresh,
780 | sum_classes=args.bce_sum,
781 | pos_weight=args.bce_pos_weight,
782 | )
783 | else:
784 | train_loss_fn = LabelSmoothingCrossEntropy(smoothing=args.smoothing)
785 | else:
786 | train_loss_fn = nn.CrossEntropyLoss()
787 | train_loss_fn = train_loss_fn.to(device=device)
788 | validate_loss_fn = nn.CrossEntropyLoss().to(device=device)
789 |
790 | # setup checkpoint saver and eval metric tracking
791 | eval_metric = args.eval_metric if loader_eval is not None else 'loss'
792 | decreasing_metric = eval_metric == 'loss'
793 | best_metric = None
794 | best_epoch = None
795 | saver = None
796 | output_dir = None
797 | if utils.is_primary(args):
798 | if args.experiment:
799 | exp_name = args.experiment
800 | else:
801 | exp_name = '-'.join([
802 | datetime.now().strftime("%Y%m%d-%H%M%S"),
803 | safe_model_name(args.model),
804 | str(data_config['input_size'][-1])
805 | ])
806 | output_dir = utils.get_outdir(args.output if args.output else './output/train', exp_name)
807 | saver = utils.CheckpointSaver(
808 | model=model,
809 | optimizer=optimizer,
810 | args=args,
811 | model_ema=model_ema,
812 | amp_scaler=loss_scaler,
813 | checkpoint_dir=output_dir,
814 | recovery_dir=output_dir,
815 | decreasing=decreasing_metric,
816 | max_history=args.checkpoint_hist
817 | )
818 | with open(os.path.join(output_dir, 'args.yaml'), 'w') as f:
819 | f.write(args_text)
820 |
821 | if utils.is_primary(args) and args.log_wandb:
822 | if has_wandb:
823 | wandb.init(project="scale-kan",
824 | name=exp_name,
825 | config=args)
826 | else:
827 | _logger.warning(
828 |                 "You've requested to log metrics to wandb, but the wandb package was not found. "
829 |                 "Metrics will not be logged to wandb; try `pip install wandb`.")
830 |
831 | # setup learning rate schedule and starting epoch
832 | updates_per_epoch = (len(loader_train) + args.grad_accum_steps - 1) // args.grad_accum_steps
833 | lr_scheduler, num_epochs = create_scheduler_v2(
834 | optimizer,
835 | **scheduler_kwargs(args, decreasing_metric=decreasing_metric),
836 | updates_per_epoch=updates_per_epoch,
837 | )
838 | start_epoch = 0
839 | if args.start_epoch is not None:
840 | # a specified start_epoch will always override the resume epoch
841 | start_epoch = args.start_epoch
842 | elif resume_epoch is not None:
843 | start_epoch = resume_epoch
844 | if lr_scheduler is not None and start_epoch > 0:
845 | if args.sched_on_updates:
846 | lr_scheduler.step_update(start_epoch * updates_per_epoch)
847 | else:
848 | lr_scheduler.step(start_epoch)
849 |
850 | if utils.is_primary(args):
851 | _logger.info(
852 |             f'Scheduled epochs: {num_epochs}. LR stepped per {"epoch" if lr_scheduler is not None and lr_scheduler.t_in_epochs else "update"}.')
853 |
854 | results = []
855 | try:
856 | for epoch in range(start_epoch, num_epochs):
857 | if hasattr(dataset_train, 'set_epoch'):
858 | dataset_train.set_epoch(epoch)
859 | elif args.distributed and hasattr(loader_train.sampler, 'set_epoch'):
860 | loader_train.sampler.set_epoch(epoch)
861 |
862 | train_metrics = train_one_epoch(
863 | epoch,
864 | model,
865 | loader_train,
866 | optimizer,
867 | train_loss_fn,
868 | args,
869 | lr_scheduler=lr_scheduler,
870 | saver=saver,
871 | output_dir=output_dir,
872 | amp_autocast=amp_autocast,
873 | loss_scaler=loss_scaler,
874 | model_ema=model_ema,
875 | mixup_fn=mixup_fn,
876 | num_updates_total=num_epochs * updates_per_epoch,
877 | )
878 |
879 | if args.distributed and args.dist_bn in ('broadcast', 'reduce'):
880 | if utils.is_primary(args):
881 | _logger.info("Distributing BatchNorm running means and vars")
882 | utils.distribute_bn(model, args.world_size, args.dist_bn == 'reduce')
883 |
884 | if loader_eval is not None:
885 | eval_metrics = validate(
886 | model,
887 | loader_eval,
888 | validate_loss_fn,
889 | args,
890 | device=device,
891 | amp_autocast=amp_autocast,
892 | )
893 |
894 | if model_ema is not None and not args.model_ema_force_cpu:
895 | if args.distributed and args.dist_bn in ('broadcast', 'reduce'):
896 | utils.distribute_bn(model_ema, args.world_size, args.dist_bn == 'reduce')
897 |
898 | ema_eval_metrics = validate(
899 | model_ema,
900 | loader_eval,
901 | validate_loss_fn,
902 | args,
903 | device=device,
904 | amp_autocast=amp_autocast,
905 | log_suffix=' (EMA)',
906 | )
907 | eval_metrics = ema_eval_metrics
908 | else:
909 | eval_metrics = None
910 |
911 | if output_dir is not None:
912 | lrs = [param_group['lr'] for param_group in optimizer.param_groups]
913 | utils.update_summary(
914 | epoch,
915 | train_metrics,
916 | eval_metrics,
917 | filename=os.path.join(output_dir, 'summary.csv'),
918 | lr=sum(lrs) / len(lrs),
919 | write_header=best_metric is None,
920 | log_wandb=args.log_wandb and has_wandb,
921 | )
922 |
923 | if eval_metrics is not None:
924 | latest_metric = eval_metrics[eval_metric]
925 | else:
926 | latest_metric = train_metrics[eval_metric]
927 |
928 | if saver is not None:
929 | # save proper checkpoint with eval metric
930 | best_metric, best_epoch = saver.save_checkpoint(epoch, metric=latest_metric)
931 |
932 | if lr_scheduler is not None:
933 | # step LR for next epoch
934 | lr_scheduler.step(epoch + 1, latest_metric)
935 |
936 | results.append({
937 | 'epoch': epoch,
938 | 'train': train_metrics,
939 | 'validation': eval_metrics,
940 | })
941 |
942 | except KeyboardInterrupt:
943 | pass
944 |
945 | results = {'all': results}
946 | if best_metric is not None:
947 | results['best'] = results['all'][best_epoch - start_epoch]
948 | _logger.info('*** Best metric: {0} (epoch {1})'.format(best_metric, best_epoch))
949 | print(f'--result\n{json.dumps(results, indent=4)}')
950 |
951 |
952 | def train_one_epoch(
953 | epoch,
954 | model,
955 | loader,
956 | optimizer,
957 | loss_fn,
958 | args,
959 | device=torch.device('cuda'),
960 | lr_scheduler=None,
961 | saver=None,
962 | output_dir=None,
963 | amp_autocast=suppress,
964 | loss_scaler=None,
965 | model_ema=None,
966 | mixup_fn=None,
967 | num_updates_total=None,
968 | ):
969 | if args.mixup_off_epoch and epoch >= args.mixup_off_epoch:
970 | if args.prefetcher and loader.mixup_enabled:
971 | loader.mixup_enabled = False
972 | elif mixup_fn is not None:
973 | mixup_fn.mixup_enabled = False
974 |
975 | second_order = hasattr(optimizer, 'is_second_order') and optimizer.is_second_order
976 | has_no_sync = hasattr(model, "no_sync")
977 | update_time_m = utils.AverageMeter()
978 | data_time_m = utils.AverageMeter()
979 | losses_m = utils.AverageMeter()
980 |
981 | model.train()
982 |
983 | accum_steps = args.grad_accum_steps
984 | last_accum_steps = len(loader) % accum_steps
985 | updates_per_epoch = (len(loader) + accum_steps - 1) // accum_steps
986 | num_updates = epoch * updates_per_epoch
987 | last_batch_idx = len(loader) - 1
988 | last_batch_idx_to_accum = len(loader) - last_accum_steps
989 |
990 | data_start_time = update_start_time = time.time()
991 | optimizer.zero_grad()
992 | update_sample_count = 0
993 | for batch_idx, (input, target) in enumerate(loader):
994 | last_batch = batch_idx == last_batch_idx
995 | need_update = last_batch or (batch_idx + 1) % accum_steps == 0
996 | update_idx = batch_idx // accum_steps
997 | if batch_idx >= last_batch_idx_to_accum:
998 | accum_steps = last_accum_steps
999 |
1000 | if not args.prefetcher:
1001 | input, target = input.to(device), target.to(device)
1002 | if mixup_fn is not None:
1003 | input, target = mixup_fn(input, target)
1004 | if args.channels_last:
1005 | input = input.contiguous(memory_format=torch.channels_last)
1006 |
1007 | # multiply by accum steps to get equivalent for full update
1008 | data_time_m.update(accum_steps * (time.time() - data_start_time))
1009 |
1010 | def _forward():
1011 | with amp_autocast():
1012 | output = model(input)
1013 | loss = loss_fn(output, target)
1014 | if accum_steps > 1:
1015 | loss /= accum_steps
1016 | return loss
1017 |
1018 | def _backward(_loss):
1019 | if loss_scaler is not None:
1020 | loss_scaler(
1021 | _loss,
1022 | optimizer,
1023 | clip_grad=args.clip_grad,
1024 | clip_mode=args.clip_mode,
1025 | parameters=model_parameters(model, exclude_head='agc' in args.clip_mode),
1026 | create_graph=second_order,
1027 | need_update=need_update,
1028 | )
1029 | else:
1030 | _loss.backward(create_graph=second_order)
1031 | if need_update:
1032 | if args.clip_grad is not None:
1033 | utils.dispatch_clip_grad(
1034 | model_parameters(model, exclude_head='agc' in args.clip_mode),
1035 | value=args.clip_grad,
1036 | mode=args.clip_mode,
1037 | )
1038 | optimizer.step()
1039 |
1040 | if has_no_sync and not need_update:
1041 | with model.no_sync():
1042 | loss = _forward()
1043 | _backward(loss)
1044 | else:
1045 | loss = _forward()
1046 | _backward(loss)
1047 |
1048 | if not args.distributed:
1049 | losses_m.update(loss.item() * accum_steps, input.size(0))
1050 | update_sample_count += input.size(0)
1051 |
1052 | if not need_update:
1053 | data_start_time = time.time()
1054 | continue
1055 |
1056 | num_updates += 1
1057 | optimizer.zero_grad()
1058 | if model_ema is not None:
1059 | model_ema.update(model, step=num_updates)
1060 |
1061 | if args.synchronize_step and device.type == 'cuda':
1062 | torch.cuda.synchronize()
1063 | time_now = time.time()
1064 | update_time_m.update(time.time() - update_start_time)
1065 | update_start_time = time_now
1066 |
1067 | if update_idx % args.log_interval == 0:
1068 | lrl = [param_group['lr'] for param_group in optimizer.param_groups]
1069 | lr = sum(lrl) / len(lrl)
1070 |
1071 | if args.distributed:
1072 | reduced_loss = utils.reduce_tensor(loss.data, args.world_size)
1073 | losses_m.update(reduced_loss.item() * accum_steps, input.size(0))
1074 | update_sample_count *= args.world_size
1075 |
1076 | if utils.is_primary(args):
1077 | _logger.info(
1078 | f'Train: {epoch} [{update_idx:>4d}/{updates_per_epoch} '
1079 | f'({100. * update_idx / (updates_per_epoch - 1):>3.0f}%)] '
1080 | f'Loss: {losses_m.val:#.3g} ({losses_m.avg:#.3g}) '
1081 | f'Time: {update_time_m.val:.3f}s, {update_sample_count / update_time_m.val:>7.2f}/s '
1082 | f'({update_time_m.avg:.3f}s, {update_sample_count / update_time_m.avg:>7.2f}/s) '
1083 | f'LR: {lr:.3e} '
1084 | f'Data: {data_time_m.val:.3f} ({data_time_m.avg:.3f})'
1085 | )
1086 |
1087 | if args.save_images and output_dir:
1088 | torchvision.utils.save_image(
1089 | input,
1090 | os.path.join(output_dir, 'train-batch-%d.jpg' % batch_idx),
1091 | padding=0,
1092 | normalize=True
1093 | )
1094 |
1095 | if saver is not None and args.recovery_interval and (
1096 | (update_idx + 1) % args.recovery_interval == 0):
1097 | saver.save_recovery(epoch, batch_idx=update_idx)
1098 |
1099 | if lr_scheduler is not None:
1100 | lr_scheduler.step_update(num_updates=num_updates, metric=losses_m.avg)
1101 |
1102 | update_sample_count = 0
1103 | data_start_time = time.time()
1104 | # end for
1105 |
1106 | if hasattr(optimizer, 'sync_lookahead'):
1107 | optimizer.sync_lookahead()
1108 |
1109 | return OrderedDict([('loss', losses_m.avg)])
1110 |
1111 |
1112 | def validate(
1113 | model,
1114 | loader,
1115 | loss_fn,
1116 | args,
1117 | device=torch.device('cuda'),
1118 | amp_autocast=suppress,
1119 | log_suffix=''
1120 | ):
1121 | batch_time_m = utils.AverageMeter()
1122 | losses_m = utils.AverageMeter()
1123 | top1_m = utils.AverageMeter()
1124 | top5_m = utils.AverageMeter()
1125 |
1126 | model.eval()
1127 |
1128 | end = time.time()
1129 | last_idx = len(loader) - 1
1130 | with torch.no_grad():
1131 | for batch_idx, (input, target) in enumerate(loader):
1132 | last_batch = batch_idx == last_idx
1133 | if not args.prefetcher:
1134 | input = input.to(device)
1135 | target = target.to(device)
1136 | if args.channels_last:
1137 | input = input.contiguous(memory_format=torch.channels_last)
1138 |
1139 | with amp_autocast():
1140 | output = model(input)
1141 | if isinstance(output, (tuple, list)):
1142 | output = output[0]
1143 |
1144 | # augmentation reduction
1145 | reduce_factor = args.tta
1146 | if reduce_factor > 1:
1147 | output = output.unfold(0, reduce_factor, reduce_factor).mean(dim=2)
1148 | target = target[0:target.size(0):reduce_factor]
1149 |
1150 | loss = loss_fn(output, target)
1151 | acc1, acc5 = utils.accuracy(output, target, topk=(1, 5))
1152 |
1153 | if args.distributed:
1154 | reduced_loss = utils.reduce_tensor(loss.data, args.world_size)
1155 | acc1 = utils.reduce_tensor(acc1, args.world_size)
1156 | acc5 = utils.reduce_tensor(acc5, args.world_size)
1157 | else:
1158 | reduced_loss = loss.data
1159 |
1160 | if device.type == 'cuda':
1161 | torch.cuda.synchronize()
1162 |
1163 | losses_m.update(reduced_loss.item(), input.size(0))
1164 | top1_m.update(acc1.item(), output.size(0))
1165 | top5_m.update(acc5.item(), output.size(0))
1166 |
1167 | batch_time_m.update(time.time() - end)
1168 | end = time.time()
1169 | if utils.is_primary(args) and (last_batch or batch_idx % args.log_interval == 0):
1170 | log_name = 'Test' + log_suffix
1171 | _logger.info(
1172 | f'{log_name}: [{batch_idx:>4d}/{last_idx}] '
1173 | f'Time: {batch_time_m.val:.3f} ({batch_time_m.avg:.3f}) '
1174 | f'Loss: {losses_m.val:>7.3f} ({losses_m.avg:>6.3f}) '
1175 | f'Acc@1: {top1_m.val:>7.3f} ({top1_m.avg:>7.3f}) '
1176 | f'Acc@5: {top5_m.val:>7.3f} ({top5_m.avg:>7.3f})'
1177 | )
1178 |
1179 | metrics = OrderedDict([('loss', losses_m.avg), ('top1', top1_m.avg), ('top5', top5_m.avg)])
1180 |
1181 | return metrics
1182 |
1183 |
1184 | if __name__ == '__main__':
1185 | main()
1186 |
--------------------------------------------------------------------------------
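Note on the accumulation logic in train_one_epoch above: the loss is divided by the number of accumulation steps so the summed gradients match one large batch, and under DDP the non-update batches run inside `model.no_sync()` so gradients are only all-reduced on the batch that actually calls `optimizer.step()`. Below is a minimal single-process sketch of that pattern; the toy model, synthetic data, and `accum_steps` value are illustrative only, and the sketch omits the shorter final window that the real loop handles via `last_accum_steps`.

import torch
import torch.nn as nn

# toy stand-ins for the real model, loader, and args.grad_accum_steps
model = nn.Linear(16, 4)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
loss_fn = nn.CrossEntropyLoss()
accum_steps = 4
data = [(torch.randn(8, 16), torch.randint(0, 4, (8,))) for _ in range(10)]

optimizer.zero_grad()
for batch_idx, (inputs, targets) in enumerate(data):
    last_batch = batch_idx == len(data) - 1
    need_update = last_batch or (batch_idx + 1) % accum_steps == 0

    # under DDP, this forward/backward would sit inside `with model.no_sync():`
    # whenever need_update is False, exactly as train_one_epoch does above
    loss = loss_fn(model(inputs), targets) / accum_steps  # scale so accumulated grads match one big batch
    loss.backward()

    if need_update:
        optimizer.step()
        optimizer.zero_grad()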
/validate.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | """ ImageNet Validation Script
3 |
4 | This is intended to be a lean and easily modifiable ImageNet validation script for evaluating pretrained
5 | models or training checkpoints against ImageNet or similarly organized image datasets. It prioritizes
6 | canonical PyTorch, standard Python style, and good performance. Repurpose as you see fit.
7 |
8 | Hacked together by Ross Wightman (https://github.com/rwightman)
9 | """
10 | import argparse
11 | import csv
12 | import glob
13 | import json
14 | import logging
15 | import os
16 | import time
17 | from collections import OrderedDict
18 | from contextlib import suppress
19 | from functools import partial
20 |
21 | import torch
22 | import torch.nn as nn
23 | import torch.nn.parallel
24 |
25 | from timm.data import create_dataset, create_loader, resolve_data_config, RealLabelsImagenet
26 | from timm.layers import apply_test_time_pool, set_fast_norm
27 | from timm.models import create_model, load_checkpoint, is_model, list_models
28 | from timm.utils import accuracy, AverageMeter, natural_key, setup_default_logging, set_jit_fuser, \
29 | decay_batch_step, check_batch_size_retry, ParseKwargs, reparameterize_model
30 |
31 | try:
32 | from apex import amp
33 | has_apex = True
34 | except ImportError:
35 | has_apex = False
36 |
37 | has_native_amp = False
38 | try:
39 | if getattr(torch.cuda.amp, 'autocast') is not None:
40 | has_native_amp = True
41 | except AttributeError:
42 | pass
43 |
44 | try:
45 | from functorch.compile import memory_efficient_fusion
46 | has_functorch = True
47 | except ImportError as e:
48 | has_functorch = False
49 |
50 | import katransformer
51 |
52 | has_compile = hasattr(torch, 'compile')
53 |
54 | _logger = logging.getLogger('validate')
55 |
56 |
57 | parser = argparse.ArgumentParser(description='PyTorch ImageNet Validation')
58 | parser.add_argument('data', nargs='?', metavar='DIR', const=None,
59 | help='path to dataset (*deprecated*, use --data-dir)')
60 | parser.add_argument('--data-dir', metavar='DIR',
61 | help='path to dataset (root dir)')
62 | parser.add_argument('--dataset', metavar='NAME', default='',
63 | help='dataset type + name ("<type>/<name>") (default: ImageFolder or ImageTar if empty)')
64 | parser.add_argument('--split', metavar='NAME', default='validation',
65 | help='dataset split (default: validation)')
66 | parser.add_argument('--num-samples', default=None, type=int,
67 | metavar='N', help='Manually specify num samples in dataset split, for IterableDatasets.')
68 | parser.add_argument('--dataset-download', action='store_true', default=False,
69 | help='Allow download of dataset for torch/ and tfds/ datasets that support it.')
70 | parser.add_argument('--class-map', default='', type=str, metavar='FILENAME',
71 | help='path to class to idx mapping file (default: "")')
72 | parser.add_argument('--input-key', default=None, type=str,
73 | help='Dataset key for input images.')
74 | parser.add_argument('--input-img-mode', default=None, type=str,
75 | help='Dataset image conversion mode for input images.')
76 | parser.add_argument('--target-key', default=None, type=str,
77 | help='Dataset key for target labels.')
78 |
79 | parser.add_argument('--model', '-m', metavar='NAME', default='kat_tiny_patch16_224',
80 | help='model architecture (default: kat_tiny_patch16_224)')
81 | parser.add_argument('--pretrained', dest='pretrained', action='store_true',
82 | help='use pre-trained model')
83 | parser.add_argument('-j', '--workers', default=4, type=int, metavar='N',
84 | help='number of data loading workers (default: 4)')
85 | parser.add_argument('-b', '--batch-size', default=256, type=int,
86 | metavar='N', help='mini-batch size (default: 256)')
87 | parser.add_argument('--img-size', default=None, type=int,
88 | metavar='N', help='Input image dimension, uses model default if empty')
89 | parser.add_argument('--in-chans', type=int, default=None, metavar='N',
90 | help='Image input channels (default: None => 3)')
91 | parser.add_argument('--input-size', default=None, nargs=3, type=int,
92 | metavar='N N N', help='Input all image dimensions (d h w, e.g. --input-size 3 224 224), uses model default if empty')
93 | parser.add_argument('--use-train-size', action='store_true', default=False,
94 | help='force use of train input size, even when test size is specified in pretrained cfg')
95 | parser.add_argument('--crop-pct', default=None, type=float,
96 | metavar='N', help='Input image center crop pct')
97 | parser.add_argument('--crop-mode', default=None, type=str,
98 | metavar='N', help='Input image crop mode (squash, border, center). Model default if None.')
99 | parser.add_argument('--crop-border-pixels', type=int, default=None,
100 | help='Crop pixels from image border.')
101 | parser.add_argument('--mean', type=float, nargs='+', default=None, metavar='MEAN',
102 | help='Override mean pixel value of dataset')
103 | parser.add_argument('--std', type=float, nargs='+', default=None, metavar='STD',
104 | help='Override std deviation of dataset')
105 | parser.add_argument('--interpolation', default='', type=str, metavar='NAME',
106 | help='Image resize interpolation type (overrides model)')
107 | parser.add_argument('--num-classes', type=int, default=None,
108 | help='Number classes in dataset')
109 | parser.add_argument('--gp', default=None, type=str, metavar='POOL',
110 | help='Global pool type, one of (fast, avg, max, avgmax, avgmaxc). Model default if None.')
111 | parser.add_argument('--log-freq', default=10, type=int,
112 | metavar='N', help='batch logging frequency (default: 10)')
113 | parser.add_argument('--checkpoint', default='', type=str, metavar='PATH',
114 | help='path to latest checkpoint (default: none)')
115 | parser.add_argument('--num-gpu', type=int, default=1,
116 | help='Number of GPUS to use')
117 | parser.add_argument('--test-pool', dest='test_pool', action='store_true',
118 | help='enable test time pool')
119 | parser.add_argument('--no-prefetcher', action='store_true', default=False,
120 | help='disable fast prefetcher')
121 | parser.add_argument('--pin-mem', action='store_true', default=False,
122 | help='Pin CPU memory in DataLoader for more efficient (sometimes) transfer to GPU.')
123 | parser.add_argument('--channels-last', action='store_true', default=False,
124 | help='Use channels_last memory layout')
125 | parser.add_argument('--device', default='cuda', type=str,
126 | help="Device (accelerator) to use.")
127 | parser.add_argument('--amp', action='store_true', default=False,
128 | help='use NVIDIA Apex AMP or Native AMP for mixed precision training')
129 | parser.add_argument('--amp-dtype', default='float16', type=str,
130 | help='lower precision AMP dtype (default: float16)')
131 | parser.add_argument('--amp-impl', default='native', type=str,
132 | help='AMP impl to use, "native" or "apex" (default: native)')
133 | parser.add_argument('--tf-preprocessing', action='store_true', default=False,
134 | help='Use Tensorflow preprocessing pipeline (requires CPU TF installed)')
135 | parser.add_argument('--use-ema', dest='use_ema', action='store_true',
136 | help='use ema version of weights if present')
137 | parser.add_argument('--fuser', default='', type=str,
138 | help="Select jit fuser. One of ('', 'te', 'old', 'nvfuser')")
139 | parser.add_argument('--fast-norm', default=False, action='store_true',
140 | help='enable experimental fast-norm')
141 | parser.add_argument('--reparam', default=False, action='store_true',
142 | help='Reparameterize model')
143 | parser.add_argument('--model-kwargs', nargs='*', default={}, action=ParseKwargs)
144 |
145 |
146 | scripting_group = parser.add_mutually_exclusive_group()
147 | scripting_group.add_argument('--torchscript', default=False, action='store_true',
148 | help='torch.jit.script the full model')
149 | scripting_group.add_argument('--torchcompile', nargs='?', type=str, default=None, const='inductor',
150 | help="Enable compilation w/ specified backend (default: inductor).")
151 | scripting_group.add_argument('--aot-autograd', default=False, action='store_true',
152 | help="Enable AOT Autograd support.")
153 |
154 | parser.add_argument('--results-file', default='', type=str, metavar='FILENAME',
155 | help='Output csv file for validation results (summary)')
156 | parser.add_argument('--results-format', default='csv', type=str,
157 | help='Format for results file one of (csv, json) (default: csv).')
158 | parser.add_argument('--real-labels', default='', type=str, metavar='FILENAME',
159 | help='Real labels JSON file for imagenet evaluation')
160 | parser.add_argument('--valid-labels', default='', type=str, metavar='FILENAME',
161 | help='Valid label indices txt file for validation of partial label space')
162 | parser.add_argument('--retry', default=False, action='store_true',
163 | help='Enable batch size decay & retry for single model validation')
164 |
165 |
166 | def validate(args):
167 | # might as well try to validate something
168 | args.pretrained = args.pretrained or not args.checkpoint
169 | args.prefetcher = not args.no_prefetcher
170 |
171 | if torch.cuda.is_available():
172 | torch.backends.cuda.matmul.allow_tf32 = True
173 | torch.backends.cudnn.benchmark = True
174 |
175 | device = torch.device(args.device)
176 |
177 | # resolve AMP arguments based on PyTorch / Apex availability
178 | use_amp = None
179 | amp_autocast = suppress
180 | if args.amp:
181 | if args.amp_impl == 'apex':
182 | assert has_apex, 'AMP impl specified as APEX but APEX is not installed.'
183 | assert args.amp_dtype == 'float16'
184 | use_amp = 'apex'
185 | _logger.info('Validating in mixed precision with NVIDIA APEX AMP.')
186 | else:
187 | assert has_native_amp, 'Please update PyTorch to a version with native AMP (or use APEX).'
188 | assert args.amp_dtype in ('float16', 'bfloat16')
189 | use_amp = 'native'
190 | amp_dtype = torch.bfloat16 if args.amp_dtype == 'bfloat16' else torch.float16
191 | amp_autocast = partial(torch.autocast, device_type=device.type, dtype=amp_dtype)
192 | _logger.info('Validating in mixed precision with native PyTorch AMP.')
193 | else:
194 | _logger.info('Validating in float32. AMP not enabled.')
195 |
196 | if args.fuser:
197 | set_jit_fuser(args.fuser)
198 |
199 | if args.fast_norm:
200 | set_fast_norm()
201 |
202 | # create model
203 | in_chans = 3
204 | if args.in_chans is not None:
205 | in_chans = args.in_chans
206 | elif args.input_size is not None:
207 | in_chans = args.input_size[0]
208 |
209 | model = create_model(
210 | args.model,
211 | pretrained=args.pretrained,
212 | num_classes=args.num_classes,
213 | in_chans=in_chans,
214 | global_pool=args.gp,
215 | scriptable=args.torchscript,
216 | **args.model_kwargs,
217 | )
218 | if args.num_classes is None:
219 | assert hasattr(model, 'num_classes'), 'Model must have `num_classes` attr if not set on cmd line/config.'
220 | args.num_classes = model.num_classes
221 |
222 | if args.checkpoint:
223 | load_checkpoint(model, args.checkpoint, args.use_ema)
224 |
225 | if args.reparam:
226 | model = reparameterize_model(model)
227 |
228 | param_count = sum([m.numel() for m in model.parameters()])
229 | _logger.info('Model %s created, param count: %d' % (args.model, param_count))
230 |
231 | data_config = resolve_data_config(
232 | vars(args),
233 | model=model,
234 | use_test_size=not args.use_train_size,
235 | verbose=True,
236 | )
237 | test_time_pool = False
238 | if args.test_pool:
239 | model, test_time_pool = apply_test_time_pool(model, data_config)
240 |
241 | model = model.to(device)
242 | if args.channels_last:
243 | model = model.to(memory_format=torch.channels_last)
244 |
245 | if args.torchscript:
246 | assert not use_amp == 'apex', 'Cannot use APEX AMP with torchscripted model'
247 | model = torch.jit.script(model)
248 | elif args.torchcompile:
249 | assert has_compile, 'A version of torch w/ torch.compile() is required for --compile, possibly a nightly.'
250 | torch._dynamo.reset()
251 | model = torch.compile(model, backend=args.torchcompile)
252 | elif args.aot_autograd:
253 | assert has_functorch, "functorch is needed for --aot-autograd"
254 | model = memory_efficient_fusion(model)
255 |
256 | if use_amp == 'apex':
257 | model = amp.initialize(model, opt_level='O1')
258 |
259 | if args.num_gpu > 1:
260 | model = torch.nn.DataParallel(model, device_ids=list(range(args.num_gpu)))
261 |
262 | criterion = nn.CrossEntropyLoss().to(device)
263 |
264 | root_dir = args.data or args.data_dir
265 | if args.input_img_mode is None:
266 | input_img_mode = 'RGB' if data_config['input_size'][0] == 3 else 'L'
267 | else:
268 | input_img_mode = args.input_img_mode
269 | dataset = create_dataset(
270 | root=root_dir,
271 | name=args.dataset,
272 | split=args.split,
273 | download=args.dataset_download,
274 | load_bytes=args.tf_preprocessing,
275 | class_map=args.class_map,
276 | num_samples=args.num_samples,
277 | input_key=args.input_key,
278 | input_img_mode=input_img_mode,
279 | target_key=args.target_key,
280 | )
281 |
282 | if args.valid_labels:
283 | with open(args.valid_labels, 'r') as f:
284 | valid_labels = [int(line.rstrip()) for line in f]
285 | else:
286 | valid_labels = None
287 |
288 | if args.real_labels:
289 | real_labels = RealLabelsImagenet(dataset.filenames(basename=True), real_json=args.real_labels)
290 | else:
291 | real_labels = None
292 |
293 | crop_pct = 1.0 if test_time_pool else data_config['crop_pct']
294 | loader = create_loader(
295 | dataset,
296 | input_size=data_config['input_size'],
297 | batch_size=args.batch_size,
298 | use_prefetcher=args.prefetcher,
299 | interpolation=data_config['interpolation'],
300 | mean=data_config['mean'],
301 | std=data_config['std'],
302 | num_workers=args.workers,
303 | crop_pct=crop_pct,
304 | crop_mode=data_config['crop_mode'],
305 | crop_border_pixels=args.crop_border_pixels,
306 | pin_memory=args.pin_mem,
307 | device=device,
308 | tf_preprocessing=args.tf_preprocessing,
309 | )
310 |
311 | batch_time = AverageMeter()
312 | losses = AverageMeter()
313 | top1 = AverageMeter()
314 | top5 = AverageMeter()
315 |
316 | model.eval()
317 | with torch.no_grad():
318 | # warmup, reduce variability of first batch time, especially for comparing torchscript vs non
319 | input = torch.randn((args.batch_size,) + tuple(data_config['input_size'])).to(device)
320 | if args.channels_last:
321 | input = input.contiguous(memory_format=torch.channels_last)
322 | with amp_autocast():
323 | model(input)
324 |
325 | end = time.time()
326 | for batch_idx, (input, target) in enumerate(loader):
327 | if args.no_prefetcher:
328 | target = target.to(device)
329 | input = input.to(device)
330 | if args.channels_last:
331 | input = input.contiguous(memory_format=torch.channels_last)
332 |
333 | # compute output
334 | with amp_autocast():
335 | output = model(input)
336 |
337 | if valid_labels is not None:
338 | output = output[:, valid_labels]
339 | loss = criterion(output, target)
340 |
341 | if real_labels is not None:
342 | real_labels.add_result(output)
343 |
344 | # measure accuracy and record loss
345 | acc1, acc5 = accuracy(output.detach(), target, topk=(1, 5))
346 | losses.update(loss.item(), input.size(0))
347 | top1.update(acc1.item(), input.size(0))
348 | top5.update(acc5.item(), input.size(0))
349 |
350 | # measure elapsed time
351 | batch_time.update(time.time() - end)
352 | end = time.time()
353 |
354 | if batch_idx % args.log_freq == 0:
355 | _logger.info(
356 | 'Test: [{0:>4d}/{1}] '
357 | 'Time: {batch_time.val:.3f}s ({batch_time.avg:.3f}s, {rate_avg:>7.2f}/s) '
358 | 'Loss: {loss.val:>7.4f} ({loss.avg:>6.4f}) '
359 | 'Acc@1: {top1.val:>7.3f} ({top1.avg:>7.3f}) '
360 | 'Acc@5: {top5.val:>7.3f} ({top5.avg:>7.3f})'.format(
361 | batch_idx,
362 | len(loader),
363 | batch_time=batch_time,
364 | rate_avg=input.size(0) / batch_time.avg,
365 | loss=losses,
366 | top1=top1,
367 | top5=top5
368 | )
369 | )
370 |
371 | if real_labels is not None:
372 | # real labels mode replaces topk values at the end
373 | top1a, top5a = real_labels.get_accuracy(k=1), real_labels.get_accuracy(k=5)
374 | else:
375 | top1a, top5a = top1.avg, top5.avg
376 | results = OrderedDict(
377 | model=args.model,
378 | top1=round(top1a, 4), top1_err=round(100 - top1a, 4),
379 | top5=round(top5a, 4), top5_err=round(100 - top5a, 4),
380 | param_count=round(param_count / 1e6, 2),
381 | img_size=data_config['input_size'][-1],
382 | crop_pct=crop_pct,
383 | interpolation=data_config['interpolation'],
384 | )
385 |
386 | _logger.info(' * Acc@1 {:.3f} ({:.3f}) Acc@5 {:.3f} ({:.3f})'.format(
387 | results['top1'], results['top1_err'], results['top5'], results['top5_err']))
388 |
389 | return results
390 |
391 |
392 | def _try_run(args, initial_batch_size):
393 | batch_size = initial_batch_size
394 | results = OrderedDict()
395 | error_str = 'Unknown'
396 | while batch_size:
397 | args.batch_size = batch_size * args.num_gpu # multiply by num-gpu for DataParallel case
398 | try:
399 | if torch.cuda.is_available() and 'cuda' in args.device:
400 | torch.cuda.empty_cache()
401 | results = validate(args)
402 | return results
403 | except RuntimeError as e:
404 | error_str = str(e)
405 | _logger.error(f'"{error_str}" while running validation.')
406 | if not check_batch_size_retry(error_str):
407 | break
408 | batch_size = decay_batch_step(batch_size)
409 | _logger.warning(f'Reducing batch size to {batch_size} for retry.')
410 | results['error'] = error_str
411 | _logger.error(f'{args.model} failed to validate ({error_str}).')
412 | return results
413 |
414 |
415 | _NON_IN1K_FILTERS = ['*_in21k', '*_in22k', '*in12k', '*_dino', '*fcmae', '*seer']
416 |
417 |
418 | def main():
419 | setup_default_logging()
420 | args = parser.parse_args()
421 | model_cfgs = []
422 | model_names = []
423 | if os.path.isdir(args.checkpoint):
424 | # validate all checkpoints in a path with same model
425 | checkpoints = glob.glob(args.checkpoint + '/*.pth.tar')
426 | checkpoints += glob.glob(args.checkpoint + '/*.pth')
427 | model_names = list_models(args.model)
428 | model_cfgs = [(args.model, c) for c in sorted(checkpoints, key=natural_key)]
429 | else:
430 | if args.model == 'all':
431 | # validate all models in a list of names with pretrained checkpoints
432 | args.pretrained = True
433 | model_names = list_models(
434 | pretrained=True,
435 | exclude_filters=_NON_IN1K_FILTERS,
436 | )
437 | model_cfgs = [(n, '') for n in model_names]
438 | elif not is_model(args.model):
439 | # model name doesn't exist, try as wildcard filter
440 | model_names = list_models(
441 | args.model,
442 | pretrained=True,
443 | )
444 | model_cfgs = [(n, '') for n in model_names]
445 |
446 | if not model_cfgs and os.path.isfile(args.model):
447 | with open(args.model) as f:
448 | model_names = [line.rstrip() for line in f]
449 | model_cfgs = [(n, None) for n in model_names if n]
450 |
451 | if len(model_cfgs):
452 | _logger.info('Running bulk validation on these pretrained models: {}'.format(', '.join(model_names)))
453 | results = []
454 | try:
455 | initial_batch_size = args.batch_size
456 | for m, c in model_cfgs:
457 | args.model = m
458 | args.checkpoint = c
459 | r = _try_run(args, initial_batch_size)
460 | if 'error' in r:
461 | continue
462 | if args.checkpoint:
463 | r['checkpoint'] = args.checkpoint
464 | results.append(r)
465 | except KeyboardInterrupt as e:
466 | pass
467 | results = sorted(results, key=lambda x: x['top1'], reverse=True)
468 | else:
469 | if args.retry:
470 | results = _try_run(args, args.batch_size)
471 | else:
472 | results = validate(args)
473 |
474 | if args.results_file:
475 | write_results(args.results_file, results, format=args.results_format)
476 |
477 | # output results in JSON to stdout w/ delimiter for runner script
478 | print(f'--result\n{json.dumps(results, indent=4)}')
479 |
480 |
481 | def write_results(results_file, results, format='csv'):
482 | with open(results_file, mode='w') as cf:
483 | if format == 'json':
484 | json.dump(results, cf, indent=4)
485 | else:
486 | if not isinstance(results, (list, tuple)):
487 | results = [results]
488 | if not results:
489 | return
490 | dw = csv.DictWriter(cf, fieldnames=results[0].keys())
491 | dw.writeheader()
492 | for r in results:
493 | dw.writerow(r)
494 | cf.flush()
495 |
496 |
497 |
498 | if __name__ == '__main__':
499 | main()
500 |
--------------------------------------------------------------------------------
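Both scripts register the KAT architectures simply by importing `katransformer` before calling timm's `create_model` (see the `import katransformer` line and the default model name `kat_tiny_patch16_224` in validate.py). A minimal sketch of that flow, useful as a quick smoke test of an installation, is below; it assumes timm is installed, builds the model without pretrained weights, and the 224x224 input and 1000-class head are the defaults implied by the model name.

import torch
from timm.models import create_model

import katransformer  # noqa: F401  -- importing registers the kat_* models with timm

model = create_model('kat_tiny_patch16_224', pretrained=False)
model.eval()

with torch.no_grad():
    logits = model(torch.randn(1, 3, 224, 224))  # dummy 224x224 RGB batch

print(logits.shape)  # expected torch.Size([1, 1000]) for the default ImageNet head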