├── LICENSE
├── README.md
├── assets
│   ├── KAT.png
│   ├── kat3-1.png
│   └── logo.webp
├── dist_train.sh
├── example.py
├── katransformer.py
├── scripts
│   ├── train_kat_base_8x128.sh
│   ├── train_kat_base_8x128_vitft.sh
│   ├── train_kat_small_8x128.sh
│   ├── train_kat_small_8x128_vitft.sh
│   ├── train_kat_tiny_8x128.sh
│   └── train_kat_tiny_8x128_vitft.sh
├── tools
│   └── calculate_flops.py
├── train.py
└── validate.py

/LICENSE:
--------------------------------------------------------------------------------
MIT License

Copyright (c) 2024 Xingyi Yang

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.

--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
# Kolmogorov–Arnold Transformer: A PyTorch Implementation

[Tested PyTorch Versions] [License] [ICLR 2025]

> *"Yes, I kan!"*
🎉 This is a PyTorch/GPU implementation of the paper **Kolmogorov–Arnold Transformer (KAT)**, which replaces the MLP layers in transformers with KAN layers.

For more technical details, please refer to our ICLR'25 paper.

> **Kolmogorov–Arnold Transformer**
> 📝[[Paper](https://arxiv.org/abs/2409.10594)] [[code](https://github.com/Adamdad/kat)] [[Triton/CUDA kernel](https://github.com/Adamdad/rational_kat_cu)]
> [Xingyi Yang](https://adamdad.github.io/), [Xinchao Wang](https://sites.google.com/site/sitexinchaowang/)
> National University of Singapore
> International Conference on Learning Representations (**ICLR'25**)

### 🔑 Key Insight:

A vanilla ViT with KAN layers struggles to scale effectively. We introduce the KAT model, which integrates GR-KANs into transformers for large-scale training scenarios such as ImageNet, achieving significant performance improvements. A minimal sketch of the group-rational idea is shown below.
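To make the GR-KAN idea concrete, here is a minimal, pure-PyTorch sketch of a group-wise rational activation. It is illustrative only: the class name `NaiveGroupRational`, the coefficient shapes, and the initial values are assumptions made for this sketch; the repository's real layer is the fused `KAT_Group` Triton/CUDA kernel from [rational_kat_cu](https://github.com/Adamdad/rational_kat_cu).

```python
import torch
import torch.nn as nn


class NaiveGroupRational(nn.Module):
    """Illustrative group-wise rational activation (not the fused KAT_Group kernel).

    Channels are split into `num_groups` groups; each group shares one rational
    function P(x) / (1 + |Q(x)|) instead of learning a separate function per edge.
    """

    def __init__(self, num_groups: int = 8, p_order: int = 5, q_order: int = 4):
        super().__init__()
        self.num_groups = num_groups
        # One coefficient set per group (shared within the group); the 0.1 scale is arbitrary.
        self.p = nn.Parameter(0.1 * torch.randn(num_groups, p_order + 1))
        self.q = nn.Parameter(0.1 * torch.randn(num_groups, q_order))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        *lead, channels = x.shape  # expects channels divisible by num_groups
        xg = x.reshape(*lead, self.num_groups, channels // self.num_groups)
        # P(x) = sum_k p_k x^k (k = 0..p_order), Q(x) = sum_k q_k x^k (k = 1..q_order)
        xp = torch.stack([xg ** k for k in range(self.p.shape[1])], dim=-1)
        xq = torch.stack([xg ** (k + 1) for k in range(self.q.shape[1])], dim=-1)
        p = self.p.view(*([1] * len(lead)), self.num_groups, 1, -1)
        q = self.q.view(*([1] * len(lead)), self.num_groups, 1, -1)
        num = (xp * p).sum(dim=-1)
        den = 1.0 + (xq * q).sum(dim=-1).abs()  # "safe" denominator, always >= 1
        return (num / den).reshape(*lead, channels)


# Example: activate a (batch, tokens, channels) tensor, as KAT does inside its KAN blocks.
y = NaiveGroupRational(num_groups=8)(torch.randn(2, 196, 192))
print(y.shape)  # torch.Size([2, 196, 192])
```

Sharing one rational function per group of channels, rather than learning a separate spline per edge as in the original KAN, is what keeps the parameter and compute overhead close to a plain MLP; the items below summarize the full set of changes.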


### 🎯 Our Solutions:
1. **Base Function**: Replace B-splines with a CUDA-implemented rational function.
2. **Group KAN**: Share weights among groups of edges for efficiency.
3. **Initialization**: Maintain activation magnitudes across layers.

### ✅ Updates
- [x] Release the KAT paper, CUDA implementation and IN-1k training code.
- [x] 🎉🎉🎉🎉 Triton implementation for 1D and 2D tasks. This is much easier to install than the CUDA version. Please see [https://github.com/Adamdad/rational_kat_cu](https://github.com/Adamdad/rational_kat_cu).
- [ ] KAT detection and segmentation code.
- [ ] KAT on NLP tasks.

## 🛠️ Installation and Dataset
Please find our CUDA implementation in [https://github.com/Adamdad/rational_kat_cu.git](https://github.com/Adamdad/rational_kat_cu.git).
```shell
# install torch and other things
pip install timm==1.0.3
pip install wandb # I personally use wandb for results visualizations
git clone https://github.com/Adamdad/rational_kat_cu.git
cd rational_kat_cu
pip install -e .
```

📦 Data preparation: prepare ImageNet with the folder structure below; you can extract ImageNet with this [script](https://gist.github.com/BIGBALLON/8a71d225eff18d88e469e6ea9b39cef4).

```
│imagenet/
├──train/
│  ├── n01440764
│  │   ├── n01440764_10026.JPEG
│  │   ├── n01440764_10027.JPEG
│  │   ├── ......
│  ├── ......
├──val/
│  ├── n01440764
│  │   ├── ILSVRC2012_val_00000293.JPEG
│  │   ├── ILSVRC2012_val_00002138.JPEG
│  │   ├── ......
│  ├── ......
```

## Usage

Refer to `example.py` for a detailed use case demonstrating how to use KAT with timm to classify an image.

## 📊 Model Checkpoints
Download pre-trained models or access training checkpoints:

|🏷️ Model |⚙️ Setup |📦 Param| 📈 Top1 |🔗 Link|
| ---|---|---| ---|---|
|KAT-T| From Scratch|5.7M | 74.6| [link](https://github.com/Adamdad/kat/releases/download/checkpoint/kat_tiny_patch16_224_1f3ad3b2e69821f3d412f2924cf159a0e266f142d739cb68f68f796f5a0fe289.pth)/[huggingface](https://huggingface.co/adamdad/kat_tiny_patch16_224)|
|KAT-T | From ViT | 5.7M | 75.7| [link](https://github.com/Adamdad/kat/releases/download/checkpoint/kat_tiny_patch16_224-finetune_64f124d003803e4a7e1aba1ba23500ace359b544e8a5f0110993f25052e402fb.pth)/[huggingface](https://huggingface.co/adamdad/kat_tiny_patch16_224.vitft)|
|KAT-S| From Scratch| 22.1M | 81.2| [link](https://github.com/Adamdad/kat/releases/download/checkpoint/kat_small_patch16_224_32487885cf13d2c14e461c9016fac8ad43f7c769171f132530941e930aeb5fe2.pth)/[huggingface](https://huggingface.co/adamdad/kat_small_patch16_224)|
|KAT-S | From ViT |22.1M | 82.0| [link](https://github.com/Adamdad/kat/releases/download/checkpoint/kat_small_patch_224-finetune_3ae087a4c28e2993468eb377d5151350c52c80b2a70cc48ceec63d1328ba58e0.pth)/[huggingface](https://huggingface.co/adamdad/kat_small_patch16_224.vitft)|
| KAT-B| From Scratch |86.6M| 82.3 | [link](https://github.com/Adamdad/kat/releases/download/checkpoint/kat_base_patch16_224_abff874d925d756d15cde97303f772a3460ddbd44b9c53fb9ce5cf15be230fb6.pth)/[huggingface](https://huggingface.co/adamdad/kat_base_patch16_224)|
| KAT-B | From ViT |86.6M| 82.8 | [link](https://huggingface.co/adamdad/kat_pretained/resolve/main/kat_base_patch16_224-finetune_440bf1ead9dd8ecab642078cfb60ae542f1fa33ca65517260501e02c011e38f2.pth)/[huggingface](https://huggingface.co/adamdad/kat_base_patch16_224.vitft)|
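If you download one of the `.pth` files above rather than pulling from the Hugging Face Hub (the Hub route is shown in `example.py`), the weights can be loaded into a locally created model. A minimal sketch, assuming `rational_kat_cu` is installed; the placeholder path and the wrapper-key handling are assumptions for illustration, not something the repository documents:

```python
import torch
import timm

import katransformer  # registers the kat_* models with timm

# Placeholder path: point this at whichever checkpoint you downloaded from the table above.
checkpoint_path = "checkpoints/kat_tiny_patch16_224.pth"

model = timm.create_model("kat_tiny_patch16_224", pretrained=False)
state = torch.load(checkpoint_path, map_location="cpu")
state = state.get("model", state.get("state_dict", state))  # unwrap a possible wrapper dict
print(model.load_state_dict(state, strict=False))
model.eval()

with torch.no_grad():
    logits = model(torch.randn(1, 3, 224, 224))
print(logits.shape)  # torch.Size([1, 1000])
```

For full ImageNet evaluation, `validate.py` (see the Evaluation section below) is the route the repository itself uses.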
## 🎓 Model Training

All training scripts are under `scripts/`:
```shell
bash scripts/train_kat_tiny_8x128.sh
```

If you want to change the hyper-parameters, you can edit the script, for example:
```shell
#!/bin/bash
DATA_PATH=/local_home/dataset/imagenet/

# kat_tiny_swish_patch16_224: rationals are initialized as swish functions
bash ./dist_train.sh 8 $DATA_PATH \
    --model kat_tiny_swish_patch16_224 \
    -b 128 \
    --opt adamw \
    --lr 1e-3 \
    --weight-decay 0.05 \
    --epochs 300 \
    --mixup 0.8 \
    --cutmix 1.0 \
    --sched cosine \
    --smoothing 0.1 \
    --drop-path 0.1 \
    --aa rand-m9-mstd0.5 \
    --remode pixel --reprob 0.25 \
    --amp \
    --crop-pct 0.875 \
    --mean 0.485 0.456 0.406 \
    --std 0.229 0.224 0.225 \
    --model-ema \
    --model-ema-decay 0.9999 \
    --output output/kat_tiny_swish_patch16_224 \
    --log-wandb
```

## 🧪 Evaluation
To evaluate our `kat_tiny_patch16_224` models, run:

```shell
DATA_PATH=/local_home/dataset/imagenet/
CHECKPOINT_PATH=kat_tiny_patch16_224_1f3ad3b2e69821f3d412f2924cf159a0e266f142d739cb68f68f796f5a0fe289.pth
python validate.py $DATA_PATH --model kat_tiny_patch16_224 \
    --checkpoint $CHECKPOINT_PATH -b 512

###################
Validating in float32. AMP not enabled.
Loaded state_dict from checkpoint 'kat_tiny_patch16_224_1f3ad3b2e69821f3d412f2924cf159a0e266f142d739cb68f68f796f5a0fe289.pth'
Model kat_tiny_patch16_224 created, param count: 5718328
Data processing configuration for current model + dataset:
    input_size: (3, 224, 224)
    interpolation: bicubic
    mean: (0.485, 0.456, 0.406)
    std: (0.229, 0.224, 0.225)
    crop_pct: 0.875
    crop_mode: center
Test: [   0/98]  Time: 3.453s (3.453s,  148.28/s)  Loss:  0.6989 (0.6989)  Acc@1:  84.375 ( 84.375)  Acc@5:  96.875 ( 96.875)
.......
Test: [  90/98]  Time: 0.212s (0.592s,  864.23/s)  Loss:  1.1640 (1.1143)  Acc@1:  71.875 ( 74.270)  Acc@5:  93.750 ( 92.220)
 * Acc@1 74.558 (25.442) Acc@5 92.390 (7.610)
--result
{
    "model": "kat_tiny_patch16_224",
    "top1": 74.558,
    "top1_err": 25.442,
    "top5": 92.39,
    "top5_err": 7.61,
    "param_count": 5.72,
    "img_size": 224,
    "crop_pct": 0.875,
    "interpolation": "bicubic"
}
```

## 🙏 Acknowledgments
We extend our gratitude to the authors of [rational_activations](https://github.com/ml-research/rational_activations) for their contributions to CUDA rational function implementations that inspired parts of this work. We thank [@yuweihao](https://github.com/yuweihao), [@florinshen](https://github.com/florinshen), [@Huage001](https://github.com/Huage001) and [@yu-rp](https://github.com/yu-rp) for valuable discussions.
## 📚 Bibtex
If you use this repository, please cite:
```bibtex
@inproceedings{
yang2025kolmogorovarnold,
title={Kolmogorov-Arnold Transformer},
author={Xingyi Yang and Xinchao Wang},
booktitle={The Thirteenth International Conference on Learning Representations},
year={2025},
url={https://openreview.net/forum?id=BCeock53nt}
}
```

--------------------------------------------------------------------------------
/assets/KAT.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Adamdad/kat/d254de7c14b6c050bd00cac3689b0a5614659a7f/assets/KAT.png
--------------------------------------------------------------------------------
/assets/kat3-1.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Adamdad/kat/d254de7c14b6c050bd00cac3689b0a5614659a7f/assets/kat3-1.png
--------------------------------------------------------------------------------
/assets/logo.webp:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Adamdad/kat/d254de7c14b6c050bd00cac3689b0a5614659a7f/assets/logo.webp
--------------------------------------------------------------------------------
/dist_train.sh:
--------------------------------------------------------------------------------
#!/bin/bash
NUM_PROC=$1
shift
torchrun --nproc_per_node=$NUM_PROC train.py "$@"
--------------------------------------------------------------------------------
/example.py:
--------------------------------------------------------------------------------
from urllib.request import urlopen
from PIL import Image
import timm
import torch
import json
import katransformer

# Load the image
img = Image.open(urlopen(
    'https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/beignets-task-guide.png'
))

# Pick the device (CUDA if available)
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Load the pre-trained KAT model
model = timm.create_model('hf_hub:adamdad/kat_tiny_patch16_224.vitft', pretrained=True)
model = model.to(device)
model = model.eval()

# Get model-specific transforms (normalization, resize)
data_config = timm.data.resolve_model_data_config(model)
transforms = timm.data.create_transform(**data_config, is_training=False)

# Preprocess image and make predictions
output = model(transforms(img).unsqueeze(0).to(device))  # unsqueeze single image into batch of 1
top5_probabilities, top5_class_indices = torch.topk(output.softmax(dim=1) * 100, k=5)

# Load ImageNet class names
imagenet_classes_url = 'https://raw.githubusercontent.com/anishathalye/imagenet-simple-labels/master/imagenet-simple-labels.json'
class_idx = json.load(urlopen(imagenet_classes_url))

# Map class indices to class names
top5_class_names = [class_idx[idx] for idx in top5_class_indices[0].tolist()]

# Print top 5 probabilities and corresponding class names
print("Top-5 Class Names:", top5_class_names)
print("Top-5 Probabilities:", top5_probabilities)
--------------------------------------------------------------------------------
/katransformer.py:
--------------------------------------------------------------------------------
""" Kolmogorov-Arnold Transformer (KAT) model
author 
= "Xingyi Yang" 3 | """ 4 | import logging 5 | import math 6 | from collections import OrderedDict 7 | from functools import partial 8 | from typing import Any, Callable, Dict, Optional, Set, Tuple, Type, Union, List 9 | try: 10 | from typing import Literal 11 | except ImportError: 12 | from typing_extensions import Literal 13 | 14 | import torch 15 | import torch.nn as nn 16 | import torch.nn.functional as F 17 | import torch.utils.checkpoint 18 | from torch.jit import Final 19 | 20 | from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD, IMAGENET_INCEPTION_MEAN, IMAGENET_INCEPTION_STD, \ 21 | OPENAI_CLIP_MEAN, OPENAI_CLIP_STD 22 | from timm.layers import PatchEmbed, DropPath, AttentionPoolLatent, RmsNorm, PatchDropout, SwiGLUPacked, \ 23 | trunc_normal_, lecun_normal_, resample_patch_embed, resample_abs_pos_embed, use_fused_attn, \ 24 | get_act_layer, get_norm_layer, LayerType 25 | from timm.models._builder import build_model_with_cfg 26 | from timm.models._features import feature_take_indices 27 | from timm.models._manipulate import named_apply, checkpoint_seq, adapt_input_conv 28 | from timm.models._registry import generate_default_cfgs, register_model, register_model_deprecations 29 | from timm.models.layers import to_2tuple 30 | 31 | import numpy as np 32 | 33 | 34 | __all__ = ['KAT'] # model_registry will add each entrypoint fn to this 35 | 36 | 37 | _logger = logging.getLogger(__name__) 38 | 39 | import sys 40 | sys.path.insert(0, 'rational_kat_cu') 41 | from kat_rational import KAT_Group 42 | 43 | 44 | def calculate_gain(nonlinearity, param=None): 45 | r"""Return the recommended gain value for the given nonlinearity function. 46 | The values are as follows: 47 | 48 | ================= ==================================================== 49 | nonlinearity gain 50 | ================= ==================================================== 51 | Linear / Identity :math:`1` 52 | Conv{1,2,3}D :math:`1` 53 | Sigmoid :math:`1` 54 | Tanh :math:`\frac{5}{3}` 55 | ReLU :math:`\sqrt{2}` 56 | Leaky Relu :math:`\sqrt{\frac{2}{1 + \text{negative\_slope}^2}}` 57 | SELU :math:`\frac{3}{4}` 58 | ================= ==================================================== 59 | 60 | .. warning:: 61 | In order to implement `Self-Normalizing Neural Networks`_ , 62 | you should use ``nonlinearity='linear'`` instead of ``nonlinearity='selu'``. 63 | This gives the initial weights a variance of ``1 / N``, 64 | which is necessary to induce a stable fixed point in the forward pass. 65 | In contrast, the default gain for ``SELU`` sacrifices the normalisation 66 | effect for more stable gradient flow in rectangular layers. 67 | 68 | Args: 69 | nonlinearity: the non-linear function (`nn.functional` name) 70 | param: optional parameter for the non-linear function 71 | 72 | Examples: 73 | >>> gain = nn.init.calculate_gain('leaky_relu', 0.2) # leaky_relu with negative_slope=0.2 74 | 75 | .. 
_Self-Normalizing Neural Networks: https://papers.nips.cc/paper/2017/hash/5d44ee6f2c3f71b73125876103c8f6c4-Abstract.html 76 | """ 77 | linear_fns = ['linear', 'conv1d', 'conv2d', 'conv3d', 'conv_transpose1d', 'conv_transpose2d', 'conv_transpose3d'] 78 | if nonlinearity in linear_fns or nonlinearity == 'sigmoid': 79 | return 1 80 | elif nonlinearity == 'tanh': 81 | return 5.0 / 3 82 | elif nonlinearity == 'relu': 83 | return math.sqrt(2.0) 84 | elif nonlinearity == 'gelu': 85 | return math.sqrt(2.3567850379928976) 86 | elif nonlinearity == 'silu' or nonlinearity == 'swish': 87 | return math.sqrt(2.8178205270359653) 88 | elif nonlinearity == 'leaky_relu': 89 | if param is None: 90 | negative_slope = 0.01 91 | elif not isinstance(param, bool) and isinstance(param, int) or isinstance(param, float): 92 | # True/False are instances of int, hence check above 93 | negative_slope = param 94 | else: 95 | raise ValueError("negative_slope {} not a valid number".format(param)) 96 | return math.sqrt(2.0 / (1 + negative_slope ** 2)) 97 | elif nonlinearity == 'selu': 98 | return 3.0 / 4 # Value found empirically (https://github.com/pytorch/pytorch/pull/50664) 99 | else: 100 | raise ValueError("Unsupported nonlinearity {}".format(nonlinearity)) 101 | 102 | torch.nn.init.calculate_gain = calculate_gain 103 | 104 | 105 | class KAN(nn.Module): 106 | """ MLP as used in Vision Transformer, MLP-Mixer and related networks 107 | """ 108 | def __init__( 109 | self, 110 | in_features, 111 | hidden_features=None, 112 | out_features=None, 113 | act_layer=KAT_Group, 114 | norm_layer=None, 115 | bias=True, 116 | drop=0., 117 | use_conv=False, 118 | act_init="gelu", 119 | device=None 120 | ): 121 | super().__init__() 122 | if device is None: 123 | device = "cuda" if torch.cuda.is_available() else "cpu" 124 | out_features = out_features or in_features 125 | hidden_features = hidden_features or in_features 126 | bias = to_2tuple(bias) 127 | drop_probs = to_2tuple(drop) 128 | linear_layer = partial(nn.Conv2d, kernel_size=1) if use_conv else nn.Linear 129 | 130 | self.fc1 = linear_layer(in_features, hidden_features, bias=bias[0]) 131 | self.act1 = KAT_Group(mode="identity", device=device) 132 | self.drop1 = nn.Dropout(drop_probs[0]) 133 | self.norm = norm_layer(hidden_features) if norm_layer is not None else nn.Identity() 134 | self.act2 = KAT_Group(mode=act_init, device=device) 135 | self.fc2 = linear_layer(hidden_features, out_features, bias=bias[1]) 136 | self.drop2 = nn.Dropout(drop_probs[1]) 137 | 138 | def forward(self, x): 139 | x = self.act1(x) 140 | x = self.drop1(x) 141 | x = self.fc1(x) 142 | x = self.act2(x) 143 | x = self.drop2(x) 144 | x = self.fc2(x) 145 | return x 146 | 147 | def get_ortho_like(dim, alpha, beta, sign=1, dist='uniform'): 148 | """ 149 | Generate an orthogonal-like matrix with specified dimensions and properties. 150 | 151 | Args: 152 | dim (int): The dimension of the matrix. 153 | alpha (float): The scaling factor for the random matrix. 154 | beta (float): The scaling factor for the identity matrix. 155 | sign (int, optional): The sign of the identity matrix. Defaults to 1. 156 | dist (str, optional): The distribution of the random matrix. 157 | Can be either 'normal' or 'uniform'. Defaults to 'uniform'. 158 | 159 | Returns: 160 | tuple: A tuple containing the left and right orthogonal-like matrices. 
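
    Example (illustrative only; the dimension and the alpha/beta values below are
    arbitrary choices for the doctest, and the matrices are random on every call):

        >>> L, R = get_ortho_like(dim=64, alpha=0.7, beta=0.7)
        >>> L.shape, R.shape
        ((64, 64), (64, 64))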
161 | """ 162 | if dist == 'normal': 163 | A = alpha * np.random.normal(size=(dim,dim)) / (dim**0.5) + sign * beta * np.eye(dim) 164 | if dist == 'uniform': 165 | A = alpha * np.random.uniform(size=(dim,dim), low=-3**0.5 / (dim**0.5), high = 3**0.5 / (dim**0.5)) + sign * beta * np.eye(dim) 166 | 167 | U, S, V = np.linalg.svd(A) 168 | L = U @ np.diag(np.sqrt(S)) 169 | R = np.diag(np.sqrt(S)) @ V 170 | return L, R 171 | 172 | 173 | def get_ortho_like_gaussian2d(dim, alpha, beta, sigma=1.0, sign=1, dist='uniform'): 174 | if dist == 'normal': 175 | A = alpha * np.random.normal(size=(dim, dim)) / (dim**0.5) 176 | elif dist == 'uniform': 177 | A = alpha * np.random.uniform(low=-3**0.5 / (dim**0.5), high=3**0.5 / (dim**0.5), size=(dim, dim)) 178 | 179 | size = int(np.sqrt(dim)) 180 | gauss_diag = beta * gaussian_correlation_matrix((size, size), sigma=sigma) 181 | gauss_diag = sign * gauss_diag 182 | 183 | A += gauss_diag 184 | 185 | # Compute the SVD of the updated matrix A 186 | U, S, V = np.linalg.svd(A) 187 | 188 | # Create the left and right matrices 189 | L = U @ np.diag(np.sqrt(S)) 190 | R = np.diag(np.sqrt(S)) @ V 191 | 192 | return L, R 193 | 194 | 195 | def gaussian_correlation_matrix(image_shape, sigma=1.0): 196 | """ 197 | Generate a Gaussian correlation matrix for a flattened 2D image. 198 | 199 | Args: 200 | image_shape (tuple): The shape of the image (height, width). 201 | sigma (float): The standard deviation of the Gaussian distribution. 202 | 203 | Returns: 204 | np.array: A 2D array where each entry represents the correlation 205 | between two flattened image pixels. 206 | """ 207 | # Create an array of all pixel coordinates 208 | y, x = np.indices(image_shape) 209 | 210 | # Flatten the coordinate arrays 211 | y = y.ravel() 212 | x = x.ravel() 213 | 214 | # Compute distances between every pair of pixels 215 | distance_matrix = np.sqrt((y[:, np.newaxis] - y[np.newaxis, :])**2 + 216 | (x[:, np.newaxis] - x[np.newaxis, :])**2) 217 | 218 | # Compute the Gaussian correlation matrix using the distance matrix 219 | correlation_matrix = np.exp(-distance_matrix**2 / (2 * sigma**2)) 220 | 221 | return correlation_matrix 222 | 223 | class Attention(nn.Module): 224 | fused_attn: Final[bool] 225 | 226 | def __init__( 227 | self, 228 | dim: int, 229 | num_heads: int = 8, 230 | qkv_bias: bool = False, 231 | qk_norm: bool = False, 232 | attn_drop: float = 0., 233 | proj_drop: float = 0., 234 | norm_layer: nn.Module = nn.LayerNorm, 235 | ) -> None: 236 | super().__init__() 237 | assert dim % num_heads == 0, 'dim should be divisible by num_heads' 238 | self.num_heads = num_heads 239 | self.head_dim = dim // num_heads 240 | self.scale = self.head_dim ** -0.5 241 | self.fused_attn = use_fused_attn() 242 | 243 | self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) 244 | self.q_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity() 245 | self.k_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity() 246 | self.attn_drop = nn.Dropout(attn_drop) 247 | self.proj = nn.Linear(dim, dim) 248 | self.proj_drop = nn.Dropout(proj_drop) 249 | 250 | def forward(self, x: torch.Tensor) -> torch.Tensor: 251 | B, N, C = x.shape 252 | qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, self.head_dim).permute(2, 0, 3, 1, 4) 253 | q, k, v = qkv.unbind(0) 254 | q, k = self.q_norm(q), self.k_norm(k) 255 | 256 | if self.fused_attn: 257 | x = F.scaled_dot_product_attention( 258 | q, k, v, 259 | dropout_p=self.attn_drop.p if self.training else 0., 260 | ) 261 | else: 262 | q = q * self.scale 263 | attn = q @ 
k.transpose(-2, -1) 264 | attn = attn.softmax(dim=-1) 265 | attn = self.attn_drop(attn) 266 | x = attn @ v 267 | 268 | x = x.transpose(1, 2).reshape(B, N, C) 269 | x = self.proj(x) 270 | x = self.proj_drop(x) 271 | return x 272 | 273 | class LayerScale(nn.Module): 274 | def __init__( 275 | self, 276 | dim: int, 277 | init_values: float = 1e-5, 278 | inplace: bool = False, 279 | ) -> None: 280 | super().__init__() 281 | self.inplace = inplace 282 | self.gamma = nn.Parameter(init_values * torch.ones(dim)) 283 | 284 | def forward(self, x: torch.Tensor) -> torch.Tensor: 285 | return x.mul_(self.gamma) if self.inplace else x * self.gamma 286 | 287 | 288 | class Block(nn.Module): 289 | def __init__( 290 | self, 291 | dim: int, 292 | num_heads: int, 293 | mlp_ratio: float = 4., 294 | qkv_bias: bool = False, 295 | qk_norm: bool = False, 296 | proj_drop: float = 0., 297 | attn_drop: float = 0., 298 | init_values: Optional[float] = None, 299 | drop_path: float = 0., 300 | act_layer: nn.Module = nn.GELU, 301 | norm_layer: nn.Module = nn.LayerNorm, 302 | mlp_layer: nn.Module = KAN, 303 | act_init: str = 'gelu', 304 | ) -> None: 305 | super().__init__() 306 | self.norm1 = norm_layer(dim) 307 | self.attn = Attention( 308 | dim, 309 | num_heads=num_heads, 310 | qkv_bias=qkv_bias, 311 | qk_norm=qk_norm, 312 | attn_drop=attn_drop, 313 | proj_drop=proj_drop, 314 | norm_layer=norm_layer, 315 | ) 316 | self.ls1 = LayerScale(dim, init_values=init_values) if init_values else nn.Identity() 317 | self.drop_path1 = DropPath(drop_path) if drop_path > 0. else nn.Identity() 318 | 319 | self.norm2 = norm_layer(dim) 320 | self.mlp = mlp_layer( 321 | in_features=dim, 322 | hidden_features=int(dim * mlp_ratio), 323 | act_layer=act_layer, 324 | drop=proj_drop, 325 | act_init=act_init, 326 | ) 327 | self.ls2 = LayerScale(dim, init_values=init_values) if init_values else nn.Identity() 328 | self.drop_path2 = DropPath(drop_path) if drop_path > 0. 
else nn.Identity() 329 | 330 | def forward(self, x: torch.Tensor) -> torch.Tensor: 331 | x = x + self.drop_path1(self.ls1(self.attn(self.norm1(x)))) 332 | x = x + self.drop_path2(self.ls2(self.mlp(self.norm2(x)))) 333 | return x 334 | 335 | 336 | 337 | class KATVisionTransformer(nn.Module): 338 | """ Vision Transformer 339 | 340 | A PyTorch impl of : `An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale` 341 | - https://arxiv.org/abs/2010.11929 342 | """ 343 | dynamic_img_size: Final[bool] 344 | 345 | def __init__( 346 | self, 347 | img_size: Union[int, Tuple[int, int]] = 224, 348 | patch_size: Union[int, Tuple[int, int]] = 16, 349 | in_chans: int = 3, 350 | num_classes: int = 1000, 351 | global_pool: Literal['', 'avg', 'token', 'map'] = 'token', 352 | embed_dim: int = 768, 353 | depth: int = 12, 354 | num_heads: int = 12, 355 | mlp_ratio: float = 4., 356 | qkv_bias: bool = True, 357 | qk_norm: bool = False, 358 | init_values: Optional[float] = None, 359 | class_token: bool = True, 360 | pos_embed: str = 'learn', 361 | no_embed_class: bool = False, 362 | reg_tokens: int = 0, 363 | pre_norm: bool = False, 364 | fc_norm: Optional[bool] = None, 365 | dynamic_img_size: bool = False, 366 | dynamic_img_pad: bool = False, 367 | drop_rate: float = 0., 368 | pos_drop_rate: float = 0., 369 | patch_drop_rate: float = 0., 370 | proj_drop_rate: float = 0., 371 | attn_drop_rate: float = 0., 372 | drop_path_rate: float = 0., 373 | weight_init: Literal['skip', 'jax', 'jax_nlhb', 'moco', '', 'kan', 'kan_mimetic', 'mimetic', 'kan_mimetic2d'] = '', 374 | fix_init: bool = False, 375 | embed_layer: Callable = PatchEmbed, 376 | norm_layer: Optional[LayerType] = None, 377 | act_layer: Optional[LayerType] = None, # None, 378 | block_fn: Type[nn.Module] = Block, 379 | mlp_layer: Type[nn.Module] = KAN, 380 | act_init: str = 'gelu', 381 | ) -> None: 382 | """ 383 | Args: 384 | img_size: Input image size. 385 | patch_size: Patch size. 386 | in_chans: Number of image input channels. 387 | num_classes: Mumber of classes for classification head. 388 | global_pool: Type of global pooling for final sequence (default: 'token'). 389 | embed_dim: Transformer embedding dimension. 390 | depth: Depth of transformer. 391 | num_heads: Number of attention heads. 392 | mlp_ratio: Ratio of mlp hidden dim to embedding dim. 393 | qkv_bias: Enable bias for qkv projections if True. 394 | init_values: Layer-scale init values (layer-scale enabled if not None). 395 | class_token: Use class token. 396 | no_embed_class: Don't include position embeddings for class (or reg) tokens. 397 | reg_tokens: Number of register tokens. 398 | fc_norm: Pre head norm after pool (instead of before), if None, enabled when global_pool == 'avg'. 399 | drop_rate: Head dropout rate. 400 | pos_drop_rate: Position embedding dropout rate. 401 | attn_drop_rate: Attention dropout rate. 402 | drop_path_rate: Stochastic depth rate. 403 | weight_init: Weight initialization scheme. 404 | fix_init: Apply weight initialization fix (scaling w/ layer index). 405 | embed_layer: Patch embedding layer. 406 | norm_layer: Normalization layer. 407 | act_layer: MLP activation layer. 408 | block_fn: Transformer block layer. 
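
        Example (an illustrative sketch, not from the original docs; the config mirrors
        the registered kat_tiny_* variants and requires the `rational_kat_cu` kernels to
        be installed, hence the doctest skip):

            >>> model = KATVisionTransformer(
            ...     embed_dim=192, depth=12, num_heads=3, act_init='swish')  # doctest: +SKIP
            >>> model(torch.randn(1, 3, 224, 224)).shape  # doctest: +SKIP
            torch.Size([1, 1000])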
409 | """ 410 | super().__init__() 411 | assert global_pool in ('', 'avg', 'token', 'map') 412 | assert class_token or global_pool != 'token' 413 | assert pos_embed in ('', 'none', 'learn') 414 | use_fc_norm = global_pool == 'avg' if fc_norm is None else fc_norm 415 | norm_layer = get_norm_layer(norm_layer) or partial(nn.LayerNorm, eps=1e-6) 416 | act_layer = get_act_layer(act_layer) or nn.GELU 417 | 418 | self.num_classes = num_classes 419 | self.global_pool = global_pool 420 | self.num_features = self.head_hidden_size = self.embed_dim = embed_dim # for consistency with other models 421 | self.num_prefix_tokens = 1 if class_token else 0 422 | self.num_prefix_tokens += reg_tokens 423 | self.num_reg_tokens = reg_tokens 424 | self.has_class_token = class_token 425 | self.no_embed_class = no_embed_class # don't embed prefix positions (includes reg) 426 | self.dynamic_img_size = dynamic_img_size 427 | self.grad_checkpointing = False 428 | self.act_init = act_init 429 | 430 | embed_args = {} 431 | if dynamic_img_size: 432 | # flatten deferred until after pos embed 433 | embed_args.update(dict(strict_img_size=False, output_fmt='NHWC')) 434 | self.patch_embed = embed_layer( 435 | img_size=img_size, 436 | patch_size=patch_size, 437 | in_chans=in_chans, 438 | embed_dim=embed_dim, 439 | bias=not pre_norm, # disable bias if pre-norm is used (e.g. CLIP) 440 | dynamic_img_pad=dynamic_img_pad, 441 | **embed_args, 442 | ) 443 | num_patches = self.patch_embed.num_patches 444 | reduction = self.patch_embed.feat_ratio() if hasattr(self.patch_embed, 'feat_ratio') else patch_size 445 | 446 | self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim)) if class_token else None 447 | self.reg_token = nn.Parameter(torch.zeros(1, reg_tokens, embed_dim)) if reg_tokens else None 448 | embed_len = num_patches if no_embed_class else num_patches + self.num_prefix_tokens 449 | if not pos_embed or pos_embed == 'none': 450 | self.pos_embed = None 451 | else: 452 | self.pos_embed = nn.Parameter(torch.randn(1, embed_len, embed_dim) * .02) 453 | self.pos_drop = nn.Dropout(p=pos_drop_rate) 454 | if patch_drop_rate > 0: 455 | self.patch_drop = PatchDropout( 456 | patch_drop_rate, 457 | num_prefix_tokens=self.num_prefix_tokens, 458 | ) 459 | else: 460 | self.patch_drop = nn.Identity() 461 | self.norm_pre = norm_layer(embed_dim) if pre_norm else nn.Identity() 462 | 463 | dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)] # stochastic depth decay rule 464 | self.blocks = nn.Sequential(*[ 465 | block_fn( 466 | dim=embed_dim, 467 | num_heads=num_heads, 468 | mlp_ratio=mlp_ratio, 469 | qkv_bias=qkv_bias, 470 | qk_norm=qk_norm, 471 | init_values=init_values, 472 | proj_drop=proj_drop_rate, 473 | attn_drop=attn_drop_rate, 474 | drop_path=dpr[i], 475 | norm_layer=norm_layer, 476 | act_layer=act_layer, 477 | mlp_layer=mlp_layer, 478 | act_init=act_init, 479 | ) 480 | for i in range(depth)]) 481 | self.feature_info = [ 482 | dict(module=f'blocks.{i}', num_chs=embed_dim, reduction=reduction) for i in range(depth)] 483 | self.norm = norm_layer(embed_dim) if not use_fc_norm else nn.Identity() 484 | 485 | # Classifier Head 486 | if global_pool == 'map': 487 | self.attn_pool = AttentionPoolLatent( 488 | self.embed_dim, 489 | num_heads=num_heads, 490 | mlp_ratio=mlp_ratio, 491 | norm_layer=norm_layer, 492 | ) 493 | else: 494 | self.attn_pool = None 495 | self.fc_norm = norm_layer(embed_dim) if use_fc_norm else nn.Identity() 496 | self.head_drop = nn.Dropout(drop_rate) 497 | self.head = nn.Linear(self.embed_dim, num_classes) if 
num_classes > 0 else nn.Identity() 498 | 499 | if weight_init != 'skip': 500 | self.init_weights(weight_init) 501 | if fix_init: 502 | self.fix_init_weight() 503 | 504 | def fix_init_weight(self): 505 | def rescale(param, _layer_id): 506 | param.div_(math.sqrt(2.0 * _layer_id)) 507 | 508 | for layer_id, layer in enumerate(self.blocks): 509 | rescale(layer.attn.proj.weight.data, layer_id + 1) 510 | rescale(layer.mlp.fc2.weight.data, layer_id + 1) 511 | 512 | def init_weights(self, mode: str = '') -> None: 513 | assert mode in ('jax', 'jax_nlhb', 'moco', '', 'kan', 'kan_mimetic', 'mimetic') 514 | head_bias = -math.log(self.num_classes) if 'nlhb' in mode else 0. 515 | if self.pos_embed is not None: 516 | trunc_normal_(self.pos_embed, std=.02) 517 | if self.cls_token is not None: 518 | nn.init.normal_(self.cls_token, std=1e-6) 519 | 520 | mode = mode + f'_{self.act_init}' if 'kan' in mode else mode 521 | named_apply(get_init_weights_vit(mode, head_bias), self) 522 | 523 | if 'mimetic' in mode: 524 | named_apply(init_weights_attn_mimetic, self) 525 | # named_apply(init_weights_attn_mimetic, self) 526 | 527 | def _init_weights(self, m: nn.Module) -> None: 528 | # this fn left here for compat with downstream users 529 | init_weights_vit_timm(m) 530 | 531 | @torch.jit.ignore() 532 | def load_pretrained(self, checkpoint_path: str, prefix: str = '') -> None: 533 | if checkpoint_path.endswith('.npz'): 534 | # load from .npz checkpoint 535 | _load_weights(self, checkpoint_path, prefix) 536 | else: 537 | print(f'Loading model weights from: {checkpoint_path}') 538 | # load from .pth checkpoint 539 | state_dict = torch.load(checkpoint_path, map_location='cpu')['model'] 540 | if 'state_dict' in state_dict: 541 | state_dict = state_dict['state_dict'] 542 | state_dict = {k[len(prefix):]: v for k, v in state_dict.items() if k.startswith(prefix)} 543 | state_dict['pos_embed'] = state_dict['pos_embed'][:, 1:] 544 | msg = self.load_state_dict(state_dict, strict=False) 545 | print(msg) 546 | 547 | @torch.jit.ignore 548 | def no_weight_decay(self) -> Set: 549 | return {'pos_embed', 'cls_token', 'dist_token'} 550 | 551 | @torch.jit.ignore 552 | def group_matcher(self, coarse: bool = False) -> Dict: 553 | return dict( 554 | stem=r'^cls_token|pos_embed|patch_embed', # stem and embed 555 | blocks=[(r'^blocks\.(\d+)', None), (r'^norm', (99999,))] 556 | ) 557 | 558 | @torch.jit.ignore 559 | def set_grad_checkpointing(self, enable: bool = True) -> None: 560 | self.grad_checkpointing = enable 561 | if hasattr(self.patch_embed, 'set_grad_checkpointing'): 562 | self.patch_embed.set_grad_checkpointing(enable) 563 | 564 | @torch.jit.ignore 565 | def get_classifier(self) -> nn.Module: 566 | return self.head 567 | 568 | def reset_classifier(self, num_classes: int, global_pool = None) -> None: 569 | self.num_classes = num_classes 570 | if global_pool is not None: 571 | assert global_pool in ('', 'avg', 'token', 'map') 572 | if global_pool == 'map' and self.attn_pool is None: 573 | assert False, "Cannot currently add attention pooling in reset_classifier()." 
574 | elif global_pool != 'map ' and self.attn_pool is not None: 575 | self.attn_pool = None # remove attention pooling 576 | self.global_pool = global_pool 577 | self.head = nn.Linear(self.embed_dim, num_classes) if num_classes > 0 else nn.Identity() 578 | 579 | def _pos_embed(self, x: torch.Tensor) -> torch.Tensor: 580 | if self.pos_embed is None: 581 | return x.view(x.shape[0], -1, x.shape[-1]) 582 | 583 | if self.dynamic_img_size: 584 | B, H, W, C = x.shape 585 | pos_embed = resample_abs_pos_embed( 586 | self.pos_embed, 587 | (H, W), 588 | num_prefix_tokens=0 if self.no_embed_class else self.num_prefix_tokens, 589 | ) 590 | x = x.view(B, -1, C) 591 | else: 592 | pos_embed = self.pos_embed 593 | 594 | to_cat = [] 595 | if self.cls_token is not None: 596 | to_cat.append(self.cls_token.expand(x.shape[0], -1, -1)) 597 | if self.reg_token is not None: 598 | to_cat.append(self.reg_token.expand(x.shape[0], -1, -1)) 599 | 600 | if self.no_embed_class: 601 | # deit-3, updated JAX (big vision) 602 | # position embedding does not overlap with class token, add then concat 603 | x = x + pos_embed 604 | if to_cat: 605 | x = torch.cat(to_cat + [x], dim=1) 606 | else: 607 | # original timm, JAX, and deit vit impl 608 | # pos_embed has entry for class token, concat then add 609 | if to_cat: 610 | x = torch.cat(to_cat + [x], dim=1) 611 | x = x + pos_embed 612 | 613 | return self.pos_drop(x) 614 | 615 | def forward_intermediates( 616 | self, 617 | x: torch.Tensor, 618 | indices: Optional[Union[int, List[int], Tuple[int]]] = None, 619 | return_prefix_tokens: bool = False, 620 | norm: bool = False, 621 | stop_early: bool = False, 622 | output_fmt: str = 'NCHW', 623 | intermediates_only: bool = False, 624 | ) -> Union[List[torch.Tensor], Tuple[torch.Tensor, List[torch.Tensor]]]: 625 | """ Forward features that returns intermediates. 626 | 627 | Args: 628 | x: Input image tensor 629 | indices: Take last n blocks if int, all if None, select matching indices if sequence 630 | return_prefix_tokens: Return both prefix and spatial intermediate tokens 631 | norm: Apply norm layer to all intermediates 632 | stop_early: Stop iterating over blocks when last desired intermediate hit 633 | output_fmt: Shape of intermediate feature outputs 634 | intermediates_only: Only return intermediate features 635 | Returns: 636 | 637 | """ 638 | assert output_fmt in ('NCHW', 'NLC'), 'Output format must be one of NCHW or NLC.' 639 | reshape = output_fmt == 'NCHW' 640 | intermediates = [] 641 | take_indices, max_index = feature_take_indices(len(self.blocks), indices) 642 | 643 | # forward pass 644 | B, _, height, width = x.shape 645 | x = self.patch_embed(x) 646 | x = self._pos_embed(x) 647 | x = self.patch_drop(x) 648 | x = self.norm_pre(x) 649 | 650 | if torch.jit.is_scripting() or not stop_early: # can't slice blocks in torchscript 651 | blocks = self.blocks 652 | else: 653 | blocks = self.blocks[:max_index + 1] 654 | for i, blk in enumerate(blocks): 655 | x = blk(x) 656 | if i in take_indices: 657 | # normalize intermediates with final norm layer if enabled 658 | intermediates.append(self.norm(x) if norm else x) 659 | 660 | # process intermediates 661 | if self.num_prefix_tokens: 662 | # split prefix (e.g. 
class, distill) and spatial feature tokens 663 | prefix_tokens = [y[:, 0:self.num_prefix_tokens] for y in intermediates] 664 | intermediates = [y[:, self.num_prefix_tokens:] for y in intermediates] 665 | if reshape: 666 | # reshape to BCHW output format 667 | H, W = self.patch_embed.dynamic_feat_size((height, width)) 668 | intermediates = [y.reshape(B, H, W, -1).permute(0, 3, 1, 2).contiguous() for y in intermediates] 669 | if not torch.jit.is_scripting() and return_prefix_tokens: 670 | # return_prefix not support in torchscript due to poor type handling 671 | intermediates = list(zip(intermediates, prefix_tokens)) 672 | 673 | if intermediates_only: 674 | return intermediates 675 | 676 | x = self.norm(x) 677 | 678 | return x, intermediates 679 | 680 | def prune_intermediate_layers( 681 | self, 682 | indices: Union[int, List[int], Tuple[int]] = 1, 683 | prune_norm: bool = False, 684 | prune_head: bool = True, 685 | ): 686 | """ Prune layers not required for specified intermediates. 687 | """ 688 | take_indices, max_index = feature_take_indices(len(self.blocks), indices) 689 | self.blocks = self.blocks[:max_index + 1] # truncate blocks 690 | if prune_norm: 691 | self.norm = nn.Identity() 692 | if prune_head: 693 | self.fc_norm = nn.Identity() 694 | self.reset_classifier(0, '') 695 | return take_indices 696 | 697 | def get_intermediate_layers( 698 | self, 699 | x: torch.Tensor, 700 | n: Union[int, List[int], Tuple[int]] = 1, 701 | reshape: bool = False, 702 | return_prefix_tokens: bool = False, 703 | norm: bool = False, 704 | ) -> List[torch.Tensor]: 705 | """ Intermediate layer accessor inspired by DINO / DINOv2 interface. 706 | NOTE: This API is for backwards compat, favour using forward_intermediates() directly. 707 | """ 708 | return self.forward_intermediates( 709 | x, n, 710 | return_prefix_tokens=return_prefix_tokens, 711 | norm=norm, 712 | output_fmt='NCHW' if reshape else 'NLC', 713 | intermediates_only=True, 714 | ) 715 | 716 | def forward_features(self, x: torch.Tensor) -> torch.Tensor: 717 | x = self.patch_embed(x) 718 | x = self._pos_embed(x) 719 | x = self.patch_drop(x) 720 | x = self.norm_pre(x) 721 | if self.grad_checkpointing and not torch.jit.is_scripting(): 722 | x = checkpoint_seq(self.blocks, x) 723 | else: 724 | x = self.blocks(x) 725 | x = self.norm(x) 726 | return x 727 | 728 | def forward_head(self, x: torch.Tensor, pre_logits: bool = False) -> torch.Tensor: 729 | if self.attn_pool is not None: 730 | x = self.attn_pool(x) 731 | elif self.global_pool == 'avg': 732 | x = x[:, self.num_prefix_tokens:].mean(dim=1) 733 | elif self.global_pool: 734 | x = x[:, 0] # class token 735 | x = self.fc_norm(x) 736 | x = self.head_drop(x) 737 | return x if pre_logits else self.head(x) 738 | 739 | def forward(self, x: torch.Tensor) -> torch.Tensor: 740 | x = self.forward_features(x) 741 | x = self.forward_head(x) 742 | return x 743 | 744 | 745 | def init_weights_vit_timm(module: nn.Module, name: str = '') -> None: 746 | """ ViT weight initialization, original timm impl (for reproducibility) """ 747 | if isinstance(module, nn.Linear): 748 | trunc_normal_(module.weight, std=.02) 749 | if module.bias is not None: 750 | nn.init.zeros_(module.bias) 751 | elif hasattr(module, 'init_weights'): 752 | module.init_weights() 753 | 754 | 755 | def init_weights_vit_kan_gelu(module: nn.Module, name: str = '') -> None: 756 | """ ViT weight initialization, original timm impl (for reproducibility) """ 757 | if isinstance(module, nn.Linear): 758 | if 'mlp' in name: 759 | if 'fc1' in name: 760 | 
nn.init.kaiming_normal_(module.weight, mode='fan_in', nonlinearity='linear') 761 | elif 'fc2' in name: 762 | nn.init.kaiming_normal_(module.weight, mode='fan_in', nonlinearity='gelu') 763 | 764 | # nn.init.kaiming_normal_(module.weight, mode='fan_in', nonlinearity='relu') 765 | 766 | if 'qkv' in name: 767 | # treat the weights of Q, K, V separately 768 | val = math.sqrt(6. / float(module.weight.shape[0] // 3 + module.weight.shape[1])) 769 | nn.init.uniform_(module.weight, -val, val) 770 | else: 771 | trunc_normal_(module.weight, std=.02) 772 | if module.bias is not None: 773 | nn.init.zeros_(module.bias) 774 | elif hasattr(module, 'init_weights'): 775 | module.init_weights() 776 | 777 | def init_weights_vit_kan_swish(module: nn.Module, name: str = '') -> None: 778 | """ ViT weight initialization, original timm impl (for reproducibility) """ 779 | if isinstance(module, nn.Linear): 780 | if 'mlp' in name: 781 | if 'fc1' in name: 782 | nn.init.kaiming_normal_(module.weight, mode='fan_in', nonlinearity='linear') 783 | elif 'fc2' in name: 784 | nn.init.kaiming_normal_(module.weight, mode='fan_in', nonlinearity='swish') 785 | if 'qkv' in name: 786 | # treat the weights of Q, K, V separately 787 | val = math.sqrt(6. / float(module.weight.shape[0] // 3 + module.weight.shape[1])) 788 | nn.init.uniform_(module.weight, -val, val) 789 | else: 790 | trunc_normal_(module.weight, std=.02) 791 | if module.bias is not None: 792 | nn.init.zeros_(module.bias) 793 | elif hasattr(module, 'init_weights'): 794 | module.init_weights() 795 | 796 | def init_weights_attn_mimetic(module: nn.Module, name: str = '') -> None: 797 | """ ViT weight initialization, original timm impl (for reproducibility) """ 798 | if isinstance(module, Attention): 799 | alpha1 = 0.7 800 | beta1 = 0.7 801 | alpha2 = 0.4 802 | beta2 = 0.4 803 | head_dim = module.head_dim 804 | embed_dim = module.head_dim * module.num_heads 805 | 806 | for h in range(module.num_heads): 807 | Q, K = get_ortho_like(embed_dim, alpha1, beta1, 1) 808 | Q = Q[:,:head_dim] 809 | K = K.T[:,:head_dim] 810 | 811 | module.qkv.weight.data[(h*head_dim):((h+1)*head_dim)] = torch.tensor(Q.T).float() 812 | module.qkv.weight.data[embed_dim+(h*head_dim):embed_dim+((h+1)*head_dim)] = torch.tensor(K.T).float() 813 | 814 | V, Proj = get_ortho_like(embed_dim, alpha2, beta2, -1) 815 | module.qkv.weight.data[2*embed_dim:] = torch.tensor(V).float() 816 | module.proj.weight.data = torch.tensor(Proj).float() 817 | 818 | 819 | def init_weights_vit_jax(module: nn.Module, name: str = '', head_bias: float = 0.0) -> None: 820 | """ ViT weight initialization, matching JAX (Flax) impl """ 821 | if isinstance(module, nn.Linear): 822 | if name.startswith('head'): 823 | nn.init.zeros_(module.weight) 824 | nn.init.constant_(module.bias, head_bias) 825 | else: 826 | nn.init.xavier_uniform_(module.weight) 827 | if module.bias is not None: 828 | nn.init.normal_(module.bias, std=1e-6) if 'mlp' in name else nn.init.zeros_(module.bias) 829 | elif isinstance(module, nn.Conv2d): 830 | lecun_normal_(module.weight) 831 | if module.bias is not None: 832 | nn.init.zeros_(module.bias) 833 | elif hasattr(module, 'init_weights'): 834 | module.init_weights() 835 | 836 | 837 | def init_weights_vit_moco(module: nn.Module, name: str = '') -> None: 838 | """ ViT weight initialization, matching moco-v3 impl minus fixed PatchEmbed """ 839 | if isinstance(module, nn.Linear): 840 | if 'qkv' in name: 841 | # treat the weights of Q, K, V separately 842 | val = math.sqrt(6. 
/ float(module.weight.shape[0] // 3 + module.weight.shape[1])) 843 | nn.init.uniform_(module.weight, -val, val) 844 | else: 845 | nn.init.xavier_uniform_(module.weight) 846 | if module.bias is not None: 847 | nn.init.zeros_(module.bias) 848 | elif hasattr(module, 'init_weights'): 849 | module.init_weights() 850 | 851 | 852 | def get_init_weights_vit(mode: str = 'jax', head_bias: float = 0.0) -> Callable: 853 | if 'jax' in mode: 854 | return partial(init_weights_vit_jax, head_bias=head_bias) 855 | elif 'moco' in mode: 856 | return init_weights_vit_moco 857 | elif 'kan' in mode: 858 | if 'gelu' in mode: 859 | return init_weights_vit_kan_gelu 860 | elif 'swish' in mode: 861 | return init_weights_vit_kan_swish 862 | else: 863 | AssertionError(f'Unknown mode {mode}') 864 | else: 865 | return init_weights_vit_timm 866 | 867 | 868 | def resize_pos_embed( 869 | posemb: torch.Tensor, 870 | posemb_new: torch.Tensor, 871 | num_prefix_tokens: int = 1, 872 | gs_new: Tuple[int, int] = (), 873 | interpolation: str = 'bicubic', 874 | antialias: bool = False, 875 | ) -> torch.Tensor: 876 | """ Rescale the grid of position embeddings when loading from state_dict. 877 | *DEPRECATED* This function is being deprecated in favour of using resample_abs_pos_embed 878 | """ 879 | ntok_new = posemb_new.shape[1] - num_prefix_tokens 880 | ntok_old = posemb.shape[1] - num_prefix_tokens 881 | gs_old = [int(math.sqrt(ntok_old))] * 2 882 | if not len(gs_new): # backwards compatibility 883 | gs_new = [int(math.sqrt(ntok_new))] * 2 884 | return resample_abs_pos_embed( 885 | posemb, gs_new, gs_old, 886 | num_prefix_tokens=num_prefix_tokens, 887 | interpolation=interpolation, 888 | antialias=antialias, 889 | verbose=True, 890 | ) 891 | 892 | 893 | @torch.no_grad() 894 | def _load_weights(model: KATVisionTransformer, checkpoint_path: str, prefix: str = '') -> None: 895 | """ Load weights from .npz checkpoints for official Google Brain Flax implementation 896 | """ 897 | import numpy as np 898 | 899 | def _n2p(w, t=True, idx=None): 900 | if idx is not None: 901 | w = w[idx] 902 | if w.ndim == 4 and w.shape[0] == w.shape[1] == w.shape[2] == 1: 903 | w = w.flatten() 904 | if t: 905 | if w.ndim == 4: 906 | w = w.transpose([3, 2, 0, 1]) 907 | elif w.ndim == 3: 908 | w = w.transpose([2, 0, 1]) 909 | elif w.ndim == 2: 910 | w = w.transpose([1, 0]) 911 | return torch.from_numpy(w) 912 | 913 | w = np.load(checkpoint_path) 914 | interpolation = 'bilinear' 915 | antialias = False 916 | big_vision = False 917 | if not prefix: 918 | if 'opt/target/embedding/kernel' in w: 919 | prefix = 'opt/target/' 920 | elif 'params/embedding/kernel' in w: 921 | prefix = 'params/' 922 | big_vision = True 923 | elif 'params/img/embedding/kernel' in w: 924 | prefix = 'params/img/' 925 | big_vision = True 926 | 927 | if hasattr(model.patch_embed, 'backbone'): 928 | # hybrid 929 | backbone = model.patch_embed.backbone 930 | stem_only = not hasattr(backbone, 'stem') 931 | stem = backbone if stem_only else backbone.stem 932 | stem.conv.weight.copy_(adapt_input_conv(stem.conv.weight.shape[1], _n2p(w[f'{prefix}conv_root/kernel']))) 933 | stem.norm.weight.copy_(_n2p(w[f'{prefix}gn_root/scale'])) 934 | stem.norm.bias.copy_(_n2p(w[f'{prefix}gn_root/bias'])) 935 | if not stem_only: 936 | for i, stage in enumerate(backbone.stages): 937 | for j, block in enumerate(stage.blocks): 938 | bp = f'{prefix}block{i + 1}/unit{j + 1}/' 939 | for r in range(3): 940 | getattr(block, f'conv{r + 1}').weight.copy_(_n2p(w[f'{bp}conv{r + 1}/kernel'])) 941 | getattr(block, f'norm{r + 
1}').weight.copy_(_n2p(w[f'{bp}gn{r + 1}/scale'])) 942 | getattr(block, f'norm{r + 1}').bias.copy_(_n2p(w[f'{bp}gn{r + 1}/bias'])) 943 | if block.downsample is not None: 944 | block.downsample.conv.weight.copy_(_n2p(w[f'{bp}conv_proj/kernel'])) 945 | block.downsample.norm.weight.copy_(_n2p(w[f'{bp}gn_proj/scale'])) 946 | block.downsample.norm.bias.copy_(_n2p(w[f'{bp}gn_proj/bias'])) 947 | embed_conv_w = _n2p(w[f'{prefix}embedding/kernel']) 948 | else: 949 | embed_conv_w = adapt_input_conv( 950 | model.patch_embed.proj.weight.shape[1], _n2p(w[f'{prefix}embedding/kernel'])) 951 | if embed_conv_w.shape[-2:] != model.patch_embed.proj.weight.shape[-2:]: 952 | embed_conv_w = resample_patch_embed( 953 | embed_conv_w, 954 | model.patch_embed.proj.weight.shape[-2:], 955 | interpolation=interpolation, 956 | antialias=antialias, 957 | verbose=True, 958 | ) 959 | 960 | model.patch_embed.proj.weight.copy_(embed_conv_w) 961 | model.patch_embed.proj.bias.copy_(_n2p(w[f'{prefix}embedding/bias'])) 962 | if model.cls_token is not None: 963 | model.cls_token.copy_(_n2p(w[f'{prefix}cls'], t=False)) 964 | if big_vision: 965 | pos_embed_w = _n2p(w[f'{prefix}pos_embedding'], t=False) 966 | else: 967 | pos_embed_w = _n2p(w[f'{prefix}Transformer/posembed_input/pos_embedding'], t=False) 968 | if pos_embed_w.shape != model.pos_embed.shape: 969 | old_shape = pos_embed_w.shape 970 | num_prefix_tokens = 0 if getattr(model, 'no_embed_class', False) else getattr(model, 'num_prefix_tokens', 1) 971 | pos_embed_w = resample_abs_pos_embed( # resize pos embedding when different size from pretrained weights 972 | pos_embed_w, 973 | new_size=model.patch_embed.grid_size, 974 | num_prefix_tokens=num_prefix_tokens, 975 | interpolation=interpolation, 976 | antialias=antialias, 977 | verbose=True, 978 | ) 979 | model.pos_embed.copy_(pos_embed_w) 980 | model.norm.weight.copy_(_n2p(w[f'{prefix}Transformer/encoder_norm/scale'])) 981 | model.norm.bias.copy_(_n2p(w[f'{prefix}Transformer/encoder_norm/bias'])) 982 | if (isinstance(model.head, nn.Linear) and 983 | f'{prefix}head/bias' in w and 984 | model.head.bias.shape[0] == w[f'{prefix}head/bias'].shape[-1]): 985 | model.head.weight.copy_(_n2p(w[f'{prefix}head/kernel'])) 986 | model.head.bias.copy_(_n2p(w[f'{prefix}head/bias'])) 987 | # NOTE representation layer has been removed, not used in latest 21k/1k pretrained weights 988 | # if isinstance(getattr(model.pre_logits, 'fc', None), nn.Linear) and f'{prefix}pre_logits/bias' in w: 989 | # model.pre_logits.fc.weight.copy_(_n2p(w[f'{prefix}pre_logits/kernel'])) 990 | # model.pre_logits.fc.bias.copy_(_n2p(w[f'{prefix}pre_logits/bias'])) 991 | if model.attn_pool is not None: 992 | block_prefix = f'{prefix}MAPHead_0/' 993 | mha_prefix = block_prefix + f'MultiHeadDotProductAttention_0/' 994 | model.attn_pool.latent.copy_(_n2p(w[f'{block_prefix}probe'], t=False)) 995 | model.attn_pool.kv.weight.copy_(torch.cat([ 996 | _n2p(w[f'{mha_prefix}{n}/kernel'], t=False).flatten(1).T for n in ('key', 'value')])) 997 | model.attn_pool.kv.bias.copy_(torch.cat([ 998 | _n2p(w[f'{mha_prefix}{n}/bias'], t=False).reshape(-1) for n in ('key', 'value')])) 999 | model.attn_pool.q.weight.copy_(_n2p(w[f'{mha_prefix}query/kernel'], t=False).flatten(1).T) 1000 | model.attn_pool.q.bias.copy_(_n2p(w[f'{mha_prefix}query/bias'], t=False).reshape(-1)) 1001 | model.attn_pool.proj.weight.copy_(_n2p(w[f'{mha_prefix}out/kernel']).flatten(1)) 1002 | model.attn_pool.proj.bias.copy_(_n2p(w[f'{mha_prefix}out/bias'])) 1003 | 
model.attn_pool.norm.weight.copy_(_n2p(w[f'{block_prefix}LayerNorm_0/scale'])) 1004 | model.attn_pool.norm.bias.copy_(_n2p(w[f'{block_prefix}LayerNorm_0/bias'])) 1005 | for r in range(2): 1006 | getattr(model.attn_pool.mlp, f'fc{r + 1}').weight.copy_(_n2p(w[f'{block_prefix}MlpBlock_0/Dense_{r}/kernel'])) 1007 | getattr(model.attn_pool.mlp, f'fc{r + 1}').bias.copy_(_n2p(w[f'{block_prefix}MlpBlock_0/Dense_{r}/bias'])) 1008 | 1009 | mha_sub, b_sub, ln1_sub = (0, 0, 1) if big_vision else (1, 3, 2) 1010 | for i, block in enumerate(model.blocks.children()): 1011 | if f'{prefix}Transformer/encoderblock/LayerNorm_0/scale' in w: 1012 | block_prefix = f'{prefix}Transformer/encoderblock/' 1013 | idx = i 1014 | else: 1015 | block_prefix = f'{prefix}Transformer/encoderblock_{i}/' 1016 | idx = None 1017 | mha_prefix = block_prefix + f'MultiHeadDotProductAttention_{mha_sub}/' 1018 | block.norm1.weight.copy_(_n2p(w[f'{block_prefix}LayerNorm_0/scale'], idx=idx)) 1019 | block.norm1.bias.copy_(_n2p(w[f'{block_prefix}LayerNorm_0/bias'], idx=idx)) 1020 | block.attn.qkv.weight.copy_(torch.cat([ 1021 | _n2p(w[f'{mha_prefix}{n}/kernel'], t=False, idx=idx).flatten(1).T for n in ('query', 'key', 'value')])) 1022 | block.attn.qkv.bias.copy_(torch.cat([ 1023 | _n2p(w[f'{mha_prefix}{n}/bias'], t=False, idx=idx).reshape(-1) for n in ('query', 'key', 'value')])) 1024 | block.attn.proj.weight.copy_(_n2p(w[f'{mha_prefix}out/kernel'], idx=idx).flatten(1)) 1025 | block.attn.proj.bias.copy_(_n2p(w[f'{mha_prefix}out/bias'], idx=idx)) 1026 | block.norm2.weight.copy_(_n2p(w[f'{block_prefix}LayerNorm_{ln1_sub}/scale'], idx=idx)) 1027 | block.norm2.bias.copy_(_n2p(w[f'{block_prefix}LayerNorm_{ln1_sub}/bias'], idx=idx)) 1028 | for r in range(2): 1029 | getattr(block.mlp, f'fc{r + 1}').weight.copy_( 1030 | _n2p(w[f'{block_prefix}MlpBlock_{b_sub}/Dense_{r}/kernel'], idx=idx)) 1031 | getattr(block.mlp, f'fc{r + 1}').bias.copy_( 1032 | _n2p(w[f'{block_prefix}MlpBlock_{b_sub}/Dense_{r}/bias'], idx=idx)) 1033 | 1034 | 1035 | def _convert_openai_clip( 1036 | state_dict: Dict[str, torch.Tensor], 1037 | model: KATVisionTransformer, 1038 | prefix: str = 'visual.', 1039 | ) -> Dict[str, torch.Tensor]: 1040 | out_dict = {} 1041 | swaps = [ 1042 | ('conv1', 'patch_embed.proj'), 1043 | ('positional_embedding', 'pos_embed'), 1044 | ('transformer.resblocks.', 'blocks.'), 1045 | ('ln_pre', 'norm_pre'), 1046 | ('ln_post', 'norm'), 1047 | ('ln_', 'norm'), 1048 | ('in_proj_', 'qkv.'), 1049 | ('out_proj', 'proj'), 1050 | ('mlp.c_fc', 'mlp.fc1'), 1051 | ('mlp.c_proj', 'mlp.fc2'), 1052 | ] 1053 | for k, v in state_dict.items(): 1054 | if not k.startswith(prefix): 1055 | continue 1056 | k = k.replace(prefix, '') 1057 | for sp in swaps: 1058 | k = k.replace(sp[0], sp[1]) 1059 | 1060 | if k == 'proj': 1061 | k = 'head.weight' 1062 | v = v.transpose(0, 1) 1063 | out_dict['head.bias'] = torch.zeros(v.shape[0]) 1064 | elif k == 'class_embedding': 1065 | k = 'cls_token' 1066 | v = v.unsqueeze(0).unsqueeze(1) 1067 | elif k == 'pos_embed': 1068 | v = v.unsqueeze(0) 1069 | out_dict[k] = v 1070 | return out_dict 1071 | 1072 | 1073 | def _convert_dinov2( 1074 | state_dict: Dict[str, torch.Tensor], 1075 | model: KATVisionTransformer, 1076 | ) -> Dict[str, torch.Tensor]: 1077 | import re 1078 | out_dict = {} 1079 | state_dict.pop("mask_token", None) 1080 | if 'register_tokens' in state_dict: 1081 | # convert dinov2 w/ registers to no_embed_class timm model (neither cls or reg tokens overlap pos embed) 1082 | out_dict['reg_token'] = state_dict.pop('register_tokens') 
1083 | out_dict['cls_token'] = state_dict.pop('cls_token') + state_dict['pos_embed'][:, 0] 1084 | out_dict['pos_embed'] = state_dict.pop('pos_embed')[:, 1:] 1085 | for k, v in state_dict.items(): 1086 | if re.match(r"blocks\.(\d+)\.mlp\.w12\.(?:weight|bias)", k): 1087 | out_dict[k.replace("w12", "fc1")] = v 1088 | continue 1089 | elif re.match(r"blocks\.(\d+)\.mlp\.w3\.(?:weight|bias)", k): 1090 | out_dict[k.replace("w3", "fc2")] = v 1091 | continue 1092 | out_dict[k] = v 1093 | return out_dict 1094 | 1095 | 1096 | def checkpoint_filter_fn( 1097 | state_dict: Dict[str, torch.Tensor], 1098 | model: KATVisionTransformer, 1099 | adapt_layer_scale: bool = False, 1100 | interpolation: str = 'bicubic', 1101 | antialias: bool = True, 1102 | ) -> Dict[str, torch.Tensor]: 1103 | """ convert patch embedding weight from manual patchify + linear proj to conv""" 1104 | import re 1105 | out_dict = {} 1106 | state_dict = state_dict.get('model', state_dict) 1107 | state_dict = state_dict.get('state_dict', state_dict) 1108 | prefix = '' 1109 | 1110 | if 'visual.class_embedding' in state_dict: 1111 | state_dict = _convert_openai_clip(state_dict, model) 1112 | elif 'module.visual.class_embedding' in state_dict: 1113 | state_dict = _convert_openai_clip(state_dict, model, prefix='module.visual.') 1114 | elif "mask_token" in state_dict: 1115 | state_dict = _convert_dinov2(state_dict, model) 1116 | elif "encoder" in state_dict: 1117 | # IJEPA, vit in an 'encoder' submodule 1118 | state_dict = state_dict['encoder'] 1119 | prefix = 'module.' 1120 | elif 'visual.trunk.pos_embed' in state_dict or 'visual.trunk.blocks.0.norm1.weight' in state_dict: 1121 | # OpenCLIP model with timm vision encoder 1122 | prefix = 'visual.trunk.' 1123 | if 'visual.head.proj.weight' in state_dict and isinstance(model.head, nn.Linear): 1124 | # remap final nn.Linear if it exists outside of the timm .trunk (ie in visual.head.proj) 1125 | out_dict['head.weight'] = state_dict['visual.head.proj.weight'] 1126 | out_dict['head.bias'] = torch.zeros(state_dict['visual.head.proj.weight'].shape[0]) 1127 | 1128 | if prefix: 1129 | # filter on & remove prefix string from keys 1130 | state_dict = {k[len(prefix):]: v for k, v in state_dict.items() if k.startswith(prefix)} 1131 | 1132 | for k, v in state_dict.items(): 1133 | if 'patch_embed.proj.weight' in k: 1134 | O, I, H, W = model.patch_embed.proj.weight.shape 1135 | if len(v.shape) < 4: 1136 | # For old models that I trained prior to conv based patchification 1137 | O, I, H, W = model.patch_embed.proj.weight.shape 1138 | v = v.reshape(O, -1, H, W) 1139 | if v.shape[-1] != W or v.shape[-2] != H: 1140 | v = resample_patch_embed( 1141 | v, 1142 | (H, W), 1143 | interpolation=interpolation, 1144 | antialias=antialias, 1145 | verbose=True, 1146 | ) 1147 | elif k == 'pos_embed' and v.shape[1] != model.pos_embed.shape[1]: 1148 | # To resize pos embedding when using model at different size from pretrained weights 1149 | num_prefix_tokens = 0 if getattr(model, 'no_embed_class', False) else getattr(model, 'num_prefix_tokens', 1) 1150 | v = resample_abs_pos_embed( 1151 | v, 1152 | new_size=model.patch_embed.grid_size, 1153 | num_prefix_tokens=num_prefix_tokens, 1154 | interpolation=interpolation, 1155 | antialias=antialias, 1156 | verbose=True, 1157 | ) 1158 | elif adapt_layer_scale and 'gamma_' in k: 1159 | # remap layer-scale gamma into sub-module (deit3 models) 1160 | k = re.sub(r'gamma_([0-9])', r'ls\1.gamma', k) 1161 | elif 'pre_logits' in k: 1162 | # NOTE representation layer removed as not used in 
latest 21k/1k pretrained weights 1163 | continue 1164 | out_dict[k] = v 1165 | return out_dict 1166 | 1167 | 1168 | def _cfg(url: str = '', **kwargs) -> Dict[str, Any]: 1169 | return { 1170 | 'url': url, 1171 | 'num_classes': 1000, 1172 | 'input_size': (3, 224, 224), 1173 | 'pool_size': None, 1174 | 'crop_pct': 0.875, 1175 | 'interpolation': 'bicubic', 1176 | 'fixed_input_size': True, 1177 | 'mean': IMAGENET_DEFAULT_MEAN, 1178 | 'std': IMAGENET_DEFAULT_STD, 1179 | 'first_conv': 'patch_embed.proj', 1180 | 'classifier': 'head', 1181 | **kwargs, 1182 | } 1183 | 1184 | default_cfgs = { 1185 | 'kat_tiny_patch16_224': _cfg( 1186 | url='https://huggingface.co/adamdad/kat_pretained/resolve/main/kat_tiny_patch16_224_1f3ad3b2e69821f3d412f2924cf159a0e266f142d739cb68f68f796f5a0fe289.pth'), 1187 | 'kat_tiny_patch16_224.vitft': _cfg( 1188 | url='https://huggingface.co/adamdad/kat_pretained/resolve/main/kat_tiny_patch16_224-finetune_64f124d003803e4a7e1aba1ba23500ace359b544e8a5f0110993f25052e402fb.pth'), 1189 | 'kat_small_patch16_224': _cfg( 1190 | url='https://huggingface.co/adamdad/kat_pretained/resolve/main/kat_small_patch16_224_32487885cf13d2c14e461c9016fac8ad43f7c769171f132530941e930aeb5fe2.pth'), 1191 | 'kat_small_patch16_224.vitft': _cfg( 1192 | url='https://huggingface.co/adamdad/kat_pretained/resolve/main/kat_small_patch_224-finetune_3ae087a4c28e2993468eb377d5151350c52c80b2a70cc48ceec63d1328ba58e0.pth'), 1193 | 'kat_base_patch16_224': _cfg( 1194 | url='https://huggingface.co/adamdad/kat_pretained/resolve/main/kat_base_patch16_224_abff874d925d756d15cde97303f772a3460ddbd44b9c53fb9ce5cf15be230fb6.pth'), 1195 | 'kat_base_patch16_224.vitft': _cfg( 1196 | url='https://huggingface.co/adamdad/kat_pretained/resolve/main/kat_base_patch16_224-finetune_440bf1ead9dd8ecab642078cfb60ae542f1fa33ca65517260501e02c011e38f2.pth') 1197 | } 1198 | 1199 | 1200 | def _create_kat_transformer(variant: str, pretrained: bool = False, **kwargs) -> KATVisionTransformer: 1201 | out_indices = kwargs.pop('out_indices', 3) 1202 | if 'flexi' in variant: 1203 | # FIXME Google FlexiViT pretrained models have a strong preference for bilinear patch / embed 1204 | # interpolation, other pretrained models resize better w/ anti-aliased bicubic interpolation. 1205 | _filter_fn = partial(checkpoint_filter_fn, interpolation='bilinear', antialias=False) 1206 | else: 1207 | _filter_fn = checkpoint_filter_fn 1208 | 1209 | # FIXME attn pool (currently only in siglip) params removed if pool disabled, is there a better soln? 
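# (such checkpoints still carry attn_pool.* weights; when the model is built without 'map' pooling that module is absent, so the state dict can only be loaded with strict=False below)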
1210 | strict = True 1211 | if 'siglip' in variant and kwargs.get('global_pool', None) != 'map': 1212 | strict = False 1213 | 1214 | return build_model_with_cfg( 1215 | KATVisionTransformer, 1216 | variant, 1217 | pretrained, 1218 | pretrained_filter_fn=_filter_fn, 1219 | pretrained_strict=strict, 1220 | feature_cfg=dict(out_indices=out_indices, feature_cls='getter'), 1221 | **kwargs, 1222 | ) 1223 | 1224 | @register_model 1225 | def kat_tiny_patch16_224(pretrained: bool = False, **kwargs) -> KATVisionTransformer: 1226 | """ KAT-Tiny with rational activations (ViT-Ti/16) 1227 | """ 1228 | model_args = dict(patch_size=16, embed_dim=192, depth=12, num_heads=3, 1229 | act_layer=KAT_Group, 1230 | act_init='swish', 1231 | mlp_layer=KAN, 1232 | weight_init="kan_mimetic") 1233 | model = _create_kat_transformer('kat_tiny_patch16_224', pretrained=pretrained, **dict(model_args, **kwargs)) 1234 | return model 1235 | 1236 | @register_model 1237 | def kat_tiny_gelu_patch16_224(pretrained: bool = False, **kwargs) -> KATVisionTransformer: 1238 | """ KAT-Tiny with rational activations (ViT-Ti/16) 1239 | """ 1240 | model_args = dict(patch_size=16, embed_dim=192, depth=12, num_heads=3, 1241 | act_layer=KAT_Group, 1242 | act_init='gelu', 1243 | mlp_layer=KAN, 1244 | weight_init="kan_mimetic") 1245 | model = _create_kat_transformer('kat_tiny_patch16_224', pretrained=pretrained, **dict(model_args, **kwargs)) 1246 | return model 1247 | 1248 | @register_model 1249 | def kat_tiny_swish_patch16_224(pretrained: bool = False, **kwargs) -> KATVisionTransformer: 1250 | """ KAT-Tiny with rational activations (ViT-Ti/16) 1251 | """ 1252 | model_args = dict(patch_size=16, embed_dim=192, depth=12, num_heads=3, 1253 | act_layer=KAT_Group, 1254 | act_init='swish', 1255 | mlp_layer=KAN, 1256 | weight_init="kan_mimetic") 1257 | model = _create_kat_transformer('kat_tiny_patch16_224', pretrained=pretrained, **dict(model_args, **kwargs)) 1258 | return model 1259 | 1260 | @register_model 1261 | def kat_small_patch16_224(pretrained: bool = False, **kwargs) -> KATVisionTransformer: 1262 | """ KAT-Small with rational activations (ViT-S/16)""" 1263 | model_args = dict(patch_size=16, 1264 | embed_dim=384, 1265 | depth=12, 1266 | num_heads=6, 1267 | act_init='swish', 1268 | act_layer=KAT_Group, 1269 | mlp_layer=KAN, 1270 | weight_init="kan_mimetic") 1271 | model = _create_kat_transformer('kat_small_patch16_224', pretrained=pretrained, **dict(model_args, **kwargs)) 1272 | return model 1273 | 1274 | @register_model 1275 | def kat_small_gelu_patch16_224(pretrained: bool = False, **kwargs) -> KATVisionTransformer: 1276 | """ KAT-Small with rational activations (ViT-S/16)""" 1277 | model_args = dict(patch_size=16, 1278 | embed_dim=384, 1279 | depth=12, 1280 | num_heads=6, 1281 | act_init='gelu', 1282 | act_layer=KAT_Group, 1283 | mlp_layer=KAN, weight_init="kan_mimetic") # , init_values=1e-5 1284 | model = _create_kat_transformer('kat_small_patch16_224', pretrained=pretrained, **dict(model_args, **kwargs)) 1285 | return model 1286 | 1287 | @register_model 1288 | def kat_small_swish_patch16_224(pretrained: bool = False, **kwargs) -> KATVisionTransformer: 1289 | """ KAT-Small with rational activations (ViT-S/16)""" 1290 | model_args = dict(patch_size=16, 1291 | embed_dim=384, 1292 | depth=12, num_heads=6, 1293 | act_init='swish', 1294 | act_layer=KAT_Group, 1295 | mlp_layer=KAN, weight_init="kan_mimetic") # , init_values=1e-5 1296 | model = _create_kat_transformer('kat_small_patch16_224', pretrained=pretrained, **dict(model_args, **kwargs)) 1297 | return model
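The variants above are plain timm registrations, so once `katransformer` is imported they can be created by name like any other timm model. A minimal usage sketch (not part of the repository files), assuming a CUDA device and the `rational_kat_cu` kernels from the installation instructions are available:

```python
import torch
import timm

import katransformer  # noqa: F401  -- importing this module registers the kat_* variants defined above

# Assumes a CUDA GPU and the rational_kat_cu package, so the KAT_Group / KAN kernels can run.
model = timm.create_model('kat_tiny_patch16_224', pretrained=False).eval().cuda()

x = torch.randn(1, 3, 224, 224, device='cuda')   # dummy ImageNet-sized batch of one
with torch.no_grad():
    logits = model(x)                            # class logits, shape (1, 1000)

print(logits.shape)
print(f'{sum(p.numel() for p in model.parameters()) / 1e6:.1f}M parameters')
```

Passing `pretrained=True` instead should fetch the matching checkpoint URL from the `default_cfgs` dict above.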
1298 | 1299 | @register_model 1300 | def kat_base_patch16_224(pretrained: bool = False, **kwargs) -> KATVisionTransformer: 1301 | """ KAT-Base with rational activations (ViT-B/16)""" 1302 | model_args = dict(patch_size=16, 1303 | embed_dim=768, 1304 | depth=12, 1305 | num_heads=12, 1306 | act_layer=KAT_Group, 1307 | mlp_layer=KAN, 1308 | act_init='swish', 1309 | weight_init="kan_mimetic") 1310 | model = _create_kat_transformer('kat_base_patch16_224', pretrained=pretrained, **dict(model_args, **kwargs)) 1311 | return model 1312 | 1313 | @register_model 1314 | def kat_base_gelu_patch16_224(pretrained: bool = False, **kwargs) -> KATVisionTransformer: 1315 | """ KAT-Base with rational activations (ViT-B/16)""" 1316 | model_args = dict(patch_size=16, embed_dim=768, depth=12, num_heads=12, 1317 | act_layer=KAT_Group, 1318 | mlp_layer=KAN, 1319 | act_init='gelu', 1320 | weight_init="kan_mimetic") 1321 | model = _create_kat_transformer('kat_base_patch16_224', pretrained=pretrained, **dict(model_args, **kwargs)) 1322 | return model 1323 | 1324 | @register_model 1325 | def kat_base_swish_patch16_224(pretrained: bool = False, **kwargs) -> KATVisionTransformer: 1326 | """ KAT-Base with rational activations (ViT-B/16)""" 1327 | model_args = dict(patch_size=16, embed_dim=768, depth=12, num_heads=12, 1328 | act_layer=KAT_Group, 1329 | mlp_layer=KAN, 1330 | act_init='swish', 1331 | weight_init="kan_mimetic") 1332 | model = _create_kat_transformer('kat_base_patch16_224', pretrained=pretrained, **dict(model_args, **kwargs)) 1333 | return model 1334 | -------------------------------------------------------------------------------- /scripts/train_kat_base_8x128.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | DATA_PATH=/local_home/dataset/imagenet/ 3 | 4 | bash ./dist_train.sh 8 $DATA_PATH \ 5 | --model kat_base_swish_patch16_224 \ 6 | -b 128 \ 7 | --opt adamw \ 8 | --lr 1e-3 \ 9 | --weight-decay 0.05 \ 10 | --epochs 300 \ 11 | --mixup 0.8 \ 12 | --cutmix 1.0 \ 13 | --sched cosine \ 14 | --smoothing 0.1 \ 15 | --drop-path 0.4 \ 16 | --aa rand-m9-mstd0.5 \ 17 | --remode pixel --reprob 0.25 \ 18 | --amp \ 19 | --crop-pct 0.875 \ 20 | --mean 0.485 0.456 0.406 \ 21 | --std 0.229 0.224 0.225 \ 22 | --model-ema \ 23 | --model-ema-decay 0.9999 \ 24 | --output output/kat_base_swish_patch16_224 \ 25 | --log-wandb -------------------------------------------------------------------------------- /scripts/train_kat_base_8x128_vitft.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | DATA_PATH=/local_home/dataset/imagenet/ 3 | 4 | bash ./dist_train.sh 8 $DATA_PATH \ 5 | --model kat_base_swish_patch16_224 \ 6 | -b 128 \ 7 | --opt adamw \ 8 | --lr 1e-3 \ 9 | --weight-decay 0.05 \ 10 | --epochs 300 \ 11 | --mixup 0.8 \ 12 | --cutmix 1.0 \ 13 | --sched cosine \ 14 | --smoothing 0.1 \ 15 | --drop-path 0.4 \ 16 | --aa rand-m9-mstd0.5 \ 17 | --remode pixel --reprob 0.25 \ 18 | --amp \ 19 | --crop-pct 0.875 \ 20 | --mean 0.485 0.456 0.406 \ 21 | --std 0.229 0.224 0.225 \ 22 | --model-ema \ 23 | --model-ema-decay 0.9999 \ 24 | --output output/kat_base_swish_patch16_224 \ 25 | --initial-checkpoint ./checkpoints/B_16-i21k-300ep-lr_0.001-aug_medium1-wd_0.1-do_0.0-sd_0.0.npz \ 26 | --log-wandb -------------------------------------------------------------------------------- /scripts/train_kat_small_8x128.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 
DATA_PATH=/local_home/dataset/imagenet/ 3 | 4 | bash ./dist_train.sh 8 $DATA_PATH \ 5 | --model kat_small_swish_patch16_224 \ 6 | -b 128 \ 7 | --opt adamw \ 8 | --lr 1e-3 \ 9 | --weight-decay 0.05 \ 10 | --epochs 300 \ 11 | --mixup 0.8 \ 12 | --cutmix 1.0 \ 13 | --sched cosine \ 14 | --smoothing 0.1 \ 15 | --drop-path 0.1 \ 16 | --aa rand-m9-mstd0.5 \ 17 | --remode pixel --reprob 0.25 \ 18 | --amp \ 19 | --crop-pct 0.875 \ 20 | --mean 0.485 0.456 0.406 \ 21 | --std 0.229 0.224 0.225 \ 22 | --model-ema \ 23 | --model-ema-decay 0.9999 \ 24 | --output output/kat_small_swish_patch16_224 \ 25 | --log-wandb -------------------------------------------------------------------------------- /scripts/train_kat_small_8x128_vitft.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | DATA_PATH=/local_home/dataset/imagenet/ 3 | 4 | bash ./dist_train.sh 8 $DATA_PATH \ 5 | --model kat_small_swish_patch16_224 \ 6 | -b 128 \ 7 | --opt adamw \ 8 | --lr 1e-3 \ 9 | --weight-decay 0.05 \ 10 | --epochs 300 \ 11 | --mixup 0.8 \ 12 | --cutmix 1.0 \ 13 | --sched cosine \ 14 | --smoothing 0.1 \ 15 | --drop-path 0.1 \ 16 | --aa rand-m9-mstd0.5 \ 17 | --remode pixel --reprob 0.25 \ 18 | --amp \ 19 | --crop-pct 0.875 \ 20 | --mean 0.485 0.456 0.406 \ 21 | --std 0.229 0.224 0.225 \ 22 | --model-ema \ 23 | --model-ema-decay 0.9999 \ 24 | --output output/kat_small_swish_patch16_224 \ 25 | --pretrained-path ./checkpoints/S_16-i21k-300ep-lr_0.001-aug_light1-wd_0.03-do_0.0-sd_0.0.npz \ 26 | --pretrained \ 27 | --log-wandb -------------------------------------------------------------------------------- /scripts/train_kat_tiny_8x128.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | DATA_PATH=/local_home/dataset/imagenet/ 3 | 4 | bash ./dist_train.sh 8 $DATA_PATH \ 5 | --model kat_tiny_swish_patch16_224 \ 6 | -b 128 \ 7 | --opt adamw \ 8 | --lr 1e-3 \ 9 | --weight-decay 0.05 \ 10 | --epochs 300 \ 11 | --mixup 0.8 \ 12 | --cutmix 1.0 \ 13 | --sched cosine \ 14 | --smoothing 0.1 \ 15 | --drop-path 0.1 \ 16 | --aa rand-m9-mstd0.5 \ 17 | --remode pixel --reprob 0.25 \ 18 | --amp \ 19 | --crop-pct 0.875 \ 20 | --mean 0.485 0.456 0.406 \ 21 | --std 0.229 0.224 0.225 \ 22 | --model-ema \ 23 | --model-ema-decay 0.9999 \ 24 | --output output/kat_tiny_swish_patch16_224 \ 25 | --log-wandb 26 | -------------------------------------------------------------------------------- /scripts/train_kat_tiny_8x128_vitft.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | DATA_PATH=/local_home/dataset/imagenet/ 3 | 4 | bash ./dist_train.sh 8 $DATA_PATH \ 5 | --model kat_tiny_swish_patch16_224 \ 6 | -b 128 \ 7 | --opt adamw \ 8 | --lr 1e-3 \ 9 | --weight-decay 0.05 \ 10 | --epochs 300 \ 11 | --mixup 0.8 \ 12 | --cutmix 1.0 \ 13 | --sched cosine \ 14 | --smoothing 0.1 \ 15 | --drop-path 0.1 \ 16 | --aa rand-m9-mstd0.5 \ 17 | --remode pixel --reprob 0.25 \ 18 | --amp \ 19 | --crop-pct 0.875 \ 20 | --mean 0.485 0.456 0.406 \ 21 | --std 0.229 0.224 0.225 \ 22 | --model-ema \ 23 | --model-ema-decay 0.9999 \ 24 | --output output/kat_tiny_swish_patch16_224 \ 25 | --initial-checkpoint ./checkpoints/Ti_16-i21k-300ep-lr_0.001-aug_none-wd_0.03-do_0.0-sd_0.0--imagenet2012-steps_20k-lr_0.03-res_224.npz \ 26 | --log-wandb -------------------------------------------------------------------------------- /tools/calculate_flops.py: -------------------------------------------------------------------------------- 1 | 
import torch 2 | 3 | import katransformer 4 | import timm 5 | from fvcore.nn import FlopCountAnalysis, flop_count_str 6 | from fvcore.nn.jit_handles import get_shape 7 | from math import prod 8 | 9 | 10 | def compute_flops(model_name, input_size, custom_op_name, device='cuda'): 11 | """ 12 | Computes the FLOPs for a given TIMM model, including a customized operator, using FlopCountAnalysis. 13 | 14 | Args: 15 | model_name (str): The name of the TIMM model. 16 | input_size (tuple): The input size (height, width) of the model. 17 | custom_op_name (str): The name of the customized operator. 18 | device (str): The device to run the model on (default: 'cuda'). 19 | 20 | Returns: 21 | float: The total number of GFLOPs. 22 | """ 23 | 24 | # Load the TIMM model 25 | model = timm.create_model(model_name, pretrained=False) 26 | 27 | # Ensure the model is in evaluation mode to avoid unnecessary computations 28 | model.eval() 29 | model.to(device) 30 | 31 | # Create a dummy input tensor with the specified size 32 | input_tensor = torch.randn(1, 3, *input_size).to(device) 33 | 34 | # Set the FLOPs for the customized operator 35 | def _custom_op_flops_fn(inputs, outputs): 36 | # Assuming each operation in custom_op involves 'n' flops per input element 37 | n = 21 # This should be adjusted based on what the custom operation does 38 | 39 | input_shape = get_shape(inputs[0]) 40 | total_elements = prod(input_shape) 41 | return total_elements * n 42 | 43 | 44 | # Use FlopCountAnalysis to compute the FLOPs 45 | analysis = FlopCountAnalysis(model, 46 | input_tensor) 47 | analysis.set_op_handle(custom_op_name, _custom_op_flops_fn) 48 | 49 | print(flop_count_str(analysis)) 50 | total_flops = analysis.total() 51 | # print total_flops in GigaFlops 52 | print("Total FLOPs: ", total_flops/1e9, "GFLOPs") 53 | total_params = sum(p.numel() for p in model.parameters()) 54 | print("Total Params: ", total_params/1e6, "M") 55 | 56 | return total_flops / 1e9 57 | 58 | 59 | # return int(flops) 60 | 61 | if __name__ == "__main__": 62 | # model_name = "vit_base_kat_mimetic_patch16_224" # Replace with your desired model name 63 | model_name = 'kat_tiny_gelu_patch16_224' # Replace with your desired model name 64 | input_size = (224, 224) # Replace with your desired input size 65 | custom_op_name = "prim::PythonOp.rational_1dgroup" # Replace with your customized operator name 66 | # custom_op_flops = 1000 # Replace with the actual FLOPs of your customized operator 67 | 68 | flops = compute_flops(model_name, input_size, custom_op_name) 69 | # print(f"FLOPs for {model_name}: {flops}") -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | """ ImageNet Training Script 3 | 4 | This is intended to be a lean and easily modifiable ImageNet training script that reproduces ImageNet 5 | training results with some of the latest networks and training techniques. It favours canonical PyTorch 6 | and standard Python style over trying to be able to 'do it all.' That said, it offers quite a few speed 7 | and training result improvements over the usual PyTorch example scripts. Repurpose as you see fit.
8 | 9 | This script was started from an early version of the PyTorch ImageNet example 10 | (https://github.com/pytorch/examples/tree/master/imagenet) 11 | 12 | NVIDIA CUDA specific speedups adopted from NVIDIA Apex examples 13 | (https://github.com/NVIDIA/apex/tree/master/examples/imagenet) 14 | 15 | Hacked together by / Copyright 2020 Ross Wightman (https://github.com/rwightman) 16 | """ 17 | import argparse 18 | import importlib 19 | import json 20 | import logging 21 | import os 22 | import time 23 | from collections import OrderedDict 24 | from contextlib import suppress 25 | from datetime import datetime 26 | from functools import partial 27 | 28 | import torch 29 | import torch.nn as nn 30 | import torchvision.utils 31 | import yaml 32 | from torch.nn.parallel import DistributedDataParallel as NativeDDP 33 | 34 | from timm import utils 35 | from timm.data import create_dataset, create_loader, resolve_data_config, Mixup, FastCollateMixup, AugMixDataset 36 | from timm.layers import convert_splitbn_model, convert_sync_batchnorm, set_fast_norm 37 | from timm.loss import JsdCrossEntropy, SoftTargetCrossEntropy, BinaryCrossEntropy, LabelSmoothingCrossEntropy 38 | from timm.models import create_model, safe_model_name, resume_checkpoint, load_checkpoint, model_parameters 39 | from timm.optim import create_optimizer_v2, optimizer_kwargs 40 | from timm.scheduler import create_scheduler_v2, scheduler_kwargs 41 | from timm.utils import ApexScaler, NativeScaler 42 | 43 | try: 44 | from apex import amp 45 | from apex.parallel import DistributedDataParallel as ApexDDP 46 | from apex.parallel import convert_syncbn_model 47 | has_apex = True 48 | except ImportError: 49 | has_apex = False 50 | 51 | has_native_amp = False 52 | try: 53 | if getattr(torch.cuda.amp, 'autocast') is not None: 54 | has_native_amp = True 55 | except AttributeError: 56 | pass 57 | 58 | try: 59 | import wandb 60 | has_wandb = True 61 | except ImportError: 62 | has_wandb = False 63 | 64 | try: 65 | from functorch.compile import memory_efficient_fusion 66 | has_functorch = True 67 | except ImportError as e: 68 | has_functorch = False 69 | 70 | import katransformer 71 | 72 | has_compile = hasattr(torch, 'compile') 73 | 74 | 75 | _logger = logging.getLogger('train') 76 | 77 | # The first arg parser parses out only the --config argument, this argument is used to 78 | # load a yaml file containing key-values that override the defaults for the main parser below 79 | config_parser = parser = argparse.ArgumentParser(description='Training Config', add_help=False) 80 | parser.add_argument('-c', '--config', default='', type=str, metavar='FILE', 81 | help='YAML config file specifying default arguments') 82 | 83 | 84 | parser = argparse.ArgumentParser(description='PyTorch ImageNet Training') 85 | 86 | # Dataset parameters 87 | group = parser.add_argument_group('Dataset parameters') 88 | # Keep this argument outside the dataset group because it is positional. 
89 | parser.add_argument('data', nargs='?', metavar='DIR', const=None, 90 | help='path to dataset (positional is *deprecated*, use --data-dir)') 91 | parser.add_argument('--data-dir', metavar='DIR', 92 | help='path to dataset (root dir)') 93 | parser.add_argument('--dataset', metavar='NAME', default='', 94 | help='dataset type + name ("/") (default: ImageFolder or ImageTar if empty)') 95 | group.add_argument('--train-split', metavar='NAME', default='train', 96 | help='dataset train split (default: train)') 97 | group.add_argument('--val-split', metavar='NAME', default='validation', 98 | help='dataset validation split (default: validation)') 99 | parser.add_argument('--train-num-samples', default=None, type=int, 100 | metavar='N', help='Manually specify num samples in train split, for IterableDatasets.') 101 | parser.add_argument('--val-num-samples', default=None, type=int, 102 | metavar='N', help='Manually specify num samples in validation split, for IterableDatasets.') 103 | group.add_argument('--dataset-download', action='store_true', default=False, 104 | help='Allow download of dataset for torch/ and tfds/ datasets that support it.') 105 | group.add_argument('--class-map', default='', type=str, metavar='FILENAME', 106 | help='path to class to idx mapping file (default: "")') 107 | group.add_argument('--input-img-mode', default=None, type=str, 108 | help='Dataset image conversion mode for input images.') 109 | group.add_argument('--input-key', default=None, type=str, 110 | help='Dataset key for input images.') 111 | group.add_argument('--target-key', default=None, type=str, 112 | help='Dataset key for target labels.') 113 | 114 | # Model parameters 115 | group = parser.add_argument_group('Model parameters') 116 | group.add_argument('--model', default='resnet50', type=str, metavar='MODEL', 117 | help='Name of model to train (default: "resnet50")') 118 | group.add_argument('--pretrained', action='store_true', default=False, 119 | help='Start with pretrained version of specified network (if avail)') 120 | group.add_argument('--pretrained-path', default=None, type=str, 121 | help='Load this checkpoint as if they were the pretrained weights (with adaptation).') 122 | group.add_argument('--initial-checkpoint', default='', type=str, metavar='PATH', 123 | help='Load this checkpoint into model after initialization (default: none)') 124 | group.add_argument('--resume', default='', type=str, metavar='PATH', 125 | help='Resume full model and optimizer state from checkpoint (default: none)') 126 | group.add_argument('--no-resume-opt', action='store_true', default=False, 127 | help='prevent resume of optimizer state when resuming model') 128 | group.add_argument('--num-classes', type=int, default=None, metavar='N', 129 | help='number of label classes (Model default if None)') 130 | group.add_argument('--gp', default=None, type=str, metavar='POOL', 131 | help='Global pool type, one of (fast, avg, max, avgmax, avgmaxc). Model default if None.') 132 | group.add_argument('--img-size', type=int, default=None, metavar='N', 133 | help='Image size (default: None => model default)') 134 | group.add_argument('--in-chans', type=int, default=None, metavar='N', 135 | help='Image input channels (default: None => 3)') 136 | group.add_argument('--input-size', default=None, nargs=3, type=int, 137 | metavar='N N N', 138 | help='Input all image dimensions (d h w, e.g. 
--input-size 3 224 224), uses model default if empty') 139 | group.add_argument('--crop-pct', default=None, type=float, 140 | metavar='N', help='Input image center crop percent (for validation only)') 141 | group.add_argument('--mean', type=float, nargs='+', default=None, metavar='MEAN', 142 | help='Override mean pixel value of dataset') 143 | group.add_argument('--std', type=float, nargs='+', default=None, metavar='STD', 144 | help='Override std deviation of dataset') 145 | group.add_argument('--interpolation', default='', type=str, metavar='NAME', 146 | help='Image resize interpolation type (overrides model)') 147 | group.add_argument('-b', '--batch-size', type=int, default=128, metavar='N', 148 | help='Input batch size for training (default: 128)') 149 | group.add_argument('-vb', '--validation-batch-size', type=int, default=None, metavar='N', 150 | help='Validation batch size override (default: None)') 151 | group.add_argument('--channels-last', action='store_true', default=False, 152 | help='Use channels_last memory layout') 153 | group.add_argument('--fuser', default='', type=str, 154 | help="Select jit fuser. One of ('', 'te', 'old', 'nvfuser')") 155 | group.add_argument('--grad-accum-steps', type=int, default=1, metavar='N', 156 | help='The number of steps to accumulate gradients (default: 1)') 157 | group.add_argument('--grad-checkpointing', action='store_true', default=False, 158 | help='Enable gradient checkpointing through model blocks/stages') 159 | group.add_argument('--fast-norm', default=False, action='store_true', 160 | help='enable experimental fast-norm') 161 | group.add_argument('--model-kwargs', nargs='*', default={}, action=utils.ParseKwargs) 162 | group.add_argument('--head-init-scale', default=None, type=float, 163 | help='Head initialization scale') 164 | group.add_argument('--head-init-bias', default=None, type=float, 165 | help='Head initialization bias value') 166 | 167 | # scripting / codegen 168 | scripting_group = group.add_mutually_exclusive_group() 169 | scripting_group.add_argument('--torchscript', dest='torchscript', action='store_true', 170 | help='torch.jit.script the full model') 171 | scripting_group.add_argument('--torchcompile', nargs='?', type=str, default=None, const='inductor', 172 | help="Enable compilation w/ specified backend (default: inductor).") 173 | 174 | # Device & distributed 175 | group = parser.add_argument_group('Device parameters') 176 | group.add_argument('--device', default='cuda', type=str, 177 | help="Device (accelerator) to use.") 178 | group.add_argument('--amp', action='store_true', default=False, 179 | help='use NVIDIA Apex AMP or Native AMP for mixed precision training') 180 | group.add_argument('--amp-dtype', default='float16', type=str, 181 | help='lower precision AMP dtype (default: float16)') 182 | group.add_argument('--amp-impl', default='native', type=str, 183 | help='AMP impl to use, "native" or "apex" (default: native)') 184 | group.add_argument('--no-ddp-bb', action='store_true', default=False, 185 | help='Force broadcast buffers for native DDP to off.') 186 | group.add_argument('--synchronize-step', action='store_true', default=False, 187 | help='torch.cuda.synchronize() end of each step') 188 | group.add_argument("--local_rank", default=0, type=int) 189 | parser.add_argument('--device-modules', default=None, type=str, nargs='+', 190 | help="Python imports for device backend modules.") 191 | 192 | # Optimizer parameters 193 | group = parser.add_argument_group('Optimizer parameters') 194 | 
group.add_argument('--opt', default='sgd', type=str, metavar='OPTIMIZER', 195 | help='Optimizer (default: "sgd")') 196 | group.add_argument('--opt-eps', default=None, type=float, metavar='EPSILON', 197 | help='Optimizer Epsilon (default: None, use opt default)') 198 | group.add_argument('--opt-betas', default=None, type=float, nargs='+', metavar='BETA', 199 | help='Optimizer Betas (default: None, use opt default)') 200 | group.add_argument('--momentum', type=float, default=0.9, metavar='M', 201 | help='Optimizer momentum (default: 0.9)') 202 | group.add_argument('--weight-decay', type=float, default=2e-5, 203 | help='weight decay (default: 2e-5)') 204 | group.add_argument('--clip-grad', type=float, default=None, metavar='NORM', 205 | help='Clip gradient norm (default: None, no clipping)') 206 | group.add_argument('--clip-mode', type=str, default='norm', 207 | help='Gradient clipping mode. One of ("norm", "value", "agc")') 208 | group.add_argument('--layer-decay', type=float, default=None, 209 | help='layer-wise learning rate decay (default: None)') 210 | group.add_argument('--opt-kwargs', nargs='*', default={}, action=utils.ParseKwargs) 211 | 212 | # Learning rate schedule parameters 213 | group = parser.add_argument_group('Learning rate schedule parameters') 214 | group.add_argument('--sched', type=str, default='cosine', metavar='SCHEDULER', 215 | help='LR scheduler (default: "step"') 216 | group.add_argument('--sched-on-updates', action='store_true', default=False, 217 | help='Apply LR scheduler step on update instead of epoch end.') 218 | group.add_argument('--lr', type=float, default=None, metavar='LR', 219 | help='learning rate, overrides lr-base if set (default: None)') 220 | group.add_argument('--lr-base', type=float, default=0.1, metavar='LR', 221 | help='base learning rate: lr = lr_base * global_batch_size / base_size') 222 | group.add_argument('--lr-base-size', type=int, default=256, metavar='DIV', 223 | help='base learning rate batch size (divisor, default: 256).') 224 | group.add_argument('--lr-base-scale', type=str, default='', metavar='SCALE', 225 | help='base learning rate vs batch_size scaling ("linear", "sqrt", based on opt if empty)') 226 | group.add_argument('--lr-noise', type=float, nargs='+', default=None, metavar='pct, pct', 227 | help='learning rate noise on/off epoch percentages') 228 | group.add_argument('--lr-noise-pct', type=float, default=0.67, metavar='PERCENT', 229 | help='learning rate noise limit percent (default: 0.67)') 230 | group.add_argument('--lr-noise-std', type=float, default=1.0, metavar='STDDEV', 231 | help='learning rate noise std-dev (default: 1.0)') 232 | group.add_argument('--lr-cycle-mul', type=float, default=1.0, metavar='MULT', 233 | help='learning rate cycle len multiplier (default: 1.0)') 234 | group.add_argument('--lr-cycle-decay', type=float, default=0.5, metavar='MULT', 235 | help='amount to decay each learning rate cycle (default: 0.5)') 236 | group.add_argument('--lr-cycle-limit', type=int, default=1, metavar='N', 237 | help='learning rate cycle limit, cycles enabled if > 1') 238 | group.add_argument('--lr-k-decay', type=float, default=1.0, 239 | help='learning rate k-decay for cosine/poly (default: 1.0)') 240 | group.add_argument('--warmup-lr', type=float, default=1e-5, metavar='LR', 241 | help='warmup learning rate (default: 1e-5)') 242 | group.add_argument('--min-lr', type=float, default=0, metavar='LR', 243 | help='lower lr bound for cyclic schedulers that hit 0 (default: 0)') 244 | group.add_argument('--epochs', type=int, 
default=300, metavar='N', 245 | help='number of epochs to train (default: 300)') 246 | group.add_argument('--epoch-repeats', type=float, default=0., metavar='N', 247 | help='epoch repeat multiplier (number of times to repeat dataset epoch per train epoch).') 248 | group.add_argument('--start-epoch', default=None, type=int, metavar='N', 249 | help='manual epoch number (useful on restarts)') 250 | group.add_argument('--decay-milestones', default=[90, 180, 270], type=int, nargs='+', metavar="MILESTONES", 251 | help='list of decay epoch indices for multistep lr. must be increasing') 252 | group.add_argument('--decay-epochs', type=float, default=90, metavar='N', 253 | help='epoch interval to decay LR') 254 | group.add_argument('--warmup-epochs', type=int, default=5, metavar='N', 255 | help='epochs to warmup LR, if scheduler supports') 256 | group.add_argument('--warmup-prefix', action='store_true', default=False, 257 | help='Exclude warmup period from decay schedule.'), 258 | group.add_argument('--cooldown-epochs', type=int, default=0, metavar='N', 259 | help='epochs to cooldown LR at min_lr, after cyclic schedule ends') 260 | group.add_argument('--patience-epochs', type=int, default=10, metavar='N', 261 | help='patience epochs for Plateau LR scheduler (default: 10)') 262 | group.add_argument('--decay-rate', '--dr', type=float, default=0.1, metavar='RATE', 263 | help='LR decay rate (default: 0.1)') 264 | 265 | # Augmentation & regularization parameters 266 | group = parser.add_argument_group('Augmentation and regularization parameters') 267 | group.add_argument('--no-aug', action='store_true', default=False, 268 | help='Disable all training augmentation, override other train aug args') 269 | group.add_argument('--train-crop-mode', type=str, default=None, 270 | help='Crop-mode in train'), 271 | group.add_argument('--scale', type=float, nargs='+', default=[0.08, 1.0], metavar='PCT', 272 | help='Random resize scale (default: 0.08 1.0)') 273 | group.add_argument('--ratio', type=float, nargs='+', default=[3. / 4., 4. / 3.], metavar='RATIO', 274 | help='Random resize aspect ratio (default: 0.75 1.33)') 275 | group.add_argument('--hflip', type=float, default=0.5, 276 | help='Horizontal flip training aug probability') 277 | group.add_argument('--vflip', type=float, default=0., 278 | help='Vertical flip training aug probability') 279 | group.add_argument('--color-jitter', type=float, default=0.4, metavar='PCT', 280 | help='Color jitter factor (default: 0.4)') 281 | group.add_argument('--color-jitter-prob', type=float, default=None, metavar='PCT', 282 | help='Probability of applying any color jitter.') 283 | group.add_argument('--grayscale-prob', type=float, default=None, metavar='PCT', 284 | help='Probability of applying random grayscale conversion.') 285 | group.add_argument('--gaussian-blur-prob', type=float, default=None, metavar='PCT', 286 | help='Probability of applying gaussian blur.') 287 | group.add_argument('--aa', type=str, default=None, metavar='NAME', 288 | help='Use AutoAugment policy. "v0" or "original". (default: None)'), 289 | group.add_argument('--aug-repeats', type=float, default=0, 290 | help='Number of augmentation repetitions (distributed training only) (default: 0)') 291 | group.add_argument('--aug-splits', type=int, default=0, 292 | help='Number of augmentation splits (default: 0, valid: 0 or >=2)') 293 | group.add_argument('--jsd-loss', action='store_true', default=False, 294 | help='Enable Jensen-Shannon Divergence + CE loss. 
Use with `--aug-splits`.') 295 | group.add_argument('--bce-loss', action='store_true', default=False, 296 | help='Enable BCE loss w/ Mixup/CutMix use.') 297 | group.add_argument('--bce-sum', action='store_true', default=False, 298 | help='Sum over classes when using BCE loss.') 299 | group.add_argument('--bce-target-thresh', type=float, default=None, 300 | help='Threshold for binarizing softened BCE targets (default: None, disabled).') 301 | group.add_argument('--bce-pos-weight', type=float, default=None, 302 | help='Positive weighting for BCE loss.') 303 | group.add_argument('--reprob', type=float, default=0., metavar='PCT', 304 | help='Random erase prob (default: 0.)') 305 | group.add_argument('--remode', type=str, default='pixel', 306 | help='Random erase mode (default: "pixel")') 307 | group.add_argument('--recount', type=int, default=1, 308 | help='Random erase count (default: 1)') 309 | group.add_argument('--resplit', action='store_true', default=False, 310 | help='Do not random erase first (clean) augmentation split') 311 | group.add_argument('--mixup', type=float, default=0.0, 312 | help='mixup alpha, mixup enabled if > 0. (default: 0.)') 313 | group.add_argument('--cutmix', type=float, default=0.0, 314 | help='cutmix alpha, cutmix enabled if > 0. (default: 0.)') 315 | group.add_argument('--cutmix-minmax', type=float, nargs='+', default=None, 316 | help='cutmix min/max ratio, overrides alpha and enables cutmix if set (default: None)') 317 | group.add_argument('--mixup-prob', type=float, default=1.0, 318 | help='Probability of performing mixup or cutmix when either/both is enabled') 319 | group.add_argument('--mixup-switch-prob', type=float, default=0.5, 320 | help='Probability of switching to cutmix when both mixup and cutmix enabled') 321 | group.add_argument('--mixup-mode', type=str, default='batch', 322 | help='How to apply mixup/cutmix params. 
Per "batch", "pair", or "elem"') 323 | group.add_argument('--mixup-off-epoch', default=0, type=int, metavar='N', 324 | help='Turn off mixup after this epoch, disabled if 0 (default: 0)') 325 | group.add_argument('--smoothing', type=float, default=0.1, 326 | help='Label smoothing (default: 0.1)') 327 | group.add_argument('--train-interpolation', type=str, default='random', 328 | help='Training interpolation (random, bilinear, bicubic default: "random")') 329 | group.add_argument('--drop', type=float, default=0.0, metavar='PCT', 330 | help='Dropout rate (default: 0.)') 331 | group.add_argument('--drop-connect', type=float, default=None, metavar='PCT', 332 | help='Drop connect rate, DEPRECATED, use drop-path (default: None)') 333 | group.add_argument('--drop-path', type=float, default=None, metavar='PCT', 334 | help='Drop path rate (default: None)') 335 | group.add_argument('--drop-block', type=float, default=None, metavar='PCT', 336 | help='Drop block rate (default: None)') 337 | 338 | # Batch norm parameters (only works with gen_efficientnet based models currently) 339 | group = parser.add_argument_group('Batch norm parameters', 'Only works with gen_efficientnet based models currently.') 340 | group.add_argument('--bn-momentum', type=float, default=None, 341 | help='BatchNorm momentum override (if not None)') 342 | group.add_argument('--bn-eps', type=float, default=None, 343 | help='BatchNorm epsilon override (if not None)') 344 | group.add_argument('--sync-bn', action='store_true', 345 | help='Enable NVIDIA Apex or Torch synchronized BatchNorm.') 346 | group.add_argument('--dist-bn', type=str, default='reduce', 347 | help='Distribute BatchNorm stats between nodes after each epoch ("broadcast", "reduce", or "")') 348 | group.add_argument('--split-bn', action='store_true', 349 | help='Enable separate BN layers per augmentation split.') 350 | 351 | # Model Exponential Moving Average 352 | group = parser.add_argument_group('Model exponential moving average parameters') 353 | group.add_argument('--model-ema', action='store_true', default=False, 354 | help='Enable tracking moving average of model weights.') 355 | group.add_argument('--model-ema-force-cpu', action='store_true', default=False, 356 | help='Force ema to be tracked on CPU, rank=0 node only. 
Disables EMA validation.') 357 | group.add_argument('--model-ema-decay', type=float, default=0.9998, 358 | help='Decay factor for model weights moving average (default: 0.9998)') 359 | group.add_argument('--model-ema-warmup', action='store_true', 360 | help='Enable warmup for model EMA decay.') 361 | 362 | # Misc 363 | group = parser.add_argument_group('Miscellaneous parameters') 364 | group.add_argument('--seed', type=int, default=42, metavar='S', 365 | help='random seed (default: 42)') 366 | group.add_argument('--worker-seeding', type=str, default='all', 367 | help='worker seed mode (default: all)') 368 | group.add_argument('--log-interval', type=int, default=50, metavar='N', 369 | help='how many batches to wait before logging training status') 370 | group.add_argument('--recovery-interval', type=int, default=0, metavar='N', 371 | help='how many batches to wait before writing recovery checkpoint') 372 | group.add_argument('--checkpoint-hist', type=int, default=10, metavar='N', 373 | help='number of checkpoints to keep (default: 10)') 374 | group.add_argument('-j', '--workers', type=int, default=4, metavar='N', 375 | help='how many training processes to use (default: 4)') 376 | group.add_argument('--save-images', action='store_true', default=False, 377 | help='save images of input bathes every log interval for debugging') 378 | group.add_argument('--pin-mem', action='store_true', default=False, 379 | help='Pin CPU memory in DataLoader for more efficient (sometimes) transfer to GPU.') 380 | group.add_argument('--no-prefetcher', action='store_true', default=False, 381 | help='disable fast prefetcher') 382 | group.add_argument('--output', default='', type=str, metavar='PATH', 383 | help='path to output folder (default: none, current dir)') 384 | group.add_argument('--experiment', default='', type=str, metavar='NAME', 385 | help='name of train experiment, name of sub-folder for output') 386 | group.add_argument('--eval-metric', default='top1', type=str, metavar='EVAL_METRIC', 387 | help='Best metric (default: "top1"') 388 | group.add_argument('--tta', type=int, default=0, metavar='N', 389 | help='Test/inference time augmentation (oversampling) factor. 0=None (default: 0)') 390 | group.add_argument('--use-multi-epochs-loader', action='store_true', default=False, 391 | help='use the multi-epochs-loader to save time at the beginning of every epoch') 392 | group.add_argument('--log-wandb', action='store_true', default=False, 393 | help='log training and validation metrics to wandb') 394 | 395 | 396 | def _parse_args(): 397 | # Do we have a config file to parse? 398 | args_config, remaining = config_parser.parse_known_args() 399 | if args_config.config: 400 | with open(args_config.config, 'r') as f: 401 | cfg = yaml.safe_load(f) 402 | parser.set_defaults(**cfg) 403 | 404 | # The main arg parser parses the rest of the args, the usual 405 | # defaults will have been overridden if config file specified. 
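# NOTE: set_defaults(**cfg) only replaces parser defaults, so options passed explicitly on the command line still take precedence over values from the YAML config file.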
406 | args = parser.parse_args(remaining) 407 | 408 | # Cache the args as a text string to save them in the output dir later 409 | args_text = yaml.safe_dump(args.__dict__, default_flow_style=False) 410 | return args, args_text 411 | 412 | 413 | def main(): 414 | utils.setup_default_logging() 415 | args, args_text = _parse_args() 416 | 417 | if args.device_modules: 418 | for module in args.device_modules: 419 | importlib.import_module(module) 420 | 421 | if torch.cuda.is_available(): 422 | torch.backends.cuda.matmul.allow_tf32 = True 423 | torch.backends.cudnn.benchmark = True 424 | 425 | args.prefetcher = not args.no_prefetcher 426 | args.grad_accum_steps = max(1, args.grad_accum_steps) 427 | device = utils.init_distributed_device(args) 428 | if args.distributed: 429 | _logger.info( 430 | 'Training in distributed mode with multiple processes, 1 device per process.' 431 | f'Process {args.rank}, total {args.world_size}, device {args.device}.') 432 | else: 433 | _logger.info(f'Training with a single process on 1 device ({args.device}).') 434 | assert args.rank >= 0 435 | 436 | # resolve AMP arguments based on PyTorch / Apex availability 437 | use_amp = None 438 | amp_dtype = torch.float16 439 | 440 | if args.amp: 441 | if args.amp_impl == 'apex': 442 | assert has_apex, 'AMP impl specified as APEX but APEX is not installed.' 443 | use_amp = 'apex' 444 | assert args.amp_dtype == 'float16' 445 | else: 446 | assert has_native_amp, 'Please update PyTorch to a version with native AMP (or use APEX).' 447 | use_amp = 'native' 448 | assert args.amp_dtype in ('float16', 'bfloat16') 449 | if args.amp_dtype == 'bfloat16': 450 | amp_dtype = torch.bfloat16 451 | 452 | utils.random_seed(args.seed, args.rank) 453 | 454 | if args.fuser: 455 | utils.set_jit_fuser(args.fuser) 456 | if args.fast_norm: 457 | set_fast_norm() 458 | 459 | in_chans = 3 460 | if args.in_chans is not None: 461 | in_chans = args.in_chans 462 | elif args.input_size is not None: 463 | in_chans = args.input_size[0] 464 | 465 | factory_kwargs = {} 466 | if args.pretrained_path: 467 | # merge with pretrained_cfg of model, 'file' has priority over 'url' and 'hf_hub'. 468 | factory_kwargs['pretrained_cfg_overlay'] = dict( 469 | file=args.pretrained_path, 470 | num_classes=-1, # force head adaptation 471 | ) 472 | 473 | model = create_model( 474 | args.model, 475 | pretrained=args.pretrained, 476 | in_chans=in_chans, 477 | num_classes=args.num_classes, 478 | drop_rate=args.drop, 479 | drop_path_rate=args.drop_path, 480 | drop_block_rate=args.drop_block, 481 | global_pool=args.gp, 482 | bn_momentum=args.bn_momentum, 483 | bn_eps=args.bn_eps, 484 | scriptable=args.torchscript, 485 | checkpoint_path=args.initial_checkpoint, 486 | **factory_kwargs, 487 | **args.model_kwargs, 488 | ) 489 | if args.head_init_scale is not None: 490 | with torch.no_grad(): 491 | model.get_classifier().weight.mul_(args.head_init_scale) 492 | model.get_classifier().bias.mul_(args.head_init_scale) 493 | if args.head_init_bias is not None: 494 | nn.init.constant_(model.get_classifier().bias, args.head_init_bias) 495 | 496 | if args.num_classes is None: 497 | assert hasattr(model, 'num_classes'), 'Model must have `num_classes` attr if not set on cmd line/config.' 
498 | args.num_classes = model.num_classes # FIXME handle model default vs config num_classes more elegantly 499 | 500 | if args.grad_checkpointing: 501 | model.set_grad_checkpointing(enable=True) 502 | 503 | if utils.is_primary(args): 504 | _logger.info( 505 | f'Model {safe_model_name(args.model)} created, param count:{sum([m.numel() for m in model.parameters()])}') 506 | print(model) 507 | 508 | data_config = resolve_data_config(vars(args), model=model, verbose=utils.is_primary(args)) 509 | 510 | # setup augmentation batch splits for contrastive loss or split bn 511 | num_aug_splits = 0 512 | if args.aug_splits > 0: 513 | assert args.aug_splits > 1, 'A split of 1 makes no sense' 514 | num_aug_splits = args.aug_splits 515 | 516 | # enable split bn (separate bn stats per batch-portion) 517 | if args.split_bn: 518 | assert num_aug_splits > 1 or args.resplit 519 | model = convert_splitbn_model(model, max(num_aug_splits, 2)) 520 | 521 | # move model to GPU, enable channels last layout if set 522 | model.to(device=device) 523 | if args.channels_last: 524 | model.to(memory_format=torch.channels_last) 525 | 526 | # setup synchronized BatchNorm for distributed training 527 | if args.distributed and args.sync_bn: 528 | args.dist_bn = '' # disable dist_bn when sync BN active 529 | assert not args.split_bn 530 | if has_apex and use_amp == 'apex': 531 | # Apex SyncBN used with Apex AMP 532 | # WARNING this won't currently work with models using BatchNormAct2d 533 | model = convert_syncbn_model(model) 534 | else: 535 | model = convert_sync_batchnorm(model) 536 | if utils.is_primary(args): 537 | _logger.info( 538 | 'Converted model to use Synchronized BatchNorm. WARNING: You may have issues if using ' 539 | 'zero initialized BN layers (enabled by default for ResNets) while sync-bn enabled.') 540 | 541 | if args.torchscript: 542 | assert not args.torchcompile 543 | assert not use_amp == 'apex', 'Cannot use APEX AMP with torchscripted model' 544 | assert not args.sync_bn, 'Cannot use SyncBatchNorm with torchscripted model' 545 | model = torch.jit.script(model) 546 | 547 | if not args.lr: 548 | global_batch_size = args.batch_size * args.world_size * args.grad_accum_steps 549 | batch_ratio = global_batch_size / args.lr_base_size 550 | if not args.lr_base_scale: 551 | on = args.opt.lower() 552 | args.lr_base_scale = 'sqrt' if any([o in on for o in ('ada', 'lamb')]) else 'linear' 553 | if args.lr_base_scale == 'sqrt': 554 | batch_ratio = batch_ratio ** 0.5 555 | args.lr = args.lr_base * batch_ratio 556 | if utils.is_primary(args): 557 | _logger.info( 558 | f'Learning rate ({args.lr}) calculated from base learning rate ({args.lr_base}) ' 559 | f'and effective global batch size ({global_batch_size}) with {args.lr_base_scale} scaling.') 560 | 561 | optimizer = create_optimizer_v2( 562 | model, 563 | **optimizer_kwargs(cfg=args), 564 | **args.opt_kwargs, 565 | ) 566 | 567 | # setup automatic mixed-precision (AMP) loss scaling and op casting 568 | amp_autocast = suppress # do nothing 569 | loss_scaler = None 570 | if use_amp == 'apex': 571 | assert device.type == 'cuda' 572 | model, optimizer = amp.initialize(model, optimizer, opt_level='O1') 573 | loss_scaler = ApexScaler() 574 | if utils.is_primary(args): 575 | _logger.info('Using NVIDIA APEX AMP. 
Training in mixed precision.') 576 | elif use_amp == 'native': 577 | try: 578 | amp_autocast = partial(torch.autocast, device_type=device.type, dtype=amp_dtype) 579 | except (AttributeError, TypeError): 580 | # fallback to CUDA only AMP for PyTorch < 1.10 581 | assert device.type == 'cuda' 582 | amp_autocast = torch.cuda.amp.autocast 583 | if device.type == 'cuda' and amp_dtype == torch.float16: 584 | # loss scaler only used for float16 (half) dtype, bfloat16 does not need it 585 | loss_scaler = NativeScaler() 586 | if utils.is_primary(args): 587 | _logger.info('Using native Torch AMP. Training in mixed precision.') 588 | else: 589 | if utils.is_primary(args): 590 | _logger.info('AMP not enabled. Training in float32.') 591 | 592 | # optionally resume from a checkpoint 593 | resume_epoch = None 594 | if args.resume: 595 | resume_epoch = resume_checkpoint( 596 | model, 597 | args.resume, 598 | optimizer=None if args.no_resume_opt else optimizer, 599 | loss_scaler=None if args.no_resume_opt else loss_scaler, 600 | log_info=utils.is_primary(args), 601 | ) 602 | 603 | # setup exponential moving average of model weights, SWA could be used here too 604 | model_ema = None 605 | if args.model_ema: 606 | # Important to create EMA model after cuda(), DP wrapper, and AMP but before DDP wrapper 607 | model_ema = utils.ModelEmaV3( 608 | model, 609 | decay=args.model_ema_decay, 610 | use_warmup=args.model_ema_warmup, 611 | device='cpu' if args.model_ema_force_cpu else None, 612 | ) 613 | if args.resume: 614 | load_checkpoint(model_ema.module, args.resume, use_ema=True) 615 | if args.torchcompile: 616 | model_ema = torch.compile(model_ema, backend=args.torchcompile) 617 | 618 | # setup distributed training 619 | if args.distributed: 620 | if has_apex and use_amp == 'apex': 621 | # Apex DDP preferred unless native amp is activated 622 | if utils.is_primary(args): 623 | _logger.info("Using NVIDIA APEX DistributedDataParallel.") 624 | model = ApexDDP(model, delay_allreduce=True) 625 | else: 626 | if utils.is_primary(args): 627 | _logger.info("Using native Torch DistributedDataParallel.") 628 | model = NativeDDP(model, device_ids=[device], broadcast_buffers=not args.no_ddp_bb) 629 | # NOTE: EMA model does not need to be wrapped by DDP 630 | 631 | if args.torchcompile: 632 | # torch compile should be done after DDP 633 | assert has_compile, 'A version of torch w/ torch.compile() is required for --compile, possibly a nightly.' 
634 | model = torch.compile(model, backend=args.torchcompile) 635 | 636 | # create the train and eval datasets 637 | if args.data and not args.data_dir: 638 | args.data_dir = args.data 639 | if args.input_img_mode is None: 640 | input_img_mode = 'RGB' if data_config['input_size'][0] == 3 else 'L' 641 | else: 642 | input_img_mode = args.input_img_mode 643 | 644 | dataset_train = create_dataset( 645 | args.dataset, 646 | root=args.data_dir, 647 | split=args.train_split, 648 | is_training=True, 649 | class_map=args.class_map, 650 | download=args.dataset_download, 651 | batch_size=args.batch_size, 652 | seed=args.seed, 653 | repeats=args.epoch_repeats, 654 | input_img_mode=input_img_mode, 655 | input_key=args.input_key, 656 | target_key=args.target_key, 657 | num_samples=args.train_num_samples, 658 | ) 659 | 660 | if args.val_split: 661 | dataset_eval = create_dataset( 662 | args.dataset, 663 | root=args.data_dir, 664 | split=args.val_split, 665 | is_training=False, 666 | class_map=args.class_map, 667 | download=args.dataset_download, 668 | batch_size=args.batch_size, 669 | input_img_mode=input_img_mode, 670 | input_key=args.input_key, 671 | target_key=args.target_key, 672 | num_samples=args.val_num_samples, 673 | ) 674 | 675 | # setup mixup / cutmix 676 | collate_fn = None 677 | mixup_fn = None 678 | mixup_active = args.mixup > 0 or args.cutmix > 0. or args.cutmix_minmax is not None 679 | if mixup_active: 680 | mixup_args = dict( 681 | mixup_alpha=args.mixup, 682 | cutmix_alpha=args.cutmix, 683 | cutmix_minmax=args.cutmix_minmax, 684 | prob=args.mixup_prob, 685 | switch_prob=args.mixup_switch_prob, 686 | mode=args.mixup_mode, 687 | label_smoothing=args.smoothing, 688 | num_classes=args.num_classes 689 | ) 690 | if args.prefetcher: 691 | assert not num_aug_splits # collate conflict (need to support de-interleaving in collate mixup) 692 | collate_fn = FastCollateMixup(**mixup_args) 693 | else: 694 | mixup_fn = Mixup(**mixup_args) 695 | 696 | # wrap dataset in AugMix helper 697 | if num_aug_splits > 1: 698 | dataset_train = AugMixDataset(dataset_train, num_splits=num_aug_splits) 699 | 700 | # create data loaders w/ augmentation pipeline 701 | train_interpolation = args.train_interpolation 702 | if args.no_aug or not train_interpolation: 703 | train_interpolation = data_config['interpolation'] 704 | loader_train = create_loader( 705 | dataset_train, 706 | input_size=data_config['input_size'], 707 | batch_size=args.batch_size, 708 | is_training=True, 709 | no_aug=args.no_aug, 710 | re_prob=args.reprob, 711 | re_mode=args.remode, 712 | re_count=args.recount, 713 | re_split=args.resplit, 714 | train_crop_mode=args.train_crop_mode, 715 | scale=args.scale, 716 | ratio=args.ratio, 717 | hflip=args.hflip, 718 | vflip=args.vflip, 719 | color_jitter=args.color_jitter, 720 | color_jitter_prob=args.color_jitter_prob, 721 | grayscale_prob=args.grayscale_prob, 722 | gaussian_blur_prob=args.gaussian_blur_prob, 723 | auto_augment=args.aa, 724 | num_aug_repeats=args.aug_repeats, 725 | num_aug_splits=num_aug_splits, 726 | interpolation=train_interpolation, 727 | mean=data_config['mean'], 728 | std=data_config['std'], 729 | num_workers=args.workers, 730 | distributed=args.distributed, 731 | collate_fn=collate_fn, 732 | pin_memory=args.pin_mem, 733 | device=device, 734 | use_prefetcher=args.prefetcher, 735 | use_multi_epochs_loader=args.use_multi_epochs_loader, 736 | worker_seeding=args.worker_seeding, 737 | ) 738 | 739 | loader_eval = None 740 | if args.val_split: 741 | eval_workers = args.workers 742 | if 
args.distributed and ('tfds' in args.dataset or 'wds' in args.dataset): 743 | # FIXME reduces validation padding issues when using TFDS, WDS w/ workers and distributed training 744 | eval_workers = min(2, args.workers) 745 | loader_eval = create_loader( 746 | dataset_eval, 747 | input_size=data_config['input_size'], 748 | batch_size=args.validation_batch_size or args.batch_size, 749 | is_training=False, 750 | interpolation=data_config['interpolation'], 751 | mean=data_config['mean'], 752 | std=data_config['std'], 753 | num_workers=eval_workers, 754 | distributed=args.distributed, 755 | crop_pct=data_config['crop_pct'], 756 | pin_memory=args.pin_mem, 757 | device=device, 758 | use_prefetcher=args.prefetcher, 759 | ) 760 | 761 | # setup loss function 762 | if args.jsd_loss: 763 | assert num_aug_splits > 1 # JSD only valid with aug splits set 764 | train_loss_fn = JsdCrossEntropy(num_splits=num_aug_splits, smoothing=args.smoothing) 765 | elif mixup_active: 766 | # smoothing is handled with mixup target transform which outputs sparse, soft targets 767 | if args.bce_loss: 768 | train_loss_fn = BinaryCrossEntropy( 769 | target_threshold=args.bce_target_thresh, 770 | sum_classes=args.bce_sum, 771 | pos_weight=args.bce_pos_weight, 772 | ) 773 | else: 774 | train_loss_fn = SoftTargetCrossEntropy() 775 | elif args.smoothing: 776 | if args.bce_loss: 777 | train_loss_fn = BinaryCrossEntropy( 778 | smoothing=args.smoothing, 779 | target_threshold=args.bce_target_thresh, 780 | sum_classes=args.bce_sum, 781 | pos_weight=args.bce_pos_weight, 782 | ) 783 | else: 784 | train_loss_fn = LabelSmoothingCrossEntropy(smoothing=args.smoothing) 785 | else: 786 | train_loss_fn = nn.CrossEntropyLoss() 787 | train_loss_fn = train_loss_fn.to(device=device) 788 | validate_loss_fn = nn.CrossEntropyLoss().to(device=device) 789 | 790 | # setup checkpoint saver and eval metric tracking 791 | eval_metric = args.eval_metric if loader_eval is not None else 'loss' 792 | decreasing_metric = eval_metric == 'loss' 793 | best_metric = None 794 | best_epoch = None 795 | saver = None 796 | output_dir = None 797 | if utils.is_primary(args): 798 | if args.experiment: 799 | exp_name = args.experiment 800 | else: 801 | exp_name = '-'.join([ 802 | datetime.now().strftime("%Y%m%d-%H%M%S"), 803 | safe_model_name(args.model), 804 | str(data_config['input_size'][-1]) 805 | ]) 806 | output_dir = utils.get_outdir(args.output if args.output else './output/train', exp_name) 807 | saver = utils.CheckpointSaver( 808 | model=model, 809 | optimizer=optimizer, 810 | args=args, 811 | model_ema=model_ema, 812 | amp_scaler=loss_scaler, 813 | checkpoint_dir=output_dir, 814 | recovery_dir=output_dir, 815 | decreasing=decreasing_metric, 816 | max_history=args.checkpoint_hist 817 | ) 818 | with open(os.path.join(output_dir, 'args.yaml'), 'w') as f: 819 | f.write(args_text) 820 | 821 | if utils.is_primary(args) and args.log_wandb: 822 | if has_wandb: 823 | wandb.init(project="scale-kan", 824 | name=exp_name, 825 | config=args) 826 | else: 827 | _logger.warning( 828 | "You've requested to log metrics to wandb but package not found. 
" 829 | "Metrics not being logged to wandb, try `pip install wandb`") 830 | 831 | # setup learning rate schedule and starting epoch 832 | updates_per_epoch = (len(loader_train) + args.grad_accum_steps - 1) // args.grad_accum_steps 833 | lr_scheduler, num_epochs = create_scheduler_v2( 834 | optimizer, 835 | **scheduler_kwargs(args, decreasing_metric=decreasing_metric), 836 | updates_per_epoch=updates_per_epoch, 837 | ) 838 | start_epoch = 0 839 | if args.start_epoch is not None: 840 | # a specified start_epoch will always override the resume epoch 841 | start_epoch = args.start_epoch 842 | elif resume_epoch is not None: 843 | start_epoch = resume_epoch 844 | if lr_scheduler is not None and start_epoch > 0: 845 | if args.sched_on_updates: 846 | lr_scheduler.step_update(start_epoch * updates_per_epoch) 847 | else: 848 | lr_scheduler.step(start_epoch) 849 | 850 | if utils.is_primary(args): 851 | _logger.info( 852 | f'Scheduled epochs: {num_epochs}. LR stepped per {"epoch" if lr_scheduler.t_in_epochs else "update"}.') 853 | 854 | results = [] 855 | try: 856 | for epoch in range(start_epoch, num_epochs): 857 | if hasattr(dataset_train, 'set_epoch'): 858 | dataset_train.set_epoch(epoch) 859 | elif args.distributed and hasattr(loader_train.sampler, 'set_epoch'): 860 | loader_train.sampler.set_epoch(epoch) 861 | 862 | train_metrics = train_one_epoch( 863 | epoch, 864 | model, 865 | loader_train, 866 | optimizer, 867 | train_loss_fn, 868 | args, 869 | lr_scheduler=lr_scheduler, 870 | saver=saver, 871 | output_dir=output_dir, 872 | amp_autocast=amp_autocast, 873 | loss_scaler=loss_scaler, 874 | model_ema=model_ema, 875 | mixup_fn=mixup_fn, 876 | num_updates_total=num_epochs * updates_per_epoch, 877 | ) 878 | 879 | if args.distributed and args.dist_bn in ('broadcast', 'reduce'): 880 | if utils.is_primary(args): 881 | _logger.info("Distributing BatchNorm running means and vars") 882 | utils.distribute_bn(model, args.world_size, args.dist_bn == 'reduce') 883 | 884 | if loader_eval is not None: 885 | eval_metrics = validate( 886 | model, 887 | loader_eval, 888 | validate_loss_fn, 889 | args, 890 | device=device, 891 | amp_autocast=amp_autocast, 892 | ) 893 | 894 | if model_ema is not None and not args.model_ema_force_cpu: 895 | if args.distributed and args.dist_bn in ('broadcast', 'reduce'): 896 | utils.distribute_bn(model_ema, args.world_size, args.dist_bn == 'reduce') 897 | 898 | ema_eval_metrics = validate( 899 | model_ema, 900 | loader_eval, 901 | validate_loss_fn, 902 | args, 903 | device=device, 904 | amp_autocast=amp_autocast, 905 | log_suffix=' (EMA)', 906 | ) 907 | eval_metrics = ema_eval_metrics 908 | else: 909 | eval_metrics = None 910 | 911 | if output_dir is not None: 912 | lrs = [param_group['lr'] for param_group in optimizer.param_groups] 913 | utils.update_summary( 914 | epoch, 915 | train_metrics, 916 | eval_metrics, 917 | filename=os.path.join(output_dir, 'summary.csv'), 918 | lr=sum(lrs) / len(lrs), 919 | write_header=best_metric is None, 920 | log_wandb=args.log_wandb and has_wandb, 921 | ) 922 | 923 | if eval_metrics is not None: 924 | latest_metric = eval_metrics[eval_metric] 925 | else: 926 | latest_metric = train_metrics[eval_metric] 927 | 928 | if saver is not None: 929 | # save proper checkpoint with eval metric 930 | best_metric, best_epoch = saver.save_checkpoint(epoch, metric=latest_metric) 931 | 932 | if lr_scheduler is not None: 933 | # step LR for next epoch 934 | lr_scheduler.step(epoch + 1, latest_metric) 935 | 936 | results.append({ 937 | 'epoch': epoch, 938 | 'train': 
train_metrics, 939 | 'validation': eval_metrics, 940 | }) 941 | 942 | except KeyboardInterrupt: 943 | pass 944 | 945 | results = {'all': results} 946 | if best_metric is not None: 947 | results['best'] = results['all'][best_epoch - start_epoch] 948 | _logger.info('*** Best metric: {0} (epoch {1})'.format(best_metric, best_epoch)) 949 | print(f'--result\n{json.dumps(results, indent=4)}') 950 | 951 | 952 | def train_one_epoch( 953 | epoch, 954 | model, 955 | loader, 956 | optimizer, 957 | loss_fn, 958 | args, 959 | device=torch.device('cuda'), 960 | lr_scheduler=None, 961 | saver=None, 962 | output_dir=None, 963 | amp_autocast=suppress, 964 | loss_scaler=None, 965 | model_ema=None, 966 | mixup_fn=None, 967 | num_updates_total=None, 968 | ): 969 | if args.mixup_off_epoch and epoch >= args.mixup_off_epoch: 970 | if args.prefetcher and loader.mixup_enabled: 971 | loader.mixup_enabled = False 972 | elif mixup_fn is not None: 973 | mixup_fn.mixup_enabled = False 974 | 975 | second_order = hasattr(optimizer, 'is_second_order') and optimizer.is_second_order 976 | has_no_sync = hasattr(model, "no_sync") 977 | update_time_m = utils.AverageMeter() 978 | data_time_m = utils.AverageMeter() 979 | losses_m = utils.AverageMeter() 980 | 981 | model.train() 982 | 983 | accum_steps = args.grad_accum_steps 984 | last_accum_steps = len(loader) % accum_steps 985 | updates_per_epoch = (len(loader) + accum_steps - 1) // accum_steps 986 | num_updates = epoch * updates_per_epoch 987 | last_batch_idx = len(loader) - 1 988 | last_batch_idx_to_accum = len(loader) - last_accum_steps 989 | 990 | data_start_time = update_start_time = time.time() 991 | optimizer.zero_grad() 992 | update_sample_count = 0 993 | for batch_idx, (input, target) in enumerate(loader): 994 | last_batch = batch_idx == last_batch_idx 995 | need_update = last_batch or (batch_idx + 1) % accum_steps == 0 996 | update_idx = batch_idx // accum_steps 997 | if batch_idx >= last_batch_idx_to_accum: 998 | accum_steps = last_accum_steps 999 | 1000 | if not args.prefetcher: 1001 | input, target = input.to(device), target.to(device) 1002 | if mixup_fn is not None: 1003 | input, target = mixup_fn(input, target) 1004 | if args.channels_last: 1005 | input = input.contiguous(memory_format=torch.channels_last) 1006 | 1007 | # multiply by accum steps to get equivalent for full update 1008 | data_time_m.update(accum_steps * (time.time() - data_start_time)) 1009 | 1010 | def _forward(): 1011 | with amp_autocast(): 1012 | output = model(input) 1013 | loss = loss_fn(output, target) 1014 | if accum_steps > 1: 1015 | loss /= accum_steps 1016 | return loss 1017 | 1018 | def _backward(_loss): 1019 | if loss_scaler is not None: 1020 | loss_scaler( 1021 | _loss, 1022 | optimizer, 1023 | clip_grad=args.clip_grad, 1024 | clip_mode=args.clip_mode, 1025 | parameters=model_parameters(model, exclude_head='agc' in args.clip_mode), 1026 | create_graph=second_order, 1027 | need_update=need_update, 1028 | ) 1029 | else: 1030 | _loss.backward(create_graph=second_order) 1031 | if need_update: 1032 | if args.clip_grad is not None: 1033 | utils.dispatch_clip_grad( 1034 | model_parameters(model, exclude_head='agc' in args.clip_mode), 1035 | value=args.clip_grad, 1036 | mode=args.clip_mode, 1037 | ) 1038 | optimizer.step() 1039 | 1040 | if has_no_sync and not need_update: 1041 | with model.no_sync(): 1042 | loss = _forward() 1043 | _backward(loss) 1044 | else: 1045 | loss = _forward() 1046 | _backward(loss) 1047 | 1048 | if not args.distributed: 1049 | losses_m.update(loss.item() * 
accum_steps, input.size(0)) 1050 | update_sample_count += input.size(0) 1051 | 1052 | if not need_update: 1053 | data_start_time = time.time() 1054 | continue 1055 | 1056 | num_updates += 1 1057 | optimizer.zero_grad() 1058 | if model_ema is not None: 1059 | model_ema.update(model, step=num_updates) 1060 | 1061 | if args.synchronize_step and device.type == 'cuda': 1062 | torch.cuda.synchronize() 1063 | time_now = time.time() 1064 | update_time_m.update(time.time() - update_start_time) 1065 | update_start_time = time_now 1066 | 1067 | if update_idx % args.log_interval == 0: 1068 | lrl = [param_group['lr'] for param_group in optimizer.param_groups] 1069 | lr = sum(lrl) / len(lrl) 1070 | 1071 | if args.distributed: 1072 | reduced_loss = utils.reduce_tensor(loss.data, args.world_size) 1073 | losses_m.update(reduced_loss.item() * accum_steps, input.size(0)) 1074 | update_sample_count *= args.world_size 1075 | 1076 | if utils.is_primary(args): 1077 | _logger.info( 1078 | f'Train: {epoch} [{update_idx:>4d}/{updates_per_epoch} ' 1079 | f'({100. * update_idx / (updates_per_epoch - 1):>3.0f}%)] ' 1080 | f'Loss: {losses_m.val:#.3g} ({losses_m.avg:#.3g}) ' 1081 | f'Time: {update_time_m.val:.3f}s, {update_sample_count / update_time_m.val:>7.2f}/s ' 1082 | f'({update_time_m.avg:.3f}s, {update_sample_count / update_time_m.avg:>7.2f}/s) ' 1083 | f'LR: {lr:.3e} ' 1084 | f'Data: {data_time_m.val:.3f} ({data_time_m.avg:.3f})' 1085 | ) 1086 | 1087 | if args.save_images and output_dir: 1088 | torchvision.utils.save_image( 1089 | input, 1090 | os.path.join(output_dir, 'train-batch-%d.jpg' % batch_idx), 1091 | padding=0, 1092 | normalize=True 1093 | ) 1094 | 1095 | if saver is not None and args.recovery_interval and ( 1096 | (update_idx + 1) % args.recovery_interval == 0): 1097 | saver.save_recovery(epoch, batch_idx=update_idx) 1098 | 1099 | if lr_scheduler is not None: 1100 | lr_scheduler.step_update(num_updates=num_updates, metric=losses_m.avg) 1101 | 1102 | update_sample_count = 0 1103 | data_start_time = time.time() 1104 | # end for 1105 | 1106 | if hasattr(optimizer, 'sync_lookahead'): 1107 | optimizer.sync_lookahead() 1108 | 1109 | return OrderedDict([('loss', losses_m.avg)]) 1110 | 1111 | 1112 | def validate( 1113 | model, 1114 | loader, 1115 | loss_fn, 1116 | args, 1117 | device=torch.device('cuda'), 1118 | amp_autocast=suppress, 1119 | log_suffix='' 1120 | ): 1121 | batch_time_m = utils.AverageMeter() 1122 | losses_m = utils.AverageMeter() 1123 | top1_m = utils.AverageMeter() 1124 | top5_m = utils.AverageMeter() 1125 | 1126 | model.eval() 1127 | 1128 | end = time.time() 1129 | last_idx = len(loader) - 1 1130 | with torch.no_grad(): 1131 | for batch_idx, (input, target) in enumerate(loader): 1132 | last_batch = batch_idx == last_idx 1133 | if not args.prefetcher: 1134 | input = input.to(device) 1135 | target = target.to(device) 1136 | if args.channels_last: 1137 | input = input.contiguous(memory_format=torch.channels_last) 1138 | 1139 | with amp_autocast(): 1140 | output = model(input) 1141 | if isinstance(output, (tuple, list)): 1142 | output = output[0] 1143 | 1144 | # augmentation reduction 1145 | reduce_factor = args.tta 1146 | if reduce_factor > 1: 1147 | output = output.unfold(0, reduce_factor, reduce_factor).mean(dim=2) 1148 | target = target[0:target.size(0):reduce_factor] 1149 | 1150 | loss = loss_fn(output, target) 1151 | acc1, acc5 = utils.accuracy(output, target, topk=(1, 5)) 1152 | 1153 | if args.distributed: 1154 | reduced_loss = utils.reduce_tensor(loss.data, args.world_size) 1155 | acc1 = 
utils.reduce_tensor(acc1, args.world_size) 1156 | acc5 = utils.reduce_tensor(acc5, args.world_size) 1157 | else: 1158 | reduced_loss = loss.data 1159 | 1160 | if device.type == 'cuda': 1161 | torch.cuda.synchronize() 1162 | 1163 | losses_m.update(reduced_loss.item(), input.size(0)) 1164 | top1_m.update(acc1.item(), output.size(0)) 1165 | top5_m.update(acc5.item(), output.size(0)) 1166 | 1167 | batch_time_m.update(time.time() - end) 1168 | end = time.time() 1169 | if utils.is_primary(args) and (last_batch or batch_idx % args.log_interval == 0): 1170 | log_name = 'Test' + log_suffix 1171 | _logger.info( 1172 | f'{log_name}: [{batch_idx:>4d}/{last_idx}] ' 1173 | f'Time: {batch_time_m.val:.3f} ({batch_time_m.avg:.3f}) ' 1174 | f'Loss: {losses_m.val:>7.3f} ({losses_m.avg:>6.3f}) ' 1175 | f'Acc@1: {top1_m.val:>7.3f} ({top1_m.avg:>7.3f}) ' 1176 | f'Acc@5: {top5_m.val:>7.3f} ({top5_m.avg:>7.3f})' 1177 | ) 1178 | 1179 | metrics = OrderedDict([('loss', losses_m.avg), ('top1', top1_m.avg), ('top5', top5_m.avg)]) 1180 | 1181 | return metrics 1182 | 1183 | 1184 | if __name__ == '__main__': 1185 | main() 1186 | -------------------------------------------------------------------------------- /validate.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | """ ImageNet Validation Script 3 | 4 | This is intended to be a lean and easily modifiable ImageNet validation script for evaluating pretrained 5 | models or training checkpoints against ImageNet or similarly organized image datasets. It prioritizes 6 | canonical PyTorch, standard Python style, and good performance. Repurpose as you see fit. 7 | 8 | Hacked together by Ross Wightman (https://github.com/rwightman) 9 | """ 10 | import argparse 11 | import csv 12 | import glob 13 | import json 14 | import logging 15 | import os 16 | import time 17 | from collections import OrderedDict 18 | from contextlib import suppress 19 | from functools import partial 20 | 21 | import torch 22 | import torch.nn as nn 23 | import torch.nn.parallel 24 | 25 | from timm.data import create_dataset, create_loader, resolve_data_config, RealLabelsImagenet 26 | from timm.layers import apply_test_time_pool, set_fast_norm 27 | from timm.models import create_model, load_checkpoint, is_model, list_models 28 | from timm.utils import accuracy, AverageMeter, natural_key, setup_default_logging, set_jit_fuser, \ 29 | decay_batch_step, check_batch_size_retry, ParseKwargs, reparameterize_model 30 | 31 | try: 32 | from apex import amp 33 | has_apex = True 34 | except ImportError: 35 | has_apex = False 36 | 37 | has_native_amp = False 38 | try: 39 | if getattr(torch.cuda.amp, 'autocast') is not None: 40 | has_native_amp = True 41 | except AttributeError: 42 | pass 43 | 44 | try: 45 | from functorch.compile import memory_efficient_fusion 46 | has_functorch = True 47 | except ImportError as e: 48 | has_functorch = False 49 | 50 | import katransformer 51 | 52 | has_compile = hasattr(torch, 'compile') 53 | 54 | _logger = logging.getLogger('validate') 55 | 56 | 57 | parser = argparse.ArgumentParser(description='PyTorch ImageNet Validation') 58 | parser.add_argument('data', nargs='?', metavar='DIR', const=None, 59 | help='path to dataset (*deprecated*, use --data-dir)') 60 | parser.add_argument('--data-dir', metavar='DIR', 61 | help='path to dataset (root dir)') 62 | parser.add_argument('--dataset', metavar='NAME', default='', 63 | help='dataset type + name ("/") (default: ImageFolder or ImageTar if empty)') 64 | 
parser.add_argument('--split', metavar='NAME', default='validation', 65 | help='dataset split (default: validation)') 66 | parser.add_argument('--num-samples', default=None, type=int, 67 | metavar='N', help='Manually specify num samples in dataset split, for IterableDatasets.') 68 | parser.add_argument('--dataset-download', action='store_true', default=False, 69 | help='Allow download of dataset for torch/ and tfds/ datasets that support it.') 70 | parser.add_argument('--class-map', default='', type=str, metavar='FILENAME', 71 | help='path to class to idx mapping file (default: "")') 72 | parser.add_argument('--input-key', default=None, type=str, 73 | help='Dataset key for input images.') 74 | parser.add_argument('--input-img-mode', default=None, type=str, 75 | help='Dataset image conversion mode for input images.') 76 | parser.add_argument('--target-key', default=None, type=str, 77 | help='Dataset key for target labels.') 78 | 79 | parser.add_argument('--model', '-m', metavar='NAME', default='kat_tiny_patch16_224', 80 | help='model architecture (default: kat_tiny_patch16_224)') 81 | parser.add_argument('--pretrained', dest='pretrained', action='store_true', 82 | help='use pre-trained model') 83 | parser.add_argument('-j', '--workers', default=4, type=int, metavar='N', 84 | help='number of data loading workers (default: 4)') 85 | parser.add_argument('-b', '--batch-size', default=256, type=int, 86 | metavar='N', help='mini-batch size (default: 256)') 87 | parser.add_argument('--img-size', default=None, type=int, 88 | metavar='N', help='Input image dimension, uses model default if empty') 89 | parser.add_argument('--in-chans', type=int, default=None, metavar='N', 90 | help='Image input channels (default: None => 3)') 91 | parser.add_argument('--input-size', default=None, nargs=3, type=int, 92 | metavar='N N N', help='Input all image dimensions (d h w, e.g. --input-size 3 224 224), uses model default if empty') 93 | parser.add_argument('--use-train-size', action='store_true', default=False, 94 | help='force use of train input size, even when test size is specified in pretrained cfg') 95 | parser.add_argument('--crop-pct', default=None, type=float, 96 | metavar='N', help='Input image center crop pct') 97 | parser.add_argument('--crop-mode', default=None, type=str, 98 | metavar='N', help='Input image crop mode (squash, border, center). Model default if None.') 99 | parser.add_argument('--crop-border-pixels', type=int, default=None, 100 | help='Crop pixels from image border.') 101 | parser.add_argument('--mean', type=float, nargs='+', default=None, metavar='MEAN', 102 | help='Override mean pixel value of dataset') 103 | parser.add_argument('--std', type=float, nargs='+', default=None, metavar='STD', 104 | help='Override std deviation of dataset') 105 | parser.add_argument('--interpolation', default='', type=str, metavar='NAME', 106 | help='Image resize interpolation type (overrides model)') 107 | parser.add_argument('--num-classes', type=int, default=None, 108 | help='Number of classes in dataset') 109 | parser.add_argument('--gp', default=None, type=str, metavar='POOL',
110 | help='Global pool type, one of (fast, avg, max, avgmax, avgmaxc). Model default if None.') 111 | parser.add_argument('--log-freq', default=10, type=int, 112 | metavar='N', help='batch logging frequency (default: 10)') 113 | parser.add_argument('--checkpoint', default='', type=str, metavar='PATH', 114 | help='path to latest checkpoint (default: none)') 115 | parser.add_argument('--num-gpu', type=int, default=1, 116 | help='Number of GPUs to use') 117 | parser.add_argument('--test-pool', dest='test_pool', action='store_true', 118 | help='enable test time pool') 119 | parser.add_argument('--no-prefetcher', action='store_true', default=False, 120 | help='disable fast prefetcher') 121 | parser.add_argument('--pin-mem', action='store_true', default=False, 122 | help='Pin CPU memory in DataLoader for more efficient (sometimes) transfer to GPU.') 123 | parser.add_argument('--channels-last', action='store_true', default=False, 124 | help='Use channels_last memory layout') 125 | parser.add_argument('--device', default='cuda', type=str, 126 | help="Device (accelerator) to use.") 127 | parser.add_argument('--amp', action='store_true', default=False, 128 | help='use NVIDIA Apex AMP or Native AMP for mixed precision evaluation') 129 | parser.add_argument('--amp-dtype', default='float16', type=str, 130 | help='lower precision AMP dtype (default: float16)') 131 | parser.add_argument('--amp-impl', default='native', type=str, 132 | help='AMP impl to use, "native" or "apex" (default: native)') 133 | parser.add_argument('--tf-preprocessing', action='store_true', default=False, 134 | help='Use Tensorflow preprocessing pipeline (requires CPU TF installed)') 135 | parser.add_argument('--use-ema', dest='use_ema', action='store_true', 136 | help='use ema version of weights if present') 137 | parser.add_argument('--fuser', default='', type=str, 138 | help="Select jit fuser. 
One of ('', 'te', 'old', 'nvfuser')") 139 | parser.add_argument('--fast-norm', default=False, action='store_true', 140 | help='enable experimental fast-norm') 141 | parser.add_argument('--reparam', default=False, action='store_true', 142 | help='Reparameterize model') 143 | parser.add_argument('--model-kwargs', nargs='*', default={}, action=ParseKwargs) 144 | 145 | 146 | scripting_group = parser.add_mutually_exclusive_group() 147 | scripting_group.add_argument('--torchscript', default=False, action='store_true', 148 | help='torch.jit.script the full model') 149 | scripting_group.add_argument('--torchcompile', nargs='?', type=str, default=None, const='inductor', 150 | help="Enable compilation w/ specified backend (default: inductor).") 151 | scripting_group.add_argument('--aot-autograd', default=False, action='store_true', 152 | help="Enable AOT Autograd support.") 153 | 154 | parser.add_argument('--results-file', default='', type=str, metavar='FILENAME', 155 | help='Output csv file for validation results (summary)') 156 | parser.add_argument('--results-format', default='csv', type=str, 157 | help='Format for results file one of (csv, json) (default: csv).') 158 | parser.add_argument('--real-labels', default='', type=str, metavar='FILENAME', 159 | help='Real labels JSON file for imagenet evaluation') 160 | parser.add_argument('--valid-labels', default='', type=str, metavar='FILENAME', 161 | help='Valid label indices txt file for validation of partial label space') 162 | parser.add_argument('--retry', default=False, action='store_true', 163 | help='Enable batch size decay & retry for single model validation') 164 | 165 | 166 | def validate(args): 167 | # might as well try to validate something 168 | args.pretrained = args.pretrained or not args.checkpoint 169 | args.prefetcher = not args.no_prefetcher 170 | 171 | if torch.cuda.is_available(): 172 | torch.backends.cuda.matmul.allow_tf32 = True 173 | torch.backends.cudnn.benchmark = True 174 | 175 | device = torch.device(args.device) 176 | 177 | # resolve AMP arguments based on PyTorch / Apex availability 178 | use_amp = None 179 | amp_autocast = suppress 180 | if args.amp: 181 | if args.amp_impl == 'apex': 182 | assert has_apex, 'AMP impl specified as APEX but APEX is not installed.' 183 | assert args.amp_dtype == 'float16' 184 | use_amp = 'apex' 185 | _logger.info('Validating in mixed precision with NVIDIA APEX AMP.') 186 | else: 187 | assert has_native_amp, 'Please update PyTorch to a version with native AMP (or use APEX).' 188 | assert args.amp_dtype in ('float16', 'bfloat16') 189 | use_amp = 'native' 190 | amp_dtype = torch.bfloat16 if args.amp_dtype == 'bfloat16' else torch.float16 191 | amp_autocast = partial(torch.autocast, device_type=device.type, dtype=amp_dtype) 192 | _logger.info('Validating in mixed precision with native PyTorch AMP.') 193 | else: 194 | _logger.info('Validating in float32. 
AMP not enabled.') 195 | 196 | if args.fuser: 197 | set_jit_fuser(args.fuser) 198 | 199 | if args.fast_norm: 200 | set_fast_norm() 201 | 202 | # create model 203 | in_chans = 3 204 | if args.in_chans is not None: 205 | in_chans = args.in_chans 206 | elif args.input_size is not None: 207 | in_chans = args.input_size[0] 208 | 209 | model = create_model( 210 | args.model, 211 | pretrained=args.pretrained, 212 | num_classes=args.num_classes, 213 | in_chans=in_chans, 214 | global_pool=args.gp, 215 | scriptable=args.torchscript, 216 | **args.model_kwargs, 217 | ) 218 | if args.num_classes is None: 219 | assert hasattr(model, 'num_classes'), 'Model must have `num_classes` attr if not set on cmd line/config.' 220 | args.num_classes = model.num_classes 221 | 222 | if args.checkpoint: 223 | load_checkpoint(model, args.checkpoint, args.use_ema) 224 | 225 | if args.reparam: 226 | model = reparameterize_model(model) 227 | 228 | param_count = sum([m.numel() for m in model.parameters()]) 229 | _logger.info('Model %s created, param count: %d' % (args.model, param_count)) 230 | 231 | data_config = resolve_data_config( 232 | vars(args), 233 | model=model, 234 | use_test_size=not args.use_train_size, 235 | verbose=True, 236 | ) 237 | test_time_pool = False 238 | if args.test_pool: 239 | model, test_time_pool = apply_test_time_pool(model, data_config) 240 | 241 | model = model.to(device) 242 | if args.channels_last: 243 | model = model.to(memory_format=torch.channels_last) 244 | 245 | if args.torchscript: 246 | assert not use_amp == 'apex', 'Cannot use APEX AMP with torchscripted model' 247 | model = torch.jit.script(model) 248 | elif args.torchcompile: 249 | assert has_compile, 'A version of torch w/ torch.compile() is required for --compile, possibly a nightly.' 250 | torch._dynamo.reset() 251 | model = torch.compile(model, backend=args.torchcompile) 252 | elif args.aot_autograd: 253 | assert has_functorch, "functorch is needed for --aot-autograd" 254 | model = memory_efficient_fusion(model) 255 | 256 | if use_amp == 'apex': 257 | model = amp.initialize(model, opt_level='O1') 258 | 259 | if args.num_gpu > 1: 260 | model = torch.nn.DataParallel(model, device_ids=list(range(args.num_gpu))) 261 | 262 | criterion = nn.CrossEntropyLoss().to(device) 263 | 264 | root_dir = args.data or args.data_dir 265 | if args.input_img_mode is None: 266 | input_img_mode = 'RGB' if data_config['input_size'][0] == 3 else 'L' 267 | else: 268 | input_img_mode = args.input_img_mode 269 | dataset = create_dataset( 270 | root=root_dir, 271 | name=args.dataset, 272 | split=args.split, 273 | download=args.dataset_download, 274 | load_bytes=args.tf_preprocessing, 275 | class_map=args.class_map, 276 | num_samples=args.num_samples, 277 | input_key=args.input_key, 278 | input_img_mode=input_img_mode, 279 | target_key=args.target_key, 280 | ) 281 | 282 | if args.valid_labels: 283 | with open(args.valid_labels, 'r') as f: 284 | valid_labels = [int(line.rstrip()) for line in f] 285 | else: 286 | valid_labels = None 287 | 288 | if args.real_labels: 289 | real_labels = RealLabelsImagenet(dataset.filenames(basename=True), real_json=args.real_labels) 290 | else: 291 | real_labels = None 292 | 293 | crop_pct = 1.0 if test_time_pool else data_config['crop_pct'] 294 | loader = create_loader( 295 | dataset, 296 | input_size=data_config['input_size'], 297 | batch_size=args.batch_size, 298 | use_prefetcher=args.prefetcher, 299 | interpolation=data_config['interpolation'], 300 | mean=data_config['mean'], 301 | std=data_config['std'], 302 | 
num_workers=args.workers, 303 | crop_pct=crop_pct, 304 | crop_mode=data_config['crop_mode'], 305 | crop_border_pixels=args.crop_border_pixels, 306 | pin_memory=args.pin_mem, 307 | device=device, 308 | tf_preprocessing=args.tf_preprocessing, 309 | ) 310 | 311 | batch_time = AverageMeter() 312 | losses = AverageMeter() 313 | top1 = AverageMeter() 314 | top5 = AverageMeter() 315 | 316 | model.eval() 317 | with torch.no_grad(): 318 | # warmup, reduce variability of first batch time, especially for comparing torchscript vs non 319 | input = torch.randn((args.batch_size,) + tuple(data_config['input_size'])).to(device) 320 | if args.channels_last: 321 | input = input.contiguous(memory_format=torch.channels_last) 322 | with amp_autocast(): 323 | model(input) 324 | 325 | end = time.time() 326 | for batch_idx, (input, target) in enumerate(loader): 327 | if args.no_prefetcher: 328 | target = target.to(device) 329 | input = input.to(device) 330 | if args.channels_last: 331 | input = input.contiguous(memory_format=torch.channels_last) 332 | 333 | # compute output 334 | with amp_autocast(): 335 | output = model(input) 336 | 337 | if valid_labels is not None: 338 | output = output[:, valid_labels] 339 | loss = criterion(output, target) 340 | 341 | if real_labels is not None: 342 | real_labels.add_result(output) 343 | 344 | # measure accuracy and record loss 345 | acc1, acc5 = accuracy(output.detach(), target, topk=(1, 5)) 346 | losses.update(loss.item(), input.size(0)) 347 | top1.update(acc1.item(), input.size(0)) 348 | top5.update(acc5.item(), input.size(0)) 349 | 350 | # measure elapsed time 351 | batch_time.update(time.time() - end) 352 | end = time.time() 353 | 354 | if batch_idx % args.log_freq == 0: 355 | _logger.info( 356 | 'Test: [{0:>4d}/{1}] ' 357 | 'Time: {batch_time.val:.3f}s ({batch_time.avg:.3f}s, {rate_avg:>7.2f}/s) ' 358 | 'Loss: {loss.val:>7.4f} ({loss.avg:>6.4f}) ' 359 | 'Acc@1: {top1.val:>7.3f} ({top1.avg:>7.3f}) ' 360 | 'Acc@5: {top5.val:>7.3f} ({top5.avg:>7.3f})'.format( 361 | batch_idx, 362 | len(loader), 363 | batch_time=batch_time, 364 | rate_avg=input.size(0) / batch_time.avg, 365 | loss=losses, 366 | top1=top1, 367 | top5=top5 368 | ) 369 | ) 370 | 371 | if real_labels is not None: 372 | # real labels mode replaces topk values at the end 373 | top1a, top5a = real_labels.get_accuracy(k=1), real_labels.get_accuracy(k=5) 374 | else: 375 | top1a, top5a = top1.avg, top5.avg 376 | results = OrderedDict( 377 | model=args.model, 378 | top1=round(top1a, 4), top1_err=round(100 - top1a, 4), 379 | top5=round(top5a, 4), top5_err=round(100 - top5a, 4), 380 | param_count=round(param_count / 1e6, 2), 381 | img_size=data_config['input_size'][-1], 382 | crop_pct=crop_pct, 383 | interpolation=data_config['interpolation'], 384 | ) 385 | 386 | _logger.info(' * Acc@1 {:.3f} ({:.3f}) Acc@5 {:.3f} ({:.3f})'.format( 387 | results['top1'], results['top1_err'], results['top5'], results['top5_err'])) 388 | 389 | return results 390 | 391 | 392 | def _try_run(args, initial_batch_size): 393 | batch_size = initial_batch_size 394 | results = OrderedDict() 395 | error_str = 'Unknown' 396 | while batch_size: 397 | args.batch_size = batch_size * args.num_gpu # multiply by num-gpu for DataParallel case 398 | try: 399 | if torch.cuda.is_available() and 'cuda' in args.device: 400 | torch.cuda.empty_cache() 401 | results = validate(args) 402 | return results 403 | except RuntimeError as e: 404 | error_str = str(e) 405 | _logger.error(f'"{error_str}" while running validation.') 406 | if not 
check_batch_size_retry(error_str): 407 | break 408 | batch_size = decay_batch_step(batch_size) 409 | _logger.warning(f'Reducing batch size to {batch_size} for retry.') 410 | results['error'] = error_str 411 | _logger.error(f'{args.model} failed to validate ({error_str}).') 412 | return results 413 | 414 | 415 | _NON_IN1K_FILTERS = ['*_in21k', '*_in22k', '*in12k', '*_dino', '*fcmae', '*seer'] 416 | 417 | 418 | def main(): 419 | setup_default_logging() 420 | args = parser.parse_args() 421 | model_cfgs = [] 422 | model_names = [] 423 | if os.path.isdir(args.checkpoint): 424 | # validate all checkpoints in a path with same model 425 | checkpoints = glob.glob(args.checkpoint + '/*.pth.tar') 426 | checkpoints += glob.glob(args.checkpoint + '/*.pth') 427 | model_names = list_models(args.model) 428 | model_cfgs = [(args.model, c) for c in sorted(checkpoints, key=natural_key)] 429 | else: 430 | if args.model == 'all': 431 | # validate all models in a list of names with pretrained checkpoints 432 | args.pretrained = True 433 | model_names = list_models( 434 | pretrained=True, 435 | exclude_filters=_NON_IN1K_FILTERS, 436 | ) 437 | model_cfgs = [(n, '') for n in model_names] 438 | elif not is_model(args.model): 439 | # model name doesn't exist, try as wildcard filter 440 | model_names = list_models( 441 | args.model, 442 | pretrained=True, 443 | ) 444 | model_cfgs = [(n, '') for n in model_names] 445 | 446 | if not model_cfgs and os.path.isfile(args.model): 447 | with open(args.model) as f: 448 | model_names = [line.rstrip() for line in f] 449 | model_cfgs = [(n, None) for n in model_names if n] 450 | 451 | if len(model_cfgs): 452 | _logger.info('Running bulk validation on these pretrained models: {}'.format(', '.join(model_names))) 453 | results = [] 454 | try: 455 | initial_batch_size = args.batch_size 456 | for m, c in model_cfgs: 457 | args.model = m 458 | args.checkpoint = c 459 | r = _try_run(args, initial_batch_size) 460 | if 'error' in r: 461 | continue 462 | if args.checkpoint: 463 | r['checkpoint'] = args.checkpoint 464 | results.append(r) 465 | except KeyboardInterrupt as e: 466 | pass 467 | results = sorted(results, key=lambda x: x['top1'], reverse=True) 468 | else: 469 | if args.retry: 470 | results = _try_run(args, args.batch_size) 471 | else: 472 | results = validate(args) 473 | 474 | if args.results_file: 475 | write_results(args.results_file, results, format=args.results_format) 476 | 477 | # output results in JSON to stdout w/ delimiter for runner script 478 | print(f'--result\n{json.dumps(results, indent=4)}') 479 | 480 | 481 | def write_results(results_file, results, format='csv'): 482 | with open(results_file, mode='w') as cf: 483 | if format == 'json': 484 | json.dump(results, cf, indent=4) 485 | else: 486 | if not isinstance(results, (list, tuple)): 487 | results = [results] 488 | if not results: 489 | return 490 | dw = csv.DictWriter(cf, fieldnames=results[0].keys()) 491 | dw.writeheader() 492 | for r in results: 493 | dw.writerow(r) 494 | cf.flush() 495 | 496 | 497 | 498 | if __name__ == '__main__': 499 | main() 500 | --------------------------------------------------------------------------------