├── .gitignore ├── LICENSE ├── README.md ├── crate_figure.png ├── data ├── dataset.py └── randomaug.py ├── figs ├── CRATE_arch.pdf ├── CRATE_fig1_patches.pdf ├── fig_arch.png ├── fig_arch_autoencoder.png ├── fig_layerwise.png ├── fig_masked_reconstruction.png ├── fig_objective.png ├── fig_pipeline.png ├── fig_seg.png └── fig_seg_headwise.png ├── finetune.py ├── main.py ├── model ├── crate.py ├── crate_ae │ ├── crate_ae.py │ ├── crate_decoder.py │ ├── crate_encoder.py │ └── pos_embed.py └── vit.py ├── requirements.txt ├── utils.py └── vis_utils ├── coding_rate.py ├── crate_hook.py ├── pca_visualization.py ├── plot.py └── shared_extractor.py /.gitignore: -------------------------------------------------------------------------------- 1 | __pycache__/ 2 | checkpoint/ 3 | log/ 4 | 5 | #ignore all pyache 6 | *.pyc 7 | 8 | #ignore dataset 9 | *.tar.gz 10 | #ignore folders in data 11 | data/* 12 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2023 Ma-Lab-Berkeley 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # CRATE (Coding RAte reduction TransformEr) 2 | This repository is the official PyTorch implementation of the papers: 3 | 4 | - **White-Box Transformers via Sparse Rate Reduction** [**NeurIPS-2023**, [paper link](https://openreview.net/forum?id=THfl8hdVxH#)]. By [Yaodong Yu](https://yaodongyu.github.io) (UC Berkeley), [Sam Buchanan](https://sdbuchanan.com) (TTIC), [Druv Pai](https://druvpai.github.io) (UC Berkeley), [Tianzhe Chu](https://tianzhechu.com/) (UC Berkeley), [Ziyang Wu](https://robinwu218.github.io/) (UC Berkeley), [Shengbang Tong](https://tsb0601.github.io/petertongsb/) (UC Berkeley), [Benjamin D Haeffele](https://www.cis.jhu.edu/~haeffele/#about) (Johns Hopkins University), and [Yi Ma](http://people.eecs.berkeley.edu/~yima/) (UC Berkeley). 5 | - **Emergence of Segmentation with Minimalistic White-Box Transformers** [**CPAL-2024**, [paper link](https://arxiv.org/abs/2308.16271)]. 
By [Yaodong Yu](https://yaodongyu.github.io)* (UC Berkeley), [Tianzhe Chu](https://tianzhechu.com/)* (UC Berkeley & ShanghaiTech U), [Shengbang Tong](https://tsb0601.github.io/petertongsb/) (UC Berkeley & NYU), [Ziyang Wu](https://robinwu218.github.io/) (UC Berkeley), [Druv Pai](https://druvpai.github.io) (UC Berkeley), [Sam Buchanan](https://sdbuchanan.com) (TTIC), and [Yi Ma](http://people.eecs.berkeley.edu/~yima/) (UC Berkeley & HKU). 2023. (* equal contribution) 6 | - **Masked Autoencoding via Structured Diffusion with White-Box Transformers** [**ICLR-2024**, [paper link](https://arxiv.org/abs/2404.02446)]. By [Druv Pai](https://druvpai.github.io) (UC Berkeley), [Ziyang Wu](https://robinwu218.github.io/) (UC Berkeley), [Sam Buchanan](https://sdbuchanan.com), [Yaodong Yu](https://yaodongyu.github.io) (UC Berkeley), and [Yi Ma](http://people.eecs.berkeley.edu/~yima/) (UC Berkeley). 7 | 8 | Also, we have released a larger journal-length overview paper of this line of research, which contains a superset of all the results presented above, and also more results in NLP and vision SSL. 9 | - **White-Box Transformers via Sparse Rate Reduction: Compression is All There Is?** [[paper link](https://arxiv.org/abs/2311.13110)]. By [Yaodong Yu](https://yaodongyu.github.io) (UC Berkeley), [Sam Buchanan](https://sdbuchanan.com) (TTIC), [Druv Pai](https://druvpai.github.io) (UC Berkeley), [Tianzhe Chu](https://tianzhechu.com/) (UC Berkeley), [Ziyang Wu](https://robinwu218.github.io/) (UC Berkeley), [Shengbang Tong](https://tsb0601.github.io/petertongsb/) (UC Berkeley), [Hao Bai](https://www.jackgethome.com/) (UIUC), [Yuexiang Zhai](https://yx-s-z.github.io/) (UC Berkeley), [Benjamin D Haeffele](https://www.cis.jhu.edu/~haeffele/#about) (Johns Hopkins University), and [Yi Ma](http://people.eecs.berkeley.edu/~yima/) (UC Berkeley). 10 | 11 | 12 | # Table of Contents 13 | 14 | * [CRATE (Coding RAte reduction TransformEr)](#crate-coding-rate-reduction-transformer) 15 | * [Theoretical Background: What is CRATE?](#theoretical-background-what-is-crate) 16 | * [1. CRATE Architecture overview](#1-crate-architecture-overview) 17 | * [2. One layer/block of CRATE](#2-one-layerblock-of-crate) 18 | * [3. Per-layer optimization in CRATE](#3-per-layer-optimization-in-crate) 19 | * [4. Segmentation visualization of CRATE](#4-segmentation-visualization-of-crate) 20 | * [Autoencoding](#autoencoding) 21 | * [Implementation and experiments](#implementation-and-experiments) 22 | * [Constructing a CRATE model](#constructing-a-crate-model) 23 | * [Pre-trained Checkpoints (ImageNet-1K)](#pre-trained-checkpoints-imagenet-1k) 24 | * [Training CRATE on ImageNet](#training-crate-on-imagenet) 25 | * [Finetuning pretrained / training random initialized CRATE on CIFAR10](#finetuning-pretrained--training-random-initialized-crate-on-cifar10) 26 | * [Demo: Emergent segmentation in CRATE](#demo-emergent-segmentation-in-crate) 27 | * [Constructing a CRATE autoencoding model](#constructing-a-crate-autoencoding-model) 28 | * [Pre-trained Checkpoints (ImageNet-1K)](#pre-trained-checkpoints-imagenet-1k-1) 29 | * [Training/Fine-Tuning CRATE-MAE](#trainingfine-tuning-crate-mae) 30 | * [Reference](#reference) 31 | 32 | ## Theoretical Background: What is CRATE? 33 | CRATE (Coding RAte reduction TransformEr) is a white-box (mathematically interpretable) transformer architecture, where each layer performs a single step of an alternating minimization algorithm to optimize the **sparse rate reduction objective** 34 |
35 | ![The sparse rate reduction objective](figs/fig_objective.png)
36 |
38 | 39 | where $R$ and $R^{c}$ are different _coding rates_ for the input representations w.r.t.~different codebooks, and the $\ell^{0}$-norm promotes the sparsity of the final token representations $\boldsymbol{Z} = f(\boldsymbol{X})$. The function $f$ is defined as 40 | $$f=f^{L} \circ f^{L-1} \circ \cdots \circ f^{1} \circ f^{\mathrm{pre}},$$ 41 | where $f^{\mathrm{pre}}$ is the pre-processing mapping, and $f^{\ell}$ is the $\ell$-th layer forward mapping that transforms the token distribution to optimize the above sparse rate reduction objective incrementally. More specifically, $f^{\ell}$ transforms the $\ell$-th layer token representations $\boldsymbol{Z}^{\ell}$ to $\boldsymbol{Z}^{\ell+1}$ via the $\texttt{MSSA}$ (Multi-Head Subspace Self-Attention) block and the $\texttt{ISTA}$ (Iterative Shrinkage-Thresholding Algorithms) block, i.e., 42 | $$\boldsymbol{Z}^{\ell+1} = f^{\ell}(\boldsymbol{Z}^{\ell}) = \texttt{ISTA}(\boldsymbol{Z}^{\ell} + \texttt{MSSA}(\boldsymbol{Z}^{\ell})).$$ 43 | 44 | ### 1. CRATE Architecture overview 45 | 46 | The following figure presents an overview of the pipeline for our proposed **CRATE** architecture: 47 | 48 |
49 | ![Overview of the CRATE pipeline](figs/fig_pipeline.png)
50 |
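Concretely, the pre-processing map $f^{\mathrm{pre}}$ patchifies the image and linearly embeds each patch, the $L$ layer maps are applied in sequence, and a linear head reads out the class token. A minimal shape-level sketch, using the CRATE-Tiny modules defined in `model/crate.py` later in this repository:

```python
# Minimal shape-level sketch of the pipeline f = f^L ∘ ... ∘ f^1 ∘ f^pre,
# using the CRATE-Tiny modules from model/crate.py in this repository.
import torch
from model.crate import CRATE_tiny

model = CRATE_tiny(num_classes=1000).eval()
img = torch.randn(1, 3, 224, 224)                    # one 224x224 RGB image

with torch.no_grad():
    tokens = model.to_patch_embedding(img)           # f^pre: patchify + embed -> (1, 196, 384)
    z = torch.cat([model.cls_token, tokens], dim=1)  # prepend CLS token       -> (1, 197, 384)
    z = z + model.pos_embedding[:, : z.shape[1]]     # add positional embeddings
    z = model.transformer(z)                         # f^1, ..., f^12: CRATE blocks (MSSA + ISTA)
    logits = model.mlp_head(z[:, 0])                 # linear head on the CLS token -> (1, 1000)
```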
52 | 53 | ### 2. One layer/block of CRATE 54 | 55 | The following figure shows the overall architecture of one layer of **CRATE** as the composition of $\texttt{MSSA}$ and $\texttt{ISTA}$ blocks. 56 | 57 |
58 | ![One layer of CRATE: an MSSA block followed by an ISTA block](figs/fig_arch.png)
59 |
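In code, one block is exactly this composition: a residual MSSA step followed by an ISTA step, mirroring `Transformer.forward` in `model/crate.py`. A minimal sketch with CRATE-Tiny sizes:

```python
# Minimal sketch of one CRATE block, mirroring Transformer.forward in model/crate.py:
# Z^{l+1/2} = Z^l + MSSA(Z^l), then Z^{l+1} = ISTA(Z^{l+1/2}).
import torch
from model.crate import Attention, FeedForward, PreNorm

dim, heads = 384, 6                                                      # CRATE-Tiny sizes
mssa = PreNorm(dim, Attention(dim, heads=heads, dim_head=dim // heads))  # MSSA operator
ista = PreNorm(dim, FeedForward(dim, dim, step_size=0.1))                # ISTA operator

z = torch.randn(2, 197, dim)     # (batch, tokens, dim): 14x14 patches + CLS token
z_half = z + mssa(z)             # compression step: Z^{l+1/2} = Z^l + MSSA(Z^l)
z_next = ista(z_half)            # sparsification step: Z^{l+1} = ISTA(Z^{l+1/2})
```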
61 | 62 | ### 3. Per-layer optimization in CRATE 63 | 64 | In the following figure, we measure the compression term [ $R^{c}$ ($\boldsymbol{Z}^{\ell+1/2}$) ] and the sparsity term [ $||\boldsymbol{Z}^{\ell+1}||_0$ ] defined in the **sparse rate reduction objective**, and we find that each layer of **CRATE** indeed optimizes the targeted objectives, showing that our white-box theoretical design is predictive of practice. 65 |
66 | ![Layer-wise compression and sparsity measurements](figs/fig_layerwise.png)
67 |
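For reference, the compression term is built from the coding rate $R(\boldsymbol{Z}) = \tfrac{1}{2}\log\det\left(\boldsymbol{I} + \tfrac{d}{n\epsilon^{2}}\boldsymbol{Z}\boldsymbol{Z}^{*}\right)$, summed over the per-head projections $\boldsymbol{U}_k^{*}\boldsymbol{Z}$. The measurement used for the figure is implemented in `vis_utils/coding_rate.py`; the sketch below is an illustrative re-derivation from the definition, not that file:

```python
# Hedged sketch: coding-rate and sparsity measurements for a layer's token representations.
# Follows the definitions in the papers; the repository's own measurement code lives in
# vis_utils/coding_rate.py.
import torch

def coding_rate(Z: torch.Tensor, eps: float = 0.5) -> torch.Tensor:
    """Z: (n, d) matrix of n tokens. Returns R(Z) = 1/2 * logdet(I + d/(n*eps^2) * Z^T Z)."""
    n, d = Z.shape
    identity = torch.eye(d, device=Z.device, dtype=Z.dtype)
    return 0.5 * torch.logdet(identity + (d / (n * eps ** 2)) * Z.T @ Z)

def zero_fraction(Z: torch.Tensor, tol: float = 1e-6) -> float:
    """Fraction of near-zero entries; ||Z||_0 is the number of entries above tol."""
    return (Z.abs() <= tol).float().mean().item()
```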
69 | 70 | ### 4. Segmentation visualization of CRATE 71 | In the following figure, we visualize self-attention maps from a supervised **CRATE** model with 8x8 patches (similar to the ones shown in [DINO](https://github.com/facebookresearch/dino) :t-rex:). 72 |
73 | ![Self-attention maps of a supervised CRATE model](figs/fig_seg.png)
74 |
76 | 77 | We also discover a surprising empirical phenomenon where each attention head in **CRATE** retains its own semantics. 78 |
79 | ![Head-wise self-attention visualizations](figs/fig_seg_headwise.png)
80 |
82 | 83 | 84 | ## Autoencoding 85 | 86 | We can also use our theory to build a principled autoencoder, which has the following architecture. 87 |
88 | ![CRATE autoencoder architecture](figs/fig_arch_autoencoder.png)
89 |
91 | 92 | It has many of the same empirical properties as the base **CRATE** model, such as segmented attention maps and amenability to layer-wise analysis. We train it on the masked autoencoding task (calling this model **CRATE-MAE**), and it achieves comparable performance in linear probing and reconstruction quality as the base ViT-MAE. 93 | 94 |
95 | ![Masked reconstructions from CRATE-MAE](figs/fig_masked_reconstruction.png)
96 |
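A hedged usage sketch for the autoencoder (the constructor is described under [Constructing a CRATE autoencoding model](#constructing-a-crate-autoencoding-model) below). The call signature is an assumption carried over from the Meta FAIR MAE reference implementation that `model/crate_ae/crate_ae.py` mirrors; verify against that file before relying on it:

```python
# Hedged sketch: masked reconstruction with CRATE-MAE-Base. The call signature
# (loss, pred, mask) = model(imgs, mask_ratio=...) is assumed from the MAE reference
# implementation that model/crate_ae/crate_ae.py follows; check that file before use.
import torch
from model.crate_ae.crate_ae import mae_crate_base

model = mae_crate_base().eval()
# Optionally load the released CRATE-MAE-Base weights (path is a placeholder):
# model.load_state_dict(torch.load("crate_mae_base.pth"), strict=False)

imgs = torch.randn(1, 3, 224, 224)                   # replace with normalized ImageNet images
with torch.no_grad():
    loss, pred, mask = model(imgs, mask_ratio=0.75)  # assumed MAE-style forward pass
```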
98 | 99 | 100 | # Implementation and Experiments 101 | 102 | ## Constructing a CRATE model 103 | A CRATE model can be defined using the following code, (the below parameters are specified for CRATE-Tiny) 104 | ```python 105 | from model.crate import CRATE 106 | dim = 384 107 | n_heads = 6 108 | depth = 12 109 | model = CRATE(image_size=224, 110 | patch_size=16, 111 | num_classes=1000, 112 | dim=dim, 113 | depth=depth, 114 | heads=n_heads, 115 | dim_head=dim // n_heads) 116 | ``` 117 | 118 | ### Pre-trained Checkpoints (ImageNet-1K) 119 | | model | `dim` | `n_heads` | `depth` | pre-trained checkpoint | 120 | | -------- | -------- | -------- | -------- | -------- | 121 | | **CRATE-T**(iny) | 384 | 6 | 12 | TODO | 122 | | **CRATE-S**(mall) | 576 | 12 | 12 | [download link](https://drive.google.com/file/d/1hYgDJl4EKHYfKprwhEjmWmWHuxnK6_h8/view?usp=share_link) | 123 | | **CRATE-B**(ase) | 768 | 12 | 12 | TODO | 124 | | **CRATE-L**(arge) | 1024 | 16 | 24 | TODO | 125 | 126 | ## Training CRATE on ImageNet 127 | To train a CRATE model on ImageNet-1K, run the following script (training CRATE-tiny) 128 | 129 | As an example, we use the following command for training CRATE-tiny on ImageNet-1K: 130 | ```python 131 | python main.py 132 | --arch CRATE_tiny 133 | --batch-size 512 134 | --epochs 200 135 | --optimizer Lion 136 | --lr 0.0002 137 | --weight-decay 0.05 138 | --print-freq 25 139 | --data DATA_DIR 140 | ``` 141 | and replace `DATA_DIR` with `[imagenet-folder with train and val folders]`. 142 | 143 | 144 | ## Finetuning pretrained / training random initialized CRATE on CIFAR10 145 | 146 | ```python 147 | python finetune.py 148 | --bs 256 149 | --net CRATE_tiny 150 | --opt adamW 151 | --lr 5e-5 152 | --n_epochs 200 153 | --randomaug 1 154 | --data cifar10 155 | --ckpt_dir CKPT_DIR 156 | --data_dir DATA_DIR 157 | ``` 158 | Replace `CKPT_DIR` with the path for the pretrained CRATE weight, and replace `DATA_DIR` with the path for the `CIFAR10` dataset. If `CKPT_DIR` is `None`, then this script is for training CRATE from random initialization on CIFAR10. 159 | 160 | ## Demo: Emergent segmentation in CRATE 161 | 162 | CRATE models exhibit emergent segmentation in their self-attention maps solely through supervised training. 163 | We provide a Colab Jupyter notebook to visualize the emerged segmentations from a supervised **CRATE** model. The demo provides visualizations which match the segmentation figures above. 164 | 165 | Link: [crate-emergence.ipynb](https://colab.research.google.com/drive/1rYn_NlepyW7Fu5LDliyBDmFZylHco7ss?usp=sharing) (in colab) 166 | 167 |
168 | ![Emergent segmentation visualized in the demo](crate_figure.png)
169 |
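To inspect the attention maps locally rather than in Colab, one option is a forward hook on the softmax inside the final `Attention` block. A minimal sketch, assuming a trained CRATE-Tiny checkpoint (the notebook and `vis_utils/` contain the full visualization pipeline used for the figures):

```python
# Hedged sketch: capture last-layer self-attention maps from a CRATE model with a forward
# hook on the softmax inside the final Attention block (see model/crate.py). The Colab demo
# and vis_utils/ implement the full visualization pipeline.
import torch
from model.crate import CRATE_tiny

model = CRATE_tiny(num_classes=1000).eval()
# model.load_state_dict(...)  # load a trained checkpoint here; random weights give noise

attn_maps = []
last_attn = model.transformer.layers[-1][0].fn       # PreNorm(...).fn -> Attention module
last_attn.attend.register_forward_hook(
    lambda module, inputs, output: attn_maps.append(output.detach())
)

img = torch.randn(1, 3, 224, 224)                    # replace with a preprocessed real image
with torch.no_grad():
    model(img)

attn = attn_maps[0]                                      # (batch, heads, tokens, tokens)
cls_to_patches = attn[0, :, 0, 1:].reshape(-1, 14, 14)   # per-head CLS->patch maps (6, 14, 14)
```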
171 | 172 | ## Constructing a CRATE autoencoding model 173 | A CRATE-autoencoding model (specifically **CRATE-MAE-Base**) can be defined using the following code: 174 | ```python 175 | from model.crate_ae.crate_ae import mae_crate_base 176 | model = mae_crate_base() 177 | ``` 178 | The other sizes in the paper are also importable in that way. Modifying the `model/crate_ae/crate_ae.py` file will let you initialize and serve your own config. 179 | 180 | ### Pre-trained Checkpoints (ImageNet-1K) 181 | | model | `dim` | `n_heads` | `depth` | pre-trained checkpoint | 182 | | -------- | -------- | -------- | -------- | -------- | 183 | | **CRATE-MAE-S**(mall) | 576 | 12 | 12 | TODO | 184 | | **CRATE-MAE-B**(ase) | 768 | 12 | 12 | [link](https://drive.google.com/file/d/11i5BMwymqOsunq44WD3omN5mS6ZREQPO/view?usp=sharing) | 185 | 186 | ## Training/Fine-Tuning CRATE-MAE 187 | To train or fine-tune a CRATE-MAE model on ImageNet-1K, please refer to the [codebase on MAE training](https://github.com/facebookresearch/mae) from Meta FAIR. The `models_mae.py` file in that codebase can be replaced with the contents of `model/crate_ae/crate_ae.py`, and the rest of the code should go through with minimal alterations. 188 | 189 | 190 | ## Demo: Emergent segmentation in CRATE-MAE 191 | 192 | CRATE-MAE models also exhibit emergent segmentation in their self-attention maps. 193 | We provide a Colab Jupyter notebook to visualize the emerged segmentations from a **CRATE-MAE** model. The demo provides visualizations which match the segmentation figures above. 194 | 195 | Link: [crate-mae.ipynb](https://colab.research.google.com/drive/1xcD-xcxprfgZuvwsRKuDroH7xMjr0Ad3?usp=sharing) (in colab) 196 | 197 | # Reference 198 | For technical details and full experimental results, please check the [CRATE paper](https://arxiv.org/abs/2306.01129), [CRATE segmentation paper](https://arxiv.org/abs/2308.16271), [CRATE autoencoding paper](https://openreview.net/forum?id=PvyOYleymy), or [the long-form overview paper](https://arxiv.org/abs/2311.13110). 
Please consider citing our work if you find it helpful to yours: 199 | 200 | ``` 201 | @article{yu2024white, 202 | title={White-Box Transformers via Sparse Rate Reduction}, 203 | author={Yu, Yaodong and Buchanan, Sam and Pai, Druv and Chu, Tianzhe and Wu, Ziyang and Tong, Shengbang and Haeffele, Benjamin and Ma, Yi}, 204 | journal={Advances in Neural Information Processing Systems}, 205 | volume={36}, 206 | year={2024} 207 | } 208 | ``` 209 | ``` 210 | @inproceedings{yu2024emergence, 211 | title={Emergence of Segmentation with Minimalistic White-Box Transformers}, 212 | author={Yu, Yaodong and Chu, Tianzhe and Tong, Shengbang and Wu, Ziyang and Pai, Druv and Buchanan, Sam and Ma, Yi}, 213 | booktitle={Conference on Parsimony and Learning}, 214 | pages={72--93}, 215 | year={2024}, 216 | organization={PMLR} 217 | } 218 | ``` 219 | ``` 220 | @inproceedings{pai2024masked, 221 | title={Masked Completion via Structured Diffusion with White-Box Transformers}, 222 | author={Pai, Druv and Buchanan, Sam and Wu, Ziyang and Yu, Yaodong and Ma, Yi}, 223 | booktitle={The Twelfth International Conference on Learning Representations}, 224 | year={2024} 225 | } 226 | ``` 227 | ``` 228 | @article{yu2023white, 229 | title={White-Box Transformers via Sparse Rate Reduction: Compression Is All There Is?}, 230 | author={Yu, Yaodong and Buchanan, Sam and Pai, Druv and Chu, Tianzhe and Wu, Ziyang and Tong, Shengbang and Bai, Hao and Zhai, Yuexiang and Haeffele, Benjamin D and Ma, Yi}, 231 | journal={arXiv preprint arXiv:2311.13110}, 232 | year={2023} 233 | } 234 | ``` 235 | -------------------------------------------------------------------------------- /crate_figure.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Ma-Lab-Berkeley/CRATE/674408fa82475fe1f172aa8213e21d4ba608afc4/crate_figure.png -------------------------------------------------------------------------------- /data/dataset.py: -------------------------------------------------------------------------------- 1 | import torchvision 2 | import torchvision.transforms as transforms 3 | 4 | 5 | 6 | def load_dataset(data, size, transform_train, transform_test, data_dir=None): 7 | if data_dir is None: 8 | data_dir = "../" + data 9 | if data == "cifar10": 10 | trainset = torchvision.datasets.CIFAR10(root=data_dir, train=True, download=True, transform=transform_train) 11 | # trainloader = torch.utils.data.DataLoader(trainset, batch_size=bs, shuffle=True, num_workers=8) 12 | 13 | testset = torchvision.datasets.CIFAR10(root=data_dir, train=False, download=True, transform=transform_test) 14 | # testloader = torch.utils.data.DataLoader(testset, batch_size=100, shuffle=False, num_workers=8) 15 | elif data == "cifar100": 16 | trainset = torchvision.datasets.CIFAR100(root=data_dir, train=True, download=True, transform=transform_train) 17 | # trainloader = torch.utils.data.DataLoader(trainset, batch_size=bs, shuffle=True, num_workers=8) 18 | 19 | testset = torchvision.datasets.CIFAR100(root=data_dir, train=False, download=True, transform=transform_test) 20 | # testloader = torch.utils.data.DataLoader(testset, batch_size=100, shuffle=False, num_workers=8) 21 | elif data == "flower": 22 | trainset = torchvision.datasets.Flowers102(root=data_dir, split="train", download=True, transform=transform_train) 23 | # trainloader = torch.utils.data.DataLoader(trainset, batch_size=bs, shuffle=True, num_workers=8) 24 | 25 | testset = torchvision.datasets.Flowers102(root=data_dir, split="test", download=True, 
transform=transform_test) 26 | # testloader = torch.utils.data.DataLoader(testset, batch_size=100, shuffle=False, num_workers=8) 27 | elif data == "pets": 28 | trainset = torchvision.datasets.OxfordIIITPet(root=data_dir, split="trainval", download=True, transform=transform_train) 29 | # trainloader = torch.utils.data.DataLoader(trainset, batch_size=bs, shuffle=True, num_workers=8) 30 | 31 | testset = torchvision.datasets.OxfordIIITPet(root=data_dir, split="test", download=True, transform=transform_test) 32 | # testloader = torch.utils.data.DataLoader(testset, batch_size=100, shuffle=False, num_workers=8) 33 | 34 | return trainset, testset -------------------------------------------------------------------------------- /data/randomaug.py: -------------------------------------------------------------------------------- 1 | # code in this file is adpated from rpmcruz/autoaugment 2 | # https://github.com/rpmcruz/autoaugment/blob/master/transformations.py 3 | import random 4 | 5 | import PIL, PIL.ImageOps, PIL.ImageEnhance, PIL.ImageDraw 6 | import numpy as np 7 | import torch 8 | from PIL import Image 9 | 10 | 11 | def ShearX(img, v): # [-0.3, 0.3] 12 | assert -0.3 <= v <= 0.3 13 | if random.random() > 0.5: 14 | v = -v 15 | return img.transform(img.size, PIL.Image.AFFINE, (1, v, 0, 0, 1, 0)) 16 | 17 | 18 | def ShearY(img, v): # [-0.3, 0.3] 19 | assert -0.3 <= v <= 0.3 20 | if random.random() > 0.5: 21 | v = -v 22 | return img.transform(img.size, PIL.Image.AFFINE, (1, 0, 0, v, 1, 0)) 23 | 24 | 25 | def TranslateX(img, v): # [-150, 150] => percentage: [-0.45, 0.45] 26 | assert -0.45 <= v <= 0.45 27 | if random.random() > 0.5: 28 | v = -v 29 | v = v * img.size[0] 30 | return img.transform(img.size, PIL.Image.AFFINE, (1, 0, v, 0, 1, 0)) 31 | 32 | 33 | def TranslateXabs(img, v): # [-150, 150] => percentage: [-0.45, 0.45] 34 | assert 0 <= v 35 | if random.random() > 0.5: 36 | v = -v 37 | return img.transform(img.size, PIL.Image.AFFINE, (1, 0, v, 0, 1, 0)) 38 | 39 | 40 | def TranslateY(img, v): # [-150, 150] => percentage: [-0.45, 0.45] 41 | assert -0.45 <= v <= 0.45 42 | if random.random() > 0.5: 43 | v = -v 44 | v = v * img.size[1] 45 | return img.transform(img.size, PIL.Image.AFFINE, (1, 0, 0, 0, 1, v)) 46 | 47 | 48 | def TranslateYabs(img, v): # [-150, 150] => percentage: [-0.45, 0.45] 49 | assert 0 <= v 50 | if random.random() > 0.5: 51 | v = -v 52 | return img.transform(img.size, PIL.Image.AFFINE, (1, 0, 0, 0, 1, v)) 53 | 54 | 55 | def Rotate(img, v): # [-30, 30] 56 | assert -30 <= v <= 30 57 | if random.random() > 0.5: 58 | v = -v 59 | return img.rotate(v) 60 | 61 | 62 | def AutoContrast(img, _): 63 | return PIL.ImageOps.autocontrast(img) 64 | 65 | 66 | def Invert(img, _): 67 | return PIL.ImageOps.invert(img) 68 | 69 | 70 | def Equalize(img, _): 71 | return PIL.ImageOps.equalize(img) 72 | 73 | 74 | def Flip(img, _): # not from the paper 75 | return PIL.ImageOps.mirror(img) 76 | 77 | 78 | def Solarize(img, v): # [0, 256] 79 | assert 0 <= v <= 256 80 | return PIL.ImageOps.solarize(img, v) 81 | 82 | 83 | def SolarizeAdd(img, addition=0, threshold=128): 84 | img_np = np.array(img).astype(np.int) 85 | img_np = img_np + addition 86 | img_np = np.clip(img_np, 0, 255) 87 | img_np = img_np.astype(np.uint8) 88 | img = Image.fromarray(img_np) 89 | return PIL.ImageOps.solarize(img, threshold) 90 | 91 | 92 | def Posterize(img, v): # [4, 8] 93 | v = int(v) 94 | v = max(1, v) 95 | return PIL.ImageOps.posterize(img, v) 96 | 97 | 98 | def Contrast(img, v): # [0.1,1.9] 99 | assert 0.1 <= v <= 1.9 100 | return 
PIL.ImageEnhance.Contrast(img).enhance(v) 101 | 102 | 103 | def Color(img, v): # [0.1,1.9] 104 | assert 0.1 <= v <= 1.9 105 | return PIL.ImageEnhance.Color(img).enhance(v) 106 | 107 | 108 | def Brightness(img, v): # [0.1,1.9] 109 | assert 0.1 <= v <= 1.9 110 | return PIL.ImageEnhance.Brightness(img).enhance(v) 111 | 112 | 113 | def Sharpness(img, v): # [0.1,1.9] 114 | assert 0.1 <= v <= 1.9 115 | return PIL.ImageEnhance.Sharpness(img).enhance(v) 116 | 117 | 118 | def Cutout(img, v): # [0, 60] => percentage: [0, 0.2] 119 | assert 0.0 <= v <= 0.2 120 | if v <= 0.: 121 | return img 122 | 123 | v = v * img.size[0] 124 | return CutoutAbs(img, v) 125 | 126 | 127 | def CutoutAbs(img, v): # [0, 60] => percentage: [0, 0.2] 128 | # assert 0 <= v <= 20 129 | if v < 0: 130 | return img 131 | w, h = img.size 132 | x0 = np.random.uniform(w) 133 | y0 = np.random.uniform(h) 134 | 135 | x0 = int(max(0, x0 - v / 2.)) 136 | y0 = int(max(0, y0 - v / 2.)) 137 | x1 = min(w, x0 + v) 138 | y1 = min(h, y0 + v) 139 | 140 | xy = (x0, y0, x1, y1) 141 | color = (125, 123, 114) 142 | # color = (0, 0, 0) 143 | img = img.copy() 144 | PIL.ImageDraw.Draw(img).rectangle(xy, color) 145 | return img 146 | 147 | 148 | def SamplePairing(imgs): # [0, 0.4] 149 | def f(img1, v): 150 | i = np.random.choice(len(imgs)) 151 | img2 = PIL.Image.fromarray(imgs[i]) 152 | return PIL.Image.blend(img1, img2, v) 153 | 154 | return f 155 | 156 | 157 | def Identity(img, v): 158 | return img 159 | 160 | 161 | def augment_list(): # 16 oeprations and their ranges 162 | # https://github.com/google-research/uda/blob/master/image/randaugment/policies.py#L57 163 | # l = [ 164 | # (Identity, 0., 1.0), 165 | # (ShearX, 0., 0.3), # 0 166 | # (ShearY, 0., 0.3), # 1 167 | # (TranslateX, 0., 0.33), # 2 168 | # (TranslateY, 0., 0.33), # 3 169 | # (Rotate, 0, 30), # 4 170 | # (AutoContrast, 0, 1), # 5 171 | # (Invert, 0, 1), # 6 172 | # (Equalize, 0, 1), # 7 173 | # (Solarize, 0, 110), # 8 174 | # (Posterize, 4, 8), # 9 175 | # # (Contrast, 0.1, 1.9), # 10 176 | # (Color, 0.1, 1.9), # 11 177 | # (Brightness, 0.1, 1.9), # 12 178 | # (Sharpness, 0.1, 1.9), # 13 179 | # # (Cutout, 0, 0.2), # 14 180 | # # (SamplePairing(imgs), 0, 0.4), # 15 181 | # ] 182 | 183 | # https://github.com/tensorflow/tpu/blob/8462d083dd89489a79e3200bcc8d4063bf362186/models/official/efficientnet/autoaugment.py#L505 184 | l = [ 185 | (AutoContrast, 0, 1), 186 | (Equalize, 0, 1), 187 | (Invert, 0, 1), 188 | (Rotate, 0, 30), 189 | (Posterize, 0, 4), 190 | (Solarize, 0, 256), 191 | (SolarizeAdd, 0, 110), 192 | (Color, 0.1, 1.9), 193 | (Contrast, 0.1, 1.9), 194 | (Brightness, 0.1, 1.9), 195 | (Sharpness, 0.1, 1.9), 196 | (ShearX, 0., 0.3), 197 | (ShearY, 0., 0.3), 198 | (CutoutAbs, 0, 40), 199 | (TranslateXabs, 0., 100), 200 | (TranslateYabs, 0., 100), 201 | ] 202 | 203 | return l 204 | 205 | 206 | class Lighting(object): 207 | """Lighting noise(AlexNet - style PCA - based noise)""" 208 | 209 | def __init__(self, alphastd, eigval, eigvec): 210 | self.alphastd = alphastd 211 | self.eigval = torch.Tensor(eigval) 212 | self.eigvec = torch.Tensor(eigvec) 213 | 214 | def __call__(self, img): 215 | if self.alphastd == 0: 216 | return img 217 | 218 | alpha = img.new().resize_(3).normal_(0, self.alphastd) 219 | rgb = self.eigvec.type_as(img).clone() \ 220 | .mul(alpha.view(1, 3).expand(3, 3)) \ 221 | .mul(self.eigval.view(1, 3).expand(3, 3)) \ 222 | .sum(1).squeeze() 223 | 224 | return img.add(rgb.view(3, 1, 1).expand_as(img)) 225 | 226 | 227 | class CutoutDefault(object): 228 | """ 229 | Reference : 
https://github.com/quark0/darts/blob/master/cnn/utils.py 230 | """ 231 | def __init__(self, length): 232 | self.length = length 233 | 234 | def __call__(self, img): 235 | h, w = img.size(1), img.size(2) 236 | mask = np.ones((h, w), np.float32) 237 | y = np.random.randint(h) 238 | x = np.random.randint(w) 239 | 240 | y1 = np.clip(y - self.length // 2, 0, h) 241 | y2 = np.clip(y + self.length // 2, 0, h) 242 | x1 = np.clip(x - self.length // 2, 0, w) 243 | x2 = np.clip(x + self.length // 2, 0, w) 244 | 245 | mask[y1: y2, x1: x2] = 0. 246 | mask = torch.from_numpy(mask) 247 | mask = mask.expand_as(img) 248 | img *= mask 249 | return img 250 | 251 | 252 | class RandAugment: 253 | def __init__(self, n, m): 254 | self.n = n 255 | self.m = m # [0, 30] 256 | self.augment_list = augment_list() 257 | 258 | def __call__(self, img): 259 | ops = random.choices(self.augment_list, k=self.n) 260 | for op, minval, maxval in ops: 261 | val = (float(self.m) / 30) * float(maxval - minval) + minval 262 | img = op(img, val) 263 | 264 | return img -------------------------------------------------------------------------------- /figs/CRATE_arch.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Ma-Lab-Berkeley/CRATE/674408fa82475fe1f172aa8213e21d4ba608afc4/figs/CRATE_arch.pdf -------------------------------------------------------------------------------- /figs/CRATE_fig1_patches.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Ma-Lab-Berkeley/CRATE/674408fa82475fe1f172aa8213e21d4ba608afc4/figs/CRATE_fig1_patches.pdf -------------------------------------------------------------------------------- /figs/fig_arch.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Ma-Lab-Berkeley/CRATE/674408fa82475fe1f172aa8213e21d4ba608afc4/figs/fig_arch.png -------------------------------------------------------------------------------- /figs/fig_arch_autoencoder.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Ma-Lab-Berkeley/CRATE/674408fa82475fe1f172aa8213e21d4ba608afc4/figs/fig_arch_autoencoder.png -------------------------------------------------------------------------------- /figs/fig_layerwise.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Ma-Lab-Berkeley/CRATE/674408fa82475fe1f172aa8213e21d4ba608afc4/figs/fig_layerwise.png -------------------------------------------------------------------------------- /figs/fig_masked_reconstruction.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Ma-Lab-Berkeley/CRATE/674408fa82475fe1f172aa8213e21d4ba608afc4/figs/fig_masked_reconstruction.png -------------------------------------------------------------------------------- /figs/fig_objective.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Ma-Lab-Berkeley/CRATE/674408fa82475fe1f172aa8213e21d4ba608afc4/figs/fig_objective.png -------------------------------------------------------------------------------- /figs/fig_pipeline.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Ma-Lab-Berkeley/CRATE/674408fa82475fe1f172aa8213e21d4ba608afc4/figs/fig_pipeline.png 
-------------------------------------------------------------------------------- /figs/fig_seg.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Ma-Lab-Berkeley/CRATE/674408fa82475fe1f172aa8213e21d4ba608afc4/figs/fig_seg.png -------------------------------------------------------------------------------- /figs/fig_seg_headwise.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Ma-Lab-Berkeley/CRATE/674408fa82475fe1f172aa8213e21d4ba608afc4/figs/fig_seg_headwise.png -------------------------------------------------------------------------------- /finetune.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | from __future__ import print_function 3 | 4 | import torch 5 | import torch.nn as nn 6 | import torch.optim as optim 7 | import torch.nn.functional as F 8 | import torch.backends.cudnn as cudnn 9 | import numpy as np 10 | 11 | import torchvision 12 | import torchvision.transforms as transforms 13 | 14 | import os 15 | import argparse 16 | import csv 17 | import time 18 | 19 | from utils import progress_bar 20 | from data.randomaug import RandAugment 21 | from model.crate import * 22 | from model.vit import * 23 | from data.dataset import * 24 | # parsers 25 | parser = argparse.ArgumentParser(description='PyTorch CIFAR10 Training') 26 | parser.add_argument('--lr', default=1e-4, type=float, help='learning rate') # resnets.. 1e-3, Vit..1e-4 27 | parser.add_argument('--opt', default="adamW") 28 | parser.add_argument('--net', default='vit') 29 | parser.add_argument('--bs', type=int, default=50) 30 | parser.add_argument('--data', default="cifar10") 31 | parser.add_argument('--classes',type=int, default=10) 32 | parser.add_argument('--resume',type=int, default=0) 33 | parser.add_argument('--randomaug',type=int, default=1) 34 | parser.add_argument('--rand_aug_n',type=int, default=2) 35 | parser.add_argument('--rand_aug_m',type=int, default=14) 36 | parser.add_argument('--erase_prob',type=float, default=0.0) 37 | parser.add_argument('--n_epochs', type=int, default='400') 38 | parser.add_argument('--patch', default='4', type=int, help="patch for ViT") 39 | parser.add_argument('--ckpt_dir', type=str, default=None,help='location for the pretrained CRATE weight') 40 | parser.add_argument('--data_dir', type=str, default='./data',help='location for datasets') 41 | 42 | args = parser.parse_args() 43 | 44 | # take in args 45 | 46 | use_amp = True 47 | bs = args.bs 48 | 49 | device = 'cuda' if torch.cuda.is_available() else 'cpu' 50 | best_acc = 0 # best test accuracy 51 | start_epoch = 0 # start from epoch 0 or last checkpoint epoch 52 | 53 | # Data 54 | print('==> Preparing data..') 55 | size=224 56 | transform_train = transforms.Compose([ 57 | transforms.RandomResizedCrop((size,size)), 58 | transforms.RandomHorizontalFlip(), 59 | transforms.RandAugment(args.rand_aug_n, args.rand_aug_m) if args.randomaug else transforms.TrivialAugmentWide(), 60 | transforms.ToTensor(), 61 | transforms.Normalize(mean=[0.485, 0.456, 0.406],std=[0.229, 0.224, 0.225]), 62 | transforms.RandomErasing(p=args.erase_prob), 63 | ]) 64 | print("size", size) 65 | transform_test = transforms.Compose([ 66 | transforms.Resize((size,size)), 67 | transforms.ToTensor(), 68 | transforms.Normalize(mean=[0.485, 0.456, 0.406],std=[0.229, 0.224, 0.225]) 69 | ]) 70 | 71 | transet, testset = load_dataset(args.data, size=size, 
transform_train=transform_train, transform_test=transform_test, data_dir=args.data_dir) 72 | trainloader = torch.utils.data.DataLoader(transet, batch_size=bs, shuffle=True, num_workers=8) 73 | testloader = torch.utils.data.DataLoader(testset, batch_size=100, shuffle=False, num_workers=8) 74 | # Model factory.. 75 | print('==> Building model..') 76 | if args.ckpt_dir is None: 77 | print("Train from scratch.") 78 | if args.net == 'vit_tiny': 79 | net = vit_tiny_patch16(global_pool=True) 80 | net.head = nn.Linear(192, args.classes) 81 | elif args.net == 'vit_small': 82 | net = vit_small_patch16(global_pool=True) 83 | net.head = nn.Linear(384, args.classes) 84 | elif args.net == 'CRATE_tiny': 85 | net = CRATE_tiny(args.classes) 86 | elif args.net == "CRATE_small": 87 | net = CRATE_small(args.classes) 88 | elif args.net == "CRATE_base": 89 | net = CRATE_base(args.classes) 90 | elif args.net == "CRATE_large": 91 | net = CRATE_large(args.classes) 92 | 93 | # For Multi-GPU 94 | if 'cuda' in device: 95 | print(device) 96 | print("using data parallel") 97 | net = torch.nn.DataParallel(net) # make parallel 98 | if args.ckpt_dir is not None: 99 | #upd keys 100 | state_dict = torch.load(args.ckpt_dir)['state_dict'] 101 | for key in list(state_dict.keys()): 102 | if 'mlp_head' in key: 103 | del state_dict[key] 104 | print("deleted:", key) 105 | net.load_state_dict(state_dict, strict=False) 106 | cudnn.benchmark = True 107 | if args.resume: 108 | # Load checkpoint. 109 | print('==> Resuming from checkpoint..') 110 | assert os.path.isdir('checkpoint'), 'Error: no checkpoint directory found!' 111 | checkpoint = torch.load('./checkpoint/{}-ckpt.t7'.format(args.net)) 112 | net.load_state_dict(checkpoint['net']) 113 | best_acc = checkpoint['acc'] 114 | start_epoch = checkpoint['epoch'] 115 | 116 | # Loss is CE 117 | criterion = nn.CrossEntropyLoss() 118 | 119 | if args.opt == "adam": 120 | optimizer = optim.Adam(net.parameters(), lr=args.lr) 121 | elif args.opt == "sgd": 122 | optimizer = optim.SGD(net.parameters(), lr=args.lr, momentum=0.9) 123 | elif args.opt == "adamW": 124 | print("using adamW") 125 | optimizer = optim.AdamW(net.parameters(), lr=args.lr, weight_decay=0.01, betas = (0.9, 0.999), eps = 1e-8) 126 | # use cosine scheduling 127 | scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, args.n_epochs) 128 | 129 | ##### Training 130 | scaler = torch.cuda.amp.GradScaler(enabled=use_amp) 131 | def train(epoch): 132 | print('\nEpoch: %d' % epoch) 133 | net.train() 134 | train_loss = 0 135 | correct = 0 136 | total = 0 137 | for batch_idx, (inputs, targets) in enumerate(trainloader): 138 | inputs, targets = inputs.to(device), targets.to(device) 139 | # Train with amp 140 | with torch.cuda.amp.autocast(enabled=use_amp): 141 | outputs = net(inputs) 142 | loss = criterion(outputs, targets) 143 | scaler.scale(loss).backward() 144 | scaler.step(optimizer) 145 | scaler.update() 146 | optimizer.zero_grad() 147 | 148 | train_loss += loss.item() 149 | _, predicted = outputs.max(1) 150 | total += targets.size(0) 151 | correct += predicted.eq(targets).sum().item() 152 | 153 | progress_bar(batch_idx, len(trainloader), 'Loss: %.3f | Acc: %.3f%% (%d/%d)' 154 | % (train_loss/(batch_idx+1), 100.*correct/total, correct, total)) 155 | return train_loss/(batch_idx+1) 156 | 157 | ##### Validation 158 | def test(epoch): 159 | global best_acc 160 | net.eval() 161 | test_loss = 0 162 | correct = 0 163 | total = 0 164 | with torch.no_grad(): 165 | for batch_idx, (inputs, targets) in enumerate(testloader): 166 | 
inputs, targets = inputs.to(device), targets.to(device) 167 | outputs = net(inputs) 168 | loss = criterion(outputs, targets) 169 | 170 | test_loss += loss.item() 171 | _, predicted = outputs.max(1) 172 | total += targets.size(0) 173 | correct += predicted.eq(targets).sum().item() 174 | 175 | progress_bar(batch_idx, len(testloader), 'Loss: %.3f | Acc: %.3f%% (%d/%d)' 176 | % (test_loss/(batch_idx+1), 100.*correct/total, correct, total)) 177 | 178 | # Save checkpoint. 179 | acc = 100.*correct/total 180 | if acc > best_acc: 181 | print('Saving..') 182 | state = {"model": net.state_dict(), 183 | "optimizer": optimizer.state_dict(), 184 | "scaler": scaler.state_dict()} 185 | if not os.path.isdir('checkpoint'): 186 | os.mkdir('checkpoint') 187 | torch.save(state, './checkpoint/'+args.net+'-{}-ckpt.t7'.format(args.patch)) 188 | best_acc = acc 189 | 190 | os.makedirs("log", exist_ok=True) 191 | content = time.ctime() + ' ' + f'Epoch {epoch}, lr: {optimizer.param_groups[0]["lr"]:.7f}, val loss: {test_loss:.5f}, acc: {(acc):.5f}' 192 | print(content) 193 | with open(f'log/log_{args.net}_patch{args.patch}.txt', 'a') as appender: 194 | appender.write(content + "\n") 195 | return test_loss, acc 196 | 197 | list_loss = [] 198 | list_acc = [] 199 | 200 | net.cuda() 201 | for epoch in range(start_epoch, args.n_epochs): 202 | start = time.time() 203 | trainloss = train(epoch) 204 | val_loss, acc = test(epoch) 205 | 206 | scheduler.step(epoch-1) # step cosine scheduling 207 | 208 | list_loss.append(val_loss) 209 | list_acc.append(acc) 210 | 211 | # Write out csv.. 212 | with open(f'log/log_{args.net}_patch{args.patch}.csv', 'w') as f: 213 | writer = csv.writer(f, lineterminator='\n') 214 | writer.writerow(list_loss) 215 | writer.writerow(list_acc) 216 | print(list_loss) 217 | -------------------------------------------------------------------------------- /main.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | import random 4 | import shutil 5 | import time 6 | import warnings 7 | from enum import Enum 8 | 9 | import torch 10 | import torch.backends.cudnn as cudnn 11 | import torch.distributed as dist 12 | import torch.multiprocessing as mp 13 | import torchvision.datasets as datasets 14 | import torchvision.transforms as transforms 15 | from torch.utils.data import Subset 16 | import math 17 | from model.crate import * 18 | from model.vit import * 19 | from timm.loss.cross_entropy import LabelSmoothingCrossEntropy 20 | from lion_pytorch import Lion 21 | 22 | model_names = ["vit_tiny", "vit_small", "CRATE_tiny", "CRATE_small", "CRATE_base", "CRATE_large"] 23 | 24 | def get_args_parser(): 25 | 26 | parser = argparse.ArgumentParser(description='PyTorch ImageNet Training') 27 | parser.add_argument('--data', metavar='DIR', default="/path/to/imagenet", 28 | help='path to dataset (default: imagenet)') 29 | parser.add_argument('-a', '--arch', metavar='ARCH', default='CRATE_tiny', 30 | choices=model_names, 31 | help='model architecture: ' + 32 | ' | '.join(model_names) + 33 | ' (default: CRATE_tiny)') 34 | parser.add_argument('-j', '--workers', default=16, type=int, metavar='N', 35 | help='number of data loading workers (default: 4)') 36 | parser.add_argument('--epochs', default=90, type=int, metavar='N', 37 | help='number of total epochs to run') 38 | parser.add_argument('--start-epoch', default=0, type=int, metavar='N', 39 | help='manual epoch number (useful on restarts)') 40 | parser.add_argument('--label_smooth', default=0.1, type=float, 
metavar='L', 41 | help='label smoothing coef') 42 | parser.add_argument('-b', '--batch-size', default=256, type=int, 43 | metavar='N', 44 | help='mini-batch size (default: 256), this is the total ' 45 | 'batch size of all GPUs on the current node when ' 46 | 'using Data Parallel or Distributed Data Parallel') 47 | parser.add_argument('--lr', '--learning-rate', default=0.0004, type=float, 48 | metavar='LR', help='initial learning rate', dest='lr') 49 | parser.add_argument('--momentum', default=0.9, type=float, metavar='M', 50 | help='momentum') 51 | parser.add_argument('--wd', '--weight-decay', default=0.1, type=float, 52 | metavar='W', help='weight decay (default: 1e-4)', 53 | dest='weight_decay') 54 | parser.add_argument('-p', '--print-freq', default=10, type=int, 55 | metavar='N', help='print frequency (default: 10)') 56 | parser.add_argument('--resume', default='', type=str, metavar='PATH', 57 | help='path to latest checkpoint (default: none)') 58 | parser.add_argument('-e', '--evaluate', dest='evaluate', action='store_true', 59 | help='evaluate model on validation set') 60 | parser.add_argument('--pretrained', dest='pretrained', action='store_true', 61 | help='use pre-trained model') 62 | parser.add_argument('--world-size', default=-1, type=int, 63 | help='number of nodes for distributed training') 64 | parser.add_argument('--rank', default=-1, type=int, 65 | help='node rank for distributed training') 66 | parser.add_argument('--dist-url', default='tcp://224.66.41.62:23456', type=str, 67 | help='url used to set up distributed training') 68 | parser.add_argument('--dist-backend', default='nccl', type=str, 69 | help='distributed backend') 70 | parser.add_argument('--seed', default=None, type=int, 71 | help='seed for initializing training. ') 72 | parser.add_argument('--gpu', default=None, type=int, 73 | help='GPU id to use.') 74 | parser.add_argument('--optimizer', default="AdamW", type=str, 75 | help='Optimizer to Use.') 76 | parser.add_argument('--multiprocessing-distributed', action='store_true', 77 | help='Use multi-processing distributed training to launch ' 78 | 'N processes per node, which has N GPUs. This is the ' 79 | 'fastest way to use PyTorch for either single node or ' 80 | 'multi node data parallel training') 81 | parser.add_argument('--dummy', action='store_true', help="use fake data to benchmark") 82 | return parser 83 | 84 | parser = get_args_parser() 85 | best_acc1 = 0 86 | 87 | from torch.cuda.amp import autocast, GradScaler 88 | scaler = GradScaler() 89 | 90 | def main(): 91 | args = parser.parse_args() 92 | 93 | if args.seed is not None: 94 | random.seed(args.seed) 95 | torch.manual_seed(args.seed) 96 | cudnn.deterministic = True 97 | cudnn.benchmark = False 98 | warnings.warn('You have chosen to seed training. ' 99 | 'This will turn on the CUDNN deterministic setting, ' 100 | 'which can slow down your training considerably! ' 101 | 'You may see unexpected behavior when restarting ' 102 | 'from checkpoints.') 103 | 104 | if args.gpu is not None: 105 | warnings.warn('You have chosen a specific GPU. 
This will completely ' 106 | 'disable data parallelism.') 107 | 108 | if args.dist_url == "env://" and args.world_size == -1: 109 | args.world_size = int(os.environ["WORLD_SIZE"]) 110 | 111 | args.distributed = args.world_size > 1 or args.multiprocessing_distributed 112 | 113 | if torch.cuda.is_available(): 114 | ngpus_per_node = torch.cuda.device_count() 115 | else: 116 | ngpus_per_node = 1 117 | if args.multiprocessing_distributed: 118 | # Since we have ngpus_per_node processes per node, the total world_size 119 | # needs to be adjusted accordingly 120 | args.world_size = ngpus_per_node * args.world_size 121 | # Use torch.multiprocessing.spawn to launch distributed processes: the 122 | # main_worker process function 123 | mp.spawn(main_worker, nprocs=ngpus_per_node, args=(ngpus_per_node, args)) 124 | else: 125 | # Simply call main_worker function 126 | main_worker(args.gpu, ngpus_per_node, args) 127 | 128 | 129 | 130 | def main_worker(gpu, ngpus_per_node, args): 131 | global best_acc1 132 | args.gpu = gpu 133 | 134 | if args.gpu is not None: 135 | print("Use GPU: {} for training".format(args.gpu)) 136 | 137 | if args.distributed: 138 | if args.dist_url == "env://" and args.rank == -1: 139 | args.rank = int(os.environ["RANK"]) 140 | if args.multiprocessing_distributed: 141 | # For multiprocessing distributed training, rank needs to be the 142 | # global rank among all the processes 143 | args.rank = args.rank * ngpus_per_node + gpu 144 | dist.init_process_group(backend=args.dist_backend, init_method=args.dist_url, 145 | world_size=args.world_size, rank=args.rank) 146 | 147 | print('==> Building model: {}'.format(args.arch)) 148 | if args.arch == 'vit_tiny': 149 | model = vit_tiny_patch16(global_pool=True) 150 | elif args.arch == 'vit_small': 151 | model = vit_small_patch16(global_pool=True) 152 | elif args.arch == 'CRATE_tiny': 153 | model = CRATE_tiny() 154 | elif args.arch == "CRATE_small": 155 | model = CRATE_small() 156 | elif args.arch == "CRATE_base": 157 | model = CRATE_base() 158 | elif args.arch == "CRATE_large": 159 | model = CRATE_large() 160 | else: 161 | raise NotImplementedError 162 | 163 | if not torch.cuda.is_available() and not torch.backends.mps.is_available(): 164 | print('using CPU, this will be slow') 165 | elif args.distributed: 166 | # For multiprocessing distributed, DistributedDataParallel constructor 167 | # should always set the single device scope, otherwise, 168 | # DistributedDataParallel will use all available devices. 169 | if torch.cuda.is_available(): 170 | if args.gpu is not None: 171 | torch.cuda.set_device(args.gpu) 172 | model.cuda(args.gpu) 173 | # When using a single GPU per process and per 174 | # DistributedDataParallel, we need to divide the batch size 175 | # ourselves based on the total number of GPUs of the current node. 
176 | args.batch_size = int(args.batch_size / ngpus_per_node) 177 | args.workers = int((args.workers + ngpus_per_node - 1) / ngpus_per_node) 178 | model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.gpu]) 179 | else: 180 | model.cuda() 181 | # DistributedDataParallel will divide and allocate batch_size to all 182 | # available GPUs if device_ids are not set 183 | model = torch.nn.parallel.DistributedDataParallel(model) 184 | elif args.gpu is not None and torch.cuda.is_available(): 185 | torch.cuda.set_device(args.gpu) 186 | model = model.cuda(args.gpu) 187 | else: 188 | # DataParallel will divide and allocate batch_size to all available GPUs 189 | if args.arch.startswith('alexnet') or args.arch.startswith('vgg'): 190 | model.features = torch.nn.DataParallel(model.features) 191 | model.cuda() 192 | else: 193 | model = torch.nn.DataParallel(model).cuda() 194 | 195 | if torch.cuda.is_available(): 196 | if args.gpu: 197 | device = torch.device('cuda:{}'.format(args.gpu)) 198 | else: 199 | device = torch.device("cuda") 200 | else: 201 | device = torch.device("cpu") 202 | # define loss function (criterion), optimizer, and learning rate scheduler 203 | criterion = LabelSmoothingCrossEntropy(smoothing=args.label_smooth).to(device) 204 | 205 | if args.optimizer == "AdamW": 206 | optimizer = torch.optim.AdamW(model.parameters(), lr=args.lr, 207 | betas=(0.9, 0.999), 208 | weight_decay=args.weight_decay) 209 | elif args.optimizer == "Lion": 210 | optimizer = Lion(model.parameters(), lr=args.lr, weight_decay=args.weight_decay) 211 | else: 212 | raise NotImplementedError 213 | 214 | warmup_steps = 20 215 | lr_func = lambda step: min((step + 1) / (warmup_steps + 1e-8), 216 | 0.5 * (math.cos(step / args.epochs * math.pi) + 1)) 217 | scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=lr_func, verbose=True) 218 | 219 | # optionally resume from a checkpoint 220 | if args.resume: 221 | if os.path.isfile(args.resume): 222 | print("=> loading checkpoint '{}'".format(args.resume)) 223 | if args.gpu is None: 224 | checkpoint = torch.load(args.resume) 225 | elif torch.cuda.is_available(): 226 | # Map model to be loaded to specified single gpu. 
227 | loc = 'cuda:{}'.format(args.gpu) 228 | checkpoint = torch.load(args.resume, map_location=loc) 229 | args.start_epoch = checkpoint['epoch'] 230 | best_acc1 = checkpoint['best_acc1'] 231 | if args.gpu is not None: 232 | # best_acc1 may be from a checkpoint from a different GPU 233 | best_acc1 = best_acc1.to(args.gpu) 234 | model.load_state_dict(checkpoint['state_dict']) 235 | optimizer.load_state_dict(checkpoint['optimizer']) 236 | scheduler.load_state_dict(checkpoint['scheduler']) 237 | print("=> loaded checkpoint '{}' (epoch {})" 238 | .format(args.resume, checkpoint['epoch'])) 239 | else: 240 | print("=> no checkpoint found at '{}'".format(args.resume)) 241 | 242 | 243 | # Data loading code 244 | if args.dummy: 245 | print("=> Dummy data is used!") 246 | train_dataset = datasets.FakeData(1281167, (3, 224, 224), 1000, transforms.ToTensor()) 247 | val_dataset = datasets.FakeData(50000, (3, 224, 224), 1000, transforms.ToTensor()) 248 | else: 249 | traindir = os.path.join(args.data, 'train') 250 | valdir = os.path.join(args.data, 'val') 251 | normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406], 252 | std=[0.229, 0.224, 0.225]) 253 | 254 | transform_simple = transforms.Compose([ 255 | transforms.RandomResizedCrop(224), 256 | transforms.RandomHorizontalFlip(), 257 | transforms.ToTensor(), 258 | normalize, 259 | ]) 260 | 261 | train_dataset = datasets.ImageFolder( 262 | traindir, 263 | transform_simple 264 | ) 265 | 266 | val_dataset = datasets.ImageFolder( 267 | valdir, 268 | transforms.Compose([ 269 | transforms.Resize(256), 270 | transforms.CenterCrop(224), 271 | transforms.ToTensor(), 272 | normalize, 273 | ])) 274 | 275 | if args.distributed: 276 | train_sampler = torch.utils.data.distributed.DistributedSampler(train_dataset) 277 | val_sampler = torch.utils.data.distributed.DistributedSampler(val_dataset, shuffle=False, drop_last=True) 278 | else: 279 | train_sampler = None 280 | val_sampler = None 281 | print(f"I am using {args.workers} worker") 282 | train_loader = torch.utils.data.DataLoader( 283 | train_dataset, batch_size=args.batch_size, shuffle=(train_sampler is None), 284 | num_workers=args.workers, pin_memory=True, sampler=train_sampler) 285 | 286 | val_loader = torch.utils.data.DataLoader( 287 | val_dataset, batch_size=args.batch_size, shuffle=False, 288 | num_workers=args.workers, pin_memory=True, sampler=val_sampler) 289 | 290 | if args.evaluate: 291 | validate(val_loader, model, criterion, args) 292 | return 293 | 294 | for epoch in range(args.start_epoch, args.epochs): 295 | if args.distributed: 296 | print('distributed loader') 297 | train_sampler.set_epoch(epoch) 298 | else: 299 | print('non-distributed loader') 300 | 301 | # train for one epoch 302 | train(train_loader, model, criterion, optimizer, epoch, device, args) 303 | 304 | # evaluate on validation set 305 | acc1 = validate(val_loader, model, criterion, args) 306 | 307 | scheduler.step() 308 | 309 | # remember best acc@1 and save checkpoint 310 | is_best = acc1 > best_acc1 311 | best_acc1 = max(acc1, best_acc1) 312 | 313 | if not args.multiprocessing_distributed or (args.multiprocessing_distributed 314 | and args.rank % ngpus_per_node == 0): 315 | save_checkpoint({ 316 | 'epoch': epoch + 1, 317 | 'arch': args.arch, 318 | 'state_dict': model.state_dict(), 319 | 'best_acc1': best_acc1, 320 | 'optimizer' : optimizer.state_dict(), 321 | 'scheduler' : scheduler.state_dict() 322 | }, is_best) 323 | 324 | grad_clip_norm = 1.0 325 | def train(train_loader, model, criterion, optimizer, epoch, device, args): 326 
| batch_time = AverageMeter('Time', ':6.3f') 327 | data_time = AverageMeter('Data', ':6.3f') 328 | losses = AverageMeter('Loss', ':.4e') 329 | top1 = AverageMeter('Acc@1', ':6.2f') 330 | top5 = AverageMeter('Acc@5', ':6.2f') 331 | progress = ProgressMeter( 332 | len(train_loader), 333 | [batch_time, data_time, losses, top1, top5], 334 | prefix="Epoch: [{}]".format(epoch)) 335 | 336 | # switch to train mode 337 | model.train() 338 | 339 | end = time.time() 340 | for i, (images, target) in enumerate(train_loader): 341 | # measure data loading time 342 | data_time.update(time.time() - end) 343 | 344 | # move data to the same device as model 345 | images = images.to(device, non_blocking=True) 346 | target = target.to(device, non_blocking=True) 347 | 348 | # compute output 349 | with autocast(): 350 | output = model(images) 351 | loss = criterion(output, target) 352 | 353 | # measure accuracy and record loss 354 | acc1, acc5 = accuracy(output, target, topk=(1, 5)) 355 | losses.update(loss.item(), images.size(0)) 356 | top1.update(acc1[0], images.size(0)) 357 | top5.update(acc5[0], images.size(0)) 358 | 359 | # compute gradient and do SGD step 360 | optimizer.zero_grad() 361 | scaler.scale(loss).backward() 362 | 363 | scaler.step(optimizer) 364 | scaler.update() 365 | 366 | # measure elapsed time 367 | batch_time.update(time.time() - end) 368 | end = time.time() 369 | 370 | if i % args.print_freq == 0: 371 | progress.display(i + 1) 372 | 373 | 374 | def validate(val_loader, model, criterion, args): 375 | 376 | def run_validate(loader, base_progress=0): 377 | with torch.no_grad(): 378 | end = time.time() 379 | for i, (images, target) in enumerate(loader): 380 | i = base_progress + i 381 | if args.gpu is not None and torch.cuda.is_available(): 382 | images = images.cuda(args.gpu, non_blocking=True) 383 | if torch.cuda.is_available(): 384 | target = target.cuda(args.gpu, non_blocking=True) 385 | 386 | # compute output 387 | output = model(images) 388 | loss = criterion(output, target) 389 | 390 | # measure accuracy and record loss 391 | acc1, acc5 = accuracy(output, target, topk=(1, 5)) 392 | losses.update(loss.item(), images.size(0)) 393 | top1.update(acc1[0], images.size(0)) 394 | top5.update(acc5[0], images.size(0)) 395 | 396 | # measure elapsed time 397 | batch_time.update(time.time() - end) 398 | end = time.time() 399 | 400 | if i % args.print_freq == 0: 401 | progress.display(i + 1) 402 | 403 | batch_time = AverageMeter('Time', ':6.3f', Summary.NONE) 404 | losses = AverageMeter('Loss', ':.4e', Summary.NONE) 405 | top1 = AverageMeter('Acc@1', ':6.2f', Summary.AVERAGE) 406 | top5 = AverageMeter('Acc@5', ':6.2f', Summary.AVERAGE) 407 | progress = ProgressMeter( 408 | len(val_loader) + (args.distributed and (len(val_loader.sampler) * args.world_size < len(val_loader.dataset))), 409 | [batch_time, losses, top1, top5], 410 | prefix='Test: ') 411 | 412 | # switch to evaluate mode 413 | model.eval() 414 | 415 | run_validate(val_loader) 416 | if args.distributed: 417 | top1.all_reduce() 418 | top5.all_reduce() 419 | 420 | if args.distributed and (len(val_loader.sampler) * args.world_size < len(val_loader.dataset)): 421 | aux_val_dataset = Subset(val_loader.dataset, 422 | range(len(val_loader.sampler) * args.world_size, len(val_loader.dataset))) 423 | aux_val_loader = torch.utils.data.DataLoader( 424 | aux_val_dataset, batch_size=args.batch_size, shuffle=False, 425 | num_workers=args.workers, pin_memory=True) 426 | run_validate(aux_val_loader, len(val_loader)) 427 | 428 | progress.display_summary() 429 
| 430 | return top1.avg 431 | 432 | 433 | def save_checkpoint(state, is_best, filename='checkpoint.pth.tar'): 434 | torch.save(state, filename) 435 | if is_best: 436 | shutil.copyfile(filename, 'model_best.pth.tar') 437 | 438 | class Summary(Enum): 439 | NONE = 0 440 | AVERAGE = 1 441 | SUM = 2 442 | COUNT = 3 443 | 444 | class AverageMeter(object): 445 | """Computes and stores the average and current value""" 446 | def __init__(self, name, fmt=':f', summary_type=Summary.AVERAGE): 447 | self.name = name 448 | self.fmt = fmt 449 | self.summary_type = summary_type 450 | self.reset() 451 | 452 | def reset(self): 453 | self.val = 0 454 | self.avg = 0 455 | self.sum = 0 456 | self.count = 0 457 | 458 | def update(self, val, n=1): 459 | self.val = val 460 | self.sum += val * n 461 | self.count += n 462 | self.avg = self.sum / self.count 463 | 464 | def all_reduce(self): 465 | if torch.cuda.is_available(): 466 | device = torch.device("cuda") 467 | else: 468 | device = torch.device("cpu") 469 | total = torch.tensor([self.sum, self.count], dtype=torch.float32, device=device) 470 | dist.all_reduce(total, dist.ReduceOp.SUM, async_op=False) 471 | self.sum, self.count = total.tolist() 472 | self.avg = self.sum / self.count 473 | 474 | def __str__(self): 475 | fmtstr = '{name} {val' + self.fmt + '} ({avg' + self.fmt + '})' 476 | return fmtstr.format(**self.__dict__) 477 | 478 | def summary(self): 479 | fmtstr = '' 480 | if self.summary_type is Summary.NONE: 481 | fmtstr = '' 482 | elif self.summary_type is Summary.AVERAGE: 483 | fmtstr = '{name} {avg:.3f}' 484 | elif self.summary_type is Summary.SUM: 485 | fmtstr = '{name} {sum:.3f}' 486 | elif self.summary_type is Summary.COUNT: 487 | fmtstr = '{name} {count:.3f}' 488 | else: 489 | raise ValueError('invalid summary type %r' % self.summary_type) 490 | 491 | return fmtstr.format(**self.__dict__) 492 | 493 | 494 | class ProgressMeter(object): 495 | def __init__(self, num_batches, meters, prefix=""): 496 | self.batch_fmtstr = self._get_batch_fmtstr(num_batches) 497 | self.meters = meters 498 | self.prefix = prefix 499 | 500 | def display(self, batch): 501 | entries = [self.prefix + self.batch_fmtstr.format(batch)] 502 | entries += [str(meter) for meter in self.meters] 503 | print('\t'.join(entries)) 504 | 505 | def display_summary(self): 506 | entries = [" *"] 507 | entries += [meter.summary() for meter in self.meters] 508 | print(' '.join(entries)) 509 | 510 | def _get_batch_fmtstr(self, num_batches): 511 | num_digits = len(str(num_batches // 1)) 512 | fmt = '{:' + str(num_digits) + 'd}' 513 | return '[' + fmt + '/' + fmt.format(num_batches) + ']' 514 | 515 | def accuracy(output, target, topk=(1,)): 516 | """Computes the accuracy over the k top predictions for the specified values of k""" 517 | with torch.no_grad(): 518 | maxk = max(topk) 519 | batch_size = target.size(0) 520 | 521 | _, pred = output.topk(maxk, 1, True, True) 522 | pred = pred.t() 523 | correct = pred.eq(target.view(1, -1).expand_as(pred)) 524 | 525 | res = [] 526 | for k in topk: 527 | correct_k = correct[:k].reshape(-1).float().sum(0, keepdim=True) 528 | res.append(correct_k.mul_(100.0 / batch_size)) 529 | return res 530 | 531 | 532 | if __name__ == '__main__': 533 | main() 534 | -------------------------------------------------------------------------------- /model/crate.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | 4 | from einops import rearrange, repeat 5 | from einops.layers.torch import Rearrange 
6 | import torch.nn.functional as F 7 | import torch.nn.init as init 8 | 9 | 10 | def pair(t): 11 | return t if isinstance(t, tuple) else (t, t) 12 | 13 | 14 | class PreNorm(nn.Module): 15 | def __init__(self, dim, fn): 16 | super().__init__() 17 | self.norm = nn.LayerNorm(dim) 18 | self.fn = fn 19 | 20 | def forward(self, x, **kwargs): 21 | return self.fn(self.norm(x), **kwargs) 22 | 23 | 24 | class FeedForward(nn.Module): 25 | def __init__(self, dim, hidden_dim, dropout=0., step_size=0.1): 26 | super().__init__() 27 | self.weight = nn.Parameter(torch.Tensor(dim, dim)) 28 | with torch.no_grad(): 29 | init.kaiming_uniform_(self.weight) 30 | self.step_size = step_size 31 | self.lambd = 0.1 32 | 33 | def forward(self, x): 34 | # compute D^T * D * x 35 | x1 = F.linear(x, self.weight, bias=None) 36 | grad_1 = F.linear(x1, self.weight.t(), bias=None) 37 | # compute D^T * x 38 | grad_2 = F.linear(x, self.weight.t(), bias=None) 39 | # compute negative gradient update: step_size * (D^T * x - D^T * D * x) 40 | grad_update = self.step_size * (grad_2 - grad_1) - self.step_size * self.lambd 41 | 42 | output = F.relu(x + grad_update) 43 | return output 44 | 45 | 46 | class Attention(nn.Module): 47 | def __init__(self, dim, heads=8, dim_head=64, dropout=0.): 48 | super().__init__() 49 | inner_dim = dim_head * heads 50 | project_out = not (heads == 1 and dim_head == dim) 51 | 52 | self.heads = heads 53 | self.scale = dim_head ** -0.5 54 | 55 | self.attend = nn.Softmax(dim=-1) 56 | self.dropout = nn.Dropout(dropout) 57 | 58 | self.qkv = nn.Linear(dim, inner_dim, bias=False) 59 | 60 | self.to_out = nn.Sequential( 61 | nn.Linear(inner_dim, dim), 62 | nn.Dropout(dropout) 63 | ) if project_out else nn.Identity() 64 | 65 | def forward(self, x): 66 | w = rearrange(self.qkv(x), 'b n (h d) -> b h n d', h=self.heads) 67 | 68 | dots = torch.matmul(w, w.transpose(-1, -2)) * self.scale 69 | 70 | attn = self.attend(dots) 71 | attn = self.dropout(attn) 72 | 73 | out = torch.matmul(attn, w) 74 | 75 | out = rearrange(out, 'b h n d -> b n (h d)') 76 | return self.to_out(out) 77 | 78 | 79 | class Transformer(nn.Module): 80 | def __init__(self, dim, depth, heads, dim_head, dropout=0., ista=0.1): 81 | super().__init__() 82 | self.layers = nn.ModuleList([]) 83 | self.heads = heads 84 | self.depth = depth 85 | self.dim = dim 86 | for _ in range(depth): 87 | self.layers.append( 88 | nn.ModuleList( 89 | [ 90 | PreNorm(dim, Attention(dim, heads=heads, dim_head=dim_head, dropout=dropout)), 91 | PreNorm(dim, FeedForward(dim, dim, dropout=dropout, step_size=ista)) 92 | ] 93 | ) 94 | ) 95 | 96 | def forward(self, x): 97 | depth = 0 98 | for attn, ff in self.layers: 99 | grad_x = attn(x) + x 100 | 101 | x = ff(grad_x) 102 | return x 103 | 104 | 105 | class CRATE(nn.Module): 106 | def __init__( 107 | self, *, image_size, patch_size, num_classes, dim, depth, heads, pool='cls', channels=3, dim_head=64, 108 | dropout=0., emb_dropout=0., ista=0.1 109 | ): 110 | super().__init__() 111 | image_height, image_width = pair(image_size) 112 | patch_height, patch_width = pair(patch_size) 113 | 114 | assert image_height % patch_height == 0 and image_width % patch_width == 0, 'Image dimensions must be divisible by the patch size.' 
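# ---------------------------------------------------------------------------
# Illustration (added annotation, hypothetical usage only): a minimal, hedged
# sketch of how one CRATE layer built from the classes above transforms a
# dummy token tensor of shape (batch, tokens, dim).
#
#   import torch
#   from model.crate import Attention, FeedForward
#   z = torch.randn(2, 197, 384)                     # (b, n, d)
#   attn = Attention(dim=384, heads=6, dim_head=64)  # self-attention / compression step
#   ista = FeedForward(dim=384, hidden_dim=384)      # ISTA sparsification step
#   z = ista(attn(z) + z)                            # one layer (without the PreNorm wrappers),
#                                                    # cf. Transformer.forward below
#   assert z.shape == (2, 197, 384)
# ---------------------------------------------------------------------------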
115 | 116 | num_patches = (image_height // patch_height) * (image_width // patch_width) 117 | patch_dim = channels * patch_height * patch_width 118 | assert pool in {'cls', 'mean'}, 'pool type must be either cls (cls token) or mean (mean pooling)' 119 | 120 | self.to_patch_embedding = nn.Sequential( 121 | Rearrange('b c (h p1) (w p2) -> b (h w) (p1 p2 c)', p1=patch_height, p2=patch_width), 122 | nn.LayerNorm(patch_dim), 123 | nn.Linear(patch_dim, dim), 124 | nn.LayerNorm(dim), 125 | ) 126 | 127 | self.pos_embedding = nn.Parameter(torch.randn(1, num_patches + 1, dim)) 128 | self.cls_token = nn.Parameter(torch.randn(1, 1, dim)) 129 | self.dropout = nn.Dropout(emb_dropout) 130 | 131 | self.transformer = Transformer(dim, depth, heads, dim_head, dropout, ista=ista) 132 | 133 | self.pool = pool 134 | self.to_latent = nn.Identity() 135 | 136 | self.mlp_head = nn.Sequential( 137 | nn.LayerNorm(dim), 138 | nn.Linear(dim, num_classes) 139 | ) 140 | 141 | def forward(self, img): 142 | x = self.to_patch_embedding(img) 143 | b, n, _ = x.shape 144 | 145 | cls_tokens = repeat(self.cls_token, '1 1 d -> b 1 d', b=b) 146 | x = torch.cat((cls_tokens, x), dim=1) 147 | x += self.pos_embedding[:, :(n + 1)] 148 | x = self.dropout(x) 149 | 150 | x = self.transformer(x) 151 | feature_pre = x 152 | x = x.mean(dim=1) if self.pool == 'mean' else x[:, 0] 153 | 154 | x = self.to_latent(x) 155 | feature_last = x 156 | return self.mlp_head(x) 157 | 158 | 159 | def CRATE_tiny(num_classes=1000): 160 | return CRATE( 161 | image_size=224, 162 | patch_size=16, 163 | num_classes=num_classes, 164 | dim=384, 165 | depth=12, 166 | heads=6, 167 | dropout=0.0, 168 | emb_dropout=0.0, 169 | dim_head=384 // 6 170 | ) 171 | 172 | 173 | def CRATE_small(num_classes=1000): 174 | return CRATE( 175 | image_size=224, 176 | patch_size=16, 177 | num_classes=num_classes, 178 | dim=576, 179 | depth=12, 180 | heads=12, 181 | dropout=0.0, 182 | emb_dropout=0.0, 183 | dim_head=576 // 12 184 | ) 185 | 186 | 187 | def CRATE_base(num_classes=1000): 188 | return CRATE( 189 | image_size=224, 190 | patch_size=16, 191 | num_classes=num_classes, 192 | dim=768, 193 | depth=12, 194 | heads=12, 195 | dropout=0.0, 196 | emb_dropout=0.0, 197 | dim_head=768 // 12 198 | ) 199 | 200 | 201 | def CRATE_large(num_classes=1000): 202 | return CRATE( 203 | image_size=224, 204 | patch_size=16, 205 | num_classes=num_classes, 206 | dim=1024, 207 | depth=24, 208 | heads=16, 209 | dropout=0.0, 210 | emb_dropout=0.0, 211 | dim_head=1024 // 16 212 | ) 213 | -------------------------------------------------------------------------------- /model/crate_ae/crate_ae.py: -------------------------------------------------------------------------------- 1 | from functools import partial 2 | 3 | import torch 4 | import torch.nn as nn 5 | from timm.models.vision_transformer import PatchEmbed, Block as ViTBlock 6 | 7 | from model.crate_ae.crate_decoder import Block_CRATE as CRATEDecoderBlock 8 | from model.crate_ae.crate_encoder import Block_CRATE as CRATEEncoderBlock 9 | from model.crate_ae.pos_embed import get_2d_sincos_pos_embed 10 | import math 11 | 12 | class MaskedAutoencoderViT(nn.Module): 13 | """ Masked Autoencoder with VisionTransformer backbone 14 | """ 15 | 16 | def __init__( 17 | self, encoder_block, decoder_block, img_size=224, patch_size=16, in_chans=3, 18 | embed_dim=1024, depth=24, num_heads=16, 19 | decoder_embed_dim=512, decoder_depth=8, decoder_num_heads=16, 20 | mlp_ratio=4., norm_layer=nn.LayerNorm, norm_pix_loss=False, lambd = 0.5 21 | ): 22 | super().__init__() 23 
| self.heads = num_heads 24 | self.depth = depth 25 | 26 | # -------------------------------------------------------------------------- 27 | # MAE encoder specifics 28 | self.patch_size = patch_size 29 | self.patch_embed = PatchEmbed(img_size, patch_size, in_chans, embed_dim, strict_img_size=False) 30 | num_patches = self.patch_embed.num_patches 31 | 32 | self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim)) 33 | self.pos_embed = nn.Parameter( 34 | torch.zeros(1, num_patches + 1, embed_dim), requires_grad=False 35 | ) # fixed sin-cos embedding 36 | try: 37 | self.blocks = nn.ModuleList( 38 | [ # activates for CRATE blocks 39 | encoder_block(dim=embed_dim, heads=num_heads, dim_head=embed_dim // num_heads, lambd = lambd) 40 | for i in range(depth)] 41 | ) 42 | except TypeError: 43 | self.blocks = nn.ModuleList( 44 | [ # activates for ViT blocks 45 | encoder_block(dim=embed_dim, num_heads=num_heads) 46 | for i in range(depth)] 47 | ) 48 | 49 | self.norm = norm_layer(embed_dim) 50 | # -------------------------------------------------------------------------- 51 | 52 | # -------------------------------------------------------------------------- 53 | # MAE decoder specifics 54 | self.decoder_embed = nn.Linear(embed_dim, decoder_embed_dim, bias=True) 55 | 56 | self.mask_token = nn.Parameter(torch.zeros(1, 1, decoder_embed_dim)) 57 | 58 | self.decoder_pos_embed = nn.Parameter( 59 | torch.zeros(1, num_patches + 1, decoder_embed_dim), requires_grad=False 60 | ) # fixed sin-cos embedding 61 | try: 62 | self.decoder_blocks = nn.ModuleList( 63 | [ 64 | decoder_block( 65 | dim=decoder_embed_dim, heads=decoder_num_heads, dim_head=decoder_embed_dim // decoder_num_heads 66 | ) 67 | for i in range(decoder_depth)] 68 | ) 69 | except TypeError: 70 | self.decoder_blocks = nn.ModuleList( 71 | [ 72 | decoder_block(dim=decoder_embed_dim, num_heads=decoder_num_heads) 73 | for i in range(decoder_depth)] 74 | ) 75 | 76 | self.decoder_norm = norm_layer(decoder_embed_dim) 77 | self.decoder_pred = nn.Linear(decoder_embed_dim, patch_size ** 2 * in_chans, bias=True) # decoder to patch 78 | # -------------------------------------------------------------------------- 79 | 80 | self.norm_pix_loss = norm_pix_loss 81 | 82 | self.initialize_weights() 83 | 84 | def initialize_weights(self): 85 | # initialization 86 | # initialize (and freeze) pos_embed by sin-cos embedding 87 | pos_embed = get_2d_sincos_pos_embed( 88 | self.pos_embed.shape[-1], int(self.patch_embed.num_patches ** .5), cls_token=True 89 | ) 90 | self.pos_embed.data.copy_(torch.from_numpy(pos_embed).float().unsqueeze(0)) 91 | 92 | decoder_pos_embed = get_2d_sincos_pos_embed( 93 | self.decoder_pos_embed.shape[-1], int(self.patch_embed.num_patches ** .5), cls_token=True 94 | ) 95 | self.decoder_pos_embed.data.copy_(torch.from_numpy(decoder_pos_embed).float().unsqueeze(0)) 96 | 97 | # initialize patch_embed like nn.Linear (instead of nn.Conv2d) 98 | w = self.patch_embed.proj.weight.data 99 | torch.nn.init.xavier_uniform_(w.view([w.shape[0], -1])) 100 | 101 | # timm's trunc_normal_(std=.02) is effectively normal_(std=0.02) as cutoff is too big (2.) 
102 | torch.nn.init.normal_(self.cls_token, std=.02) 103 | # torch.nn.init.normal_(self.cls_token) 104 | torch.nn.init.normal_(self.mask_token, std=.02) 105 | 106 | # initialize nn.Linear and nn.LayerNorm 107 | self.apply(self._init_weights) 108 | 109 | def _init_weights(self, m): 110 | if isinstance(m, nn.Linear): 111 | # we use xavier_uniform following official JAX ViT: 112 | torch.nn.init.xavier_uniform_(m.weight) 113 | if isinstance(m, nn.Linear) and m.bias is not None: 114 | nn.init.constant_(m.bias, 0) 115 | elif isinstance(m, nn.LayerNorm): 116 | nn.init.constant_(m.bias, 0) 117 | nn.init.constant_(m.weight, 1.0) 118 | 119 | def patchify(self, imgs): 120 | """ 121 | imgs: (N, 3, H, W) 122 | x: (N, L, patch_size**2 *3) 123 | """ 124 | p = self.patch_embed.patch_size[0] 125 | assert imgs.shape[2] == imgs.shape[3] and imgs.shape[2] % p == 0 126 | 127 | h = w = imgs.shape[2] // p 128 | x = imgs.reshape(shape=(imgs.shape[0], 3, h, p, w, p)) 129 | x = torch.einsum('nchpwq->nhwpqc', x) 130 | x = x.reshape(shape=(imgs.shape[0], h * w, p ** 2 * 3)) 131 | return x 132 | 133 | def unpatchify(self, x): 134 | """ 135 | x: (N, L, patch_size**2 *3) 136 | imgs: (N, 3, H, W) 137 | """ 138 | p = self.patch_embed.patch_size[0] 139 | h = w = int(x.shape[1] ** .5) 140 | assert h * w == x.shape[1] 141 | 142 | x = x.reshape(shape=(x.shape[0], h, w, p, p, 3)) 143 | x = torch.einsum('nhwpqc->nchpwq', x) 144 | imgs = x.reshape(shape=(x.shape[0], 3, h * p, h * p)) 145 | return imgs 146 | 147 | def random_masking(self, x, mask_ratio): 148 | """ 149 | Perform per-sample random masking by per-sample shuffling. 150 | Per-sample shuffling is done by argsort random noise. 151 | x: [N, L, D], sequence 152 | """ 153 | N, L, D = x.shape # batch, length, dim 154 | len_keep = int(L * (1 - mask_ratio)) 155 | 156 | noise = torch.rand(N, L, device=x.device) # noise in [0, 1] 157 | 158 | # sort noise for each sample 159 | ids_shuffle = torch.argsort(noise, dim=1) # ascend: small is keep, large is remove 160 | ids_restore = torch.argsort(ids_shuffle, dim=1) 161 | 162 | # keep the first subset 163 | ids_keep = ids_shuffle[:, :len_keep] 164 | x_masked = torch.gather(x, dim=1, index=ids_keep.unsqueeze(-1).repeat(1, 1, D)) 165 | 166 | # generate the binary mask: 0 is keep, 1 is remove 167 | mask = torch.ones([N, L], device=x.device) 168 | mask[:, :len_keep] = 0 169 | # unshuffle to get the binary mask 170 | mask = torch.gather(mask, dim=1, index=ids_restore) 171 | 172 | return x_masked, mask, ids_restore 173 | 174 | def forward_encoder(self, x, mask_ratio): 175 | # embed patches 176 | x = self.patch_embed(x) 177 | 178 | # masking: length -> length * mask_ratio 179 | x, mask, ids_restore = self.random_masking(x, mask_ratio) 180 | 181 | # append cls token 182 | cls_token = self.cls_token 183 | cls_tokens = cls_token.expand(x.shape[0], -1, -1) 184 | x = torch.cat((cls_tokens, x), dim=1) 185 | 186 | # append mask tokens to sequence 187 | mask_tokens = self.mask_token.repeat(x.shape[0], ids_restore.shape[1] + 1 - x.shape[1], 1) 188 | x_ = torch.cat([x[:, 1:, :], mask_tokens], dim=1) # no cls token 189 | x_ = torch.gather(x_, dim=1, index=ids_restore.unsqueeze(-1).repeat(1, 1, x.shape[2])) # unshuffle 190 | # append cls token back 191 | x = torch.cat([x[:, :1, :], x_], dim=1) 192 | 193 | # add pos embed (after putting mask tokens and input tokens together) 194 | x = x + self.decoder_pos_embed 195 | 196 | # apply Transformer blocks 197 | for blk in self.blocks: 198 | x = blk(x) 199 | x = self.norm(x) 200 | 201 | return x, mask, 
ids_restore 202 | 203 | def forward_decoder(self, x, ids_restore): 204 | # apply Transformer blocks 205 | for blk in self.decoder_blocks: 206 | x = blk(x) 207 | x = self.decoder_norm(x) 208 | 209 | # predictor projection 210 | x = self.decoder_pred(x) 211 | 212 | # remove cls token 213 | x = x[:, 1:, :] 214 | 215 | return x 216 | 217 | def forward_loss(self, imgs, latent, pred, mask): 218 | """ 219 | imgs: [N, 3, H, W] 220 | pred: [N, L, p*p*3] 221 | mask: [N, L], 0 is keep, 1 is remove, 222 | """ 223 | target = self.patchify(imgs) 224 | if self.norm_pix_loss: 225 | mean = target.mean(dim=-1, keepdim=True) 226 | var = target.var(dim=-1, keepdim=True) 227 | target = (target - mean) / (var + 1.e-6) ** .5 228 | 229 | loss = (pred - target) ** 2 230 | loss = loss.mean(dim=-1) # [N, L], mean loss per patch 231 | 232 | loss = (loss * mask).sum() / mask.sum() # mean loss on removed patches 233 | 234 | return loss 235 | 236 | 237 | ### functions for visualizing attention 238 | def interpolate_pos_encoding(self, x, w, h): 239 | npatch = x.shape[1] - 1 240 | N = self.decoder_pos_embed.shape[1] - 1 241 | if npatch == N and w == h: 242 | return self.decoder_pos_embed 243 | class_pos_embed = self.decoder_pos_embed[:, 0] 244 | patch_pos_embed = self.decoder_pos_embed[:, 1:] 245 | dim = x.shape[-1] 246 | w0 = w // self.patch_size 247 | h0 = h // self.patch_size 248 | # we add a small number to avoid floating point error in the interpolation 249 | # see discussion at https://github.com/facebookresearch/dino/issues/8 250 | w0, h0 = w0 + 0.1, h0 + 0.1 251 | patch_pos_embed = nn.functional.interpolate( 252 | patch_pos_embed.reshape(1, int(math.sqrt(N)), int(math.sqrt(N)), dim).permute(0, 3, 1, 2), 253 | scale_factor=(w0 / math.sqrt(N), h0 / math.sqrt(N)), 254 | mode='bicubic', 255 | ) 256 | assert int(w0) == patch_pos_embed.shape[-2] and int(h0) == patch_pos_embed.shape[-1] 257 | patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).view(1, -1, dim) 258 | return torch.cat((class_pos_embed.unsqueeze(0), patch_pos_embed), dim=1) 259 | 260 | def forward(self, imgs, mask_ratio=0.75): 261 | latent, mask, ids_restore = self.forward_encoder(imgs, mask_ratio) 262 | pred = self.forward_decoder(latent, ids_restore) # [N, L, p*p*3] 263 | loss = self.forward_loss(imgs, latent, pred, mask) 264 | return loss, pred, mask 265 | 266 | def prepare_tokens(self, x): 267 | B, nc, w, h = x.shape 268 | x = self.patch_embed(x) # patch linear embedding 269 | 270 | # add the [CLS] token to the embed patch tokens 271 | cls_tokens = self.cls_token.expand(B, -1, -1) 272 | x = torch.cat((cls_tokens, x), dim=1) 273 | 274 | # add positional encoding to each token 275 | x = x + self.interpolate_pos_encoding(x, w, h) 276 | return x 277 | 278 | def get_selfattention_enc(self, x, layer=10): 279 | x = self.prepare_tokens(x) 280 | for i, blk in enumerate(self.blocks): 281 | if i < layer: 282 | x = blk(x) 283 | else: 284 | # return attention of the last block 285 | return blk(x, return_attention=True) 286 | 287 | def get_last_key_enc(self, x, layer=10): 288 | x = self.prepare_tokens(x) 289 | for i, blk in enumerate(self.blocks): 290 | if i < layer: 291 | x = blk(x) 292 | else: 293 | # return sharedqkv of the last block 294 | return blk(x, return_key=True) 295 | 296 | def mae_vit_base(**kwargs): 297 | model = MaskedAutoencoderViT( 298 | ViTBlock, ViTBlock, 299 | patch_size=16, embed_dim=768, depth=12, num_heads=12, 300 | decoder_embed_dim=768, decoder_depth=8, decoder_num_heads=16, 301 | mlp_ratio=4, norm_layer=partial(nn.LayerNorm, eps=1e-6), 
**kwargs 302 | ) 303 | return model 304 | 305 | 306 | def mae_vit_large(**kwargs): 307 | model = MaskedAutoencoderViT( 308 | ViTBlock, ViTBlock, 309 | patch_size=16, embed_dim=1024, depth=24, num_heads=16, 310 | decoder_embed_dim=1024, decoder_depth=8, decoder_num_heads=16, 311 | mlp_ratio=4, norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs 312 | ) 313 | return model 314 | 315 | 316 | def mae_vit_huge(**kwargs): 317 | model = MaskedAutoencoderViT( 318 | ViTBlock, ViTBlock, 319 | patch_size=14, embed_dim=1280, depth=32, num_heads=16, 320 | decoder_embed_dim=1280, decoder_depth=8, decoder_num_heads=16, 321 | mlp_ratio=4, norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs 322 | ) 323 | return model 324 | 325 | 326 | def mae_crate_tiny(**kwargs): 327 | model = MaskedAutoencoderViT( 328 | CRATEEncoderBlock, CRATEDecoderBlock, 329 | patch_size=16, embed_dim=384, depth=12, num_heads=6, 330 | decoder_embed_dim=384, decoder_depth=12, decoder_num_heads=6, 331 | mlp_ratio=4, norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs 332 | ) 333 | return model 334 | 335 | 336 | def mae_crate_small(**kwargs): 337 | model = MaskedAutoencoderViT( 338 | CRATEEncoderBlock, CRATEDecoderBlock, 339 | patch_size=16, embed_dim=576, depth=12, num_heads=12, 340 | decoder_embed_dim=576, decoder_depth=12, decoder_num_heads=12, 341 | mlp_ratio=4, norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs 342 | ) 343 | return model 344 | 345 | 346 | def mae_crate_base(**kwargs): 347 | model = MaskedAutoencoderViT( 348 | CRATEEncoderBlock, CRATEDecoderBlock, 349 | patch_size=16, embed_dim=768, depth=12, num_heads=12, 350 | decoder_embed_dim=768, decoder_depth=12, decoder_num_heads=12, 351 | mlp_ratio=4, norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs 352 | ) 353 | return model 354 | 355 | 356 | def mae_crate_large(**kwargs): 357 | model = MaskedAutoencoderViT( 358 | CRATEEncoderBlock, CRATEDecoderBlock, 359 | patch_size=16, embed_dim=1024, depth=24, num_heads=16, 360 | decoder_embed_dim=1024, decoder_depth=24, decoder_num_heads=16, 361 | mlp_ratio=4, norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs 362 | ) 363 | return model 364 | -------------------------------------------------------------------------------- /model/crate_ae/crate_decoder.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn.functional as F 3 | import torch.nn.init as init 4 | from einops import rearrange 5 | from torch import nn 6 | 7 | 8 | def pair(t): 9 | return t if isinstance(t, tuple) else (t, t) 10 | 11 | 12 | class PreNorm(nn.Module): 13 | def __init__(self, dim, fn): 14 | super().__init__() 15 | self.norm = nn.LayerNorm(dim) 16 | self.fn = fn 17 | 18 | def forward(self, x, **kwargs): 19 | return self.fn(self.norm(x), **kwargs) 20 | 21 | 22 | class FeedForward(nn.Module): 23 | def __init__(self, dim, hidden_dim, dropout=0., step_size=0.1): 24 | super().__init__() 25 | self.weight = nn.Parameter(torch.Tensor(dim, dim)) 26 | with torch.no_grad(): 27 | init.kaiming_uniform_(self.weight) 28 | self.step_size = step_size 29 | self.lambd = 5.0 30 | 31 | def forward(self, x): 32 | # A@x 33 | output = F.linear(x, self.weight, bias=None) 34 | return output 35 | 36 | 37 | class Attention(nn.Module): 38 | def __init__(self, dim, heads=8, dim_head=64, dropout=0., use_softmax=True): 39 | super().__init__() 40 | inner_dim = dim_head * heads 41 | project_out = not (heads == 1 and dim_head == dim) 42 | 43 | self.heads = heads 44 | self.scale = dim_head ** -0.5 45 | 46 | self.attend = 
nn.Softmax(dim=-1) if use_softmax else nn.Identity() 47 | self.dropout = nn.Dropout(dropout) 48 | 49 | self.qkv = nn.Linear(dim, inner_dim, bias=False) 50 | 51 | self.to_out = nn.Sequential( 52 | nn.Linear(inner_dim, dim), 53 | nn.Dropout(dropout) 54 | ) if project_out else nn.Identity() 55 | 56 | def forward(self, x): 57 | w = rearrange(self.qkv(x), 'b n (h d) -> b h n d', h=self.heads) 58 | 59 | dots = torch.matmul(w, w.transpose(-1, -2)) * self.scale # (b h n n) 60 | 61 | attn = self.attend(dots) 62 | attn = self.dropout(attn) 63 | 64 | out = torch.matmul(attn, w) 65 | 66 | out = rearrange(out, 'b h n d -> b n (h d)') 67 | return self.to_out(out) 68 | 69 | 70 | class Block_CRATE(nn.Module): 71 | def __init__(self, dim, heads, dim_head, dropout=0., ista=0.1, use_softmax=True): 72 | super().__init__() 73 | self.layers = nn.ModuleList([]) 74 | self.heads = heads 75 | self.dim = dim 76 | self.layers.append( 77 | nn.ModuleList( 78 | [ 79 | PreNorm(dim, FeedForward(dim, dim, dropout=dropout, step_size=ista)), 80 | PreNorm( 81 | dim, Attention(dim, heads=heads, dim_head=dim_head, dropout=dropout, use_softmax=use_softmax) 82 | ), 83 | ] 84 | ) 85 | ) 86 | 87 | def forward(self, x): 88 | for ff, attn in self.layers: 89 | grad_x = ff(x) 90 | x = grad_x - attn(grad_x) 91 | return x 92 | -------------------------------------------------------------------------------- /model/crate_ae/crate_encoder.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn.functional as F 3 | import torch.nn.init as init 4 | from einops import rearrange 5 | from torch import nn 6 | 7 | 8 | def pair(t): 9 | return t if isinstance(t, tuple) else (t, t) 10 | 11 | 12 | class PreNorm(nn.Module): 13 | def __init__(self, dim, fn): 14 | super().__init__() 15 | self.norm = nn.LayerNorm(dim) 16 | self.fn = fn 17 | 18 | def forward(self, x, **kwargs): 19 | return self.fn(self.norm(x), **kwargs) 20 | 21 | 22 | class FeedForward(nn.Module): 23 | def __init__(self, dim, hidden_dim, dropout=0., step_size=0.1, lambd = 0.5): 24 | super().__init__() 25 | self.weight = nn.Parameter(torch.Tensor(dim, dim)) 26 | with torch.no_grad(): 27 | init.kaiming_uniform_(self.weight) 28 | self.step_size = step_size 29 | # self.lambd = 0.1 30 | # self.lambd = 0.5 31 | self.lambd = lambd 32 | 33 | def forward(self, x): 34 | # compute D^T * D * x 35 | x1 = F.linear(x, self.weight, bias=None) 36 | grad_1 = F.linear(x1, self.weight.t(), bias=None) 37 | # compute D^T * x 38 | grad_2 = F.linear(x, self.weight.t(), bias=None) 39 | # compute negative gradient update: step_size * (D^T * x - D^T * D * x) 40 | grad_update = self.step_size * (grad_2 - grad_1) - self.step_size * self.lambd 41 | 42 | output = F.relu(x + grad_update) 43 | return output 44 | 45 | 46 | class Attention(nn.Module): 47 | def __init__(self, dim, heads=8, dim_head=64, dropout=0.): 48 | super().__init__() 49 | inner_dim = dim_head * heads 50 | project_out = not (heads == 1 and dim_head == dim) 51 | 52 | self.heads = heads 53 | self.scale = dim_head ** -0.5 54 | 55 | self.attend = nn.Softmax(dim=-1) 56 | self.dropout = nn.Dropout(dropout) 57 | 58 | self.qkv = nn.Linear(dim, inner_dim, bias=False) 59 | 60 | self.to_out = nn.Sequential( 61 | nn.Linear(inner_dim, dim), 62 | nn.Dropout(dropout) 63 | ) if project_out else nn.Identity() 64 | 65 | def forward(self, x, return_attention=False, return_key = False): 66 | if return_key: 67 | return self.qkv(x) 68 | w = rearrange(self.qkv(x), 'b n (h d) -> b h n d', h = self.heads) 69 | 70 
| dots = torch.matmul(w, w.transpose(-1, -2)) * self.scale 71 | 72 | attn = self.attend(dots) 73 | if return_attention: 74 | return attn 75 | attn = self.dropout(attn) 76 | 77 | out = torch.matmul(attn, w) 78 | 79 | out = rearrange(out, 'b h n d -> b n (h d)') 80 | return self.to_out(out) 81 | 82 | 83 | class Block_CRATE(nn.Module): 84 | def __init__(self, dim, heads, dim_head, dropout=0., ista=0.1, lambd = 0.5): 85 | super().__init__() 86 | self.layers = nn.ModuleList([]) 87 | self.heads = heads 88 | self.dim = dim 89 | self.layers.append( 90 | nn.ModuleList( 91 | [ 92 | PreNorm(dim, Attention(dim, heads=heads, dim_head=dim_head, dropout=dropout)), 93 | PreNorm(dim, FeedForward(dim, dim, dropout=dropout, step_size=ista, lambd=lambd)) 94 | ] 95 | ) 96 | ) 97 | 98 | def forward(self, x, return_attention=False, return_key=False): 99 | for attn, ff in self.layers: 100 | if return_attention: 101 | return attn(x, return_attention=True) 102 | if return_key: 103 | return attn(x, return_key=True) 104 | grad_x = attn(x) + x 105 | x = ff(grad_x) 106 | return x 107 | -------------------------------------------------------------------------------- /model/crate_ae/pos_embed.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | # -------------------------------------------------------- 7 | # Position embedding utils 8 | # -------------------------------------------------------- 9 | 10 | import numpy as np 11 | 12 | import torch 13 | 14 | 15 | # -------------------------------------------------------- 16 | # 2D sine-cosine position embedding 17 | # References: 18 | # Transformer: https://github.com/tensorflow/models/blob/master/official/nlp/transformer/model_utils.py 19 | # MoCo v3: https://github.com/facebookresearch/moco-v3 20 | # -------------------------------------------------------- 21 | def get_2d_sincos_pos_embed(embed_dim, grid_size, cls_token=False): 22 | """ 23 | grid_size: int of the grid height and width 24 | return: 25 | pos_embed: [grid_size*grid_size, embed_dim] or [1+grid_size*grid_size, embed_dim] (w/ or w/o cls_token) 26 | """ 27 | grid_h = np.arange(grid_size, dtype=np.float32) 28 | grid_w = np.arange(grid_size, dtype=np.float32) 29 | grid = np.meshgrid(grid_w, grid_h) # here w goes first 30 | grid = np.stack(grid, axis=0) 31 | 32 | grid = grid.reshape([2, 1, grid_size, grid_size]) 33 | pos_embed = get_2d_sincos_pos_embed_from_grid(embed_dim, grid) 34 | if cls_token: 35 | pos_embed = np.concatenate([np.zeros([1, embed_dim]), pos_embed], axis=0) 36 | return pos_embed 37 | 38 | 39 | def get_2d_sincos_pos_embed_from_grid(embed_dim, grid): 40 | assert embed_dim % 2 == 0 41 | 42 | # use half of dimensions to encode grid_h 43 | emb_h = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[0]) # (H*W, D/2) 44 | emb_w = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[1]) # (H*W, D/2) 45 | 46 | emb = np.concatenate([emb_h, emb_w], axis=1) # (H*W, D) 47 | return emb 48 | 49 | 50 | def get_1d_sincos_pos_embed_from_grid(embed_dim, pos): 51 | """ 52 | embed_dim: output dimension for each position 53 | pos: a list of positions to be encoded: size (M,) 54 | out: (M, D) 55 | """ 56 | assert embed_dim % 2 == 0 57 | omega = np.arange(embed_dim // 2, dtype=np.float32) 58 | omega /= embed_dim / 2. 59 | omega = 1. 
/ 10000 ** omega # (D/2,) 60 | 61 | pos = pos.reshape(-1) # (M,) 62 | out = np.einsum('m,d->md', pos, omega) # (M, D/2), outer product 63 | 64 | emb_sin = np.sin(out) # (M, D/2) 65 | emb_cos = np.cos(out) # (M, D/2) 66 | 67 | emb = np.concatenate([emb_sin, emb_cos], axis=1) # (M, D) 68 | return emb 69 | 70 | 71 | # -------------------------------------------------------- 72 | # Interpolate position embeddings for high-resolution 73 | # References: 74 | # DeiT: https://github.com/facebookresearch/deit 75 | # -------------------------------------------------------- 76 | def interpolate_pos_embed(model, checkpoint_model): 77 | if 'pos_embed' in checkpoint_model: 78 | pos_embed_checkpoint = checkpoint_model['pos_embed'] 79 | embedding_size = pos_embed_checkpoint.shape[-1] 80 | num_patches = model.patch_embed.num_patches 81 | num_extra_tokens = model.pos_embed.shape[-2] - num_patches 82 | # height (== width) for the checkpoint position embedding 83 | orig_size = int((pos_embed_checkpoint.shape[-2] - num_extra_tokens) ** 0.5) 84 | # height (== width) for the new position embedding 85 | new_size = int(num_patches ** 0.5) 86 | # class_token and dist_token are kept unchanged 87 | if orig_size != new_size: 88 | print("Position interpolate from %dx%d to %dx%d" % (orig_size, orig_size, new_size, new_size)) 89 | extra_tokens = pos_embed_checkpoint[:, :num_extra_tokens] 90 | # only the position tokens are interpolated 91 | pos_tokens = pos_embed_checkpoint[:, num_extra_tokens:] 92 | pos_tokens = pos_tokens.reshape(-1, orig_size, orig_size, embedding_size).permute(0, 3, 1, 2) 93 | pos_tokens = torch.nn.functional.interpolate( 94 | pos_tokens, size=(new_size, new_size), mode='bicubic', align_corners=False 95 | ) 96 | pos_tokens = pos_tokens.permute(0, 2, 3, 1).flatten(1, 2) 97 | new_pos_embed = torch.cat((extra_tokens, pos_tokens), dim=1) 98 | checkpoint_model['pos_embed'] = new_pos_embed 99 | -------------------------------------------------------------------------------- /model/vit.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 
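# ---------------------------------------------------------------------------
# Illustration (added annotation, hypothetical usage only): how the sin-cos
# helper defined in model/crate_ae/pos_embed.py above is typically consumed,
# e.g. for a 224x224 image with 16x16 patches (a 14x14 grid) plus a cls token,
# mirroring MaskedAutoencoderViT.initialize_weights.
#
#   import torch
#   from model.crate_ae.pos_embed import get_2d_sincos_pos_embed
#   pe = get_2d_sincos_pos_embed(embed_dim=768, grid_size=14, cls_token=True)
#   pe = torch.from_numpy(pe).float().unsqueeze(0)   # (1, 1 + 14*14, 768) = (1, 197, 768)
#   assert pe.shape == (1, 197, 768)
# ---------------------------------------------------------------------------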
6 | # -------------------------------------------------------- 7 | # References: 8 | # timm: https://github.com/rwightman/pytorch-image-models/tree/master/timm 9 | # DeiT: https://github.com/facebookresearch/deit 10 | # -------------------------------------------------------- 11 | 12 | from functools import partial 13 | 14 | import torch 15 | import torch.nn as nn 16 | 17 | import timm.models.vision_transformer 18 | 19 | 20 | class Identity(nn.Module): 21 | def __init__(self): 22 | super(Identity, self).__init__() 23 | 24 | def forward(self, x): 25 | return x 26 | 27 | class VisionTransformer(timm.models.vision_transformer.VisionTransformer): 28 | """ Vision Transformer with support for global average pooling 29 | """ 30 | def __init__(self, global_pool=False, **kwargs): 31 | super(VisionTransformer, self).__init__(**kwargs) 32 | 33 | self.global_pool = global_pool 34 | if self.global_pool: 35 | norm_layer=partial(nn.LayerNorm, eps=1e-6) 36 | # norm_layer = kwargs['norm_layer'] 37 | embed_dim = kwargs['embed_dim'] 38 | self.fc_norm = norm_layer(embed_dim) 39 | # self.fc_norm = Identity() 40 | 41 | del self.norm # remove the original norm 42 | 43 | def forward_head(self, x, pre_logits: bool = False): 44 | # if self.global_pool: 45 | # x = x[:, self.num_prefix_tokens:].mean(dim=1) if self.global_pool == 'avg' else x[:, 0] 46 | x = self.fc_norm(x) 47 | # x = self.head_drop(x) 48 | return self.head(x) 49 | 50 | def forward_features(self, x): 51 | B = x.shape[0] 52 | x = self.patch_embed(x) 53 | 54 | cls_tokens = self.cls_token.expand(B, -1, -1) # stole cls_tokens impl from Phil Wang, thanks 55 | x = torch.cat((cls_tokens, x), dim=1) 56 | x = x + self.pos_embed 57 | x = self.pos_drop(x) 58 | 59 | for blk in self.blocks: 60 | x = blk(x) 61 | 62 | if self.global_pool: 63 | x = x[:, 1:, :].mean(dim=1) # global pool without cls token 64 | outcome = self.fc_norm(x) 65 | else: 66 | x = self.norm(x) 67 | outcome = x[:, 0] 68 | return outcome 69 | 70 | 71 | # def vit_nano_patch16(**kwargs): 72 | # model = VisionTransformer( 73 | # patch_size=16, embed_dim=192, depth=12, num_heads=12, mlp_ratio=4, qkv_bias=True, 74 | # norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs) 75 | # return model 76 | 77 | 78 | def vit_tiny_patch16(**kwargs): 79 | model = VisionTransformer( 80 | patch_size=16, embed_dim=192, depth=12, num_heads=3, mlp_ratio=4, qkv_bias=True, 81 | norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs) 82 | return model 83 | 84 | 85 | def vit_small_patch16(**kwargs): 86 | model = VisionTransformer( 87 | patch_size=16, embed_dim=384, depth=12, num_heads=6, mlp_ratio=4, qkv_bias=True, 88 | norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs) 89 | return model 90 | 91 | 92 | def vit_base_patch16(**kwargs): 93 | model = VisionTransformer( 94 | patch_size=16, embed_dim=768, depth=12, num_heads=12, mlp_ratio=4, qkv_bias=True, 95 | norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs) 96 | return model 97 | 98 | 99 | def vit_large_patch16(**kwargs): 100 | model = VisionTransformer( 101 | patch_size=16, embed_dim=1024, depth=24, num_heads=16, mlp_ratio=4, qkv_bias=True, 102 | norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs) 103 | return model 104 | 105 | 106 | def vit_huge_patch14(**kwargs): 107 | model = VisionTransformer( 108 | patch_size=14, embed_dim=1280, depth=32, num_heads=16, mlp_ratio=4, qkv_bias=True, 109 | norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs) 110 | return model -------------------------------------------------------------------------------- /requirements.txt: 
-------------------------------------------------------------------------------- 1 | torch 2 | numpy 3 | torchvision 4 | einops 5 | argparse 6 | timm 7 | lion_pytorch -------------------------------------------------------------------------------- /utils.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | '''Some helper functions for PyTorch, including: 4 | - get_mean_and_std: calculate the mean and std value of dataset. 5 | - init_params: net parameter initialization. 6 | - progress_bar: progress bar that mimics xlua.progress. 7 | ''' 8 | import os 9 | import sys 10 | import time 11 | import math 12 | import torch 13 | import torch.nn as nn 14 | import torch.nn.init as init 15 | 16 | 17 | def get_mean_and_std(dataset): 18 | '''Compute the mean and std value of dataset.''' 19 | dataloader = torch.utils.data.DataLoader(dataset, batch_size=1, shuffle=True, num_workers=2) 20 | mean = torch.zeros(3) 21 | std = torch.zeros(3) 22 | print('==> Computing mean and std..') 23 | for inputs, targets in dataloader: 24 | for i in range(3): 25 | mean[i] += inputs[:,i,:,:].mean() 26 | std[i] += inputs[:,i,:,:].std() 27 | mean.div_(len(dataset)) 28 | std.div_(len(dataset)) 29 | return mean, std 30 | 31 | def init_params(net): 32 | '''Init layer parameters.''' 33 | for m in net.modules(): 34 | if isinstance(m, nn.Conv2d): 35 | init.kaiming_normal_(m.weight, mode='fan_out') 36 | if m.bias is not None: 37 | init.constant_(m.bias, 0) 38 | elif isinstance(m, nn.BatchNorm2d): 39 | init.constant_(m.weight, 1) 40 | init.constant_(m.bias, 0) 41 | elif isinstance(m, nn.Linear): 42 | init.normal_(m.weight, std=1e-3) 43 | if m.bias is not None: 44 | init.constant_(m.bias, 0) 45 | 46 | 47 | try: 48 | _, term_width = os.popen('stty size', 'r').read().split() 49 | except: 50 | term_width = 80 51 | term_width = int(term_width) 52 | 53 | TOTAL_BAR_LENGTH = 65. 54 | last_time = time.time() 55 | begin_time = last_time 56 | def progress_bar(current, total, msg=None): 57 | global last_time, begin_time 58 | if current == 0: 59 | begin_time = time.time() # Reset for new bar. 60 | 61 | cur_len = int(TOTAL_BAR_LENGTH*current/total) 62 | rest_len = int(TOTAL_BAR_LENGTH - cur_len) - 1 63 | 64 | sys.stdout.write(' [') 65 | for i in range(cur_len): 66 | sys.stdout.write('=') 67 | sys.stdout.write('>') 68 | for i in range(rest_len): 69 | sys.stdout.write('.') 70 | sys.stdout.write(']') 71 | 72 | cur_time = time.time() 73 | step_time = cur_time - last_time 74 | last_time = cur_time 75 | tot_time = cur_time - begin_time 76 | 77 | L = [] 78 | L.append(' Step: %s' % format_time(step_time)) 79 | L.append(' | Tot: %s' % format_time(tot_time)) 80 | if msg: 81 | L.append(' | ' + msg) 82 | 83 | msg = ''.join(L) 84 | sys.stdout.write(msg) 85 | for i in range(term_width-int(TOTAL_BAR_LENGTH)-len(msg)-3): 86 | sys.stdout.write(' ') 87 | 88 | # Go back to the center of the bar. 
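# (added annotation) The loop below emits backspace characters so the cursor
# lands roughly in the middle of the already-drawn bar; the " current/total "
# counter is then written in place. A trailing '\r' lets the next call
# overwrite the same terminal line, and only the final call emits '\n'.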
89 | for i in range(term_width-int(TOTAL_BAR_LENGTH/2)+2): 90 | sys.stdout.write('\b') 91 | sys.stdout.write(' %d/%d ' % (current+1, total)) 92 | 93 | if current < total-1: 94 | sys.stdout.write('\r') 95 | else: 96 | sys.stdout.write('\n') 97 | sys.stdout.flush() 98 | 99 | def format_time(seconds): 100 | days = int(seconds / 3600/24) 101 | seconds = seconds - days*3600*24 102 | hours = int(seconds / 3600) 103 | seconds = seconds - hours*3600 104 | minutes = int(seconds / 60) 105 | seconds = seconds - minutes*60 106 | secondsf = int(seconds) 107 | seconds = seconds - secondsf 108 | millis = int(seconds*1000) 109 | 110 | f = '' 111 | i = 1 112 | if days > 0: 113 | f += str(days) + 'D' 114 | i += 1 115 | if hours > 0 and i <= 2: 116 | f += str(hours) + 'h' 117 | i += 1 118 | if minutes > 0 and i <= 2: 119 | f += str(minutes) + 'm' 120 | i += 1 121 | if secondsf > 0 and i <= 2: 122 | f += str(secondsf) + 's' 123 | i += 1 124 | if millis > 0 and i <= 2: 125 | f += str(millis) + 'ms' 126 | i += 1 127 | if f == '': 128 | f = '0ms' 129 | return f -------------------------------------------------------------------------------- /vis_utils/coding_rate.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | # from einops import rearrange 4 | class CodingRate(nn.Module): 5 | def __init__(self, eps=0.01): 6 | super(CodingRate, self).__init__() 7 | self.eps = eps 8 | 9 | def forward(self, X): 10 | #normalize over the dim_heads dimension 11 | ''' 12 | X with shape (b, h, m, d) 13 | W with shape (b*h, d, m) 14 | I with shape (m, m) 15 | logdet2 with shape (b*h) 16 | ''' 17 | b, h, _, _ = X.shape 18 | # X = rearrange(X, 'b h m d -> (b h) m d') 19 | X = X/torch.norm(X, dim=-1, keepdim=True) 20 | # print((X @ X.transpose(1,2))[0]) 21 | W = X.transpose(-1,-2) 22 | 23 | 24 | _,_, p, m = W.shape 25 | I = torch.eye(m,device=W.device) 26 | scalar = p / (m * self.eps) 27 | 28 | product = W.transpose(-1,-2) @ W 29 | logdet2 = torch.logdet(I + scalar * product) 30 | # print(logdet2.shape) 31 | mcr2s = logdet2.sum(dim=-1)/(2.) 
32 | # print(mcr2s.shape) 33 | mean_mcr2 = mcr2s.mean() 34 | stdev = mcr2s.std() 35 | return (mean_mcr2, stdev) -------------------------------------------------------------------------------- /vis_utils/crate_hook.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import model.crate as crate 3 | import matplotlib.pyplot as plt 4 | import numpy as np 5 | from vis_utils.coding_rate import CodingRate 6 | from einops import rearrange, repeat 7 | from vis_utils.plot import * 8 | import argparse 9 | 10 | 11 | coding_rate_list = [] 12 | sparsity_list = [] 13 | def forward_hook_codingrate(module, input, output): 14 | coding_rate_list.append(criterion(rearrange(output, 'b n (h d) -> b h n d', h=model.transformer.heads))) 15 | 16 | 17 | def forward_hook_sparsity(module, input, output): 18 | sparsity_list.append(cal_sparsity(output.cpu().numpy(), is_sparse=True)) 19 | 20 | 21 | if __name__=="__main__": 22 | args = argparse.ArgumentParser() 23 | args.add_argument("--checkpoint_path", type=str, default="checkpoint.pth.tar") 24 | args.add_argument("--input_path", type=str, default="firstbatch.pt") 25 | args = args.parse_args() 26 | """ 27 | Remark: The input_path contains augmented image tensors, with expected fprmat: 28 | {'imgs': torch.Size([batch_size, 3, 224, 224])} 29 | 30 | For your convenience, we provide an example inputs from image net here: 31 | https://drive.google.com/file/d/1LTnbVy4HgfaEIGpGdWlsHCF_WAlNkEkH/view?usp=sharing 32 | """ 33 | 34 | criterion = CodingRate() 35 | model = crate.CRATE_small() # change this if you are not using CRATE_small 36 | input = torch.load(args.input_path) 37 | ckpt = torch.load(args.checkpoint_path, map_location='cpu') 38 | 39 | 40 | 41 | new_state_dict = {} 42 | for k, v in ckpt['state_dict'].items(): 43 | if k.startswith('module.'): 44 | k = k[7:] 45 | new_state_dict[k] = v 46 | 47 | model.load_state_dict(new_state_dict) 48 | model = model.cuda() 49 | model.eval() 50 | 51 | for layer in model.transformer.layers: 52 | # print(layer[0].fn.qkv) 53 | layer[0].fn.qkv.register_forward_hook(forward_hook_codingrate) 54 | layer[1].register_forward_hook(forward_hook_sparsity) 55 | with torch.no_grad(): 56 | output = model(input['imgs'].cuda()) 57 | 58 | means = [] 59 | std_devs = [] 60 | for (mean, std) in coding_rate_list: 61 | means.append(mean.item()) 62 | std_devs.append(std.item()) 63 | 64 | sparsities = [] 65 | std_sparsities = [] 66 | for (mean, std) in sparsity_list: 67 | sparsities.append(mean) 68 | std_sparsities.append(std) 69 | 70 | 71 | means = [means] 72 | std_devs = [std_devs] 73 | sparsities = [sparsities] 74 | std_sparsities = [std_sparsities] 75 | 76 | plot_coding_rate(means, std_devs) 77 | plot_sparsity(sparsities, std_sparsities) 78 | 79 | 80 | 81 | -------------------------------------------------------------------------------- /vis_utils/pca_visualization.py: -------------------------------------------------------------------------------- 1 | """ 2 | Code inspired by https://github.com/ShirAmir/dino-vit-features/blob/main/pca.py 3 | """ 4 | import argparse 5 | import PIL.Image 6 | import numpy 7 | import torch 8 | import torch.nn as nn 9 | from pathlib import Path 10 | from vis_utils.shared_extractor import CRATEExtractor 11 | from tqdm import tqdm 12 | import numpy as np 13 | from PIL import Image 14 | from sklearn.decomposition import PCA 15 | from typing import List, Tuple 16 | from scipy import ndimage 17 | 18 | @torch.no_grad() 19 | def pca(image_paths, load_size: int = 224, layer: int 
= 11, facet: str = 'key', bin: bool = False, stride: int = 4, 20 | model_type: str = 'dino_vits8', model: nn.Module = None, n_components: int = 4, 21 | all_together: bool = True, masks = None) -> List[Tuple[Image.Image, numpy.ndarray]]: 22 | """ 23 | finding pca of a set of images. 24 | :param image_paths: a list of paths of all the images. 25 | :param load_size: size of the smaller edge of loaded images. If None, does not resize. 26 | :param layer: layer to extract descriptors from. 27 | :param facet: facet to extract descriptors from. 28 | :param bin: if True use a log-binning descriptor. 29 | :param model_type: type of model to extract descriptors from. 30 | :param stride: stride of the model. 31 | :param n_components: number of pca components to produce. 32 | :param all_together: if true apply pca on all images together. 33 | :return: a list of lists containing an image and its principal components. 34 | """ 35 | device = 'cuda:0' if torch.cuda.is_available() else 'cpu' 36 | extractor = CRATEExtractor(model_type, stride, device=device, model=model) 37 | descriptors_list = [] 38 | image_pil_list = [] 39 | num_patches_list = [] 40 | load_size_list = [] 41 | mask_indices_list = [] 42 | # extract descriptors and saliency maps for each image 43 | img_idx = 0 44 | for image_path in image_paths: 45 | image_batch, image_pil = extractor.preprocess(image_path, load_size) 46 | if masks is not None: 47 | mask = masks[img_idx] 48 | img_idx += 1 49 | #compute 1d mask indices 50 | mask_indices = np.where(mask.flatten())[0] 51 | # print("mask_indices", mask_indices) 52 | mask_indices_list.append(mask_indices) 53 | image_pil_list.append(image_pil) 54 | descs_all = extractor.extract_descriptors(image_batch.to(device), layer, facet, bin, include_cls=False).cpu().numpy() 55 | if masks is not None: 56 | descs = [descs_all[:, :, i, :] for i in mask_indices] 57 | descs = np.concatenate(descs, axis=2).reshape(1, 1, -1, descs_all.shape[-1]) 58 | else: 59 | descs = descs_all 60 | curr_num_patches, curr_load_size = extractor.num_patches, extractor.load_size 61 | num_patches_list.append(curr_num_patches) 62 | load_size_list.append(curr_load_size) 63 | descriptors_list.append(descs) 64 | if all_together: 65 | descriptors = np.concatenate(descriptors_list, axis=2)[0, 0] 66 | pca = PCA(n_components=n_components).fit(descriptors) 67 | pca_descriptors = pca.transform(descriptors) 68 | if masks is not None: 69 | split_idxs = np.array([len(indices) for indices in mask_indices_list]) 70 | else: 71 | split_idxs = np.array([num_patches[0] * num_patches[1] for num_patches in num_patches_list]) 72 | split_idxs = np.cumsum(split_idxs) 73 | pca_per_image_temp = np.split(pca_descriptors, split_idxs[:-1], axis=0) 74 | pca_per_image = [] 75 | if masks is not None: 76 | for i, pcas in enumerate(pca_per_image_temp): 77 | real_pcas = np.zeros((num_patches_list[i][0] * num_patches_list[i][1], n_components)) 78 | # print() 79 | real_pcas[mask_indices_list[i], :] = pcas 80 | # assign min value for masked indices 81 | real_pcas[real_pcas == 0] = real_pcas.min() - 1 82 | pca_per_image.append(real_pcas) 83 | else: 84 | pca_per_image = pca_per_image_temp 85 | else: 86 | pca_per_image = [] 87 | for descriptors in descriptors_list: 88 | pca = PCA(n_components=n_components).fit(descriptors[0, 0]) 89 | pca_descriptors = pca.transform(descriptors[0, 0]) 90 | pca_per_image.append(pca_descriptors) 91 | results = [(pil_image, img_pca.reshape((num_patches[0], num_patches[1], n_components))) for 92 | (pil_image, img_pca, num_patches) in 
zip(image_pil_list, pca_per_image, num_patches_list)] 93 | return results 94 | 95 | 96 | def plot_pca(pil_image: Image.Image, pca_image: numpy.ndarray, last_components_rgb: bool = True, 97 | save_resized=True,): 98 | """ 99 | finding pca of a set of images. 100 | :param pil_image: The original PIL image. 101 | :param pca_image: A numpy tensor containing pca components of the image. HxWxn_components 102 | :param save_dir: if None than show results. 103 | :param last_components_rgb: If true save last 3 components as RGB image in addition to each component separately. 104 | :param save_resized: If true save PCA components resized to original resolution. 105 | :param save_prefix: optional. prefix to saving 106 | :return: a list of lists containing an image and its principal components. 107 | """ 108 | # save_dir = Path(save_dir) 109 | # save_dir.mkdir(exist_ok=True, parents=True) 110 | # pil_image_path = save_dir / f'{save_prefix}_orig_img.png' 111 | # pil_image.save(pil_image_path) 112 | 113 | pca_images = [] 114 | n_components = pca_image.shape[2] 115 | for comp_idx in range(n_components): 116 | comp = pca_image[:, :, comp_idx] 117 | comp_min = comp.min(axis=(0, 1)) 118 | comp_max = comp.max(axis=(0, 1)) 119 | comp_img = (comp - comp_min) / (comp_max - comp_min) 120 | # comp_file_path = save_dir / f'{save_prefix}_{comp_idx}.png' 121 | pca_pil = Image.fromarray((comp_img * 255).astype(np.uint8)) 122 | if save_resized: 123 | pca_pil = pca_pil.resize(pil_image.size, resample=PIL.Image.NEAREST) 124 | # pca_pil.save(comp_file_path) 125 | pca_images.append(pca_pil) 126 | 127 | if last_components_rgb: 128 | # comp_idxs = f"{n_components-3}_{n_components-2}_{n_components-1}" 129 | comp = pca_image[:, :, -3:] 130 | comp_min = comp.min(axis=(0, 1)) 131 | comp_max = comp.max(axis=(0, 1)) 132 | comp_img = (comp - comp_min) / (comp_max - comp_min) 133 | # comp_file_path = save_dir / f'{save_prefix}_{comp_idxs}_rgb.png' 134 | pca_pil = Image.fromarray((comp_img * 255).astype(np.uint8)) 135 | if save_resized: 136 | pca_pil = pca_pil.resize(pil_image.size, resample=PIL.Image.NEAREST) 137 | # pca_pil.save(comp_file_path) 138 | pca_images.append(pca_pil) 139 | return pil_image, pca_images 140 | 141 | 142 | 143 | def plot_pca_mask(pil_image: Image.Image, pca_image: numpy.ndarray, save_resized=True, th = 0.5, inv = False): 144 | """ 145 | finding pca of a set of images. 146 | :param pil_image: The original PIL image. 147 | :param pca_image: A numpy tensor containing pca components of the image. HxWxn_components 148 | :param save_dir: if None than show results. 149 | :param last_components_rgb: If true save last 3 components as RGB image in addition to each component separately. 150 | :param save_resized: If true save PCA components resized to original resolution. 151 | :param save_prefix: optional. prefix to saving 152 | :return: a list of lists containing an image and its principal components. 
153 | """ 154 | # save_dir = Path(save_dir) 155 | # save_dir.mkdir(exist_ok=True, parents=True) 156 | # pil_image_path = save_dir / f'{save_prefix}_orig_img.png' 157 | # pil_image.save(pil_image_path) 158 | 159 | n_components = pca_image.shape[2] 160 | # for threshold in [0.15, 0.2, 0.3, 0.4, 0.5, 0.6]: 161 | for threshold in [th]: 162 | for comp_idx in range(n_components): 163 | comp = pca_image[:, :, comp_idx] 164 | comp_min = comp.min(axis=(0, 1)) 165 | comp_max = comp.max(axis=(0, 1)) 166 | comp_img = (comp - comp_min) / (comp_max - comp_min) 167 | 168 | comp_img = comp_img > threshold if not inv else comp_img < threshold 169 | # pseudo_mask = densecrf(np.array(I_new), bipartition) 170 | 171 | comp_img = ndimage.binary_fill_holes(comp_img>=0.5) 172 | # print(comp_img.shape) 173 | # comp_file_path = save_dir / f'{save_prefix}_{comp_idx}_{threshold}.png' 174 | pca_pil = Image.fromarray((comp_img * 255).astype(np.uint8)) 175 | if save_resized: 176 | pca_pil = pca_pil.resize(pil_image.size, resample=PIL.Image.NEAREST) 177 | # return to tensor 178 | # comp_img = np.array(pca_pil) 179 | # pca_pil.save(comp_file_path) 180 | 181 | return comp_img 182 | 183 | """ taken from https://stackoverflow.com/questions/15008758/parsing-boolean-values-with-argparse""" 184 | def str2bool(v): 185 | if isinstance(v, bool): 186 | return v 187 | if v.lower() in ('yes', 'true', 't', 'y', '1'): 188 | return True 189 | elif v.lower() in ('no', 'false', 'f', 'n', '0'): 190 | return False 191 | else: 192 | raise argparse.ArgumentTypeError('Boolean value expected.') -------------------------------------------------------------------------------- /vis_utils/plot.py: -------------------------------------------------------------------------------- 1 | import matplotlib.pyplot as plt 2 | import numpy as np 3 | 4 | def cal_sparsity(matrix, is_sparse=False): 5 | absmatrix = np.abs(matrix) 6 | #matrix have shape [batch_size, num_patches, dim] 7 | if is_sparse==True: 8 | sparsity_list = [np.count_nonzero(absmatrix[i,:,:]==0)/(matrix.shape[1]*matrix.shape[2]) for i in range(matrix.shape[0])] 9 | sparsity = np.mean(sparsity_list) 10 | stdev = np.std(sparsity_list) 11 | else: 12 | sparsity = None 13 | stdev = None 14 | 15 | return sparsity, stdev 16 | 17 | 18 | def plot_sparsity(sparsities, std_sparsities): 19 | fontsize=20 20 | plt.rcParams["figure.figsize"] = (10, 6) 21 | cmap = plt.get_cmap('plasma') 22 | i=0 23 | cmap_val = [0.3, 0.5] 24 | epochs = ["val"] 25 | for mean_sparsity in sparsities: 26 | std_sparsity = std_sparsities[i] 27 | 28 | mean_sparsity = [1 - i for i in mean_sparsity] 29 | x_labels = np.arange(len(mean_sparsity)) + 1 30 | if i == 0: 31 | plt.plot(x_labels, mean_sparsity, marker='s', alpha=0.9, markersize=8, linewidth=2.5, 32 | markeredgecolor='black', markeredgewidth=1.0, color='C1', label=f"{epochs[i]}" 33 | ) 34 | plt.errorbar(x=x_labels, y=mean_sparsity, yerr=std_sparsity, fmt='none', 35 | ecolor='C1', alpha=0.7, 36 | capsize=5, capthick=2.0, elinewidth=2.5, zorder=0) 37 | else: 38 | plt.plot(x_labels, mean_sparsity, marker='s', alpha=0.6, markersize=8, linewidth=2.5, 39 | markeredgecolor='black', markeredgewidth=1.0, color='C1', label=f"{epochs[i]}", linestyle='--' 40 | ) 41 | plt.errorbar(x=x_labels, y=mean_sparsity, yerr=std_sparsity, fmt='none', 42 | ecolor='C1', alpha=0.4, 43 | capsize=5, capthick=2.0, elinewidth=2.5, zorder=0, linestyle='--') 44 | i+=1 45 | plt.title('Measure output sparsity across layers', fontdict={'fontsize': fontsize}) 46 | plt.ylabel(r"Sparsity [ISTA block]", 
fontdict={'fontsize': fontsize}) 47 | plt.xlabel(r"Layer index - $\ell$", fontdict={'fontsize': fontsize}) 48 | # plt.xticks(x_labels, [f"{i + 1}" for i in range(len(mean))]) 49 | plt.grid(linestyle='--', color='gray') 50 | plt.legend(fontsize=fontsize, loc='lower left') 51 | plt.savefig(f"sparsity.pdf", format='pdf', dpi=600) 52 | plt.close() 53 | 54 | def plot_coding_rate(means, std_devs): 55 | fontsize=20 56 | plt.rcParams["figure.figsize"] = (10, 6) 57 | x_labels = np.arange(len(means[0])) + 1 58 | cmap = plt.get_cmap('viridis') 59 | i=0 60 | cmap_val = [0.3, 0.5, 0.9] 61 | epochs = ["val"] 62 | for mean_mcr2 in means: 63 | #set color 64 | std_mcr2 = std_devs[i] 65 | if i == 0: 66 | plt.plot(x_labels, mean_mcr2, marker='o', alpha=0.9, markersize=10, linewidth=2.5, 67 | markeredgecolor='black', markeredgewidth=1.0, color='C0', label=f"{epochs[i]}" 68 | ) 69 | plt.errorbar(x=x_labels, y=mean_mcr2, yerr=std_mcr2, fmt='none', 70 | ecolor='C0', alpha=0.7, 71 | capsize=5, capthick=2.0, elinewidth=2.5, zorder=0) 72 | if i==1: 73 | plt.plot(x_labels, mean_mcr2, marker='o', alpha=0.6, markersize=10, linewidth=2.5, 74 | markeredgecolor='black', markeredgewidth=1.0, color='C0', label=f"{epochs[i]}", linestyle='--' 75 | ) 76 | plt.errorbar(x=x_labels, y=mean_mcr2, yerr=std_mcr2, fmt='none', 77 | ecolor='C0', alpha=0.4, 78 | capsize=5, capthick=2.0, elinewidth=2.5, zorder=0, linestyle='--') 79 | i+=1 80 | plt.legend(fontsize=fontsize, loc='lower left') 81 | plt.title('Measure coding rate across layers', fontdict={'fontsize': fontsize}) 82 | plt.ylabel(r"$R^c(Z^{\ell})$ [SSA block]", fontdict={'fontsize': fontsize}) 83 | plt.xlabel(r"Layer index - $\ell$", fontdict={'fontsize': fontsize}) 84 | plt.grid(linestyle='--', color='gray') 85 | plt.savefig(f"mcr2.pdf", format='pdf', dpi=600) 86 | plt.close() -------------------------------------------------------------------------------- /vis_utils/shared_extractor.py: -------------------------------------------------------------------------------- 1 | from model.crate_ae.crate_ae import mae_crate_base 2 | # from crate_ae_extractor import crate_base 3 | 4 | import math 5 | from typing import Union, List, Tuple 6 | import types 7 | import torch.nn.modules.utils as nn_utils 8 | import torch 9 | from PIL import Image 10 | from torchvision import transforms 11 | import torch.nn as nn 12 | from pathlib import Path 13 | 14 | class CRATEExtractor: 15 | def __init__(self, model_type: str = 'crate_mae_b16', stride: int = 4, model: nn.Module = None, device: str = 'cuda'): 16 | 17 | self.model_type = model_type 18 | self.device = device 19 | if model is not None: 20 | self.model = model 21 | else: 22 | raise NotImplementedError 23 | 24 | self.model = CRATEExtractor.patch_vit_resolution(self.model, stride=stride, model_type = model_type) 25 | self.model.eval() 26 | self.model.to(self.device) 27 | if model_type == 'crate_mae_b16': 28 | self.p = self.model.patch_embed.patch_size[0] 29 | self.stride = self.model.patch_embed.proj.stride 30 | else: 31 | raise NotImplementedError 32 | self.mean = (0.485, 0.456, 0.406) 33 | self.std = (0.229, 0.224, 0.225) 34 | 35 | self._feats = [] 36 | self.hook_handlers = [] 37 | self.load_size = None 38 | self.num_patches = None 39 | 40 | @staticmethod 41 | def _fix_pos_enc(patch_size: int, stride_hw: Tuple[int, int]): 42 | """ 43 | Creates a method for position encoding interpolation. 44 | :param patch_size: patch size of the model. 45 | :param stride_hw: A tuple containing the new height and width stride respectively. 
46 | :return: the interpolation method 47 | """ 48 | def interpolate_pos_encoding(self, x: torch.Tensor, w: int, h: int) -> torch.Tensor: 49 | 50 | npatch = x.shape[1] - 1 51 | N = self.decoder_pos_embed.shape[1] - 1 52 | if npatch == N and w == h: 53 | return self.decoder_pos_embed 54 | class_pos_embed = self.decoder_pos_embed[:, 0] 55 | patch_pos_embed = self.decoder_pos_embed[:, 1:] 56 | 57 | dim = x.shape[-1] 58 | # dim = self.dim 59 | # compute number of tokens taking stride into account 60 | w0 = 1 + (w - patch_size) // stride_hw[1] 61 | h0 = 1 + (h - patch_size) // stride_hw[0] 62 | # print(w0, h0, npatch) 63 | assert (w0 * h0 == npatch), f"""got wrong grid size for {h}x{w} with patch_size {patch_size} and 64 | stride {stride_hw} got {h0}x{w0}={h0 * w0} expecting {npatch}""" 65 | # we add a small number to avoid floating point error in the interpolation 66 | # see discussion at https://github.com/facebookresearch/dino/issues/8 67 | w0, h0 = w0 + 0.1, h0 + 0.1 68 | # print("patch_pos_shape:", patch_pos_embed.shape) 69 | patch_pos_embed = nn.functional.interpolate( 70 | patch_pos_embed.reshape(1, int(math.sqrt(N)), int(math.sqrt(N)), dim).permute(0, 3, 1, 2), 71 | scale_factor=(w0 / math.sqrt(N), h0 / math.sqrt(N)), 72 | mode='bicubic', 73 | align_corners=False, recompute_scale_factor=False 74 | ) 75 | assert int(w0) == patch_pos_embed.shape[-2] and int(h0) == patch_pos_embed.shape[-1] 76 | patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).view(1, -1, dim) 77 | # print("patch_pos_shape:", torch.cat((class_pos_embed.unsqueeze(0), patch_pos_embed), dim=1).shape) 78 | # print(torch.cat((class_pos_embed.unsqueeze(0), patch_pos_embed), dim=1)) 79 | # exit() 80 | return torch.cat((class_pos_embed.unsqueeze(0), patch_pos_embed), dim=1) 81 | 82 | return interpolate_pos_encoding 83 | 84 | @staticmethod 85 | def patch_vit_resolution(model: nn.Module, stride: int, model_type: str) -> nn.Module: 86 | """ 87 | change resolution of model output by changing the stride of the patch extraction. 88 | :param model: the model to change resolution for. 89 | :param stride: the new stride parameter. 90 | :return: the adjusted model 91 | """ 92 | if model_type == 'crate_mae_b16': 93 | patch_size = model.patch_embed.patch_size[0] 94 | else: 95 | raise NotImplementedError 96 | # if stride == patch_size: # nothing to do 97 | # return model 98 | 99 | stride = nn_utils._pair(stride) 100 | # print(patch_size) 101 | # print(stride) 102 | assert all([(patch_size // s_) * s_ == patch_size for s_ in 103 | stride]), f'stride {stride} should divide patch_size {patch_size}' 104 | 105 | # fix the stride 106 | # model.patch_embed.proj.stride = stride 107 | # fix the positional encoding code 108 | model.interpolate_pos_encoding = types.MethodType(CRATEExtractor._fix_pos_enc(patch_size, stride), model) 109 | return model 110 | 111 | def preprocess(self, image_path: Union[str, Path], 112 | load_size: Union[int, Tuple[int, int]] = None) -> Tuple[torch.Tensor, Image.Image]: 113 | """ 114 | Preprocesses an image before extraction. 115 | :param image_path: path to image to be extracted. 116 | :param load_size: optional. Size to resize image before the rest of preprocessing. 117 | :return: a tuple containing: 118 | (1) the preprocessed image as a tensor to insert the model of shape BxCxHxW. 
119 | (2) the pil image in relevant dimensions 120 | """ 121 | pil_image = Image.open(image_path).convert('RGB') 122 | if load_size is not None: 123 | pil_image = transforms.Resize((load_size, load_size), interpolation=transforms.InterpolationMode.LANCZOS)(pil_image) 124 | prep = transforms.Compose([ 125 | transforms.ToTensor(), 126 | transforms.Normalize(mean=self.mean, std=self.std) 127 | ]) 128 | prep_img = prep(pil_image)[None, ...] 129 | return prep_img, pil_image 130 | 131 | 132 | 133 | def extract_descriptors(self, batch: torch.Tensor, layer: int = 11, facet: str = 'key', 134 | bin: bool = False, include_cls: bool = False) -> torch.Tensor: 135 | 136 | B, C, H, W = batch.shape 137 | # print("batch shape:", batch.shape) 138 | self.load_size = (H, W) 139 | self.num_patches = (1 + (H - self.p) // self.stride[0], 1 + (W - self.p) // self.stride[1]) 140 | 141 | w = self.num_patches[1] 142 | h = self.num_patches[0] 143 | # shaper = torch.zeros((1, w * h + 1, self.model.dim)) 144 | 145 | # self.model.interpolated_pos_embed= nn.Parameter(self.model.interpolate_pos_encoding(shaper, W, H)) 146 | qkv = self.model.get_last_key_enc(batch, layer = layer) 147 | qkv = qkv[None, :, :, :] 148 | # qkv = (qkv.reshape(bs, nb_token, 1, nb_head, -1).permute(2, 0, 3, 1, 4)) 149 | # print("step:", qkv.shape) 150 | # qkv = qkv[0] 151 | # print("qkv.shape", qkv.shape) 152 | # k = qkv.transpose(1,2).reshape(bs, nb_token, -1) 153 | # feats = k[:, 1:].transpose(1,2).reshape(bs, self.feat_dim, feat_h * feat_w) 154 | return qkv[:, :, 1:, :] --------------------------------------------------------------------------------
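# ---------------------------------------------------------------------------
# End-to-end illustration (added annotation, hypothetical usage only): a
# hedged sketch of driving CRATEExtractor above with a CRATE-MAE backbone.
# The checkpoint and image paths are placeholders, and stride=16 (the native
# patch stride) is used because patch_vit_resolution leaves the conv stride
# of the patch embedding unchanged.
#
#   import torch
#   from model.crate_ae.crate_ae import mae_crate_base
#   from vis_utils.shared_extractor import CRATEExtractor
#
#   model = mae_crate_base()
#   # ckpt = torch.load('crate_mae_base.pth', map_location='cpu')   # placeholder path
#   # model.load_state_dict(ckpt['model'], strict=False)
#   extractor = CRATEExtractor('crate_mae_b16', stride=16, model=model, device='cpu')
#   batch, pil_img = extractor.preprocess('example.jpg', load_size=224)   # placeholder image
#   feats = extractor.extract_descriptors(batch, layer=11, facet='key')
#   # feats: (1, 1, 196, 768) per-patch features from the chosen layer, the
#   # same descriptors consumed by vis_utils/pca_visualization.pca(...)
# ---------------------------------------------------------------------------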