├── models ├── __init__.py ├── convit.py └── vit.py ├── utils ├── __init__.py ├── datautils │ ├── __init__.py │ └── core50 │ │ └── core50data.py ├── toolkit.py ├── data_manager.py └── data.py ├── overview.jpg ├── LICENSE ├── environment.yaml ├── main.py ├── README.md ├── networks.py └── trainer.py /models/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /utils/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /utils/datautils/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /overview.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/iamwangyabin/ESN/HEAD/overview.jpg -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2022 Wang Yabin 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /utils/toolkit.py: -------------------------------------------------------------------------------- 1 | import os 2 | import numpy as np 3 | import torch 4 | 5 | 6 | def count_parameters(model, trainable=False): 7 | if trainable: 8 | return sum(p.numel() for p in model.parameters() if p.requires_grad) 9 | return sum(p.numel() for p in model.parameters()) 10 | 11 | 12 | def tensor2numpy(x): 13 | return x.cpu().data.numpy() if x.is_cuda else x.data.numpy() 14 | 15 | 16 | def target2onehot(targets, n_classes): 17 | onehot = torch.zeros(targets.shape[0], n_classes).to(targets.device) 18 | onehot.scatter_(dim=1, index=targets.long().view(-1, 1), value=1.) 19 | return onehot 20 | 21 | 22 | def makedirs(path): 23 | if not os.path.exists(path): 24 | os.makedirs(path) 25 | 26 | 27 | def accuracy(y_pred, y_true, nb_old, increment=10): 28 | assert len(y_pred) == len(y_true), 'Data length error.' 
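# Returns a dict of percentages (rounded to 2 decimals): overall accuracy,
# accuracy per `increment`-sized class group, and accuracy on old (< nb_old)
# vs. new (>= nb_old) classes. Illustrative call:
#   accuracy(np.array([0, 1, 10]), np.array([0, 2, 10]), nb_old=10, increment=10)
#   -> {'total': 66.67, '00-09': 50.0, 'old': 50.0, 'new': 100.0}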
29 | all_acc = {} 30 | all_acc['total'] = np.around((y_pred == y_true).sum()*100 / len(y_true), decimals=2) 31 | 32 | # Grouped accuracy 33 | for class_id in range(0, np.max(y_true), increment): 34 | idxes = np.where(np.logical_and(y_true >= class_id, y_true < class_id + increment))[0] 35 | label = '{}-{}'.format(str(class_id).rjust(2, '0'), str(class_id+increment-1).rjust(2, '0')) 36 | all_acc[label] = np.around((y_pred[idxes] == y_true[idxes]).sum()*100 / len(idxes), decimals=2) 37 | 38 | # Old accuracy 39 | idxes = np.where(y_true < nb_old)[0] 40 | all_acc['old'] = 0 if len(idxes) == 0 else np.around((y_pred[idxes] == y_true[idxes]).sum()*100 / len(idxes), 41 | decimals=2) 42 | 43 | # New accuracy 44 | idxes = np.where(y_true >= nb_old)[0] 45 | all_acc['new'] = np.around((y_pred[idxes] == y_true[idxes]).sum()*100 / len(idxes), decimals=2) 46 | 47 | return all_acc 48 | 49 | 50 | def split_images_labels(imgs): 51 | # split trainset.imgs in ImageFolder 52 | images = [] 53 | labels = [] 54 | for item in imgs: 55 | images.append(item[0]) 56 | labels.append(item[1]) 57 | 58 | return np.array(images), np.array(labels) 59 | 60 | 61 | 62 | 63 | def accuracy_binary(y_pred, y_true, nb_old, increment=2): 64 | assert len(y_pred) == len(y_true), 'Data length error.' 65 | all_acc = {} 66 | all_acc['total'] = np.around((y_pred%2 == y_true%2).sum()*100 / len(y_true), decimals=2) 67 | 68 | # Grouped accuracy 69 | for class_id in range(0, np.max(y_true), increment): 70 | idxes = np.where(np.logical_and(y_true >= class_id, y_true < class_id + increment))[0] 71 | label = '{}-{}'.format(str(class_id).rjust(2, '0'), str(class_id+increment-1).rjust(2, '0')) 72 | all_acc[label] = np.around(((y_pred[idxes]%2) == (y_true[idxes]%2)).sum()*100 / len(idxes), decimals=2) 73 | 74 | # Old accuracy 75 | idxes = np.where(y_true < nb_old)[0] 76 | # all_acc['old'] = 0 if len(idxes) == 0 else np.around((y_pred[idxes] == y_true[idxes]).sum()*100 / len(idxes),decimals=2) 77 | all_acc['old'] = 0 if len(idxes) == 0 else np.around(((y_pred[idxes]%2) == (y_true[idxes]%2)).sum()*100 / len(idxes),decimals=2) 78 | 79 | # New accuracy 80 | idxes = np.where(y_true >= nb_old)[0] 81 | all_acc['new'] = np.around(((y_pred[idxes]%2) == (y_true[idxes]%2)).sum()*100 / len(idxes), decimals=2) 82 | 83 | return all_acc 84 | 85 | 86 | 87 | def accuracy_domain(y_pred, y_true, nb_old, increment=10): 88 | assert len(y_pred) == len(y_true), 'Data length error.' 
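# Domain-incremental variant: labels are folded modulo 345 (the DomainNet
# category count), so a prediction is counted as correct whenever the object
# class matches, regardless of which domain increment produced the label.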
89 | all_acc = {} 90 | all_acc['total'] = np.around((y_pred%345 == y_true%345).sum()*100 / len(y_true), decimals=2) 91 | return all_acc -------------------------------------------------------------------------------- /environment.yaml: -------------------------------------------------------------------------------- 1 | name: sp 2 | channels: 3 | - pytorch 4 | - defaults 5 | dependencies: 6 | - _libgcc_mutex=0.1=main 7 | - _openmp_mutex=5.1=1_gnu 8 | - blas=1.0=mkl 9 | - brotlipy=0.7.0=py38h27cfd23_1003 10 | - bzip2=1.0.8=h7b6447c_0 11 | - ca-certificates=2022.10.11=h06a4308_0 12 | - certifi=2022.9.24=py38h06a4308_0 13 | - cffi=1.15.1=py38h74dc2b5_0 14 | - charset-normalizer=2.0.4=pyhd3eb1b0_0 15 | - cryptography=37.0.1=py38h9ce1e76_0 16 | - cudatoolkit=11.3.1=h2bc3f7f_2 17 | - ffmpeg=4.3=hf484d3e_0 18 | - freetype=2.11.0=h70c0345_0 19 | - giflib=5.2.1=h7b6447c_0 20 | - gmp=6.2.1=h295c915_3 21 | - gnutls=3.6.15=he1e5248_0 22 | - idna=3.3=pyhd3eb1b0_0 23 | - intel-openmp=2021.4.0=h06a4308_3561 24 | - jpeg=9e=h7f8727e_0 25 | - lame=3.100=h7b6447c_0 26 | - lcms2=2.12=h3be6417_0 27 | - ld_impl_linux-64=2.38=h1181459_1 28 | - lerc=3.0=h295c915_0 29 | - libdeflate=1.8=h7f8727e_5 30 | - libffi=3.3=he6710b0_2 31 | - libgcc-ng=11.2.0=h1234567_1 32 | - libgomp=11.2.0=h1234567_1 33 | - libiconv=1.16=h7f8727e_2 34 | - libidn2=2.3.2=h7f8727e_0 35 | - libpng=1.6.37=hbc83047_0 36 | - libstdcxx-ng=11.2.0=h1234567_1 37 | - libtasn1=4.16.0=h27cfd23_0 38 | - libtiff=4.4.0=hecacb30_0 39 | - libunistring=0.9.10=h27cfd23_0 40 | - libuv=1.40.0=h7b6447c_0 41 | - libwebp=1.2.2=h55f646e_0 42 | - libwebp-base=1.2.2=h7f8727e_0 43 | - lz4-c=1.9.3=h295c915_1 44 | - mkl=2021.4.0=h06a4308_640 45 | - mkl-service=2.4.0=py38h7f8727e_0 46 | - mkl_fft=1.3.1=py38hd3c417c_0 47 | - mkl_random=1.2.2=py38h51133e4_0 48 | - ncurses=6.3=h5eee18b_3 49 | - nettle=3.7.3=hbbd107a_1 50 | - numpy=1.23.1=py38h6c91a56_0 51 | - numpy-base=1.23.1=py38ha15fc14_0 52 | - openh264=2.1.1=h4ff587b_0 53 | - openssl=1.1.1q=h7f8727e_0 54 | - pillow=9.2.0=py38hace64e9_1 55 | - pip=22.1.2=py38h06a4308_0 56 | - pycparser=2.21=pyhd3eb1b0_0 57 | - pyopenssl=22.0.0=pyhd3eb1b0_0 58 | - pysocks=1.7.1=py38h06a4308_0 59 | - python=3.8.13=h12debd9_0 60 | - pytorch=1.11.0=py3.8_cuda11.3_cudnn8.2.0_0 61 | - pytorch-mutex=1.0=cuda 62 | - readline=8.1.2=h7f8727e_1 63 | - requests=2.28.1=py38h06a4308_0 64 | - setuptools=63.4.1=py38h06a4308_0 65 | - six=1.16.0=pyhd3eb1b0_1 66 | - sqlite=3.39.2=h5082296_0 67 | - tk=8.6.12=h1ccaba5_0 68 | - torchaudio=0.11.0=py38_cu113 69 | - torchvision=0.12.0=py38_cu113 70 | - typing_extensions=4.3.0=py38h06a4308_0 71 | - urllib3=1.26.11=py38h06a4308_0 72 | - wheel=0.37.1=pyhd3eb1b0_0 73 | - xz=5.2.5=h7f8727e_1 74 | - zlib=1.2.12=h5eee18b_3 75 | - zstd=1.5.2=ha4553b6_0 76 | - pip: 77 | - aiohttp==3.8.3 78 | - aiosignal==1.3.1 79 | - antlr4-python3-runtime==4.9.3 80 | - async-timeout==4.0.2 81 | - attrs==22.2.0 82 | - click==8.1.3 83 | - contourpy==1.0.5 84 | - cycler==0.11.0 85 | - docker-pycreds==0.4.0 86 | - einops==0.6.0 87 | - fonttools==4.37.3 88 | - frozenlist==1.3.3 89 | - fsspec==2022.11.0 90 | - ftfy==6.1.1 91 | - gitdb==4.0.9 92 | - gitpython==3.1.27 93 | - joblib==1.2.0 94 | - kiwisolver==1.4.4 95 | - lightning-utilities==0.5.0 96 | - matplotlib==3.6.0 97 | - multidict==6.0.4 98 | - omegaconf==2.3.0 99 | - opencv-python==4.6.0.66 100 | - packaging==21.3 101 | - pandas==1.5.0 102 | - pathtools==0.1.2 103 | - pot==0.8.2 104 | - promise==2.3 105 | - protobuf==3.20.1 106 | - psutil==5.9.2 107 | - pyparsing==3.0.9 108 | - 
python-dateutil==2.8.2 109 | - pytorch-lightning==1.8.6 110 | - pytz==2022.2.1 111 | - pyyaml==6.0 112 | - quadprog==0.1.11 113 | - regex==2022.9.13 114 | - scikit-learn==0.23.1 115 | - scipy==1.9.1 116 | - seaborn==0.12.0 117 | - sentry-sdk==1.9.8 118 | - setproctitle==1.3.2 119 | - shortuuid==1.0.9 120 | - sklearn==0.0 121 | - smmap==5.0.0 122 | - tensorboardx==2.5.1 123 | - threadpoolctl==3.1.0 124 | - timm==0.6.7 125 | - torchmetrics==0.11.0 126 | - tqdm==4.64.1 127 | - wandb==0.13.3 128 | - wcwidth==0.2.5 129 | - yarl==1.8.2 130 | prefix: /home/wangyabin/miniconda3/envs/sp 131 | -------------------------------------------------------------------------------- /main.py: -------------------------------------------------------------------------------- 1 | import os 2 | import copy 3 | import numpy as np 4 | import time 5 | import argparse 6 | import wandb 7 | 8 | import torch 9 | from torch.utils.data import Dataset, DataLoader 10 | 11 | from networks import incremental_vitood 12 | from trainer import training, eval 13 | from utils.data_manager import DataManager 14 | 15 | def setup_parser(): 16 | parser = argparse.ArgumentParser(description='Reproduce of multiple continual learning algorthms.') 17 | parser.add_argument('--method', default='ESN', type=str, help='str for comment') 18 | parser.add_argument('--model_name', default='vit', type=str, help='str for comment') 19 | parser.add_argument('--dataset', default='5datasets_vit', type=str, help='cifar100_vit, 5datasets_vit, core50') 20 | parser.add_argument('--init_cls', default=10, type=int, help='str for comment') 21 | parser.add_argument('--inc_cls', default=10, type=int, help='str for comment') 22 | parser.add_argument('--shuffle', action='store_false', help='false is l2p, which is not shuffle') 23 | parser.add_argument('--random_seed', default=1993, type=int, help='str for comment') 24 | parser.add_argument('--training_device', default="2", type=str, help='str for comment') 25 | parser.add_argument('--max_epochs', default=50, type=int, help='str for comment') 26 | parser.add_argument('--lr', default=0.01, type=float, help='Set learning rate') 27 | 28 | parser.add_argument('--using_prompt', default=True, type=bool, help='str for comment') 29 | parser.add_argument('--anchor_energy', default=-10, type=float, help='str for comment') 30 | parser.add_argument('--energy_beta', default=1, type=float, help='str for comment') 31 | parser.add_argument('--lamda', default=0.1, type=float, help='0 means do not use energy alignment') 32 | parser.add_argument('--temptures', default=20, type=int, help='max temperature') 33 | parser.add_argument('--voting', default=True, type=bool, help='wither or not to voting') 34 | 35 | parser.add_argument('--dil', default=False, type=bool, help='For domain incremental learning evaluation') 36 | parser.add_argument('--max_cls', default=2, type=int, help='For domain incremental learning evaluation') 37 | parser.add_argument('--notes', default='', type=str, help='str for comment') 38 | 39 | return parser 40 | 41 | def _set_random(): 42 | torch.manual_seed(1) 43 | torch.cuda.manual_seed(1) 44 | torch.cuda.manual_seed_all(1) 45 | torch.backends.cudnn.deterministic = True 46 | torch.backends.cudnn.benchmark = False 47 | 48 | def main(): 49 | args = setup_parser().parse_args() 50 | os.environ["CUDA_VISIBLE_DEVICES"] = args.training_device 51 | _set_random() 52 | args.localtime = time.strftime("%Y-%m-%d-%H:%M:%S", time.localtime()) 53 | 54 | if args.dataset == "5datasets_vit" or args.dataset == "core50": 55 | 
args.shuffle=False 56 | data_manager = DataManager(args.dataset, args.shuffle, args.random_seed, args.init_cls, args.inc_cls, args=vars(args)) 57 | args.class_order = data_manager._class_order 58 | 59 | wandb.init(project="ESN", 60 | name='{}_{}_{}_{}_{}_'.format(args.method, args.model_name, args.dataset, args.init_cls, args.inc_cls) + args.localtime, 61 | save_code=True, group='{}_{}'.format(args.dataset, args.model_name), notes=args.notes, config=args) 62 | 63 | all_tabs, all_classifiers, all_tokens, accuracy_log, vitpromptlist= [], [], [], [], [] 64 | vitprompt = None 65 | _known_classes=0 66 | 67 | for taskid in range(data_manager.nb_tasks): 68 | print("current task: {}".format(taskid)) 69 | _total_classes = _known_classes + data_manager.get_task_size(taskid) 70 | current_data = np.arange(_known_classes, _total_classes) 71 | train_dataset = data_manager.get_dataset(current_data, source='train', mode='train') 72 | 73 | if args.dataset == "core50": 74 | test_dataset = data_manager.get_dataset(np.arange(0, data_manager.get_task_size(0)), source='test', mode='test') 75 | else: 76 | test_dataset = data_manager.get_dataset(current_data, source='test', mode='test') 77 | 78 | train_loader = DataLoader(train_dataset, batch_size=128, shuffle=True, num_workers=8, persistent_workers=True, pin_memory=True) 79 | test_loader = DataLoader(test_dataset, batch_size=128, shuffle=False, num_workers=8) 80 | 81 | model, eval_result = training(args, incremental_vitood, taskid, train_loader, test_loader, 82 | known_classes=_known_classes, vitprompt=vitprompt, numclass=data_manager.get_task_size(taskid)) 83 | 84 | all_tabs.append(copy.deepcopy(model.tabs).cpu()) 85 | all_classifiers.append(copy.deepcopy(model.classifiers).cpu()) 86 | all_tokens.append(copy.deepcopy(model.task_tokens).cpu()) 87 | vitprompt = copy.deepcopy(model.vitprompt).cpu() 88 | vitpromptlist.append(copy.deepcopy(model.vitprompt).cpu()) 89 | 90 | del model 91 | 92 | _known_classes = _total_classes 93 | 94 | assembles = {'all_tabs': all_tabs, 'all_classifiers': all_classifiers, 'all_tokens': all_tokens, 'vitpromptlist':vitpromptlist} 95 | torch.save(assembles, './checkpoints/'+wandb.run.name+'.pth') 96 | 97 | eval(args, './checkpoints/'+wandb.run.name+'.pth', data_manager) 98 | 99 | 100 | if __name__ == '__main__': 101 | main() -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Isolation and Impartial Aggregation: A Paradigm of Incremental Learning without Interference 2 | 3 |
This is the official implementation of our AAAI 2023 paper "Isolation and Impartial Aggregation: A Paradigm of Incremental Learning without Interference".

This paper focuses on the prevalent stage interference and stage performance imbalance in incremental learning. To avoid obvious stage learning bottlenecks, we propose a new incremental learning framework that leverages a series of stage-isolated classifiers to perform the learning task at each stage, without interference from the others. Concretely, to aggregate multiple stage classifiers into a single uniform classifier impartially, we first introduce a temperature-controlled energy metric that indicates the confidence levels of the stage classifiers. We then propose an anchor-based energy self-normalization strategy to ensure that the stage classifiers work at the same energy level. Finally, we design a voting-based inference augmentation strategy for robust inference. The proposed method is rehearsal-free and works for almost all incremental learning scenarios. We evaluate it on four large datasets; extensive results demonstrate its superiority, establishing new state-of-the-art overall performance.
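For reference, the energy terms described above map directly onto the training code. Below is a minimal PyTorch sketch of the free-energy computation and the anchor-based alignment loss (mirroring `networks.py`; the function names are illustrative, and the defaults correspond to the `--anchor_energy`, `--lamda`, and `--energy_beta` flags in `main.py`):

```python
import torch

def free_energy(logits, beta=1.0):
    # Free energy of a stage classifier's logits: -(1/beta) * logsumexp(-beta * z)
    return -torch.logsumexp(-beta * logits, dim=1) / beta

def energy_alignment_loss(logits, anchor_energy=-10.0, lamda=0.1, beta=1.0):
    # Anchor-based energy self-normalization: keep every stage classifier's
    # free energy close to a shared anchor so the stages remain comparable.
    return lamda * ((free_energy(logits, beta) - anchor_energy) ** 2).mean()
```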
7 | 8 | 9 | **Isolation and Impartial Aggregation: A Paradigm of Incremental Learning without Interference**
Yabin Wang*, Zhiheng Ma*, Zhiwu Huang, Yaowei Wang, Zhou Su, Xiaopeng Hong. Proceedings of the AAAI Conference on Artificial Intelligence (AAAI 2023).
11 | [[Paper]](https://arxiv.org/abs/2211.15969) 12 | 13 | ## Introduction 14 | 15 |
In this paper, we propose anchor-based energy self-normalization for the stage classifiers.
The classifiers of the current and the previous stages, $f_{\eta_s}$ and $f_{\eta_{s-1}}$, are aligned sequentially by restricting their energies around the anchor.
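At inference time, the stage (task) identity of a test sample is decided by comparing the energies of all stage classifiers under a set of calibrated temperatures and taking a majority vote. A condensed sketch of this selection step (a simplification of the loop in `trainer.py`; names are illustrative) is shown below.

```python
import torch

def select_stage(stage_logits, temperatures):
    """stage_logits: list of [B, C] logit tensors, one per stage classifier."""
    votes = []
    for tem in temperatures:
        # Energy of each stage classifier at this temperature: [B, n_stages]
        energies = torch.stack(
            [torch.logsumexp(logits / tem, dim=-1) for logits in stage_logits], dim=1)
        votes.append(energies.argmax(dim=1))
    # Majority vote across temperatures picks the stage whose classifier predicts each sample.
    return torch.mode(torch.stack(votes, dim=1), dim=1).values
```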
![overview.jpg](overview.jpg)

## Environment setup
Create the virtual environment for ESN.
```
conda env create -f environment.yaml
```
After this, you will get a new environment **sp** (as named in `environment.yaml`) that can conduct ESN experiments.
Run `conda activate sp` to activate it.

Note that only NVIDIA GPUs are supported for now; we used an NVIDIA RTX 3090.

## Dataset preparation
Please refer to the following links to download the standard incremental learning benchmark datasets.
[CIFAR-100] Auto Download
[CORe50](https://vlomonaco.github.io/core50/index.html#dataset)
[DomainNet](http://ai.bu.edu/M3SDA/)
[5-Datasets]

Unzip the downloaded files, and you will get the following folders.

```
core50
└── core50_128x128
    ├── labels.pkl
    ├── LUP.pkl
    ├── paths.pkl
    ├── s1
    ├── s2
    ├── s3
    ...
```

```
domainnet
├── clipart
│   ├── aircraft_carrier
│   ├── airplane
│   ... ...
├── clipart_test.txt
├── clipart_train.txt
├── infograph
│   ├── aircraft_carrier
│   ├── airplane
│   ... ...
├── infograph_test.txt
├── infograph_train.txt
├── painting
│   ├── aircraft_carrier
│   ├── airplane
│   ... ...
... ...
```

## Training:

Please change the dataset paths in `utils/data.py` to the locations of the datasets on your machine.

### Split-CIFAR100:
```
python main.py --dataset cifar100_vit --max_epochs 30 --init_cls 10 --inc_cls 10 --shuffle
```

### Split-DomainNet:
```
python main.py --dataset domainnet --max_epochs 10 --init_cls 20 --inc_cls 20 --shuffle
```

### CORe50:
```
python main.py --dataset core50 --max_epochs 10 --init_cls 50 --inc_cls 50 --dil True --max_cls 50
```

### 5-Datasets:
```
python main.py --dataset 5datasets_vit --max_epochs 10 --init_cls 10 --inc_cls 10
```

## Model Zoo:

Pretrained models are available [here](https://drive.google.com/drive/folders/1D8qv1klFXePl-aQr5T8NVpoAf3nmxUld?usp=sharing).

We assume the downloaded weights are located under the `checkpoints` directory.

Otherwise, you may need to change the corresponding paths in the scripts.

[comment]: <> (## Results)

[comment]: <> (![results1.png](results1.png))

[comment]: <> (![results2.png](results2.png))

[comment]: <> (![results3.png](results3.png))

## License

Please check the MIT [license](./LICENSE) listed in this repository.

## Acknowledgments

We thank the following repositories for providing helpful components/functions used in our work.
127 | 128 | - [S-Prompts](https://github.com/iamwangyabin/S-Prompts) 129 | - [PyCIL](https://github.com/G-U-N/PyCIL) 130 | 131 | ## Citation 132 | 133 | If you use any content of this repo for your work, please cite the following bib entry: 134 | ``` 135 | @article{wang2022isolation, 136 | title={Isolation and Impartial Aggregation: A Paradigm of Incremental Learning without Interference}, 137 | author={Wang, Yabin and Ma, Zhiheng and Huang, Zhiwu and Wang, Yaowei and Su, Zhou and Hong, Xiaopeng}, 138 | booktitle={Proceedings of the AAAI Conference on Artificial Intelligence}, 139 | year={2023} 140 | } 141 | ``` 142 | -------------------------------------------------------------------------------- /networks.py: -------------------------------------------------------------------------------- 1 | import copy 2 | 3 | import torch 4 | import torch.nn as nn 5 | from torch import optim 6 | from torch.nn import functional as F 7 | import pytorch_lightning as pl 8 | 9 | from models.vit import VisionTransformer, PatchEmbed, Block,resolve_pretrained_cfg, build_model_with_cfg, checkpoint_filter_fn 10 | from models.convit import ClassAttention 11 | from models.convit import Block as ConBlock 12 | 13 | 14 | class ViT_KPrompts(VisionTransformer): 15 | def __init__( 16 | self, img_size=224, patch_size=16, in_chans=3, num_classes=1000, global_pool='token', 17 | embed_dim=768, depth=12, num_heads=12, mlp_ratio=4., qkv_bias=True, representation_size=None, 18 | drop_rate=0., attn_drop_rate=0., drop_path_rate=0., weight_init='', init_values=None, 19 | embed_layer=PatchEmbed, norm_layer=None, act_layer=None, block_fn=Block): 20 | 21 | super().__init__(img_size=img_size, patch_size=patch_size, in_chans=in_chans, num_classes=num_classes, global_pool=global_pool, 22 | embed_dim=embed_dim, depth=depth, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, representation_size=representation_size, 23 | drop_rate=drop_rate, attn_drop_rate=attn_drop_rate, drop_path_rate=drop_path_rate, weight_init=weight_init, init_values=init_values, 24 | embed_layer=embed_layer, norm_layer=norm_layer, act_layer=act_layer, block_fn=block_fn) 25 | 26 | def forward(self, x, instance_tokens=None, returnbeforepool=False, **kwargs): 27 | x = self.patch_embed(x) 28 | x = torch.cat((self.cls_token.expand(x.shape[0], -1, -1), x), dim=1) 29 | 30 | if instance_tokens is not None: 31 | instance_tokens = instance_tokens.to(x.dtype) + torch.zeros(x.shape[0], 1, x.shape[-1], dtype=x.dtype, device=x.device) 32 | 33 | x = x + self.pos_embed.to(x.dtype) 34 | if instance_tokens is not None: 35 | x = torch.cat([x[:,:1,:], instance_tokens, x[:,1:,:]], dim=1) 36 | x = self.pos_drop(x) 37 | x = self.blocks(x) 38 | if returnbeforepool == True: 39 | return x 40 | x = self.norm(x) 41 | if self.global_pool: 42 | x = x[:, 1:].mean(dim=1) if self.global_pool == 'avg' else x[:, 0] 43 | x = self.fc_norm(x) 44 | return x 45 | 46 | def _create_vision_transformer(variant, pretrained=False, **kwargs): 47 | if kwargs.get('features_only', None): 48 | raise RuntimeError('features_only not implemented for Vision Transformer models.') 49 | 50 | # NOTE this extra code to support handling of repr size for in21k pretrained models 51 | pretrained_cfg = resolve_pretrained_cfg(variant, pretrained_cfg=kwargs.pop('pretrained_cfg', None)) 52 | default_num_classes = pretrained_cfg['num_classes'] 53 | num_classes = kwargs.get('num_classes', default_num_classes) 54 | repr_size = kwargs.pop('representation_size', None) 55 | if repr_size is not None and num_classes != 
default_num_classes: 56 | repr_size = None 57 | 58 | model = build_model_with_cfg( 59 | ViT_KPrompts, variant, pretrained, 60 | pretrained_cfg=pretrained_cfg, 61 | representation_size=repr_size, 62 | pretrained_filter_fn=checkpoint_filter_fn, 63 | pretrained_custom_load='npz' in pretrained_cfg['url'], 64 | **kwargs) 65 | return model 66 | 67 | 68 | 69 | class incremental_vitood(pl.LightningModule): 70 | def __init__(self, num_cls, lr, max_epoch, weight_decay, known_classes, freezep, using_prompt, anchor_energy=-10, 71 | lamda=0.1, energy_beta=1): 72 | super().__init__() 73 | self.save_hyperparameters() 74 | 75 | model_kwargs = dict(patch_size=16, embed_dim=768, depth=12, num_heads=12) 76 | self.image_encoder =_create_vision_transformer('vit_base_patch16_224', pretrained=True, **model_kwargs) 77 | 78 | self.classifiers = nn.Linear(self.image_encoder.embed_dim, self.hparams.num_cls, bias=True) 79 | self.tabs = ConBlock(dim=self.image_encoder.embed_dim, num_heads=12, mlp_ratio=0.5, qkv_bias=True, 80 | qk_scale=None, drop=0.,attn_drop=0., norm_layer=nn.LayerNorm, attention_type=ClassAttention) 81 | self.task_tokens = copy.deepcopy(self.image_encoder.cls_token) 82 | self.vitprompt = nn.Linear(self.image_encoder.embed_dim, 100, bias=False) 83 | self.pre_vitprompt = None 84 | 85 | for name, param in self.image_encoder.named_parameters(): 86 | param.requires_grad_(False) 87 | 88 | if self.hparams.freezep: 89 | for name, param in self.vitprompt.named_parameters(): 90 | param.requires_grad_(False) 91 | 92 | def forward(self, image): 93 | if self.hparams.using_prompt: 94 | image_features = self.image_encoder(image, instance_tokens=self.vitprompt, returnbeforepool=True, ) 95 | else: 96 | image_features = self.image_encoder(image, returnbeforepool=True) 97 | 98 | B = image_features.shape[0] 99 | task_token = self.task_tokens.expand(B, -1, -1) 100 | task_token, attn, v = self.tabs(torch.cat((task_token, image_features), dim=1), mask_heads=None) 101 | logits = self.classifiers(task_token[:, 0]) 102 | 103 | return logits 104 | 105 | def configure_optimizers(self): 106 | optparams = filter(lambda p: p.requires_grad, self.parameters()) 107 | optimizer = optim.SGD(optparams, momentum=0.9,lr=self.hparams.lr,weight_decay=self.hparams.weight_decay) 108 | scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer=optimizer,T_max=self.hparams.max_epoch) 109 | return [optimizer], [scheduler] 110 | 111 | def _calculate_loss(self, batch, mode='train'): 112 | _, images, labels = batch 113 | labels = labels-self.hparams.known_classes 114 | 115 | if self.hparams.using_prompt: 116 | image_features = self.image_encoder(images, instance_tokens=self.vitprompt.weight, returnbeforepool=True, ) 117 | else: 118 | image_features = self.image_encoder(images, returnbeforepool=True) 119 | B = image_features.shape[0] 120 | task_token = self.task_tokens.expand(B, -1, -1) 121 | task_token, attn, v = self.tabs(torch.cat((task_token, image_features), dim=1), mask_heads=None) 122 | logits = self.classifiers(task_token[:, 0]) 123 | loss = F.cross_entropy(logits, labels) 124 | 125 | output_div_t = -1.0 * self.hparams.energy_beta * logits 126 | output_logsumexp = torch.logsumexp(output_div_t, dim=1, keepdim=False) 127 | free_energy = -1.0 * output_logsumexp / self.hparams.energy_beta 128 | align_loss = self.hparams.lamda * ((free_energy - self.hparams.anchor_energy) ** 2).mean() 129 | 130 | if self.pre_vitprompt is not None: 131 | pre_feature = self.image_encoder(images, instance_tokens=self.pre_vitprompt.weight, returnbeforepool=True, ) 
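# Feature-level distillation: when a previous-stage prompt is available, its
# (detached) features serve as a reference, and the MSE term below keeps the
# current features close to them, limiting prompt drift across stages.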
132 | kdloss = nn.MSELoss()(pre_feature.detach(), image_features) 133 | else: 134 | kdloss = 0 135 | 136 | loss = loss+align_loss+kdloss 137 | 138 | acc = (logits.argmax(dim=-1) == labels).float().mean() 139 | self.log("%s_loss" % mode, loss) 140 | self.log("%s_acc" % mode, acc) 141 | return loss 142 | 143 | def validation_step(self, batch, batch_idx): 144 | loss = self._calculate_loss(batch, mode='val') 145 | 146 | def training_step(self, batch, batch_idx): 147 | loss = self._calculate_loss(batch, mode='train') 148 | return loss 149 | 150 | def val_step(self, batch, batch_idx): 151 | self._calculate_loss(batch, mode='val') 152 | 153 | def test_step(self, batch, batch_idx): 154 | self._calculate_loss(batch, mode='test') 155 | 156 | -------------------------------------------------------------------------------- /trainer.py: -------------------------------------------------------------------------------- 1 | import copy 2 | import wandb 3 | import numpy as np 4 | import torch 5 | from torch.utils.data import Dataset, DataLoader 6 | import pytorch_lightning as pl 7 | from pytorch_lightning.callbacks import LearningRateMonitor, ModelCheckpoint 8 | from networks import _create_vision_transformer 9 | 10 | def training(args, prototype_mode, taskid, train_loader, test_loader, known_classes=0, vitprompt=None, numclass=10): 11 | trainer = pl.Trainer(default_root_dir='./checkpoints/'+wandb.run.name, accelerator="gpu", devices=1, 12 | max_epochs=args.max_epochs, progress_bar_refresh_rate=1, 13 | callbacks=[ModelCheckpoint(save_weights_only=True, mode="max", monitor="train_acc"), 14 | LearningRateMonitor("epoch"),]) 15 | 16 | if taskid == 0: 17 | model = prototype_mode(num_cls=numclass, lr=args.lr, max_epoch=args.max_epochs, weight_decay=0.0005, 18 | known_classes=known_classes, freezep=False, using_prompt=args.using_prompt, 19 | anchor_energy=args.anchor_energy, lamda=args.lamda, energy_beta=args.energy_beta) 20 | else: 21 | model = prototype_mode(num_cls=numclass, lr=args.lr, max_epoch=args.max_epochs, weight_decay=0.0005, 22 | known_classes=known_classes, freezep=True, using_prompt=args.using_prompt, 23 | anchor_energy=args.anchor_energy, lamda=args.lamda, energy_beta=args.energy_beta) 24 | model.vitprompt = vitprompt 25 | 26 | 27 | trainer.fit(model, train_loader) 28 | if args.dataset == "core50": 29 | val_result = [{'test_acc':0}] 30 | else: 31 | val_result = trainer.test(model, dataloaders=test_loader, verbose=False) 32 | print(val_result) 33 | return model, val_result 34 | 35 | 36 | def eval(args, load_path, datamanage): 37 | assembles = torch.load(load_path, map_location=torch.device('cpu')) 38 | model_kwargs = dict(patch_size=16, embed_dim=768, depth=12, num_heads=12) 39 | ptvit = _create_vision_transformer('vit_base_patch16_224', pretrained=True, **model_kwargs) 40 | 41 | device = 'cuda:0' 42 | ptvit = ptvit.to(device) 43 | ptvit.eval() 44 | all_tabs = assembles['all_tabs'] 45 | all_classifiers = assembles['all_classifiers'] 46 | all_tokens = assembles['all_tokens'] 47 | vitpromptlist = assembles['vitpromptlist'] 48 | 49 | all_tabs = [i.to(device) for i in all_tabs] 50 | all_classifiers = [i.to(device) for i in all_classifiers] 51 | all_tokens = [i.to(device) for i in all_tokens] 52 | vitpromptlist = [i.to(device) for i in vitpromptlist] 53 | 54 | _known_classes=0 55 | # fast mode 56 | candidata_temperatures = [0.001, 0.01, 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 1.0] 57 | # slow mode 58 | # candidata_temperatures = [i/1000 for i in range(1,1000,1)] 59 | select_temperature = [0.001] # 
for first stage with no former data 60 | accuracy_table, accs = [], [] 61 | print("Testing...") 62 | for taskid in range(datamanage.nb_tasks): 63 | _total_classes = _known_classes + datamanage.get_task_size(taskid) 64 | test_till_now_dataset = datamanage.get_dataset(np.arange(0, _total_classes), source='test', mode='test') 65 | test_till_now_loader = DataLoader(test_till_now_dataset, batch_size=128, shuffle=False, num_workers=8) 66 | 67 | if taskid > 0: 68 | classifiers = all_classifiers[:taskid+1] 69 | task_tokens = all_tokens[:taskid+1] 70 | tabs = all_tabs[:taskid+1] 71 | current_dataset = datamanage.get_dataset(np.arange(_known_classes, _total_classes), source='train', mode='test') 72 | current_dataloader = DataLoader(current_dataset, batch_size=128, shuffle=False, num_workers=8) 73 | all_energies = {i: [] for i in candidata_temperatures} 74 | with torch.no_grad(): 75 | for _, (_, inputs, targets) in enumerate(current_dataloader): 76 | inputs = inputs.to(device) 77 | targets = targets.to(device) 78 | energys = {i: [] for i in candidata_temperatures} 79 | image_features = ptvit(inputs, instance_tokens=vitpromptlist[taskid].weight, returnbeforepool=True) 80 | 81 | B = image_features.shape[0] 82 | for idx, fc in enumerate(classifiers): 83 | task_token = task_tokens[idx].expand(B, -1, -1) 84 | task_token, attn, v = tabs[idx](torch.cat((task_token, image_features), dim=1), mask_heads=None) 85 | task_token = task_token[:, 0] 86 | logit = fc(task_token) 87 | 88 | for tem in candidata_temperatures: 89 | energys[tem].append(torch.logsumexp(logit / tem, axis=-1)) 90 | energys = {i: torch.stack(energys[i]).T for i in candidata_temperatures} 91 | for i in candidata_temperatures: 92 | all_energies[i].append(energys[i]) 93 | 94 | all_energies = {i: torch.cat(all_energies[i]) for i in candidata_temperatures} 95 | seperation_accuracy = [] 96 | for i in candidata_temperatures: 97 | seperation_accuracy.append((sum(all_energies[i].max(1)[1]==(taskid))/len(all_energies[i])).item()) 98 | select_temperature.append(candidata_temperatures[np.array(seperation_accuracy).argmax()]) 99 | 100 | set_select_temperature = list(set(select_temperature)) 101 | y_pred, y_true = [], [] 102 | for _, (_, inputs, targets) in enumerate(test_till_now_loader): 103 | inputs = inputs.to(device) 104 | targets = targets.to(device) 105 | candiatetask = {i: [] for i in set_select_temperature} 106 | seperatePreds = [] 107 | 108 | with torch.no_grad(): 109 | image_features = ptvit(inputs, instance_tokens=vitpromptlist[taskid].weight, returnbeforepool=True) 110 | B = image_features.shape[0] 111 | for idx, fc in enumerate(all_classifiers[:taskid+1]): 112 | task_token = all_tokens[:taskid+1][idx].expand(B, -1, -1) 113 | task_token, attn, v = all_tabs[:taskid+1][idx](torch.cat((task_token, image_features), dim=1), mask_heads=None) 114 | task_token = task_token[:, 0] 115 | logit = fc(task_token) 116 | for tem in set_select_temperature: 117 | candiatetask[tem].append(torch.logsumexp(logit / tem, axis=-1)) 118 | seperatePreds.append(logit.max(1)[1]+idx*logit.shape[1]) 119 | 120 | candiatetask = {i: torch.stack(candiatetask[i]).T for i in set_select_temperature} 121 | seperatePreds = torch.stack(seperatePreds).T 122 | 123 | pred = [] 124 | for tem in set_select_temperature: 125 | val, ind = candiatetask[tem].max(1) 126 | pred.append(ind) 127 | indexselection = torch.stack(pred, 1) 128 | selectid = torch.mode(indexselection, dim=1, keepdim=False)[0] 129 | outputs = [] 130 | for row, idx in enumerate(selectid): 131 | 
outputs.append(seperatePreds[row][idx]) 132 | outputs = torch.stack(outputs) 133 | y_pred.append(outputs.cpu().numpy()) 134 | y_true.append(targets.cpu().numpy()) 135 | 136 | y_pred = np.concatenate(y_pred) 137 | y_true = np.concatenate(y_true) 138 | if args.dil: 139 | accs.append(np.around((y_pred.T%args.max_cls == y_true%args.max_cls).sum() * 100 / len(y_true), decimals=2)) 140 | else: 141 | accs.append(np.around((y_pred.T == y_true).sum() * 100 / len(y_true), decimals=2)) 142 | 143 | tempacc = [] 144 | for class_id in range(0, np.max(y_true), _total_classes-_known_classes): 145 | idxes = np.where(np.logical_and(y_true >= class_id, y_true < class_id + _total_classes-_known_classes))[0] 146 | tempacc.append(np.around((y_pred[idxes] == y_true[idxes]).sum() * 100 / len(idxes), decimals=3)) 147 | accuracy_table.append(tempacc) 148 | 149 | _known_classes = _total_classes 150 | 151 | 152 | np_acctable = np.zeros([taskid + 1, taskid + 1]) 153 | for idxx, line in enumerate(accuracy_table): 154 | idxy = len(line) 155 | np_acctable[idxx, :idxy] = np.array(line) 156 | # import pdb;pdb.set_trace() 157 | np_acctable = np_acctable.T 158 | print("Accuracy table:") 159 | print(np_acctable) 160 | print("Accuracy curve:") 161 | print(accs) 162 | print("FAA: {}".format(accs[-1])) 163 | 164 | forgetting = np.mean((np.max(np_acctable, axis=1) - np_acctable[:, taskid])[:taskid]) 165 | print("FF: {}".format(forgetting)) 166 | -------------------------------------------------------------------------------- /utils/data_manager.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import numpy as np 3 | from PIL import Image 4 | from torch.utils.data import Dataset 5 | from torchvision import transforms 6 | from utils.data import iCIFAR10, iCIFAR100, iImageNet100, iImageNet1000 7 | from utils.data import iCIFAR100_vit, iGanFake, iGanClass, i5Datasets_vit, iImageNetR, iCore50, iDomainnetCIL 8 | 9 | 10 | 11 | class DataManager(object): 12 | def __init__(self, dataset_name, shuffle, seed, init_cls, increment, args=None): 13 | self.args = args 14 | self.dataset_name = dataset_name 15 | self._setup_data(dataset_name, shuffle, seed) 16 | assert init_cls <= len(self._class_order), 'No enough classes.' 
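# Build per-task class counts: the first task gets init_cls classes, each
# following task gets `increment`, and any remaining classes form the last task.
# E.g. 100 classes with init_cls=10, increment=10 -> 10 tasks of 10 classes each.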
17 | self._increments = [init_cls] 18 | while sum(self._increments) + increment < len(self._class_order): 19 | self._increments.append(increment) 20 | offset = len(self._class_order) - sum(self._increments) 21 | if offset > 0: 22 | self._increments.append(offset) 23 | 24 | self.attack = [ 25 | transforms.ToTensor(), 26 | ] 27 | 28 | 29 | @property 30 | def nb_tasks(self): 31 | return len(self._increments) 32 | 33 | def get_task_size(self, task): 34 | return self._increments[task] 35 | 36 | def get_dataset(self, indices, source, mode, appendent=None, ret_data=False): 37 | if source == 'train': 38 | x, y = self._train_data, self._train_targets 39 | elif source == 'test': 40 | x, y = self._test_data, self._test_targets 41 | else: 42 | raise ValueError('Unknown data source {}.'.format(source)) 43 | 44 | if mode == 'train': 45 | trsf = transforms.Compose([*self._train_trsf, *self._common_trsf]) 46 | elif mode == 'flip': 47 | trsf = transforms.Compose([*self._test_trsf, transforms.RandomHorizontalFlip(p=1.), *self._common_trsf]) 48 | elif mode == 'test': 49 | trsf = transforms.Compose([*self._test_trsf, *self._common_trsf]) 50 | elif mode == 'attack': 51 | trsf = transforms.Compose([*self._test_trsf, *self.attack]) 52 | 53 | else: 54 | raise ValueError('Unknown mode {}.'.format(mode)) 55 | 56 | data, targets = [], [] 57 | for idx in indices: 58 | class_data, class_targets = self._select(x, y, low_range=idx, high_range=idx+1) 59 | data.append(class_data) 60 | targets.append(class_targets) 61 | 62 | if appendent is not None and len(appendent) != 0: 63 | appendent_data, appendent_targets = appendent 64 | data.append(appendent_data) 65 | targets.append(appendent_targets) 66 | 67 | data, targets = np.concatenate(data), np.concatenate(targets) 68 | 69 | if ret_data: 70 | return data, targets, DummyDataset(data, targets, trsf, self.use_path) 71 | else: 72 | return DummyDataset(data, targets, trsf, self.use_path) 73 | 74 | def get_anchor_dataset(self, mode, appendent=None, ret_data=False): 75 | if mode == 'train': 76 | trsf = transforms.Compose([*self._train_trsf, *self._common_trsf]) 77 | elif mode == 'flip': 78 | trsf = transforms.Compose([*self._test_trsf, transforms.RandomHorizontalFlip(p=1.), *self._common_trsf]) 79 | elif mode == 'test': 80 | trsf = transforms.Compose([*self._test_trsf, *self._common_trsf]) 81 | else: 82 | raise ValueError('Unknown mode {}.'.format(mode)) 83 | 84 | data, targets = [], [] 85 | if appendent is not None and len(appendent) != 0: 86 | appendent_data, appendent_targets = appendent 87 | data.append(appendent_data) 88 | targets.append(appendent_targets) 89 | 90 | data, targets = np.concatenate(data), np.concatenate(targets) 91 | 92 | if ret_data: 93 | return data, targets, DummyDataset(data, targets, trsf, self.use_path) 94 | else: 95 | return DummyDataset(data, targets, trsf, self.use_path) 96 | 97 | def get_dataset_with_split(self, indices, source, mode, appendent=None, val_samples_per_class=0): 98 | if source == 'train': 99 | x, y = self._train_data, self._train_targets 100 | elif source == 'test': 101 | x, y = self._test_data, self._test_targets 102 | else: 103 | raise ValueError('Unknown data source {}.'.format(source)) 104 | 105 | if mode == 'train': 106 | trsf = transforms.Compose([*self._train_trsf, *self._common_trsf]) 107 | elif mode == 'test': 108 | trsf = transforms.Compose([*self._test_trsf, *self._common_trsf]) 109 | else: 110 | raise ValueError('Unknown mode {}.'.format(mode)) 111 | 112 | train_data, train_targets = [], [] 113 | val_data, val_targets = [], 
[] 114 | for idx in indices: 115 | class_data, class_targets = self._select(x, y, low_range=idx, high_range=idx+1) 116 | val_indx = np.random.choice(len(class_data), val_samples_per_class, replace=False) 117 | train_indx = list(set(np.arange(len(class_data))) - set(val_indx)) 118 | val_data.append(class_data[val_indx]) 119 | val_targets.append(class_targets[val_indx]) 120 | train_data.append(class_data[train_indx]) 121 | train_targets.append(class_targets[train_indx]) 122 | 123 | if appendent is not None: 124 | appendent_data, appendent_targets = appendent 125 | for idx in range(0, int(np.max(appendent_targets))+1): 126 | append_data, append_targets = self._select(appendent_data, appendent_targets, 127 | low_range=idx, high_range=idx+1) 128 | val_indx = np.random.choice(len(append_data), val_samples_per_class, replace=False) 129 | train_indx = list(set(np.arange(len(append_data))) - set(val_indx)) 130 | val_data.append(append_data[val_indx]) 131 | val_targets.append(append_targets[val_indx]) 132 | train_data.append(append_data[train_indx]) 133 | train_targets.append(append_targets[train_indx]) 134 | 135 | train_data, train_targets = np.concatenate(train_data), np.concatenate(train_targets) 136 | val_data, val_targets = np.concatenate(val_data), np.concatenate(val_targets) 137 | 138 | return DummyDataset(train_data, train_targets, trsf, self.use_path), \ 139 | DummyDataset(val_data, val_targets, trsf, self.use_path) 140 | 141 | def _setup_data(self, dataset_name, shuffle, seed): 142 | idata = _get_idata(dataset_name, self.args) 143 | idata.download_data() 144 | 145 | # Data 146 | self._train_data, self._train_targets = idata.train_data, idata.train_targets 147 | self._test_data, self._test_targets = idata.test_data, idata.test_targets 148 | self.use_path = idata.use_path 149 | 150 | # Transforms 151 | self._train_trsf = idata.train_trsf 152 | self._test_trsf = idata.test_trsf 153 | self._common_trsf = idata.common_trsf 154 | 155 | # Order 156 | order = [i for i in range(len(np.unique(self._train_targets)))] 157 | if shuffle: 158 | np.random.seed(seed) 159 | order = np.random.permutation(len(order)).tolist() 160 | else: 161 | order = idata.class_order 162 | self._class_order = order 163 | logging.info(self._class_order) 164 | 165 | # Map indices 166 | self._train_targets = _map_new_class_index(self._train_targets, self._class_order) 167 | self._test_targets = _map_new_class_index(self._test_targets, self._class_order) 168 | 169 | def _select(self, x, y, low_range, high_range): 170 | idxes = np.where(np.logical_and(y >= low_range, y < high_range))[0] 171 | return x[idxes], y[idxes] 172 | 173 | 174 | class DummyDataset(Dataset): 175 | def __init__(self, images, labels, trsf, use_path=False): 176 | assert len(images) == len(labels), 'Data size error!' 
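# `images` holds file paths when use_path=True (decoded with pil_loader) or
# in-memory arrays otherwise; __getitem__ returns (idx, image, label) so callers
# can also track each sample's position in the underlying arrays.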
177 | self.images = images 178 | self.labels = labels 179 | self.trsf = trsf 180 | self.use_path = use_path 181 | 182 | def __len__(self): 183 | return len(self.images) 184 | 185 | def __getitem__(self, idx): 186 | if self.use_path: 187 | image = self.trsf(pil_loader(self.images[idx])) 188 | else: 189 | image = self.trsf(Image.fromarray(self.images[idx])) 190 | label = self.labels[idx] 191 | 192 | return idx, image, label 193 | 194 | 195 | def _map_new_class_index(y, order): 196 | return np.array(list(map(lambda x: order.index(x), y))) 197 | 198 | 199 | def _get_idata(dataset_name, args=None): 200 | name = dataset_name.lower() 201 | if name == 'cifar10': 202 | return iCIFAR10() 203 | elif name == 'cifar100': 204 | return iCIFAR100() 205 | elif name == 'imagenet1000': 206 | return iImageNet1000() 207 | elif name == "imagenet100": 208 | return iImageNet100() 209 | elif name == "cifar100_vit": 210 | return iCIFAR100_vit() 211 | elif name == "5datasets_vit": 212 | return i5Datasets_vit() 213 | elif name == "core50": 214 | return iCore50() 215 | elif name == "ganfake": 216 | return iGanFake(args) 217 | elif name == "imagenetr": 218 | return iImageNetR() 219 | elif name == "ganclass": 220 | return iGanClass(args) 221 | elif name == "domainnet": 222 | return iDomainnetCIL() 223 | else: 224 | raise NotImplementedError('Unknown dataset {}.'.format(dataset_name)) 225 | 226 | 227 | def pil_loader(path): 228 | ''' 229 | Ref: 230 | https://pytorch.org/docs/stable/_modules/torchvision/datasets/folder.html#ImageFolder 231 | ''' 232 | # open path as file to avoid ResourceWarning (https://github.com/python-pillow/Pillow/issues/835) 233 | with open(path, 'rb') as f: 234 | img = Image.open(f) 235 | return img.convert('RGB') 236 | 237 | 238 | def accimage_loader(path): 239 | ''' 240 | Ref: 241 | https://pytorch.org/docs/stable/_modules/torchvision/datasets/folder.html#ImageFolder 242 | accimage is an accelerated Image loader and preprocessor leveraging Intel IPP. 243 | accimage is available on conda-forge. 244 | ''' 245 | import accimage 246 | try: 247 | return accimage.Image(path) 248 | except IOError: 249 | # Potentially a decoding problem, fall back to PIL.Image 250 | return pil_loader(path) 251 | 252 | 253 | def default_loader(path): 254 | ''' 255 | Ref: 256 | https://pytorch.org/docs/stable/_modules/torchvision/datasets/folder.html#ImageFolder 257 | ''' 258 | from torchvision import get_image_backend 259 | if get_image_backend() == 'accimage': 260 | return accimage_loader(path) 261 | else: 262 | return pil_loader(path) 263 | -------------------------------------------------------------------------------- /utils/datautils/core50/core50data.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- coding: utf-8 -*- 3 | 4 | ################################################################################ 5 | # Copyright (c) 2019. Vincenzo Lomonaco. All rights reserved. # 6 | # See the accompanying LICENSE file for terms. 
# 7 | # # 8 | # Date: 23-07-2019 # 9 | # Author: Vincenzo Lomonaco # 10 | # E-mail: vincenzo.lomonaco@unibo.it # 11 | # Website: vincenzolomonaco.com # 12 | ################################################################################ 13 | 14 | """ Data Loader for the CORe50 Dataset """ 15 | 16 | # Python 2-3 compatible 17 | from __future__ import print_function 18 | from __future__ import division 19 | from __future__ import absolute_import 20 | 21 | # other imports 22 | import numpy as np 23 | import pickle as pkl 24 | import os 25 | import logging 26 | from hashlib import md5 27 | from PIL import Image 28 | 29 | 30 | class CORE50(object): 31 | """ CORe50 Data Loader calss 32 | Args: 33 | root (string): Root directory of the dataset where ``core50_128x128``, 34 | ``paths.pkl``, ``LUP.pkl``, ``labels.pkl``, ``core50_imgs.npz`` 35 | live. For example ``~/data/core50``. 36 | preload (string, optional): If True data is pre-loaded with look-up 37 | tables. RAM usage may be high. 38 | scenario (string, optional): One of the three scenarios of the CORe50 39 | benchmark ``ni``, ``nc``, ``nic``, `nicv2_79`,``nicv2_196`` and 40 | ``nicv2_391``. 41 | train (bool, optional): If True, creates the dataset from the training 42 | set, otherwise creates from test set. 43 | cumul (bool, optional): If True the cumulative scenario is assumed, the 44 | incremental scenario otherwise. Practically speaking ``cumul=True`` 45 | means that for batch=i also batch=0,...i-1 will be added to the 46 | available training data. 47 | run (int, optional): One of the 10 runs (from 0 to 9) in which the 48 | training batch order is changed as in the official benchmark. 49 | start_batch (int, optional): One of the training incremental batches 50 | from 0 to max-batch - 1. Remember that for the ``ni``, ``nc`` and 51 | ``nic`` we have respectively 8, 9 and 79 incremental batches. If 52 | ``train=False`` this parameter will be ignored. 
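        Example (illustrative; adjust ``root`` to your local copy of the data):
            dataset = CORE50(root='~/data/core50/core50_128x128', scenario='ni')
            for train_x, train_y in dataset:    # one incremental batch per step
                pass                            # train on this batch
            test_x, test_y = dataset.get_test_set()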
53 | """ 54 | 55 | nbatch = { 56 | 'ni': 8, 57 | 'nc': 9, 58 | 'nic': 79, 59 | 'nicv2_79': 79, 60 | 'nicv2_196': 196, 61 | 'nicv2_391': 391 62 | } 63 | 64 | def __init__(self, root='', preload=False, scenario='ni', cumul=False, 65 | run=0, start_batch=0): 66 | """" Initialize Object """ 67 | 68 | self.root = os.path.expanduser(root) 69 | self.preload = preload 70 | self.scenario = scenario 71 | self.cumul = cumul 72 | self.run = run 73 | self.batch = start_batch 74 | 75 | if preload: 76 | print("Loading data...") 77 | bin_path = os.path.join(root, 'core50_imgs.bin') 78 | if os.path.exists(bin_path): 79 | with open(bin_path, 'rb') as f: 80 | self.x = np.fromfile(f, dtype=np.uint8) \ 81 | .reshape(164866, 128, 128, 3) 82 | 83 | else: 84 | with open(os.path.join(root, 'core50_imgs.npz'), 'rb') as f: 85 | npzfile = np.load(f) 86 | self.x = npzfile['x'] 87 | print("Writing bin for fast reloading...") 88 | self.x.tofile(bin_path) 89 | 90 | print("Loading paths...") 91 | with open(os.path.join(root, 'paths.pkl'), 'rb') as f: 92 | self.paths = pkl.load(f) 93 | 94 | print("Loading LUP...") 95 | with open(os.path.join(root, 'LUP.pkl'), 'rb') as f: 96 | self.LUP = pkl.load(f) 97 | 98 | print("Loading labels...") 99 | with open(os.path.join(root, 'labels.pkl'), 'rb') as f: 100 | self.labels = pkl.load(f) 101 | 102 | def __iter__(self): 103 | return self 104 | 105 | def get_data_batchidx(self, idx): 106 | 107 | scen = self.scenario 108 | run = self.run 109 | batch = idx 110 | 111 | if self.batch == self.nbatch[scen]: 112 | raise StopIteration 113 | 114 | # Getting the right indexis 115 | if self.cumul: 116 | train_idx_list = [] 117 | for i in range(self.batch + 1): 118 | train_idx_list += self.LUP[scen][run][i] 119 | else: 120 | train_idx_list = self.LUP[scen][run][batch] 121 | 122 | # loading data 123 | if self.preload: 124 | train_x = np.take(self.x, train_idx_list, axis=0)\ 125 | .astype(np.float32) 126 | else: 127 | print("Loading data...") 128 | # Getting the actual paths 129 | train_paths = [] 130 | for idx in train_idx_list: 131 | train_paths.append(os.path.join(self.root, self.paths[idx])) 132 | # loading imgs 133 | train_x = self.get_batch_from_paths(train_paths).astype(np.float32) 134 | 135 | # In either case we have already loaded the y 136 | if self.cumul: 137 | train_y = [] 138 | for i in range(self.batch + 1): 139 | train_y += self.labels[scen][run][i] 140 | else: 141 | train_y = self.labels[scen][run][batch] 142 | 143 | train_y = np.asarray(train_y, dtype=np.int) 144 | 145 | return (train_x, train_y) 146 | 147 | def __next__(self): 148 | """ Next batch based on the object parameter which can be also changed 149 | from the previous iteration. 
""" 150 | 151 | scen = self.scenario 152 | run = self.run 153 | batch = self.batch 154 | 155 | if self.batch == self.nbatch[scen]: 156 | raise StopIteration 157 | 158 | # Getting the right indexis 159 | if self.cumul: 160 | train_idx_list = [] 161 | for i in range(self.batch + 1): 162 | train_idx_list += self.LUP[scen][run][i] 163 | else: 164 | train_idx_list = self.LUP[scen][run][batch] 165 | 166 | # loading data 167 | if self.preload: 168 | train_x = np.take(self.x, train_idx_list, axis=0)\ 169 | .astype(np.float32) 170 | else: 171 | print("Loading data...") 172 | # Getting the actual paths 173 | train_paths = [] 174 | for idx in train_idx_list: 175 | train_paths.append(os.path.join(self.root, self.paths[idx])) 176 | # loading imgs 177 | train_x = self.get_batch_from_paths(train_paths).astype(np.float32) 178 | 179 | # In either case we have already loaded the y 180 | if self.cumul: 181 | train_y = [] 182 | for i in range(self.batch + 1): 183 | train_y += self.labels[scen][run][i] 184 | else: 185 | train_y = self.labels[scen][run][batch] 186 | 187 | train_y = np.asarray(train_y, dtype=np.int) 188 | 189 | # Update state for next iter 190 | self.batch += 1 191 | 192 | return (train_x, train_y) 193 | 194 | def get_test_set(self): 195 | """ Return the test set (the same for each inc. batch). """ 196 | 197 | scen = self.scenario 198 | run = self.run 199 | 200 | test_idx_list = self.LUP[scen][run][-1] 201 | 202 | if self.preload: 203 | test_x = np.take(self.x, test_idx_list, axis=0).astype(np.float32) 204 | else: 205 | # test paths 206 | test_paths = [] 207 | for idx in test_idx_list: 208 | test_paths.append(os.path.join(self.root, self.paths[idx])) 209 | 210 | # test imgs 211 | test_x = self.get_batch_from_paths(test_paths).astype(np.float32) 212 | 213 | test_y = self.labels[scen][run][-1] 214 | test_y = np.asarray(test_y, dtype=np.int) 215 | 216 | return test_x, test_y 217 | 218 | next = __next__ # python2.x compatibility. 219 | 220 | @staticmethod 221 | def get_batch_from_paths(paths, compress=False, snap_dir='', 222 | on_the_fly=True, verbose=False): 223 | """ Given a number of abs. paths it returns the numpy array 224 | of all the images. """ 225 | 226 | # Getting root logger 227 | log = logging.getLogger('mylogger') 228 | 229 | # If we do not process data on the fly we check if the same train 230 | # filelist has been already processed and saved. If so, we load it 231 | # directly. In either case we end up returning x and y, as the full 232 | # training set and respective labels. 233 | num_imgs = len(paths) 234 | hexdigest = md5(''.join(paths).encode('utf-8')).hexdigest() 235 | log.debug("Paths Hex: " + str(hexdigest)) 236 | loaded = False 237 | x = None 238 | file_path = None 239 | 240 | if compress: 241 | file_path = snap_dir + hexdigest + ".npz" 242 | if os.path.exists(file_path) and not on_the_fly: 243 | loaded = True 244 | with open(file_path, 'rb') as f: 245 | npzfile = np.load(f) 246 | x, y = npzfile['x'] 247 | else: 248 | x_file_path = snap_dir + hexdigest + "_x.bin" 249 | if os.path.exists(x_file_path) and not on_the_fly: 250 | loaded = True 251 | with open(x_file_path, 'rb') as f: 252 | x = np.fromfile(f, dtype=np.uint8) \ 253 | .reshape(num_imgs, 128, 128, 3) 254 | 255 | # Here we actually load the images. 
256 | if not loaded: 257 | # Pre-allocate numpy arrays 258 | x = np.zeros((num_imgs, 128, 128, 3), dtype=np.uint8) 259 | 260 | for i, path in enumerate(paths): 261 | if verbose: 262 | print("\r" + path + " processed: " + str(i + 1), end='') 263 | x[i] = np.array(Image.open(path)) 264 | 265 | if verbose: 266 | print() 267 | 268 | if not on_the_fly: 269 | # Then we save x 270 | if compress: 271 | with open(file_path, 'wb') as g: 272 | np.savez_compressed(g, x=x) 273 | else: 274 | x.tofile(snap_dir + hexdigest + "_x.bin") 275 | 276 | assert (x is not None), 'Problems loading data. x is None!' 277 | 278 | return x 279 | 280 | 281 | -------------------------------------------------------------------------------- /utils/data.py: -------------------------------------------------------------------------------- 1 | import os 2 | import numpy as np 3 | from torchvision import datasets, transforms 4 | from utils.toolkit import split_images_labels 5 | from PIL import Image 6 | 7 | 8 | class iData(object): 9 | train_trsf = [] 10 | test_trsf = [] 11 | common_trsf = [] 12 | class_order = None 13 | 14 | 15 | class iCIFAR10(iData): 16 | use_path = False 17 | train_trsf = [ 18 | transforms.RandomCrop(32, padding=4), 19 | transforms.RandomHorizontalFlip(p=0.5), 20 | transforms.ColorJitter(brightness=63/255) 21 | ] 22 | test_trsf = [] 23 | common_trsf = [ 24 | transforms.ToTensor(), 25 | transforms.Normalize(mean=(0.4914, 0.4822, 0.4465), std=(0.2023, 0.1994, 0.2010)), 26 | ] 27 | 28 | class_order = np.arange(10).tolist() 29 | 30 | def download_data(self): 31 | train_dataset = datasets.cifar.CIFAR10('./data', train=True, download=True) 32 | test_dataset = datasets.cifar.CIFAR10('./data', train=False, download=True) 33 | self.train_data, self.train_targets = train_dataset.data, np.array(train_dataset.targets) 34 | self.test_data, self.test_targets = test_dataset.data, np.array(test_dataset.targets) 35 | 36 | class iCIFAR100(iData): 37 | use_path = False 38 | train_trsf = [ 39 | transforms.RandomCrop(32, padding=4), 40 | transforms.RandomHorizontalFlip(), 41 | transforms.ColorJitter(brightness=63/255) 42 | ] 43 | test_trsf = [] 44 | common_trsf = [ 45 | transforms.ToTensor(), 46 | transforms.Normalize(mean=(0.5071, 0.4867, 0.4408), std=(0.2675, 0.2565, 0.2761)), 47 | ] 48 | 49 | class_order = np.arange(100).tolist() 50 | 51 | def download_data(self): 52 | train_dataset = datasets.cifar.CIFAR100('./data', train=True, download=True) 53 | test_dataset = datasets.cifar.CIFAR100('./data', train=False, download=True) 54 | self.train_data, self.train_targets = train_dataset.data, np.array(train_dataset.targets) 55 | self.test_data, self.test_targets = test_dataset.data, np.array(test_dataset.targets) 56 | 57 | class iImageNet1000(iData): 58 | use_path = True 59 | train_trsf = [ 60 | transforms.RandomResizedCrop(224), 61 | transforms.RandomHorizontalFlip(), 62 | transforms.ColorJitter(brightness=63/255) 63 | ] 64 | test_trsf = [ 65 | transforms.Resize(256), 66 | transforms.CenterCrop(224), 67 | ] 68 | common_trsf = [ 69 | transforms.ToTensor(), 70 | transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]), 71 | ] 72 | 73 | class_order = np.arange(1000).tolist() 74 | 75 | def download_data(self): 76 | # assert 0,"You should specify the folder of your dataset" 77 | train_dir = '/home/wangyabin/workspace/data/train' 78 | test_dir = '/home/wangyabin/workspace/data/val' 79 | 80 | train_dset = datasets.ImageFolder(train_dir) 81 | test_dset = datasets.ImageFolder(test_dir) 82 | 83 | self.train_data, 
self.train_targets = split_images_labels(train_dset.imgs) 84 | self.test_data, self.test_targets = split_images_labels(test_dset.imgs) 85 | 86 | class iImageNet100(iData): 87 | use_path = True 88 | train_trsf = [ 89 | transforms.RandomResizedCrop(224), 90 | transforms.RandomHorizontalFlip(), 91 | ] 92 | test_trsf = [ 93 | transforms.Resize(256), 94 | transforms.CenterCrop(224), 95 | ] 96 | common_trsf = [ 97 | transforms.ToTensor(), 98 | transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]), 99 | ] 100 | 101 | class_order = np.arange(1000).tolist() 102 | 103 | def download_data(self): 104 | train_dir = '/home/wangyabin/workspace/datasets/imagenet100/train/' 105 | test_dir = '/home/wangyabin/workspace/datasets/imagenet100/val/' 106 | 107 | train_dset = datasets.ImageFolder(train_dir) 108 | test_dset = datasets.ImageFolder(test_dir) 109 | 110 | self.train_data, self.train_targets = split_images_labels(train_dset.imgs) 111 | self.test_data, self.test_targets = split_images_labels(test_dset.imgs) 112 | 113 | class iCore50(iData): 114 | 115 | use_path = False 116 | train_trsf = [ 117 | transforms.RandomResizedCrop(224), 118 | transforms.RandomHorizontalFlip(), 119 | ] 120 | test_trsf = [ 121 | transforms.Resize(256), 122 | transforms.CenterCrop(224), 123 | ] 124 | common_trsf = [ 125 | transforms.ToTensor(), 126 | transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]), 127 | ] 128 | 129 | class_order = np.arange(8*50).tolist() 130 | 131 | def download_data(self): 132 | from utils.datautils.core50.core50data import CORE50 133 | datagen = CORE50(root='/home/wangyabin/workspace/datasets/core50/data/core50_128x128', scenario="ni") 134 | 135 | dataset_list = [] 136 | for i, train_batch in enumerate(datagen): 137 | imglist, labellist = train_batch 138 | labellist += i*50 139 | imglist = imglist.astype(np.uint8) 140 | dataset_list.append([imglist, labellist]) 141 | train_x = np.concatenate(np.array(dataset_list)[:, 0]) 142 | train_y = np.concatenate(np.array(dataset_list)[:, 1]) 143 | self.train_data = train_x 144 | self.train_targets = train_y 145 | 146 | test_x, test_y = datagen.get_test_set() 147 | test_x = test_x.astype(np.uint8) 148 | self.test_data = test_x 149 | self.test_targets = test_y 150 | # import pdb;pdb.set_trace() 151 | 152 | 153 | 154 | class iDomainnetCIL(iData): 155 | use_path = True 156 | train_trsf = [ 157 | transforms.RandomResizedCrop(224), 158 | transforms.RandomHorizontalFlip(), 159 | ] 160 | test_trsf = [ 161 | transforms.Resize(256), 162 | transforms.CenterCrop(224), 163 | ] 164 | common_trsf = [ 165 | transforms.ToTensor(), 166 | transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]), 167 | ] 168 | 169 | class_order = np.arange(200).tolist() 170 | 171 | def download_data(self): 172 | rootdir = '/home/wangyabin/workspace/datasets/domainnet' 173 | 174 | train_txt = './utils/datautils/domainnet/train.txt' 175 | test_txt = './utils/datautils/domainnet/test.txt' 176 | 177 | train_images = [] 178 | train_labels = [] 179 | with open(train_txt, 'r') as dict_file: 180 | for line in dict_file: 181 | (value, key) = line.strip().split(' ') 182 | train_images.append(os.path.join(rootdir, value)) 183 | train_labels.append(int(key)) 184 | train_images = np.array(train_images) 185 | train_labels = np.array(train_labels) 186 | test_images = [] 187 | test_labels = [] 188 | with open(test_txt, 'r') as dict_file: 189 | for line in dict_file: 190 | (value, key) = line.strip().split(' ') 191 | 
test_images.append(os.path.join(rootdir, value)) 192 | test_labels.append(int(key)) 193 | test_images = np.array(test_images) 194 | test_labels = np.array(test_labels) 195 | 196 | self.train_data = train_images 197 | self.train_targets = train_labels 198 | self.test_data = test_images 199 | self.test_targets = test_labels 200 | 201 | 202 | 203 | 204 | 205 | 206 | 207 | 208 | class iImageNetR(iData): 209 | use_path = True 210 | train_trsf = [ 211 | transforms.RandomResizedCrop(224), 212 | transforms.RandomHorizontalFlip(), 213 | ] 214 | test_trsf = [ 215 | transforms.Resize(256), 216 | transforms.CenterCrop(224), 217 | ] 218 | common_trsf = [ 219 | transforms.ToTensor(), 220 | transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]), 221 | ] 222 | 223 | class_order = np.arange(200).tolist() 224 | 225 | # first we need get the pre-processed txt files which containing the taining and text split of DualPrompt. 226 | # the origin data structure is tfds, but we extract its data to a txt file 227 | """ 228 | we use following code to extract split 229 | def get_stats(split_name, filepath): 230 | stats = [] 231 | ds = dataset_builder.as_dataset(split=split_name) 232 | label_list = [] 233 | for batch in ds: 234 | label_list.append(int(batch["label"])) 235 | label_list = list(set(label_list)) 236 | data_dict = {i:[] for i in label_list} 237 | for batch in ds: 238 | data_dict[int(batch["label"])].append(batch["file_name"].numpy()) 239 | print(len(label_list)) 240 | label_list.sort() 241 | with open(filepath, 'w') as f: 242 | for i in label_list: 243 | for line in data_dict[i]: 244 | f.write(str(IR_LABEL_MAP[i]) + "\t" + line.decode("utf-8") +"\n") 245 | return data_dict 246 | train_stats = get_stats("test[:80%]", "train.txt") 247 | test_stats = get_stats("test[80%:]", "text.txt") 248 | """ 249 | 250 | def download_data(self): 251 | rootdir = '/home/wangyabin/workspace/datasets/imagenet-r' 252 | 253 | train_txt = './utils/datautils/imagenet-r/train.txt' 254 | test_txt = './utils/datautils/imagenet-r/test.txt' 255 | 256 | train_images = [] 257 | train_labels = [] 258 | with open(train_txt, 'r') as dict_file: 259 | for line in dict_file: 260 | (key, value) = line.strip().split('\t') 261 | train_images.append(os.path.join(rootdir, value)) 262 | train_labels.append(int(key)) 263 | train_images = np.array(train_images) 264 | train_labels = np.array(train_labels) 265 | 266 | test_images = [] 267 | test_labels = [] 268 | with open(test_txt, 'r') as dict_file: 269 | for line in dict_file: 270 | (key, value) = line.strip().split('\t') 271 | test_images.append(os.path.join(rootdir, value)) 272 | test_labels.append(int(key)) 273 | test_images = np.array(test_images) 274 | test_labels = np.array(test_labels) 275 | 276 | self.train_data = train_images 277 | self.train_targets = train_labels 278 | self.test_data = test_images 279 | self.test_targets = test_labels 280 | 281 | class iCIFAR100_vit(iData): 282 | use_path = False 283 | train_trsf = [ 284 | transforms.Resize(256), 285 | transforms.ColorJitter(brightness=63 / 255), 286 | transforms.RandomResizedCrop(224), 287 | transforms.RandomHorizontalFlip(), 288 | # transforms.AutoAugment(policy=transforms.AutoAugmentPolicy.IMAGENET), 289 | ] 290 | test_trsf = [ 291 | transforms.Resize(256), 292 | transforms.CenterCrop(224), 293 | ] 294 | common_trsf = [ 295 | transforms.ToTensor(), 296 | transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]), 297 | ] 298 | 299 | class_order = np.arange(100).tolist() 300 | # class_order = [51, 
48, 73, 93, 39, 67, 29, 49, 57, 33, 4, 32, 5, 75, 63, 7, 61, 36, 69, 62, 46, 30, 25, 47, 12, 11, 94, 18, 27, 88, 0, 99, 21, 87, 34, 24, 86, 35, 22, 42, 66, 64, 2, 97, 98, 96, 71, 14, 95, 37, 54, 31, 10, 20, 52, 79, 60, 72, 41, 91, 44, 15, 16, 83, 59, 6, 82, 45, 81, 13, 53, 28, 50, 17, 19, 85, 1, 77, 70, 58, 38, 43, 80, 26, 9, 55, 92, 3, 89, 40, 76, 74, 65, 90, 84, 23, 8, 78, 56, 68] 301 | 302 | def download_data(self): 303 | train_dataset = datasets.cifar.CIFAR100('./data', train=True, download=True) 304 | test_dataset = datasets.cifar.CIFAR100('./data', train=False, download=True) 305 | self.train_data, self.train_targets = train_dataset.data, np.array(train_dataset.targets) 306 | self.test_data, self.test_targets = test_dataset.data, np.array(test_dataset.targets) 307 | 308 | class i5Datasets_vit(iData): 309 | use_path = False 310 | train_trsf = [ 311 | transforms.Resize(224), 312 | # transforms.AutoAugment(policy=transforms.AutoAugmentPolicy.IMAGENET), 313 | transforms.RandomResizedCrop(224), 314 | transforms.RandomHorizontalFlip(), 315 | transforms.ColorJitter(brightness=63 / 255) 316 | ] 317 | test_trsf = [ 318 | transforms.Resize(256), 319 | transforms.CenterCrop(224), 320 | ] 321 | common_trsf = [ 322 | transforms.ToTensor(), 323 | transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]), 324 | ] 325 | 326 | 327 | class_order = np.arange(50).tolist() 328 | 329 | 330 | def download_data(self): 331 | img_size=64 332 | train_dataset = datasets.cifar.CIFAR10('./data', train=True, download=True) 333 | test_dataset = datasets.cifar.CIFAR10('./data', train=False, download=True) 334 | 335 | trainlist = [] 336 | testlist = [] 337 | train_label_list = [] 338 | test_label_list = [] 339 | 340 | # cifar10 341 | cifar10_train_dataset = datasets.cifar.CIFAR10('./data', train=True, download=True) 342 | cifar10_test_dataset = datasets.cifar.CIFAR10('./data', train=False, download=True) 343 | for img, target in zip(cifar10_train_dataset.data, cifar10_train_dataset.targets): 344 | trainlist.append(np.array(Image.fromarray(img).resize((img_size, img_size)))) 345 | train_label_list.append(target) 346 | for img, target in zip(cifar10_test_dataset.data, cifar10_test_dataset.targets): 347 | testlist.append(np.array(Image.fromarray(img).resize((img_size, img_size)))) 348 | test_label_list.append(target) 349 | 350 | # MNIST 351 | minist_train_dataset = datasets.MNIST('./data', train=True, download=True) 352 | minist_test_dataset = datasets.MNIST('./data', train=False, download=True) 353 | for img, target in zip(minist_train_dataset.data.numpy(), minist_train_dataset.targets.numpy()): 354 | trainlist.append(np.array(Image.fromarray(img).resize((img_size, img_size)).convert('RGB'))) 355 | train_label_list.append(target+10) 356 | for img, target in zip(minist_test_dataset.data.numpy(), minist_test_dataset.targets.numpy()): 357 | testlist.append(np.array(Image.fromarray(img).resize((img_size, img_size)).convert('RGB'))) 358 | test_label_list.append(target+10) 359 | 360 | # notMNIST 361 | classes = ["A", "B", "C", "D", "E", "F", "G", "H", "I", "J"] 362 | tarin_dir = "./data/notMNIST_large" 363 | test_dir = "./data/notMNIST_small" 364 | for idx, cls in enumerate(classes): 365 | image_files = os.listdir(os.path.join(tarin_dir, cls)) 366 | for img_path in image_files: 367 | try: 368 | image = np.array(Image.open(os.path.join(tarin_dir, cls, img_path)).resize((img_size, img_size)).convert('RGB')) 369 | trainlist.append(image) 370 | train_label_list.append(idx+20) 371 | except: 372 | 
print(os.path.join(tarin_dir, cls, img_path)) 373 | image_files = os.listdir(os.path.join(test_dir, cls)) 374 | for img_path in image_files: 375 | try: 376 | image = np.array(Image.open(os.path.join(test_dir, cls, img_path)).resize((img_size, img_size)).convert('RGB')) 377 | testlist.append(image) 378 | test_label_list.append(idx+20) 379 | except: 380 | print(os.path.join(test_dir, cls, img_path)) 381 | 382 | 383 | # Fashion-MNIST 384 | fminist_train_dataset = datasets.FashionMNIST('./data', train=True, download=True) 385 | fminist_test_dataset = datasets.FashionMNIST('./data', train=False, download=True) 386 | for img, target in zip(fminist_train_dataset.data.numpy(), fminist_train_dataset.targets.numpy()): 387 | trainlist.append(np.array(Image.fromarray(img).resize((img_size, img_size)).convert('RGB'))) 388 | train_label_list.append(target+30) 389 | for img, target in zip(fminist_test_dataset.data.numpy(), fminist_test_dataset.targets.numpy()): 390 | testlist.append(np.array(Image.fromarray(img).resize((img_size, img_size)).convert('RGB'))) 391 | test_label_list.append(target+30) 392 | 393 | # SVHN 394 | svhn_train_dataset = datasets.SVHN('./data', split='train', download=True) 395 | svhn_test_dataset = datasets.SVHN('./data', split='test', download=True) 396 | for img, target in zip(svhn_train_dataset.data, svhn_train_dataset.labels): 397 | trainlist.append(np.array(Image.fromarray(img.transpose(1,2,0)).resize((img_size, img_size)))) 398 | train_label_list.append(target+40) 399 | for img, target in zip(svhn_test_dataset.data, svhn_test_dataset.labels): 400 | testlist.append(np.array(Image.fromarray(img.transpose(1,2,0)).resize((img_size, img_size)))) 401 | test_label_list.append(target+40) 402 | 403 | 404 | train_dataset.data = np.array(trainlist) 405 | train_dataset.targets = np.array(train_label_list) 406 | test_dataset.data = np.array(testlist) 407 | test_dataset.targets = np.array(test_label_list) 408 | 409 | self.train_data, self.train_targets = train_dataset.data, np.array(train_dataset.targets) 410 | self.test_data, self.test_targets = test_dataset.data, np.array(test_dataset.targets) 411 | 412 | class iGanFake(object): 413 | use_path = True 414 | train_trsf = [ 415 | transforms.RandomResizedCrop(224), 416 | transforms.RandomHorizontalFlip(), 417 | transforms.ColorJitter(brightness=63/255) 418 | ] 419 | test_trsf = [ 420 | transforms.Resize(256), 421 | transforms.CenterCrop(224), 422 | ] 423 | common_trsf = [ 424 | transforms.ToTensor(), 425 | transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]), 426 | ] 427 | 428 | def __init__(self, args): 429 | self.args = args 430 | class_order = args["class_order"] 431 | self.class_order = class_order 432 | 433 | def download_data(self): 434 | 435 | train_dataset = [] 436 | test_dataset = [] 437 | for id, name in enumerate(self.args["task_name"]): 438 | root_ = os.path.join(self.args["data_path"], name, 'train') 439 | sub_classes = os.listdir(root_) if self.args["multiclass"][id] else [''] 440 | for cls in sub_classes: 441 | for imgname in os.listdir(os.path.join(root_, cls, '0_real')): 442 | train_dataset.append((os.path.join(root_, cls, '0_real', imgname), 0 + 2 * id)) 443 | for imgname in os.listdir(os.path.join(root_, cls, '1_fake')): 444 | train_dataset.append((os.path.join(root_, cls, '1_fake', imgname), 1 + 2 * id)) 445 | 446 | for id, name in enumerate(self.args["task_name"]): 447 | root_ = os.path.join(self.args["data_path"], name, 'val') 448 | sub_classes = os.listdir(root_) if self.args["multiclass"][id] 
else [''] 449 | for cls in sub_classes: 450 | for imgname in os.listdir(os.path.join(root_, cls, '0_real')): 451 | test_dataset.append((os.path.join(root_, cls, '0_real', imgname), 0 + 2 * id)) 452 | for imgname in os.listdir(os.path.join(root_, cls, '1_fake')): 453 | test_dataset.append((os.path.join(root_, cls, '1_fake', imgname), 1 + 2 * id)) 454 | 455 | self.train_data, self.train_targets = split_images_labels(train_dataset) 456 | self.test_data, self.test_targets = split_images_labels(test_dataset) 457 | 458 | class iGanClass(object): 459 | use_path = True 460 | train_trsf = [ 461 | transforms.Resize(256), 462 | transforms.RandomResizedCrop(224), 463 | transforms.RandomHorizontalFlip(), 464 | transforms.ColorJitter(brightness=63/255) 465 | ] 466 | test_trsf = [ 467 | transforms.Resize(256), 468 | transforms.CenterCrop(224), 469 | ] 470 | common_trsf = [ 471 | transforms.ToTensor(), 472 | transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]), 473 | ] 474 | 475 | def __init__(self, args): 476 | self.args = args 477 | class_number = args["class_number"] 478 | self.class_order = [i for i in range(2*class_number)] 479 | self.task_name = [str(i) for i in range(class_number)] 480 | 481 | def download_data(self): 482 | 483 | train_dataset = [] 484 | test_dataset = [] 485 | for id, name in enumerate(self.task_name): 486 | root_ = os.path.join(self.args["data_path"], name, 'train') 487 | sub_classes = [''] 488 | for cls in sub_classes: 489 | for imgname in os.listdir(os.path.join(root_, cls, '0_real')): 490 | train_dataset.append((os.path.join(root_, cls, '0_real', imgname), 0 + 2 * id)) 491 | for imgname in os.listdir(os.path.join(root_, cls, '1_fake')): 492 | train_dataset.append((os.path.join(root_, cls, '1_fake', imgname), 1 + 2 * id)) 493 | 494 | for id, name in enumerate(self.task_name): 495 | root_ = os.path.join(self.args["data_path"], name, 'val') 496 | sub_classes = [''] 497 | for cls in sub_classes: 498 | for imgname in os.listdir(os.path.join(root_, cls, '0_real')): 499 | test_dataset.append((os.path.join(root_, cls, '0_real', imgname), 0 + 2 * id)) 500 | for imgname in os.listdir(os.path.join(root_, cls, '1_fake')): 501 | test_dataset.append((os.path.join(root_, cls, '1_fake', imgname), 1 + 2 * id)) 502 | 503 | self.train_data, self.train_targets = split_images_labels(train_dataset) 504 | self.test_data, self.test_targets = split_images_labels(test_dataset) -------------------------------------------------------------------------------- /models/convit.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2015-present, Facebook, Inc. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the CC-by-NC license found in the 5 | # LICENSE file in the root directory of this source tree. 
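# The BatchEnsemble layer defined further below builds its effective weight as a
# rank-1 modulation of a shared nn.Linear weight, W_eff = outer(r, s) * W, before
# applying an ordinary affine map. The block below is a minimal, self-contained
# sketch of that computation only (toy shapes, random values, illustrative names);
# it runs only when this file is executed directly, not on import.

if __name__ == "__main__":
    import torch
    import torch.nn.functional as F

    r = torch.randn(3)           # per-output scaling (out_features = 3)
    s = torch.randn(4)           # per-input scaling  (in_features  = 4)
    W = torch.randn(3, 4)        # shared linear weight
    b = torch.zeros(3)           # shared bias
    x = torch.randn(2, 4)        # a batch of two inputs

    y = F.linear(x, torch.outer(r, s) * W, b)   # same math as BatchEnsemble.forward
    assert y.shape == (2, 3)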
6 | # 7 | 8 | '''These modules are adapted from those of timm, see 9 | https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/vision_transformer.py 10 | ''' 11 | 12 | import copy 13 | import math 14 | from functools import lru_cache 15 | 16 | import torch 17 | import torch.nn as nn 18 | import torch.nn.functional as F 19 | from timm.models.layers import DropPath, to_2tuple, trunc_normal_ 20 | 21 | def freeze_parameters(m, requires_grad=False): 22 | if m is None: 23 | return 24 | 25 | if isinstance(m, nn.Parameter): 26 | m.requires_grad = requires_grad 27 | else: 28 | for p in m.parameters(): 29 | p.requires_grad = requires_grad 30 | 31 | class BatchEnsemble(nn.Module): 32 | def __init__(self, in_features, out_features, bias=True): 33 | super().__init__() 34 | 35 | self.linear = nn.Linear(in_features, out_features, bias=bias) 36 | 37 | self.out_features, self.in_features = out_features, in_features 38 | self.bias = bias 39 | 40 | self.r = nn.Parameter(torch.randn(self.out_features)) 41 | self.s = nn.Parameter(torch.randn(self.in_features)) 42 | 43 | def __deepcopy__(self, memo): 44 | cls = self.__class__ 45 | result = cls.__new__(cls, self.in_features, self.out_features, self.bias) 46 | memo[id(self)] = result 47 | for k, v in self.__dict__.items(): 48 | setattr(result, k, copy.deepcopy(v, memo)) 49 | 50 | result.linear.weight = self.linear.weight 51 | return result 52 | 53 | def reset_parameters(self): 54 | device = self.linear.weight.device 55 | self.r = nn.Parameter(torch.randn(self.out_features).to(device)) 56 | self.s = nn.Parameter(torch.randn(self.in_features).to(device)) 57 | 58 | if self.bias: 59 | fan_in, _ = nn.init._calculate_fan_in_and_fan_out(self.linear.weight) 60 | bound = 1 / math.sqrt(fan_in) if fan_in > 0 else 0 61 | nn.init.uniform_(self.linear.bias, -bound, bound) 62 | 63 | def forward(self, x): 64 | w = torch.outer(self.r, self.s) 65 | w = w * self.linear.weight 66 | return F.linear(x, w, self.linear.bias) 67 | 68 | 69 | class Mlp(nn.Module): 70 | def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0., fc=nn.Linear): 71 | super().__init__() 72 | out_features = out_features or in_features 73 | hidden_features = hidden_features or in_features 74 | self.fc1 = fc(in_features, hidden_features) 75 | self.act = act_layer() 76 | self.fc2 = fc(hidden_features, out_features) 77 | self.drop = nn.Dropout(drop) 78 | self.apply(self._init_weights) 79 | 80 | def _init_weights(self, m): 81 | if isinstance(m, nn.Linear): 82 | trunc_normal_(m.weight, std=.02) 83 | if isinstance(m, nn.Linear) and m.bias is not None: 84 | nn.init.constant_(m.bias, 0) 85 | elif isinstance(m, BatchEnsemble): 86 | trunc_normal_(m.linear.weight, std=.02) 87 | if isinstance(m.linear, nn.Linear) and m.linear.bias is not None: 88 | nn.init.constant_(m.linear.bias, 0) 89 | elif isinstance(m, nn.LayerNorm): 90 | nn.init.constant_(m.bias, 0) 91 | nn.init.constant_(m.weight, 1.0) 92 | 93 | def forward(self, x): 94 | x = self.fc1(x) 95 | x = self.act(x) 96 | x = self.drop(x) 97 | x = self.fc2(x) 98 | x = self.drop(x) 99 | return x 100 | 101 | 102 | class GPSA(nn.Module): 103 | def __init__(self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0., proj_drop=0., 104 | locality_strength=1., use_local_init=True, fc=None): 105 | super().__init__() 106 | self.num_heads = num_heads 107 | self.dim = dim 108 | head_dim = dim // num_heads 109 | self.scale = qk_scale or head_dim ** -0.5 110 | 111 | self.qk = nn.Linear(dim, dim * 2, bias=qkv_bias) 112 | 
self.v = nn.Linear(dim, dim, bias=qkv_bias) 113 | 114 | self.attn_drop = nn.Dropout(attn_drop) 115 | self.proj = nn.Linear(dim, dim) 116 | self.pos_proj = nn.Linear(3, num_heads) 117 | self.proj_drop = nn.Dropout(proj_drop) 118 | self.locality_strength = locality_strength 119 | self.gating_param = nn.Parameter(torch.ones(self.num_heads)) 120 | self.apply(self._init_weights) 121 | if use_local_init: 122 | self.local_init(locality_strength=locality_strength) 123 | 124 | def reset_parameters(self): 125 | self.apply(self._init_weights) 126 | 127 | def _init_weights(self, m): 128 | if isinstance(m, nn.Linear): 129 | trunc_normal_(m.weight, std=.02) 130 | if isinstance(m, nn.Linear) and m.bias is not None: 131 | nn.init.constant_(m.bias, 0) 132 | elif isinstance(m, nn.LayerNorm): 133 | nn.init.constant_(m.bias, 0) 134 | nn.init.constant_(m.weight, 1.0) 135 | 136 | def forward(self, x): 137 | B, N, C = x.shape 138 | if not hasattr(self, 'rel_indices') or self.rel_indices.size(1)!=N: 139 | self.get_rel_indices(N) 140 | 141 | attn = self.get_attention(x) 142 | v = self.v(x).reshape(B, N, self.num_heads, C // self.num_heads).permute(0, 2, 1, 3) 143 | x = (attn @ v).transpose(1, 2).reshape(B, N, C) 144 | x = self.proj(x) 145 | x = self.proj_drop(x) 146 | return x, attn, v 147 | 148 | def get_attention(self, x): 149 | B, N, C = x.shape 150 | qk = self.qk(x).reshape(B, N, 2, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4) 151 | q, k = qk[0], qk[1] 152 | pos_score = self.rel_indices.expand(B, -1, -1,-1) 153 | pos_score = self.pos_proj(pos_score).permute(0,3,1,2) 154 | patch_score = (q @ k.transpose(-2, -1)) * self.scale 155 | patch_score = patch_score.softmax(dim=-1) 156 | pos_score = pos_score.softmax(dim=-1) 157 | 158 | gating = self.gating_param.view(1,-1,1,1) 159 | attn = (1.-torch.sigmoid(gating)) * patch_score + torch.sigmoid(gating) * pos_score 160 | attn /= attn.sum(dim=-1).unsqueeze(-1) 161 | attn = self.attn_drop(attn) 162 | return attn 163 | 164 | def get_attention_map(self, x, return_map = False): 165 | 166 | attn_map = self.get_attention(x).mean(0) # average over batch 167 | distances = self.rel_indices.squeeze()[:,:,-1]**.5 168 | dist = torch.einsum('nm,hnm->h', (distances, attn_map)) 169 | dist /= distances.size(0) 170 | if return_map: 171 | return dist, attn_map 172 | else: 173 | return dist 174 | 175 | def local_init(self, locality_strength=1.): 176 | self.v.weight.data.copy_(torch.eye(self.dim)) 177 | locality_distance = 1 #max(1,1/locality_strength**.5) 178 | 179 | kernel_size = int(self.num_heads**.5) 180 | center = (kernel_size-1)/2 if kernel_size%2==0 else kernel_size//2 181 | for h1 in range(kernel_size): 182 | for h2 in range(kernel_size): 183 | position = h1+kernel_size*h2 184 | self.pos_proj.weight.data[position,2] = -1 185 | self.pos_proj.weight.data[position,1] = 2*(h1-center)*locality_distance 186 | self.pos_proj.weight.data[position,0] = 2*(h2-center)*locality_distance 187 | self.pos_proj.weight.data *= locality_strength 188 | 189 | def get_rel_indices(self, num_patches): 190 | img_size = int(num_patches**.5) 191 | rel_indices = torch.zeros(1, num_patches, num_patches, 3) 192 | ind = torch.arange(img_size).view(1,-1) - torch.arange(img_size).view(-1, 1) 193 | indx = ind.repeat(img_size,img_size) 194 | indy = ind.repeat_interleave(img_size,dim=0).repeat_interleave(img_size,dim=1) 195 | indd = indx**2 + indy**2 196 | rel_indices[:,:,:,2] = indd.unsqueeze(0) 197 | rel_indices[:,:,:,1] = indy.unsqueeze(0) 198 | rel_indices[:,:,:,0] = indx.unsqueeze(0) 199 | device = 
self.qk.weight.device 200 | self.rel_indices = rel_indices.to(device) 201 | 202 | 203 | class MHSA(nn.Module): 204 | def __init__(self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0., proj_drop=0., fc=None): 205 | super().__init__() 206 | self.num_heads = num_heads 207 | head_dim = dim // num_heads 208 | self.scale = qk_scale or head_dim ** -0.5 209 | 210 | self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) 211 | self.attn_drop = nn.Dropout(attn_drop) 212 | self.proj = nn.Linear(dim, dim) 213 | self.proj_drop = nn.Dropout(proj_drop) 214 | self.apply(self._init_weights) 215 | 216 | def reset_parameters(self): 217 | self.apply(self._init_weights) 218 | 219 | def _init_weights(self, m): 220 | if isinstance(m, nn.Linear): 221 | trunc_normal_(m.weight, std=.02) 222 | if isinstance(m, nn.Linear) and m.bias is not None: 223 | nn.init.constant_(m.bias, 0) 224 | elif isinstance(m, nn.LayerNorm): 225 | nn.init.constant_(m.bias, 0) 226 | nn.init.constant_(m.weight, 1.0) 227 | 228 | def get_attention_map(self, x, return_map = False): 229 | B, N, C = x.shape 230 | qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4) 231 | q, k, v = qkv[0], qkv[1], qkv[2] 232 | attn_map = (q @ k.transpose(-2, -1)) * self.scale 233 | attn_map = attn_map.softmax(dim=-1).mean(0) 234 | 235 | img_size = int(N**.5) 236 | ind = torch.arange(img_size).view(1,-1) - torch.arange(img_size).view(-1, 1) 237 | indx = ind.repeat(img_size,img_size) 238 | indy = ind.repeat_interleave(img_size,dim=0).repeat_interleave(img_size,dim=1) 239 | indd = indx**2 + indy**2 240 | distances = indd**.5 241 | distances = distances.to('cuda') 242 | 243 | dist = torch.einsum('nm,hnm->h', (distances, attn_map)) 244 | dist /= N 245 | 246 | if return_map: 247 | return dist, attn_map 248 | else: 249 | return dist 250 | 251 | def forward(self, x): 252 | B, N, C = x.shape 253 | qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4) 254 | q, k, v = qkv[0], qkv[1], qkv[2] 255 | 256 | attn = (q @ k.transpose(-2, -1)) * self.scale 257 | attn = attn.softmax(dim=-1) 258 | attn = self.attn_drop(attn) 259 | 260 | x = (attn @ v).transpose(1, 2).reshape(B, N, C) 261 | x = self.proj(x) 262 | x = self.proj_drop(x) 263 | return x, attn 264 | 265 | 266 | class ScaleNorm(nn.Module): 267 | """See 268 | https://github.com/lucidrains/reformer-pytorch/blob/a751fe2eb939dcdd81b736b2f67e745dc8472a09/reformer_pytorch/reformer_pytorch.py#L143 269 | """ 270 | def __init__(self, dim, eps=1e-5): 271 | super().__init__() 272 | self.g = nn.Parameter(torch.ones(1)) 273 | self.eps = eps 274 | 275 | def forward(self, x): 276 | n = torch.norm(x, dim=-1, keepdim=True).clamp(min=self.eps) 277 | return x / n * self.g 278 | 279 | 280 | class Block(nn.Module): 281 | 282 | def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop=0., attn_drop=0., 283 | drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm, attention_type=GPSA, 284 | fc=nn.Linear, **kwargs): 285 | super().__init__() 286 | self.norm1 = norm_layer(dim) 287 | self.attn = attention_type(dim, num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop, fc=fc, **kwargs) 288 | self.drop_path = DropPath(drop_path) if drop_path > 0. 
else nn.Identity() 289 | self.norm2 = norm_layer(dim) 290 | mlp_hidden_dim = int(dim * mlp_ratio) 291 | self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop, fc=fc) 292 | 293 | def reset_parameters(self): 294 | self.norm1.reset_parameters() 295 | self.norm2.reset_parameters() 296 | self.attn.reset_parameters() 297 | self.mlp.apply(self.mlp._init_weights) 298 | 299 | def forward(self, x, mask_heads=None, task_index=1, attn_mask=None): 300 | if isinstance(self.attn, ClassAttention): # Like in CaiT 301 | cls_token = x[:, :task_index] 302 | 303 | xx = self.norm1(x) 304 | xx, attn, v = self.attn( 305 | xx, 306 | mask_heads=mask_heads, 307 | nb=task_index, 308 | attn_mask=attn_mask 309 | ) 310 | 311 | cls_token = self.drop_path(xx[:, :task_index]) + cls_token 312 | cls_token = self.drop_path(self.mlp(self.norm2(cls_token))) + cls_token 313 | 314 | return cls_token, attn, v 315 | 316 | xx = self.norm1(x) 317 | xx, attn, v = self.attn(xx) 318 | 319 | x = self.drop_path(xx) + x 320 | x = self.drop_path(self.mlp(self.norm2(x))) + x 321 | 322 | return x, attn, v 323 | 324 | 325 | class ClassAttention(nn.Module): 326 | # taken from https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/vision_transformer.py 327 | # with slight modifications to do CA 328 | def __init__(self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0., proj_drop=0., fc=nn.Linear): 329 | super().__init__() 330 | self.num_heads = num_heads 331 | head_dim = dim // num_heads 332 | self.scale = qk_scale or head_dim ** -0.5 333 | 334 | self.q = fc(dim, dim, bias=qkv_bias) 335 | self.k = fc(dim, dim, bias=qkv_bias) 336 | self.v = fc(dim, dim, bias=qkv_bias) 337 | self.attn_drop = nn.Dropout(attn_drop) 338 | self.proj = fc(dim, dim) 339 | self.proj_drop = nn.Dropout(proj_drop) 340 | 341 | self.apply(self._init_weights) 342 | 343 | def reset_parameters(self): 344 | self.apply(self._init_weights) 345 | 346 | def _init_weights(self, m): 347 | if isinstance(m, nn.Linear): 348 | trunc_normal_(m.weight, std=.02) 349 | if isinstance(m, nn.Linear) and m.bias is not None: 350 | nn.init.constant_(m.bias, 0) 351 | elif isinstance(m, nn.LayerNorm): 352 | nn.init.constant_(m.bias, 0) 353 | nn.init.constant_(m.weight, 1.0) 354 | 355 | def forward(self, x, mask_heads=None, **kwargs): 356 | B, N, C = x.shape 357 | q = self.q(x[:,0]).unsqueeze(1).reshape(B, 1, self.num_heads, C // self.num_heads).permute(0, 2, 1, 3) 358 | k = self.k(x).reshape(B, N, self.num_heads, C // self.num_heads).permute(0, 2, 1, 3) 359 | 360 | q = q * self.scale 361 | v = self.v(x).reshape(B, N, self.num_heads, C // self.num_heads).permute(0, 2, 1, 3) 362 | 363 | attn = (q @ k.transpose(-2, -1)) 364 | attn = attn.softmax(dim=-1) 365 | attn = self.attn_drop(attn) 366 | 367 | if mask_heads is not None: 368 | mask_heads = mask_heads.expand(B, self.num_heads, -1, N) 369 | attn = attn * mask_heads 370 | 371 | x_cls = (attn @ v).transpose(1, 2).reshape(B, 1, C) 372 | x_cls = self.proj(x_cls) 373 | x_cls = self.proj_drop(x_cls) 374 | 375 | return x_cls, attn, v 376 | 377 | 378 | class JointCA(ClassAttention): 379 | """Forward all task tokens together. 380 | 381 | It uses a masked attention so that task tokens don't interact between them. 382 | It should have the same results as independent forward per task token but being 383 | much faster. 384 | 385 | HOWEVER, it works a bit worse (like ~2pts less in 'all top-1' CIFAR100 50 steps). 386 | So if anyone knows why, please tell me! 
387 | """ 388 | def __init__(self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0., proj_drop=0., fc=nn.Linear): 389 | super().__init__() 390 | self.num_heads = num_heads 391 | head_dim = dim // num_heads 392 | self.scale = qk_scale or head_dim ** -0.5 393 | 394 | self.q = fc(dim, dim, bias=qkv_bias) 395 | self.k = fc(dim, dim, bias=qkv_bias) 396 | self.v = fc(dim, dim, bias=qkv_bias) 397 | self.attn_drop = nn.Dropout(attn_drop) 398 | self.proj = fc(dim, dim) 399 | self.proj_drop = nn.Dropout(proj_drop) 400 | 401 | self.apply(self._init_weights) 402 | 403 | def reset_parameters(self): 404 | self.apply(self._init_weights) 405 | 406 | def _init_weights(self, m): 407 | if isinstance(m, nn.Linear): 408 | trunc_normal_(m.weight, std=.02) 409 | if isinstance(m, nn.Linear) and m.bias is not None: 410 | nn.init.constant_(m.bias, 0) 411 | elif isinstance(m, nn.LayerNorm): 412 | nn.init.constant_(m.bias, 0) 413 | nn.init.constant_(m.weight, 1.0) 414 | 415 | @lru_cache(maxsize=1) 416 | def get_attention_mask(self, attn_shape, nb_task_tokens): 417 | """Mask so that task tokens don't interact together. 418 | 419 | Given two task tokens (t1, t2) and three patch tokens (p1, p2, p3), the 420 | attention matrix is: 421 | 422 | t1-t1 t1-t2 t1-p1 t1-p2 t1-p3 423 | t2-t1 t2-t2 t2-p1 t2-p2 t2-p3 424 | 425 | So that the mask (True values are deleted) should be: 426 | 427 | False True False False False 428 | True False False False False 429 | """ 430 | mask = torch.zeros(attn_shape, dtype=torch.bool) 431 | for i in range(nb_task_tokens): 432 | mask[:, i, :i] = True 433 | mask[:, i, i+1:nb_task_tokens] = True 434 | return mask 435 | 436 | def forward(self, x, attn_mask=None, nb_task_tokens=1): 437 | B, N, C = x.shape 438 | q = self.q(x[:,:nb_task_tokens]).reshape(B, nb_task_tokens, self.num_heads, C // self.num_heads).permute(0, 2, 1, 3) 439 | k = self.k(x).reshape(B, N, self.num_heads, C // self.num_heads).permute(0, 2, 1, 3) 440 | 441 | q = q * self.scale 442 | v = self.v(x).reshape(B, N, self.num_heads, C // self.num_heads).permute(0, 2, 1, 3) 443 | 444 | attn = (q @ k.transpose(-2, -1)) 445 | if attn_mask is not None: 446 | mask = self.get_attention_mask(attn.shape, nb_task_tokens, attn_mask) 447 | attn[mask] = -float('inf') 448 | attn = attn.softmax(dim=-1) 449 | 450 | attn = self.attn_drop(attn) 451 | 452 | x_cls = (attn @ v).transpose(1, 2).reshape(B, nb_task_tokens, C) 453 | x_cls = self.proj(x_cls) 454 | x_cls = self.proj_drop(x_cls) 455 | 456 | return x_cls, attn, v 457 | 458 | 459 | class PatchEmbed(nn.Module): 460 | """ Image to Patch Embedding, from timm 461 | """ 462 | def __init__(self, img_size=224, patch_size=16, in_chans=3, embed_dim=768): 463 | super().__init__() 464 | img_size = to_2tuple(img_size) 465 | patch_size = to_2tuple(patch_size) 466 | num_patches = (img_size[1] // patch_size[1]) * (img_size[0] // patch_size[0]) 467 | self.img_size = img_size 468 | self.patch_size = patch_size 469 | self.num_patches = num_patches 470 | 471 | self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size) 472 | self.apply(self._init_weights) 473 | 474 | def reset_parameters(self): 475 | self.apply(self._init_weights) 476 | 477 | def _init_weights(self, m): 478 | if isinstance(m, nn.Linear): 479 | trunc_normal_(m.weight, std=.02) 480 | if isinstance(m, nn.Linear) and m.bias is not None: 481 | nn.init.constant_(m.bias, 0) 482 | elif isinstance(m, nn.LayerNorm): 483 | nn.init.constant_(m.bias, 0) 484 | nn.init.constant_(m.weight, 1.0) 485 | 486 | def forward(self, x): 
487 | B, C, H, W = x.shape 488 | #assert H == self.img_size[0] and W == self.img_size[1], \ 489 | # f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})." 490 | x = self.proj(x).flatten(2).transpose(1, 2) 491 | return x 492 | 493 | 494 | 495 | class HybridEmbed(nn.Module): 496 | """ CNN Feature Map Embedding, from timm 497 | """ 498 | def __init__(self, backbone, img_size=224, feature_size=None, in_chans=3, embed_dim=768): 499 | super().__init__() 500 | assert isinstance(backbone, nn.Module) 501 | img_size = to_2tuple(img_size) 502 | self.img_size = img_size 503 | self.backbone = backbone 504 | if feature_size is None: 505 | with torch.no_grad(): 506 | training = backbone.training 507 | if training: 508 | backbone.eval() 509 | o = self.backbone(torch.zeros(1, in_chans, img_size[0], img_size[1]))[-1] 510 | feature_size = o.shape[-2:] 511 | feature_dim = o.shape[1] 512 | backbone.train(training) 513 | else: 514 | feature_size = to_2tuple(feature_size) 515 | feature_dim = self.backbone.feature_info.channels()[-1] 516 | self.num_patches = feature_size[0] * feature_size[1] 517 | self.proj = nn.Linear(feature_dim, embed_dim) 518 | self.apply(self._init_weights) 519 | 520 | def reset_parameters(self): 521 | self.apply(self._init_weights) 522 | 523 | def forward(self, x): 524 | x = self.backbone(x)[-1] 525 | x = x.flatten(2).transpose(1, 2) 526 | x = self.proj(x) 527 | return x 528 | 529 | 530 | class ConVit(nn.Module): 531 | """ Vision Transformer with support for patch or hybrid CNN input stage 532 | """ 533 | def __init__(self, img_size=224, patch_size=16, in_chans=3, num_classes=1000, embed_dim=768, depth=12, 534 | num_heads=12, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop_rate=0., attn_drop_rate=0., 535 | drop_path_rate=0., hybrid_backbone=None, norm_layer='layer', 536 | local_up_to_layer=3, locality_strength=1., use_pos_embed=True, 537 | class_attention=False, ca_type='base', 538 | ): 539 | super().__init__() 540 | self.num_classes = num_classes 541 | self.num_heads = num_heads 542 | self.embed_dim = embed_dim 543 | self.local_up_to_layer = local_up_to_layer 544 | self.num_features = self.final_dim = self.embed_dim = embed_dim # num_features for consistency with other models_2 545 | self.locality_strength = locality_strength 546 | self.use_pos_embed = use_pos_embed 547 | 548 | if norm_layer == 'layer': 549 | norm_layer = nn.LayerNorm 550 | elif norm_layer == 'scale': 551 | norm_layer = ScaleNorm 552 | else: 553 | raise NotImplementedError(f'Unknown normalization {norm_layer}') 554 | 555 | if hybrid_backbone is not None: 556 | self.patch_embed = HybridEmbed( 557 | hybrid_backbone, img_size=img_size, in_chans=in_chans, embed_dim=embed_dim) 558 | else: 559 | self.patch_embed = PatchEmbed( 560 | img_size=img_size, patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim) 561 | num_patches = self.patch_embed.num_patches 562 | self.num_patches = num_patches 563 | 564 | self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim)) 565 | self.pos_drop = nn.Dropout(p=drop_rate) 566 | 567 | if self.use_pos_embed: 568 | self.pos_embed = nn.Parameter(torch.zeros(1, num_patches, embed_dim)) 569 | trunc_normal_(self.pos_embed, std=.02) 570 | 571 | dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)] # stochastic depth decay rule 572 | 573 | blocks = [] 574 | 575 | if ca_type == 'base': 576 | ca_block = ClassAttention 577 | elif ca_type == 'jointca': 578 | ca_block = JointCA 579 | else: 580 | raise ValueError(f'Unknown CA type {ca_type}') 581 | 582 
| for layer_index in range(depth): 583 | if layer_index < local_up_to_layer: 584 | # Convit 585 | block = Block( 586 | dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale, 587 | drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[layer_index], norm_layer=norm_layer, 588 | attention_type=GPSA, locality_strength=locality_strength 589 | ) 590 | elif not class_attention: 591 | # Convit 592 | block = Block( 593 | dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale, 594 | drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[layer_index], norm_layer=norm_layer, 595 | attention_type=MHSA 596 | ) 597 | else: 598 | # CaiT 599 | block = Block( 600 | dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale, 601 | drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[layer_index], norm_layer=norm_layer, 602 | attention_type=ca_block 603 | ) 604 | 605 | blocks.append(block) 606 | 607 | self.blocks = nn.ModuleList(blocks) 608 | self.norm = norm_layer(embed_dim) 609 | self.use_class_attention = class_attention 610 | 611 | # Classifier head 612 | self.feature_info = [dict(num_chs=embed_dim, reduction=0, module='head')] 613 | self.head = nn.Linear(embed_dim, num_classes) if num_classes > 0 else nn.Identity() 614 | 615 | trunc_normal_(self.cls_token, std=.02) 616 | self.head.apply(self._init_weights) 617 | 618 | def freeze(self, names): 619 | for name in names: 620 | if name == 'all': 621 | return freeze_parameters(self) 622 | elif name == 'old_heads': 623 | self.head.freeze(name) 624 | elif name == 'backbone': 625 | freeze_parameters(self.blocks) 626 | freeze_parameters(self.patch_embed) 627 | freeze_parameters(self.pos_embed) 628 | freeze_parameters(self.norm) 629 | else: 630 | raise NotImplementedError(f'Unknown name={name}.') 631 | 632 | def reset_classifier(self): 633 | self.head.apply(self._init_weights) 634 | 635 | def reset_parameters(self): 636 | for b in self.blocks: 637 | b.reset_parameters() 638 | self.norm.reset_parameters() 639 | self.head.apply(self._init_weights) 640 | 641 | def _init_weights(self, m): 642 | if isinstance(m, nn.Linear): 643 | trunc_normal_(m.weight, std=.02) 644 | if isinstance(m, nn.Linear) and m.bias is not None: 645 | nn.init.constant_(m.bias, 0) 646 | elif isinstance(m, nn.LayerNorm): 647 | nn.init.constant_(m.bias, 0) 648 | nn.init.constant_(m.weight, 1.0) 649 | 650 | def get_internal_losses(self, clf_loss): 651 | return {} 652 | 653 | def end_finetuning(self): 654 | pass 655 | 656 | def begin_finetuning(self): 657 | pass 658 | 659 | def epoch_log(self): 660 | return {} 661 | 662 | @torch.jit.ignore 663 | def no_weight_decay(self): 664 | return {'pos_embed', 'cls_token'} 665 | 666 | def get_classifier(self): 667 | return self.head 668 | 669 | def forward_sa(self, x): 670 | B = x.shape[0] 671 | x = self.patch_embed(x) 672 | 673 | if self.use_pos_embed: 674 | x = x + self.pos_embed 675 | x = self.pos_drop(x) 676 | 677 | for blk in self.blocks[:self.local_up_to_layer]: 678 | x, _ = blk(x) 679 | 680 | return x 681 | 682 | def forward_features(self, x, final_norm=True): 683 | B = x.shape[0] 684 | x = self.patch_embed(x) 685 | 686 | cls_tokens = self.cls_token.expand(B, -1, -1) 687 | 688 | if self.use_pos_embed: 689 | x = x + self.pos_embed 690 | x = self.pos_drop(x) 691 | 692 | for blk in self.blocks[:self.local_up_to_layer]: 693 | x, _, _ = blk(x) 694 | 695 | if self.use_class_attention: 696 | for blk in self.blocks[self.local_up_to_layer:]: 697 | 
cls_tokens, _, _ = blk(torch.cat((cls_tokens, x), dim=1)) 698 | else: 699 | x = torch.cat((cls_tokens, x), dim=1) 700 | for blk in self.blocks[self.local_up_to_layer:]: 701 | x, _ , _ = blk(x) 702 | 703 | if final_norm: 704 | if self.use_class_attention: 705 | cls_tokens = self.norm(cls_tokens) 706 | else: 707 | x = self.norm(x) 708 | 709 | if self.use_class_attention: 710 | return cls_tokens[:, 0], None, None 711 | else: 712 | return x[:, 0], None, None 713 | 714 | def forward(self, x): 715 | x = self.forward_features(x)[0] 716 | x = self.head(x) 717 | return x 718 | 719 | 720 | -------------------------------------------------------------------------------- /models/vit.py: -------------------------------------------------------------------------------- 1 | """ Vision Transformer (ViT) in PyTorch 2 | A PyTorch implement of Vision Transformers as described in: 3 | 'An Image Is Worth 16 x 16 Words: Transformers for Image Recognition at Scale' 4 | - https://arxiv.org/abs/2010.11929 5 | `How to train your ViT? Data, Augmentation, and Regularization in Vision Transformers` 6 | - https://arxiv.org/abs/2106.10270 7 | The official jax code is released and available at https://github.com/google-research/vision_transformer 8 | Acknowledgments: 9 | * The paper authors for releasing code and weights, thanks! 10 | * I fixed my class token impl based on Phil Wang's https://github.com/lucidrains/vit-pytorch ... check it out 11 | for some einops/einsum fun 12 | * Simple transformer style inspired by Andrej Karpathy's https://github.com/karpathy/minGPT 13 | * Bert reference code checks against Huggingface Transformers and Tensorflow Bert 14 | Hacked together by / Copyright 2020, Ross Wightman 15 | """ 16 | import math 17 | import logging 18 | from functools import partial 19 | from collections import OrderedDict 20 | 21 | import torch 22 | import torch.nn as nn 23 | import torch.nn.functional as F 24 | import torch.utils.checkpoint 25 | 26 | from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD, IMAGENET_INCEPTION_MEAN, IMAGENET_INCEPTION_STD 27 | from timm.models.helpers import build_model_with_cfg, resolve_pretrained_cfg, named_apply, adapt_input_conv, checkpoint_seq 28 | from timm.models.layers import PatchEmbed, Mlp, DropPath, trunc_normal_, lecun_normal_ 29 | from timm.models.registry import register_model 30 | 31 | _logger = logging.getLogger(__name__) 32 | 33 | 34 | def _cfg(url='', **kwargs): 35 | return { 36 | 'url': url, 37 | 'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': None, 38 | 'crop_pct': .9, 'interpolation': 'bicubic', 'fixed_input_size': True, 39 | 'mean': IMAGENET_INCEPTION_MEAN, 'std': IMAGENET_INCEPTION_STD, 40 | 'first_conv': 'patch_embed.proj', 'classifier': 'head', 41 | **kwargs 42 | } 43 | 44 | 45 | default_cfgs = { 46 | # patch engine (weights from official Google JAX impl) 47 | 'vit_tiny_patch16_224': _cfg( 48 | url='https://storage.googleapis.com/vit_models/augreg/' 49 | 'Ti_16-i21k-300ep-lr_0.001-aug_none-wd_0.03-do_0.0-sd_0.0--imagenet2012-steps_20k-lr_0.03-res_224.npz'), 50 | 'vit_tiny_patch16_384': _cfg( 51 | url='https://storage.googleapis.com/vit_models/augreg/' 52 | 'Ti_16-i21k-300ep-lr_0.001-aug_none-wd_0.03-do_0.0-sd_0.0--imagenet2012-steps_20k-lr_0.03-res_384.npz', 53 | input_size=(3, 384, 384), crop_pct=1.0), 54 | 'vit_small_patch32_224': _cfg( 55 | url='https://storage.googleapis.com/vit_models/augreg/' 56 | 'S_32-i21k-300ep-lr_0.001-aug_light1-wd_0.03-do_0.0-sd_0.0--imagenet2012-steps_20k-lr_0.03-res_224.npz'), 57 | 
'vit_small_patch32_384': _cfg( 58 | url='https://storage.googleapis.com/vit_models/augreg/' 59 | 'S_32-i21k-300ep-lr_0.001-aug_light1-wd_0.03-do_0.0-sd_0.0--imagenet2012-steps_20k-lr_0.03-res_384.npz', 60 | input_size=(3, 384, 384), crop_pct=1.0), 61 | 'vit_small_patch16_224': _cfg( 62 | url='https://storage.googleapis.com/vit_models/augreg/' 63 | 'S_16-i21k-300ep-lr_0.001-aug_light1-wd_0.03-do_0.0-sd_0.0--imagenet2012-steps_20k-lr_0.03-res_224.npz'), 64 | 'vit_small_patch16_384': _cfg( 65 | url='https://storage.googleapis.com/vit_models/augreg/' 66 | 'S_16-i21k-300ep-lr_0.001-aug_light1-wd_0.03-do_0.0-sd_0.0--imagenet2012-steps_20k-lr_0.03-res_384.npz', 67 | input_size=(3, 384, 384), crop_pct=1.0), 68 | 'vit_base_patch32_224': _cfg( 69 | url='https://storage.googleapis.com/vit_models/augreg/' 70 | 'B_32-i21k-300ep-lr_0.001-aug_medium1-wd_0.03-do_0.0-sd_0.0--imagenet2012-steps_20k-lr_0.03-res_224.npz'), 71 | 'vit_base_patch32_384': _cfg( 72 | url='https://storage.googleapis.com/vit_models/augreg/' 73 | 'B_32-i21k-300ep-lr_0.001-aug_light1-wd_0.1-do_0.0-sd_0.0--imagenet2012-steps_20k-lr_0.03-res_384.npz', 74 | input_size=(3, 384, 384), crop_pct=1.0), 75 | 'vit_base_patch16_224': _cfg( 76 | url='https://storage.googleapis.com/vit_models/augreg/' 77 | 'B_16-i21k-300ep-lr_0.001-aug_medium1-wd_0.1-do_0.0-sd_0.0--imagenet2012-steps_20k-lr_0.01-res_224.npz'), 78 | 'vit_base_patch16_384': _cfg( 79 | url='https://storage.googleapis.com/vit_models/augreg/' 80 | 'B_16-i21k-300ep-lr_0.001-aug_medium1-wd_0.1-do_0.0-sd_0.0--imagenet2012-steps_20k-lr_0.01-res_384.npz', 81 | input_size=(3, 384, 384), crop_pct=1.0), 82 | 'vit_base_patch8_224': _cfg( 83 | url='https://storage.googleapis.com/vit_models/augreg/' 84 | 'B_8-i21k-300ep-lr_0.001-aug_medium1-wd_0.1-do_0.0-sd_0.0--imagenet2012-steps_20k-lr_0.01-res_224.npz'), 85 | 'vit_large_patch32_224': _cfg( 86 | url='', # no official model weights for this combo, only for in21k 87 | ), 88 | 'vit_large_patch32_384': _cfg( 89 | url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-vitjx/jx_vit_large_p32_384-9b920ba8.pth', 90 | input_size=(3, 384, 384), crop_pct=1.0), 91 | 'vit_large_patch16_224': _cfg( 92 | url='https://storage.googleapis.com/vit_models/augreg/' 93 | 'L_16-i21k-300ep-lr_0.001-aug_medium1-wd_0.1-do_0.1-sd_0.1--imagenet2012-steps_20k-lr_0.01-res_224.npz'), 94 | 'vit_large_patch16_384': _cfg( 95 | url='https://storage.googleapis.com/vit_models/augreg/' 96 | 'L_16-i21k-300ep-lr_0.001-aug_medium1-wd_0.1-do_0.1-sd_0.1--imagenet2012-steps_20k-lr_0.01-res_384.npz', 97 | input_size=(3, 384, 384), crop_pct=1.0), 98 | 99 | 'vit_large_patch14_224': _cfg(url=''), 100 | 'vit_huge_patch14_224': _cfg(url=''), 101 | 'vit_giant_patch14_224': _cfg(url=''), 102 | 'vit_gigantic_patch14_224': _cfg(url=''), 103 | 104 | 'vit_base2_patch32_256': _cfg(url='', input_size=(3, 256, 256), crop_pct=0.95), 105 | 106 | # patch engine, imagenet21k (weights from official Google JAX impl) 107 | 'vit_tiny_patch16_224_in21k': _cfg( 108 | url='https://storage.googleapis.com/vit_models/augreg/Ti_16-i21k-300ep-lr_0.001-aug_none-wd_0.03-do_0.0-sd_0.0.npz', 109 | num_classes=21843), 110 | 'vit_small_patch32_224_in21k': _cfg( 111 | url='https://storage.googleapis.com/vit_models/augreg/S_32-i21k-300ep-lr_0.001-aug_light1-wd_0.03-do_0.0-sd_0.0.npz', 112 | num_classes=21843), 113 | 'vit_small_patch16_224_in21k': _cfg( 114 | url='https://storage.googleapis.com/vit_models/augreg/S_16-i21k-300ep-lr_0.001-aug_light1-wd_0.03-do_0.0-sd_0.0.npz', 115 | num_classes=21843), 116 
| 'vit_base_patch32_224_in21k': _cfg( 117 | url='https://storage.googleapis.com/vit_models/augreg/B_32-i21k-300ep-lr_0.001-aug_medium1-wd_0.03-do_0.0-sd_0.0.npz', 118 | num_classes=21843), 119 | 'vit_base_patch16_224_in21k': _cfg( 120 | url='https://storage.googleapis.com/vit_models/augreg/B_16-i21k-300ep-lr_0.001-aug_medium1-wd_0.1-do_0.0-sd_0.0.npz', 121 | num_classes=21843), 122 | 'vit_base_patch8_224_in21k': _cfg( 123 | url='https://storage.googleapis.com/vit_models/augreg/B_8-i21k-300ep-lr_0.001-aug_medium1-wd_0.1-do_0.0-sd_0.0.npz', 124 | num_classes=21843), 125 | 'vit_large_patch32_224_in21k': _cfg( 126 | url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-vitjx/jx_vit_large_patch32_224_in21k-9046d2e7.pth', 127 | num_classes=21843), 128 | 'vit_large_patch16_224_in21k': _cfg( 129 | url='https://storage.googleapis.com/vit_models/augreg/L_16-i21k-300ep-lr_0.001-aug_medium1-wd_0.1-do_0.1-sd_0.1.npz', 130 | num_classes=21843), 131 | 'vit_huge_patch14_224_in21k': _cfg( 132 | url='https://storage.googleapis.com/vit_models/imagenet21k/ViT-H_14.npz', 133 | hf_hub_id='timm/vit_huge_patch14_224_in21k', 134 | num_classes=21843), 135 | 136 | # SAM trained engine (https://arxiv.org/abs/2106.01548) 137 | 'vit_base_patch32_224_sam': _cfg( 138 | url='https://storage.googleapis.com/vit_models/sam/ViT-B_32.npz'), 139 | 'vit_base_patch16_224_sam': _cfg( 140 | url='https://storage.googleapis.com/vit_models/sam/ViT-B_16.npz'), 141 | 142 | # DINO pretrained - https://arxiv.org/abs/2104.14294 (no classifier head, for fine-tune only) 143 | 'vit_small_patch16_224_dino': _cfg( 144 | url='https://dl.fbaipublicfiles.com/dino/dino_deitsmall16_pretrain/dino_deitsmall16_pretrain.pth', 145 | mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD, num_classes=0), 146 | 'vit_small_patch8_224_dino': _cfg( 147 | url='https://dl.fbaipublicfiles.com/dino/dino_deitsmall8_pretrain/dino_deitsmall8_pretrain.pth', 148 | mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD, num_classes=0), 149 | 'vit_base_patch16_224_dino': _cfg( 150 | url='https://dl.fbaipublicfiles.com/dino/dino_vitbase16_pretrain/dino_vitbase16_pretrain.pth', 151 | mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD, num_classes=0), 152 | 'vit_base_patch8_224_dino': _cfg( 153 | url='https://dl.fbaipublicfiles.com/dino/dino_vitbase8_pretrain/dino_vitbase8_pretrain.pth', 154 | mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD, num_classes=0), 155 | 156 | 157 | # ViT ImageNet-21K-P pretraining by MILL 158 | 'vit_base_patch16_224_miil_in21k': _cfg( 159 | url='https://miil-public-eu.oss-eu-central-1.aliyuncs.com/model-zoo/ImageNet_21K_P/models/timm/vit_base_patch16_224_in21k_miil.pth', 160 | mean=(0, 0, 0), std=(1, 1, 1), crop_pct=0.875, interpolation='bilinear', num_classes=11221, 161 | ), 162 | 'vit_base_patch16_224_miil': _cfg( 163 | url='https://miil-public-eu.oss-eu-central-1.aliyuncs.com/model-zoo/ImageNet_21K_P/models/timm' 164 | '/vit_base_patch16_224_1k_miil_84_4.pth', 165 | mean=(0, 0, 0), std=(1, 1, 1), crop_pct=0.875, interpolation='bilinear', 166 | ), 167 | 168 | # experimental 169 | 'vit_small_patch16_36x1_224': _cfg(url=''), 170 | 'vit_small_patch16_18x2_224': _cfg(url=''), 171 | 'vit_base_patch16_18x2_224': _cfg(url=''), 172 | } 173 | 174 | 175 | class Attention(nn.Module): 176 | def __init__(self, dim, num_heads=8, qkv_bias=False, attn_drop=0., proj_drop=0.): 177 | super().__init__() 178 | assert dim % num_heads == 0, 'dim should be divisible by num_heads' 179 | self.num_heads = num_heads 180 | head_dim = dim // 
num_heads 181 | self.scale = head_dim ** -0.5 182 | 183 | self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) 184 | self.attn_drop = nn.Dropout(attn_drop) 185 | self.proj = nn.Linear(dim, dim) 186 | self.proj_drop = nn.Dropout(proj_drop) 187 | 188 | def forward(self, x): 189 | B, N, C = x.shape 190 | qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4) 191 | q, k, v = qkv.unbind(0) # make torchscript happy (cannot use tensor as tuple) 192 | 193 | attn = (q @ k.transpose(-2, -1)) * self.scale 194 | attn = attn.softmax(dim=-1) 195 | attn = self.attn_drop(attn) 196 | 197 | x = (attn @ v).transpose(1, 2).reshape(B, N, C) 198 | x = self.proj(x) 199 | x = self.proj_drop(x) 200 | return x 201 | 202 | 203 | class LayerScale(nn.Module): 204 | def __init__(self, dim, init_values=1e-5, inplace=False): 205 | super().__init__() 206 | self.inplace = inplace 207 | self.gamma = nn.Parameter(init_values * torch.ones(dim)) 208 | 209 | def forward(self, x): 210 | return x.mul_(self.gamma) if self.inplace else x * self.gamma 211 | 212 | 213 | class Block(nn.Module): 214 | 215 | def __init__( 216 | self, dim, num_heads, mlp_ratio=4., qkv_bias=False, drop=0., attn_drop=0., init_values=None, 217 | drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm): 218 | super().__init__() 219 | self.norm1 = norm_layer(dim) 220 | self.attn = Attention(dim, num_heads=num_heads, qkv_bias=qkv_bias, attn_drop=attn_drop, proj_drop=drop) 221 | self.ls1 = LayerScale(dim, init_values=init_values) if init_values else nn.Identity() 222 | # NOTE: drop path for stochastic depth, we shall see if this is better than dropout here 223 | self.drop_path1 = DropPath(drop_path) if drop_path > 0. else nn.Identity() 224 | 225 | self.norm2 = norm_layer(dim) 226 | mlp_hidden_dim = int(dim * mlp_ratio) 227 | self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop) 228 | self.ls2 = LayerScale(dim, init_values=init_values) if init_values else nn.Identity() 229 | self.drop_path2 = DropPath(drop_path) if drop_path > 0. else nn.Identity() 230 | 231 | def forward(self, x): 232 | x = x + self.drop_path1(self.ls1(self.attn(self.norm1(x)))) 233 | x = x + self.drop_path2(self.ls2(self.mlp(self.norm2(x)))) 234 | return x 235 | 236 | 237 | class ParallelBlock(nn.Module): 238 | 239 | def __init__( 240 | self, dim, num_heads, num_parallel=2, mlp_ratio=4., qkv_bias=False, init_values=None, 241 | drop=0., attn_drop=0., drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm): 242 | super().__init__() 243 | self.num_parallel = num_parallel 244 | self.attns = nn.ModuleList() 245 | self.ffns = nn.ModuleList() 246 | for _ in range(num_parallel): 247 | self.attns.append(nn.Sequential(OrderedDict([ 248 | ('norm', norm_layer(dim)), 249 | ('attn', Attention(dim, num_heads=num_heads, qkv_bias=qkv_bias, attn_drop=attn_drop, proj_drop=drop)), 250 | ('ls', LayerScale(dim, init_values=init_values) if init_values else nn.Identity()), 251 | ('drop_path', DropPath(drop_path) if drop_path > 0. else nn.Identity()) 252 | ]))) 253 | self.ffns.append(nn.Sequential(OrderedDict([ 254 | ('norm', norm_layer(dim)), 255 | ('mlp', Mlp(dim, hidden_features=int(dim * mlp_ratio), act_layer=act_layer, drop=drop)), 256 | ('ls', LayerScale(dim, init_values=init_values) if init_values else nn.Identity()), 257 | ('drop_path', DropPath(drop_path) if drop_path > 0. 
else nn.Identity()) 258 | ]))) 259 | 260 | def _forward_jit(self, x): 261 | x = x + torch.stack([attn(x) for attn in self.attns]).sum(dim=0) 262 | x = x + torch.stack([ffn(x) for ffn in self.ffns]).sum(dim=0) 263 | return x 264 | 265 | @torch.jit.ignore 266 | def _forward(self, x): 267 | x = x + sum(attn(x) for attn in self.attns) 268 | x = x + sum(ffn(x) for ffn in self.ffns) 269 | return x 270 | 271 | def forward(self, x): 272 | if torch.jit.is_scripting() or torch.jit.is_tracing(): 273 | return self._forward_jit(x) 274 | else: 275 | return self._forward(x) 276 | 277 | 278 | class VisionTransformer(nn.Module): 279 | """ Vision Transformer 280 | A PyTorch impl of : `An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale` 281 | - https://arxiv.org/abs/2010.11929 282 | """ 283 | 284 | def __init__( 285 | self, img_size=224, patch_size=16, in_chans=3, num_classes=1000, global_pool='token', 286 | embed_dim=768, depth=12, num_heads=12, mlp_ratio=4., qkv_bias=True, representation_size=None, 287 | drop_rate=0., attn_drop_rate=0., drop_path_rate=0., weight_init='', init_values=None, 288 | embed_layer=PatchEmbed, norm_layer=None, act_layer=None, block_fn=Block): 289 | """ 290 | Args: 291 | img_size (int, tuple): input image size 292 | patch_size (int, tuple): patch size 293 | in_chans (int): number of input channels 294 | num_classes (int): number of classes for classification head 295 | global_pool (str): type of global pooling for final sequence (default: 'token') 296 | embed_dim (int): embedding dimension 297 | depth (int): depth of transformer 298 | num_heads (int): number of attention heads 299 | mlp_ratio (int): ratio of mlp hidden dim to embedding dim 300 | qkv_bias (bool): enable bias for qkv if True 301 | representation_size (Optional[int]): enable and set representation layer (pre-logits) to this value if set 302 | drop_rate (float): dropout rate 303 | attn_drop_rate (float): attention dropout rate 304 | drop_path_rate (float): stochastic depth rate 305 | weight_init: (str): weight init scheme 306 | init_values: (float): layer-scale init values 307 | embed_layer (nn.Module): patch embedding layer 308 | norm_layer: (nn.Module): normalization layer 309 | act_layer: (nn.Module): MLP activation layer 310 | """ 311 | super().__init__() 312 | assert global_pool in ('', 'avg', 'token') 313 | norm_layer = norm_layer or partial(nn.LayerNorm, eps=1e-6) 314 | act_layer = act_layer or nn.GELU 315 | 316 | self.num_classes = num_classes 317 | self.global_pool = global_pool 318 | self.num_features = self.embed_dim = embed_dim # num_features for consistency with other engine 319 | self.num_tokens = 1 320 | self.grad_checkpointing = False 321 | 322 | self.patch_embed = embed_layer( 323 | img_size=img_size, patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim) 324 | num_patches = self.patch_embed.num_patches 325 | 326 | self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim)) 327 | # self.cls_token_grow = nn.Parameter(torch.zeros(1, 5000, embed_dim)) 328 | self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + self.num_tokens, embed_dim)) 329 | # self.pos_embed_grow = nn.Parameter(torch.zeros(1, num_patches + 1000, embed_dim)) 330 | self.pos_drop = nn.Dropout(p=drop_rate) 331 | 332 | dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)] # stochastic depth decay rule 333 | self.blocks = nn.Sequential(*[ 334 | block_fn( 335 | dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, init_values=init_values, 336 | drop=drop_rate, 
attn_drop=attn_drop_rate, drop_path=dpr[i], norm_layer=norm_layer, act_layer=act_layer) 337 | for i in range(depth)]) 338 | use_fc_norm = self.global_pool == 'avg' 339 | self.norm = norm_layer(embed_dim) if not use_fc_norm else nn.Identity() 340 | 341 | # Representation layer. Used for original ViT engine w/ in21k pretraining. 342 | self.representation_size = representation_size 343 | self.pre_logits = nn.Identity() 344 | if representation_size: 345 | self._reset_representation(representation_size) 346 | 347 | # Classifier Head 348 | self.fc_norm = norm_layer(embed_dim) if use_fc_norm else nn.Identity() 349 | final_chs = self.representation_size if self.representation_size else self.embed_dim 350 | self.head = nn.Linear(final_chs, num_classes) if num_classes > 0 else nn.Identity() 351 | self.out_dim = final_chs 352 | 353 | if weight_init != 'skip': 354 | self.init_weights(weight_init) 355 | 356 | def _reset_representation(self, representation_size): 357 | self.representation_size = representation_size 358 | if self.representation_size: 359 | self.pre_logits = nn.Sequential(OrderedDict([ 360 | ('fc', nn.Linear(self.embed_dim, self.representation_size)), 361 | ('act', nn.Tanh()) 362 | ])) 363 | else: 364 | self.pre_logits = nn.Identity() 365 | 366 | def init_weights(self, mode=''): 367 | assert mode in ('jax', 'jax_nlhb', 'moco', '') 368 | head_bias = -math.log(self.num_classes) if 'nlhb' in mode else 0. 369 | trunc_normal_(self.pos_embed, std=.02) 370 | # trunc_normal_(self.pos_embed_grow, std=.02) 371 | nn.init.normal_(self.cls_token, std=1e-6) 372 | # nn.init.normal_(self.cls_token_grow, std=1e-6) 373 | named_apply(get_init_weights_vit(mode, head_bias), self) 374 | 375 | def _init_weights(self, m): 376 | # this fn left here for compat with downstream users 377 | init_weights_vit_timm(m) 378 | 379 | @torch.jit.ignore() 380 | def load_pretrained(self, checkpoint_path, prefix=''): 381 | _load_weights(self, checkpoint_path, prefix) 382 | 383 | @torch.jit.ignore 384 | def no_weight_decay(self): 385 | return {'pos_embed', 'cls_token', 'dist_token'} 386 | 387 | @torch.jit.ignore 388 | def group_matcher(self, coarse=False): 389 | return dict( 390 | stem=r'^cls_token|pos_embed|patch_embed', # stem and embed 391 | blocks=[(r'^blocks\.(\d+)', None), (r'^norm', (99999,))] 392 | ) 393 | 394 | @torch.jit.ignore 395 | def set_grad_checkpointing(self, enable=True): 396 | self.grad_checkpointing = enable 397 | 398 | @torch.jit.ignore 399 | def get_classifier(self): 400 | return self.head 401 | 402 | def reset_classifier(self, num_classes: int, global_pool=None, representation_size=None): 403 | self.num_classes = num_classes 404 | if global_pool is not None: 405 | assert global_pool in ('', 'avg', 'token') 406 | self.global_pool = global_pool 407 | if representation_size is not None: 408 | self._reset_representation(representation_size) 409 | final_chs = self.representation_size if self.representation_size else self.embed_dim 410 | self.head = nn.Linear(final_chs, num_classes) if num_classes > 0 else nn.Identity() 411 | 412 | def forward_features(self, x): 413 | x = self.patch_embed(x) 414 | x = torch.cat((self.cls_token.expand(x.shape[0], -1, -1), x), dim=1) 415 | 416 | x = self.pos_drop(x + self.pos_embed) 417 | if self.grad_checkpointing and not torch.jit.is_scripting(): 418 | x = checkpoint_seq(self.blocks, x) 419 | else: 420 | x = self.blocks(x) 421 | x = self.norm(x) 422 | return x 423 | 424 | def forward_features_grow(self, x, class_num): 425 | x = self.patch_embed(x) 426 | # x = 
torch.cat((self.cls_token_grow[:, :class_num, :].expand(x.shape[0], -1, -1), self.cls_token.expand(x.shape[0], -1, -1), x), dim=1) 427 | # x = self.pos_drop(x + self.pos_embed_grow[:, :self.patch_embed.num_patches+class_num, :]) 428 | x = torch.cat((self.cls_token.expand(x.shape[0], -1, -1), x), dim=1) 429 | x = self.pos_drop(x + self.pos_embed) 430 | x = torch.cat((self.cls_token_grow[:, :class_num*2, :].expand(x.shape[0], -1, -1), x), dim=1) 431 | 432 | # import pdb;pdb.set_trace() 433 | if self.grad_checkpointing and not torch.jit.is_scripting(): 434 | x = checkpoint_seq(self.blocks, x) 435 | else: 436 | x = self.blocks(x) 437 | x = self.norm(x) 438 | return x 439 | 440 | def forward_head(self, x, pre_logits: bool = False): 441 | if self.global_pool: 442 | x = x[:, 1:].mean(dim=1) if self.global_pool == 'avg' else x[:, 0] 443 | x = self.fc_norm(x) 444 | x = self.pre_logits(x) 445 | return x if pre_logits else self.head(x) 446 | 447 | def forward(self, x, grow_flag=False, numcls=0): 448 | if not grow_flag: 449 | x = self.forward_features(x) 450 | else: 451 | x = self.forward_features_grow(x, numcls) 452 | 453 | if self.global_pool: 454 | x = x[:, 1:].mean(dim=1) if self.global_pool == 'avg' else x[:, 0] 455 | x = self.fc_norm(x) 456 | return { 457 | 'fmaps': [x], 458 | 'features': x 459 | } 460 | 461 | 462 | def init_weights_vit_timm(module: nn.Module, name: str = ''): 463 | """ ViT weight initialization, original timm impl (for reproducibility) """ 464 | if isinstance(module, nn.Linear): 465 | trunc_normal_(module.weight, std=.02) 466 | if module.bias is not None: 467 | nn.init.zeros_(module.bias) 468 | 469 | 470 | def init_weights_vit_jax(module: nn.Module, name: str = '', head_bias: float = 0.): 471 | """ ViT weight initialization, matching JAX (Flax) impl """ 472 | if isinstance(module, nn.Linear): 473 | if name.startswith('head'): 474 | nn.init.zeros_(module.weight) 475 | nn.init.constant_(module.bias, head_bias) 476 | elif name.startswith('pre_logits'): 477 | lecun_normal_(module.weight) 478 | nn.init.zeros_(module.bias) 479 | else: 480 | nn.init.xavier_uniform_(module.weight) 481 | if module.bias is not None: 482 | nn.init.normal_(module.bias, std=1e-6) if 'mlp' in name else nn.init.zeros_(module.bias) 483 | elif isinstance(module, nn.Conv2d): 484 | lecun_normal_(module.weight) 485 | if module.bias is not None: 486 | nn.init.zeros_(module.bias) 487 | 488 | 489 | def init_weights_vit_moco(module: nn.Module, name: str = ''): 490 | """ ViT weight initialization, matching moco-v3 impl minus fixed PatchEmbed """ 491 | if isinstance(module, nn.Linear): 492 | if 'qkv' in name: 493 | # treat the weights of Q, K, V separately 494 | val = math.sqrt(6. 
/ float(module.weight.shape[0] // 3 + module.weight.shape[1])) 495 | nn.init.uniform_(module.weight, -val, val) 496 | else: 497 | nn.init.xavier_uniform_(module.weight) 498 | if module.bias is not None: 499 | nn.init.zeros_(module.bias) 500 | 501 | 502 | def get_init_weights_vit(mode='jax', head_bias: float = 0.): 503 | if 'jax' in mode: 504 | return partial(init_weights_vit_jax, head_bias=head_bias) 505 | elif 'moco' in mode: 506 | return init_weights_vit_moco 507 | else: 508 | return init_weights_vit_timm 509 | 510 | 511 | @torch.no_grad() 512 | def _load_weights(model: VisionTransformer, checkpoint_path: str, prefix: str = ''): 513 | """ Load weights from .npz checkpoints for official Google Brain Flax implementation 514 | """ 515 | import numpy as np 516 | 517 | def _n2p(w, t=True): 518 | if w.ndim == 4 and w.shape[0] == w.shape[1] == w.shape[2] == 1: 519 | w = w.flatten() 520 | if t: 521 | if w.ndim == 4: 522 | w = w.transpose([3, 2, 0, 1]) 523 | elif w.ndim == 3: 524 | w = w.transpose([2, 0, 1]) 525 | elif w.ndim == 2: 526 | w = w.transpose([1, 0]) 527 | return torch.from_numpy(w) 528 | 529 | w = np.load(checkpoint_path) 530 | if not prefix and 'opt/target/embedding/kernel' in w: 531 | prefix = 'opt/target/' 532 | 533 | if hasattr(model.patch_embed, 'backbone'): 534 | # hybrid 535 | backbone = model.patch_embed.backbone 536 | stem_only = not hasattr(backbone, 'stem') 537 | stem = backbone if stem_only else backbone.stem 538 | stem.conv.weight.copy_(adapt_input_conv(stem.conv.weight.shape[1], _n2p(w[f'{prefix}conv_root/kernel']))) 539 | stem.norm.weight.copy_(_n2p(w[f'{prefix}gn_root/scale'])) 540 | stem.norm.bias.copy_(_n2p(w[f'{prefix}gn_root/bias'])) 541 | if not stem_only: 542 | for i, stage in enumerate(backbone.stages): 543 | for j, block in enumerate(stage.blocks): 544 | bp = f'{prefix}block{i + 1}/unit{j + 1}/' 545 | for r in range(3): 546 | getattr(block, f'conv{r + 1}').weight.copy_(_n2p(w[f'{bp}conv{r + 1}/kernel'])) 547 | getattr(block, f'norm{r + 1}').weight.copy_(_n2p(w[f'{bp}gn{r + 1}/scale'])) 548 | getattr(block, f'norm{r + 1}').bias.copy_(_n2p(w[f'{bp}gn{r + 1}/bias'])) 549 | if block.downsample is not None: 550 | block.downsample.conv.weight.copy_(_n2p(w[f'{bp}conv_proj/kernel'])) 551 | block.downsample.norm.weight.copy_(_n2p(w[f'{bp}gn_proj/scale'])) 552 | block.downsample.norm.bias.copy_(_n2p(w[f'{bp}gn_proj/bias'])) 553 | embed_conv_w = _n2p(w[f'{prefix}embedding/kernel']) 554 | else: 555 | embed_conv_w = adapt_input_conv( 556 | model.patch_embed.proj.weight.shape[1], _n2p(w[f'{prefix}embedding/kernel'])) 557 | model.patch_embed.proj.weight.copy_(embed_conv_w) 558 | model.patch_embed.proj.bias.copy_(_n2p(w[f'{prefix}embedding/bias'])) 559 | model.cls_token.copy_(_n2p(w[f'{prefix}cls'], t=False)) 560 | pos_embed_w = _n2p(w[f'{prefix}Transformer/posembed_input/pos_embedding'], t=False) 561 | if pos_embed_w.shape != model.pos_embed.shape: 562 | pos_embed_w = resize_pos_embed( # resize pos embedding when different size from pretrained weights 563 | pos_embed_w, model.pos_embed, getattr(model, 'num_tokens', 1), model.patch_embed.grid_size) 564 | model.pos_embed.copy_(pos_embed_w) 565 | model.norm.weight.copy_(_n2p(w[f'{prefix}Transformer/encoder_norm/scale'])) 566 | model.norm.bias.copy_(_n2p(w[f'{prefix}Transformer/encoder_norm/bias'])) 567 | if isinstance(model.head, nn.Linear) and model.head.bias.shape[0] == w[f'{prefix}head/bias'].shape[-1]: 568 | model.head.weight.copy_(_n2p(w[f'{prefix}head/kernel'])) 569 | model.head.bias.copy_(_n2p(w[f'{prefix}head/bias'])) 
570 | if isinstance(getattr(model.pre_logits, 'fc', None), nn.Linear) and f'{prefix}pre_logits/bias' in w: 571 | model.pre_logits.fc.weight.copy_(_n2p(w[f'{prefix}pre_logits/kernel'])) 572 | model.pre_logits.fc.bias.copy_(_n2p(w[f'{prefix}pre_logits/bias'])) 573 | for i, block in enumerate(model.blocks.children()): 574 | block_prefix = f'{prefix}Transformer/encoderblock_{i}/' 575 | mha_prefix = block_prefix + 'MultiHeadDotProductAttention_1/' 576 | block.norm1.weight.copy_(_n2p(w[f'{block_prefix}LayerNorm_0/scale'])) 577 | block.norm1.bias.copy_(_n2p(w[f'{block_prefix}LayerNorm_0/bias'])) 578 | block.attn.qkv.weight.copy_(torch.cat([ 579 | _n2p(w[f'{mha_prefix}{n}/kernel'], t=False).flatten(1).T for n in ('query', 'key', 'value')])) 580 | block.attn.qkv.bias.copy_(torch.cat([ 581 | _n2p(w[f'{mha_prefix}{n}/bias'], t=False).reshape(-1) for n in ('query', 'key', 'value')])) 582 | block.attn.proj.weight.copy_(_n2p(w[f'{mha_prefix}out/kernel']).flatten(1)) 583 | block.attn.proj.bias.copy_(_n2p(w[f'{mha_prefix}out/bias'])) 584 | for r in range(2): 585 | getattr(block.mlp, f'fc{r + 1}').weight.copy_(_n2p(w[f'{block_prefix}MlpBlock_3/Dense_{r}/kernel'])) 586 | getattr(block.mlp, f'fc{r + 1}').bias.copy_(_n2p(w[f'{block_prefix}MlpBlock_3/Dense_{r}/bias'])) 587 | block.norm2.weight.copy_(_n2p(w[f'{block_prefix}LayerNorm_2/scale'])) 588 | block.norm2.bias.copy_(_n2p(w[f'{block_prefix}LayerNorm_2/bias'])) 589 | 590 | 591 | def resize_pos_embed(posemb, posemb_new, num_tokens=1, gs_new=()): 592 | # Rescale the grid of position embeddings when loading from state_dict. Adapted from 593 | # https://github.com/google-research/vision_transformer/blob/00883dd691c63a6830751563748663526e811cee/vit_jax/checkpoint.py#L224 594 | _logger.info('Resized position embedding: %s to %s', posemb.shape, posemb_new.shape) 595 | ntok_new = posemb_new.shape[1] 596 | if num_tokens: 597 | posemb_tok, posemb_grid = posemb[:, :num_tokens], posemb[0, num_tokens:] 598 | ntok_new -= num_tokens 599 | else: 600 | posemb_tok, posemb_grid = posemb[:, :0], posemb[0] 601 | gs_old = int(math.sqrt(len(posemb_grid))) 602 | if not len(gs_new): # backwards compatibility 603 | gs_new = [int(math.sqrt(ntok_new))] * 2 604 | assert len(gs_new) >= 2 605 | _logger.info('Position embedding grid-size from %s to %s', [gs_old, gs_old], gs_new) 606 | posemb_grid = posemb_grid.reshape(1, gs_old, gs_old, -1).permute(0, 3, 1, 2) 607 | posemb_grid = F.interpolate(posemb_grid, size=gs_new, mode='bicubic', align_corners=False) 608 | posemb_grid = posemb_grid.permute(0, 2, 3, 1).reshape(1, gs_new[0] * gs_new[1], -1) 609 | posemb = torch.cat([posemb_tok, posemb_grid], dim=1) 610 | return posemb 611 | 612 | 613 | def checkpoint_filter_fn(state_dict, model): 614 | """ convert patch embedding weight from manual patchify + linear proj to conv""" 615 | out_dict = {} 616 | if 'model' in state_dict: 617 | # For deit engine 618 | state_dict = state_dict['model'] 619 | for k, v in state_dict.items(): 620 | if 'patch_embed.proj.weight' in k and len(v.shape) < 4: 621 | # For old engine that I trained prior to conv based patchification 622 | O, I, H, W = model.patch_embed.proj.weight.shape 623 | v = v.reshape(O, -1, H, W) 624 | elif k == 'pos_embed' and v.shape != model.pos_embed.shape: 625 | # To resize pos embedding when using model at different size from pretrained weights 626 | v = resize_pos_embed( 627 | v, model.pos_embed, getattr(model, 'num_tokens', 1), model.patch_embed.grid_size) 628 | out_dict[k] = v 629 | return out_dict 630 | 631 | 632 | def 
_create_vision_transformer(variant, pretrained=False, **kwargs): 633 | if kwargs.get('features_only', None): 634 | raise RuntimeError('features_only not implemented for Vision Transformer engine.') 635 | 636 | # NOTE this extra code to support handling of repr size for in21k pretrained engine 637 | pretrained_cfg = resolve_pretrained_cfg(variant, kwargs=kwargs) 638 | default_num_classes = pretrained_cfg['num_classes'] 639 | num_classes = kwargs.get('num_classes', default_num_classes) 640 | repr_size = kwargs.pop('representation_size', None) 641 | if repr_size is not None and num_classes != default_num_classes: 642 | # Remove representation layer if fine-tuning. This may not always be the desired action, 643 | # but I feel better than doing nothing by default for fine-tuning. Perhaps a better interface? 644 | _logger.warning("Removing representation layer for fine-tuning.") 645 | repr_size = None 646 | 647 | model = build_model_with_cfg( 648 | VisionTransformer, variant, pretrained, 649 | pretrained_cfg=pretrained_cfg, 650 | representation_size=repr_size, 651 | pretrained_filter_fn=checkpoint_filter_fn, 652 | pretrained_custom_load='npz' in pretrained_cfg['url'], 653 | **kwargs) 654 | return model 655 | 656 | 657 | @register_model 658 | def vit_tiny_patch16_224(pretrained=False, **kwargs): 659 | """ ViT-Tiny (Vit-Ti/16) 660 | """ 661 | model_kwargs = dict(patch_size=16, embed_dim=192, depth=12, num_heads=3, **kwargs) 662 | model = _create_vision_transformer('vit_tiny_patch16_224', pretrained=pretrained, **model_kwargs) 663 | return model 664 | 665 | 666 | @register_model 667 | def vit_tiny_patch16_384(pretrained=False, **kwargs): 668 | """ ViT-Tiny (Vit-Ti/16) @ 384x384. 669 | """ 670 | model_kwargs = dict(patch_size=16, embed_dim=192, depth=12, num_heads=3, **kwargs) 671 | model = _create_vision_transformer('vit_tiny_patch16_384', pretrained=pretrained, **model_kwargs) 672 | return model 673 | 674 | 675 | @register_model 676 | def vit_small_patch32_224(pretrained=False, **kwargs): 677 | """ ViT-Small (ViT-S/32) 678 | """ 679 | model_kwargs = dict(patch_size=32, embed_dim=384, depth=12, num_heads=6, **kwargs) 680 | model = _create_vision_transformer('vit_small_patch32_224', pretrained=pretrained, **model_kwargs) 681 | return model 682 | 683 | 684 | @register_model 685 | def vit_small_patch32_384(pretrained=False, **kwargs): 686 | """ ViT-Small (ViT-S/32) at 384x384. 
687 | """ 688 | model_kwargs = dict(patch_size=32, embed_dim=384, depth=12, num_heads=6, **kwargs) 689 | model = _create_vision_transformer('vit_small_patch32_384', pretrained=pretrained, **model_kwargs) 690 | return model 691 | 692 | 693 | @register_model 694 | def vit_small_patch16_224(pretrained=False, **kwargs): 695 | """ ViT-Small (ViT-S/16) 696 | NOTE I've replaced my previous 'small' model definition and weights with the small variant from the DeiT paper 697 | """ 698 | model_kwargs = dict(patch_size=16, embed_dim=384, depth=12, num_heads=6, **kwargs) 699 | model = _create_vision_transformer('vit_small_patch16_224', pretrained=pretrained, **model_kwargs) 700 | return model 701 | 702 | 703 | @register_model 704 | def vit_small_patch16_384(pretrained=False, **kwargs): 705 | """ ViT-Small (ViT-S/16) 706 | NOTE I've replaced my previous 'small' model definition and weights with the small variant from the DeiT paper 707 | """ 708 | model_kwargs = dict(patch_size=16, embed_dim=384, depth=12, num_heads=6, **kwargs) 709 | model = _create_vision_transformer('vit_small_patch16_384', pretrained=pretrained, **model_kwargs) 710 | return model 711 | 712 | 713 | @register_model 714 | def vit_base_patch32_224(pretrained=False, **kwargs): 715 | """ ViT-Base (ViT-B/32) from original paper (https://arxiv.org/abs/2010.11929). 716 | ImageNet-1k weights fine-tuned from in21k, source https://github.com/google-research/vision_transformer. 717 | """ 718 | model_kwargs = dict(patch_size=32, embed_dim=768, depth=12, num_heads=12, **kwargs) 719 | model = _create_vision_transformer('vit_base_patch32_224', pretrained=pretrained, **model_kwargs) 720 | return model 721 | 722 | 723 | @register_model 724 | def vit_base2_patch32_256(pretrained=False, **kwargs): 725 | """ ViT-Base (ViT-B/32) 726 | # FIXME experiment 727 | """ 728 | model_kwargs = dict(patch_size=32, embed_dim=896, depth=12, num_heads=14, **kwargs) 729 | model = _create_vision_transformer('vit_base2_patch32_256', pretrained=pretrained, **model_kwargs) 730 | return model 731 | 732 | 733 | @register_model 734 | def vit_base_patch32_384(pretrained=False, **kwargs): 735 | """ ViT-Base model (ViT-B/32) from original paper (https://arxiv.org/abs/2010.11929). 736 | ImageNet-1k weights fine-tuned from in21k @ 384x384, source https://github.com/google-research/vision_transformer. 737 | """ 738 | model_kwargs = dict(patch_size=32, embed_dim=768, depth=12, num_heads=12, **kwargs) 739 | model = _create_vision_transformer('vit_base_patch32_384', pretrained=pretrained, **model_kwargs) 740 | return model 741 | 742 | 743 | @register_model 744 | def vit_base_patch16_224(pretrained=False, **kwargs): 745 | """ ViT-Base (ViT-B/16) from original paper (https://arxiv.org/abs/2010.11929). 746 | ImageNet-1k weights fine-tuned from in21k @ 224x224, source https://github.com/google-research/vision_transformer. 747 | """ 748 | model_kwargs = dict(patch_size=16, embed_dim=768, depth=12, num_heads=12, **kwargs) 749 | model = _create_vision_transformer('vit_base_patch16_224', pretrained=pretrained, **model_kwargs) 750 | return model 751 | 752 | 753 | @register_model 754 | def vit_base_patch16_384(pretrained=False, **kwargs): 755 | """ ViT-Base model (ViT-B/16) from original paper (https://arxiv.org/abs/2010.11929). 756 | ImageNet-1k weights fine-tuned from in21k @ 384x384, source https://github.com/google-research/vision_transformer. 
757 | """ 758 | model_kwargs = dict(patch_size=16, embed_dim=768, depth=12, num_heads=12, **kwargs) 759 | model = _create_vision_transformer('vit_base_patch16_384', pretrained=pretrained, **model_kwargs) 760 | return model 761 | 762 | 763 | @register_model 764 | def vit_base_patch8_224(pretrained=False, **kwargs): 765 | """ ViT-Base (ViT-B/8) from original paper (https://arxiv.org/abs/2010.11929). 766 | ImageNet-1k weights fine-tuned from in21k @ 224x224, source https://github.com/google-research/vision_transformer. 767 | """ 768 | model_kwargs = dict(patch_size=8, embed_dim=768, depth=12, num_heads=12, **kwargs) 769 | model = _create_vision_transformer('vit_base_patch8_224', pretrained=pretrained, **model_kwargs) 770 | return model 771 | 772 | 773 | @register_model 774 | def vit_large_patch32_224(pretrained=False, **kwargs): 775 | """ ViT-Large model (ViT-L/32) from original paper (https://arxiv.org/abs/2010.11929). No pretrained weights. 776 | """ 777 | model_kwargs = dict(patch_size=32, embed_dim=1024, depth=24, num_heads=16, **kwargs) 778 | model = _create_vision_transformer('vit_large_patch32_224', pretrained=pretrained, **model_kwargs) 779 | return model 780 | 781 | 782 | @register_model 783 | def vit_large_patch32_384(pretrained=False, **kwargs): 784 | """ ViT-Large model (ViT-L/32) from original paper (https://arxiv.org/abs/2010.11929). 785 | ImageNet-1k weights fine-tuned from in21k @ 384x384, source https://github.com/google-research/vision_transformer. 786 | """ 787 | model_kwargs = dict(patch_size=32, embed_dim=1024, depth=24, num_heads=16, **kwargs) 788 | model = _create_vision_transformer('vit_large_patch32_384', pretrained=pretrained, **model_kwargs) 789 | return model 790 | 791 | 792 | @register_model 793 | def vit_large_patch16_224(pretrained=False, **kwargs): 794 | """ ViT-Large model (ViT-L/16) from original paper (https://arxiv.org/abs/2010.11929). 795 | ImageNet-1k weights fine-tuned from in21k @ 224x224, source https://github.com/google-research/vision_transformer. 796 | """ 797 | model_kwargs = dict(patch_size=16, embed_dim=1024, depth=24, num_heads=16, **kwargs) 798 | model = _create_vision_transformer('vit_large_patch16_224', pretrained=pretrained, **model_kwargs) 799 | return model 800 | 801 | 802 | @register_model 803 | def vit_large_patch16_384(pretrained=False, **kwargs): 804 | """ ViT-Large model (ViT-L/16) from original paper (https://arxiv.org/abs/2010.11929). 805 | ImageNet-1k weights fine-tuned from in21k @ 384x384, source https://github.com/google-research/vision_transformer. 806 | """ 807 | model_kwargs = dict(patch_size=16, embed_dim=1024, depth=24, num_heads=16, **kwargs) 808 | model = _create_vision_transformer('vit_large_patch16_384', pretrained=pretrained, **model_kwargs) 809 | return model 810 | 811 | 812 | @register_model 813 | def vit_large_patch14_224(pretrained=False, **kwargs): 814 | """ ViT-Large model (ViT-L/14) 815 | """ 816 | model_kwargs = dict(patch_size=14, embed_dim=1024, depth=24, num_heads=16, **kwargs) 817 | model = _create_vision_transformer('vit_large_patch14_224', pretrained=pretrained, **model_kwargs) 818 | return model 819 | 820 | 821 | @register_model 822 | def vit_huge_patch14_224(pretrained=False, **kwargs): 823 | """ ViT-Huge model (ViT-H/14) from original paper (https://arxiv.org/abs/2010.11929). 
824 | """ 825 | model_kwargs = dict(patch_size=14, embed_dim=1280, depth=32, num_heads=16, **kwargs) 826 | model = _create_vision_transformer('vit_huge_patch14_224', pretrained=pretrained, **model_kwargs) 827 | return model 828 | 829 | 830 | @register_model 831 | def vit_giant_patch14_224(pretrained=False, **kwargs): 832 | """ ViT-Giant model (ViT-g/14) from `Scaling Vision Transformers` - https://arxiv.org/abs/2106.04560 833 | """ 834 | model_kwargs = dict(patch_size=14, embed_dim=1408, mlp_ratio=48/11, depth=40, num_heads=16, **kwargs) 835 | model = _create_vision_transformer('vit_giant_patch14_224', pretrained=pretrained, **model_kwargs) 836 | return model 837 | 838 | 839 | @register_model 840 | def vit_gigantic_patch14_224(pretrained=False, **kwargs): 841 | """ ViT-Gigantic model (ViT-G/14) from `Scaling Vision Transformers` - https://arxiv.org/abs/2106.04560 842 | """ 843 | model_kwargs = dict(patch_size=14, embed_dim=1664, mlp_ratio=64/13, depth=48, num_heads=16, **kwargs) 844 | model = _create_vision_transformer('vit_gigantic_patch14_224', pretrained=pretrained, **model_kwargs) 845 | return model 846 | 847 | 848 | @register_model 849 | def vit_tiny_patch16_224_in21k(pretrained=False, **kwargs): 850 | """ ViT-Tiny (Vit-Ti/16). 851 | ImageNet-21k weights @ 224x224, source https://github.com/google-research/vision_transformer. 852 | NOTE: this model has valid 21k classifier head and no representation (pre-logits) layer 853 | """ 854 | model_kwargs = dict(patch_size=16, embed_dim=192, depth=12, num_heads=3, **kwargs) 855 | model = _create_vision_transformer('vit_tiny_patch16_224_in21k', pretrained=pretrained, **model_kwargs) 856 | return model 857 | 858 | 859 | @register_model 860 | def vit_small_patch32_224_in21k(pretrained=False, **kwargs): 861 | """ ViT-Small (ViT-S/16) 862 | ImageNet-21k weights @ 224x224, source https://github.com/google-research/vision_transformer. 863 | NOTE: this model has valid 21k classifier head and no representation (pre-logits) layer 864 | """ 865 | model_kwargs = dict(patch_size=32, embed_dim=384, depth=12, num_heads=6, **kwargs) 866 | model = _create_vision_transformer('vit_small_patch32_224_in21k', pretrained=pretrained, **model_kwargs) 867 | return model 868 | 869 | 870 | @register_model 871 | def vit_small_patch16_224_in21k(pretrained=False, **kwargs): 872 | """ ViT-Small (ViT-S/16) 873 | ImageNet-21k weights @ 224x224, source https://github.com/google-research/vision_transformer. 874 | NOTE: this model has valid 21k classifier head and no representation (pre-logits) layer 875 | """ 876 | model_kwargs = dict(patch_size=16, embed_dim=384, depth=12, num_heads=6, **kwargs) 877 | model = _create_vision_transformer('vit_small_patch16_224_in21k', pretrained=pretrained, **model_kwargs) 878 | return model 879 | 880 | 881 | @register_model 882 | def vit_base_patch32_224_in21k(pretrained=False, **kwargs): 883 | """ ViT-Base model (ViT-B/32) from original paper (https://arxiv.org/abs/2010.11929). 884 | ImageNet-21k weights @ 224x224, source https://github.com/google-research/vision_transformer. 
885 | NOTE: this model has valid 21k classifier head and no representation (pre-logits) layer 886 | """ 887 | model_kwargs = dict( 888 | patch_size=32, embed_dim=768, depth=12, num_heads=12, **kwargs) 889 | model = _create_vision_transformer('vit_base_patch32_224_in21k', pretrained=pretrained, **model_kwargs) 890 | return model 891 | 892 | 893 | @register_model 894 | def vit_base_patch16_224_in21k(pretrained=False, **kwargs): 895 | """ ViT-Base model (ViT-B/16) from original paper (https://arxiv.org/abs/2010.11929). 896 | ImageNet-21k weights @ 224x224, source https://github.com/google-research/vision_transformer. 897 | NOTE: this model has valid 21k classifier head and no representation (pre-logits) layer 898 | """ 899 | model_kwargs = dict( 900 | patch_size=16, embed_dim=768, depth=12, num_heads=12, **kwargs) 901 | model = _create_vision_transformer('vit_base_patch16_224_in21k', pretrained=pretrained, **model_kwargs) 902 | return model 903 | 904 | 905 | @register_model 906 | def vit_base_patch8_224_in21k(pretrained=False, **kwargs): 907 | """ ViT-Base model (ViT-B/8) from original paper (https://arxiv.org/abs/2010.11929). 908 | ImageNet-21k weights @ 224x224, source https://github.com/google-research/vision_transformer. 909 | NOTE: this model has valid 21k classifier head and no representation (pre-logits) layer 910 | """ 911 | model_kwargs = dict( 912 | patch_size=8, embed_dim=768, depth=12, num_heads=12, **kwargs) 913 | model = _create_vision_transformer('vit_base_patch8_224_in21k', pretrained=pretrained, **model_kwargs) 914 | return model 915 | 916 | 917 | @register_model 918 | def vit_large_patch32_224_in21k(pretrained=False, **kwargs): 919 | """ ViT-Large model (ViT-L/32) from original paper (https://arxiv.org/abs/2010.11929). 920 | ImageNet-21k weights @ 224x224, source https://github.com/google-research/vision_transformer. 921 | NOTE: this model has a representation layer but the 21k classifier head is zero'd out in original weights 922 | """ 923 | model_kwargs = dict( 924 | patch_size=32, embed_dim=1024, depth=24, num_heads=16, representation_size=1024, **kwargs) 925 | model = _create_vision_transformer('vit_large_patch32_224_in21k', pretrained=pretrained, **model_kwargs) 926 | return model 927 | 928 | 929 | @register_model 930 | def vit_large_patch16_224_in21k(pretrained=False, **kwargs): 931 | """ ViT-Large model (ViT-L/16) from original paper (https://arxiv.org/abs/2010.11929). 932 | ImageNet-21k weights @ 224x224, source https://github.com/google-research/vision_transformer. 933 | NOTE: this model has valid 21k classifier head and no representation (pre-logits) layer 934 | """ 935 | model_kwargs = dict( 936 | patch_size=16, embed_dim=1024, depth=24, num_heads=16, **kwargs) 937 | model = _create_vision_transformer('vit_large_patch16_224_in21k', pretrained=pretrained, **model_kwargs) 938 | return model 939 | 940 | 941 | @register_model 942 | def vit_huge_patch14_224_in21k(pretrained=False, **kwargs): 943 | """ ViT-Huge model (ViT-H/14) from original paper (https://arxiv.org/abs/2010.11929). 944 | ImageNet-21k weights @ 224x224, source https://github.com/google-research/vision_transformer. 
945 | NOTE: this model has a representation layer but the 21k classifier head is zero'd out in original weights 946 | """ 947 | model_kwargs = dict( 948 | patch_size=14, embed_dim=1280, depth=32, num_heads=16, representation_size=1280, **kwargs) 949 | model = _create_vision_transformer('vit_huge_patch14_224_in21k', pretrained=pretrained, **model_kwargs) 950 | return model 951 | 952 | 953 | @register_model 954 | def vit_base_patch16_224_sam(pretrained=False, **kwargs): 955 | """ ViT-Base (ViT-B/16) w/ SAM pretrained weights. Paper: https://arxiv.org/abs/2106.01548 956 | """ 957 | # NOTE original SAM weights release worked with representation_size=768 958 | model_kwargs = dict(patch_size=16, embed_dim=768, depth=12, num_heads=12, **kwargs) 959 | model = _create_vision_transformer('vit_base_patch16_224_sam', pretrained=pretrained, **model_kwargs) 960 | return model 961 | 962 | 963 | @register_model 964 | def vit_base_patch32_224_sam(pretrained=False, **kwargs): 965 | """ ViT-Base (ViT-B/32) w/ SAM pretrained weights. Paper: https://arxiv.org/abs/2106.01548 966 | """ 967 | # NOTE original SAM weights release worked with representation_size=768 968 | model_kwargs = dict(patch_size=32, embed_dim=768, depth=12, num_heads=12, **kwargs) 969 | model = _create_vision_transformer('vit_base_patch32_224_sam', pretrained=pretrained, **model_kwargs) 970 | return model 971 | 972 | 973 | @register_model 974 | def vit_small_patch16_224_dino(pretrained=False, **kwargs): 975 | """ ViT-Small (ViT-S/16) w/ DINO pretrained weights (no head) - https://arxiv.org/abs/2104.14294 976 | """ 977 | model_kwargs = dict(patch_size=16, embed_dim=384, depth=12, num_heads=6, **kwargs) 978 | model = _create_vision_transformer('vit_small_patch16_224_dino', pretrained=pretrained, **model_kwargs) 979 | return model 980 | 981 | 982 | @register_model 983 | def vit_small_patch8_224_dino(pretrained=False, **kwargs): 984 | """ ViT-Small (ViT-S/8) w/ DINO pretrained weights (no head) - https://arxiv.org/abs/2104.14294 985 | """ 986 | model_kwargs = dict(patch_size=8, embed_dim=384, depth=12, num_heads=6, **kwargs) 987 | model = _create_vision_transformer('vit_small_patch8_224_dino', pretrained=pretrained, **model_kwargs) 988 | return model 989 | 990 | 991 | @register_model 992 | def vit_base_patch16_224_dino(pretrained=False, **kwargs): 993 | """ ViT-Base (ViT-B/16) /w DINO pretrained weights (no head) - https://arxiv.org/abs/2104.14294 994 | """ 995 | model_kwargs = dict(patch_size=16, embed_dim=768, depth=12, num_heads=12, **kwargs) 996 | model = _create_vision_transformer('vit_base_patch16_224_dino', pretrained=pretrained, **model_kwargs) 997 | return model 998 | 999 | 1000 | @register_model 1001 | def vit_base_patch8_224_dino(pretrained=False, **kwargs): 1002 | """ ViT-Base (ViT-B/8) w/ DINO pretrained weights (no head) - https://arxiv.org/abs/2104.14294 1003 | """ 1004 | model_kwargs = dict(patch_size=8, embed_dim=768, depth=12, num_heads=12, **kwargs) 1005 | model = _create_vision_transformer('vit_base_patch8_224_dino', pretrained=pretrained, **model_kwargs) 1006 | return model 1007 | 1008 | 1009 | @register_model 1010 | def vit_base_patch16_224_miil_in21k(pretrained=False, **kwargs): 1011 | """ ViT-Base (ViT-B/16) from original paper (https://arxiv.org/abs/2010.11929). 
1012 | Weights taken from: https://github.com/Alibaba-MIIL/ImageNet21K 1013 | """ 1014 | model_kwargs = dict(patch_size=16, embed_dim=768, depth=12, num_heads=12, qkv_bias=False, **kwargs) 1015 | model = _create_vision_transformer('vit_base_patch16_224_miil_in21k', pretrained=pretrained, **model_kwargs) 1016 | return model 1017 | 1018 | 1019 | @register_model 1020 | def vit_base_patch16_224_miil(pretrained=False, **kwargs): 1021 | """ ViT-Base (ViT-B/16) from original paper (https://arxiv.org/abs/2010.11929). 1022 | Weights taken from: https://github.com/Alibaba-MIIL/ImageNet21K 1023 | """ 1024 | model_kwargs = dict(patch_size=16, embed_dim=768, depth=12, num_heads=12, qkv_bias=False, **kwargs) 1025 | model = _create_vision_transformer('vit_base_patch16_224_miil', pretrained=pretrained, **model_kwargs) 1026 | return model 1027 | 1028 | 1029 | @register_model 1030 | def vit_small_patch16_36x1_224(pretrained=False, **kwargs): 1031 | """ ViT-Base w/ LayerScale + 36 x 1 (36 block serial) config. Experimental, may remove. 1032 | Based on `Three things everyone should know about Vision Transformers` - https://arxiv.org/abs/2203.09795 1033 | Paper focuses on 24x2 + 48x1 for 'Small' width but those are extremely slow. 1034 | """ 1035 | model_kwargs = dict(patch_size=16, embed_dim=384, depth=36, num_heads=6, init_values=1e-5, **kwargs) 1036 | model = _create_vision_transformer('vit_small_patch16_36x1_224', pretrained=pretrained, **model_kwargs) 1037 | return model 1038 | 1039 | 1040 | @register_model 1041 | def vit_small_patch16_18x2_224(pretrained=False, **kwargs): 1042 | """ ViT-Small w/ LayerScale + 18 x 2 (36 block parallel) config. Experimental, may remove. 1043 | Based on `Three things everyone should know about Vision Transformers` - https://arxiv.org/abs/2203.09795 1044 | Paper focuses on 24x2 + 48x1 for 'Small' width but those are extremely slow. 1045 | """ 1046 | model_kwargs = dict( 1047 | patch_size=16, embed_dim=384, depth=18, num_heads=6, init_values=1e-5, block_fn=ParallelBlock, **kwargs) 1048 | model = _create_vision_transformer('vit_small_patch16_18x2_224', pretrained=pretrained, **model_kwargs) 1049 | return model 1050 | 1051 | 1052 | @register_model 1053 | def vit_base_patch16_18x2_224(pretrained=False, **kwargs): 1054 | """ ViT-Base w/ LayerScale + 18 x 2 (36 block parallel) config. Experimental, may remove. 1055 | Based on `Three things everyone should know about Vision Transformers` - https://arxiv.org/abs/2203.09795 1056 | """ 1057 | model_kwargs = dict( 1058 | patch_size=16, embed_dim=768, depth=18, num_heads=12, init_values=1e-5, block_fn=ParallelBlock, **kwargs) 1059 | model = _create_vision_transformer('vit_base_patch16_18x2_224', pretrained=pretrained, **model_kwargs) 1060 | return model --------------------------------------------------------------------------------
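A brief usage note on the file above: unlike stock timm, the forward method of this VisionTransformer returns a feature dictionary ('fmaps' and 'features') rather than classification logits, and the linear head defined in the constructor is not applied inside forward, presumably so the continual-learning code elsewhere in the repository (e.g. networks.py) can attach its own classifiers. The sketch below is illustrative only and not part of the repository; it assumes the timm version pinned in environment.yaml is installed and that this file is importable as models.vit per the repository layout. Names such as "backbone" and "dummy" are placeholders.

    # Minimal sketch, assuming the pinned timm version and the models.vit module path.
    import torch
    from models import vit  # importing the module registers the vit_* factories with timm

    # Build the ViT-B/16 backbone defined above; pretrained=True would instead load the
    # ImageNet-1k weights fine-tuned from ImageNet-21k, as described in its docstring.
    backbone = vit.vit_base_patch16_224(pretrained=False)
    backbone.eval()

    dummy = torch.randn(2, 3, 224, 224)
    with torch.no_grad():
        out = backbone(dummy)  # grow_flag defaults to False, so this takes the forward_features path

    print(backbone.out_dim)        # 768: feature width exposed for downstream heads
    print(out['features'].shape)   # torch.Size([2, 768]): class-token embedding per image
    print(len(out['fmaps']))       # 1: the same tensor, wrapped in a list for downstream code

If this matches the pinned timm API, the returned 'features' tensor is what the rest of the codebase consumes as the frozen or fine-tuned representation; the unused head can be replaced or ignored by the training code.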