├── models
│   ├── __init__.py
│   ├── convit.py
│   └── vit.py
├── utils
│   ├── __init__.py
│   ├── datautils
│   │   ├── __init__.py
│   │   └── core50
│   │       └── core50data.py
│   ├── toolkit.py
│   ├── data_manager.py
│   └── data.py
├── overview.jpg
├── LICENSE
├── environment.yaml
├── main.py
├── README.md
├── networks.py
└── trainer.py
/models/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/utils/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/utils/datautils/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/overview.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/iamwangyabin/ESN/HEAD/overview.jpg
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2022 Wang Yabin
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/utils/toolkit.py:
--------------------------------------------------------------------------------
1 | import os
2 | import numpy as np
3 | import torch
4 |
5 |
6 | def count_parameters(model, trainable=False):
7 | if trainable:
8 | return sum(p.numel() for p in model.parameters() if p.requires_grad)
9 | return sum(p.numel() for p in model.parameters())
10 |
11 |
12 | def tensor2numpy(x):
13 | return x.cpu().data.numpy() if x.is_cuda else x.data.numpy()
14 |
15 |
16 | def target2onehot(targets, n_classes):
17 | onehot = torch.zeros(targets.shape[0], n_classes).to(targets.device)
18 | onehot.scatter_(dim=1, index=targets.long().view(-1, 1), value=1.)
19 | return onehot
20 |
21 |
22 | def makedirs(path):
23 | if not os.path.exists(path):
24 | os.makedirs(path)
25 |
26 |
27 | def accuracy(y_pred, y_true, nb_old, increment=10):
28 | assert len(y_pred) == len(y_true), 'Data length error.'
29 | all_acc = {}
30 | all_acc['total'] = np.around((y_pred == y_true).sum()*100 / len(y_true), decimals=2)
31 |
32 | # Grouped accuracy
33 | for class_id in range(0, np.max(y_true), increment):
34 | idxes = np.where(np.logical_and(y_true >= class_id, y_true < class_id + increment))[0]
35 | label = '{}-{}'.format(str(class_id).rjust(2, '0'), str(class_id+increment-1).rjust(2, '0'))
36 | all_acc[label] = np.around((y_pred[idxes] == y_true[idxes]).sum()*100 / len(idxes), decimals=2)
37 |
38 | # Old accuracy
39 | idxes = np.where(y_true < nb_old)[0]
40 | all_acc['old'] = 0 if len(idxes) == 0 else np.around((y_pred[idxes] == y_true[idxes]).sum()*100 / len(idxes),
41 | decimals=2)
42 |
43 | # New accuracy
44 | idxes = np.where(y_true >= nb_old)[0]
45 | all_acc['new'] = np.around((y_pred[idxes] == y_true[idxes]).sum()*100 / len(idxes), decimals=2)
46 |
47 | return all_acc
48 |
49 |
50 | def split_images_labels(imgs):
51 | # split trainset.imgs in ImageFolder
52 | images = []
53 | labels = []
54 | for item in imgs:
55 | images.append(item[0])
56 | labels.append(item[1])
57 |
58 | return np.array(images), np.array(labels)
59 |
60 |
61 |
62 |
63 | def accuracy_binary(y_pred, y_true, nb_old, increment=2):
64 | assert len(y_pred) == len(y_true), 'Data length error.'
65 | all_acc = {}
66 | all_acc['total'] = np.around((y_pred%2 == y_true%2).sum()*100 / len(y_true), decimals=2)
67 |
68 | # Grouped accuracy
69 | for class_id in range(0, np.max(y_true), increment):
70 | idxes = np.where(np.logical_and(y_true >= class_id, y_true < class_id + increment))[0]
71 | label = '{}-{}'.format(str(class_id).rjust(2, '0'), str(class_id+increment-1).rjust(2, '0'))
72 | all_acc[label] = np.around(((y_pred[idxes]%2) == (y_true[idxes]%2)).sum()*100 / len(idxes), decimals=2)
73 |
74 | # Old accuracy
75 | idxes = np.where(y_true < nb_old)[0]
76 | # all_acc['old'] = 0 if len(idxes) == 0 else np.around((y_pred[idxes] == y_true[idxes]).sum()*100 / len(idxes),decimals=2)
77 | all_acc['old'] = 0 if len(idxes) == 0 else np.around(((y_pred[idxes]%2) == (y_true[idxes]%2)).sum()*100 / len(idxes),decimals=2)
78 |
79 | # New accuracy
80 | idxes = np.where(y_true >= nb_old)[0]
81 | all_acc['new'] = np.around(((y_pred[idxes]%2) == (y_true[idxes]%2)).sum()*100 / len(idxes), decimals=2)
82 |
83 | return all_acc
84 |
85 |
86 |
87 | def accuracy_domain(y_pred, y_true, nb_old, increment=10):
88 | assert len(y_pred) == len(y_true), 'Data length error.'
89 | all_acc = {}
90 | all_acc['total'] = np.around((y_pred%345 == y_true%345).sum()*100 / len(y_true), decimals=2)
91 | return all_acc
--------------------------------------------------------------------------------
/environment.yaml:
--------------------------------------------------------------------------------
1 | name: sp
2 | channels:
3 | - pytorch
4 | - defaults
5 | dependencies:
6 | - _libgcc_mutex=0.1=main
7 | - _openmp_mutex=5.1=1_gnu
8 | - blas=1.0=mkl
9 | - brotlipy=0.7.0=py38h27cfd23_1003
10 | - bzip2=1.0.8=h7b6447c_0
11 | - ca-certificates=2022.10.11=h06a4308_0
12 | - certifi=2022.9.24=py38h06a4308_0
13 | - cffi=1.15.1=py38h74dc2b5_0
14 | - charset-normalizer=2.0.4=pyhd3eb1b0_0
15 | - cryptography=37.0.1=py38h9ce1e76_0
16 | - cudatoolkit=11.3.1=h2bc3f7f_2
17 | - ffmpeg=4.3=hf484d3e_0
18 | - freetype=2.11.0=h70c0345_0
19 | - giflib=5.2.1=h7b6447c_0
20 | - gmp=6.2.1=h295c915_3
21 | - gnutls=3.6.15=he1e5248_0
22 | - idna=3.3=pyhd3eb1b0_0
23 | - intel-openmp=2021.4.0=h06a4308_3561
24 | - jpeg=9e=h7f8727e_0
25 | - lame=3.100=h7b6447c_0
26 | - lcms2=2.12=h3be6417_0
27 | - ld_impl_linux-64=2.38=h1181459_1
28 | - lerc=3.0=h295c915_0
29 | - libdeflate=1.8=h7f8727e_5
30 | - libffi=3.3=he6710b0_2
31 | - libgcc-ng=11.2.0=h1234567_1
32 | - libgomp=11.2.0=h1234567_1
33 | - libiconv=1.16=h7f8727e_2
34 | - libidn2=2.3.2=h7f8727e_0
35 | - libpng=1.6.37=hbc83047_0
36 | - libstdcxx-ng=11.2.0=h1234567_1
37 | - libtasn1=4.16.0=h27cfd23_0
38 | - libtiff=4.4.0=hecacb30_0
39 | - libunistring=0.9.10=h27cfd23_0
40 | - libuv=1.40.0=h7b6447c_0
41 | - libwebp=1.2.2=h55f646e_0
42 | - libwebp-base=1.2.2=h7f8727e_0
43 | - lz4-c=1.9.3=h295c915_1
44 | - mkl=2021.4.0=h06a4308_640
45 | - mkl-service=2.4.0=py38h7f8727e_0
46 | - mkl_fft=1.3.1=py38hd3c417c_0
47 | - mkl_random=1.2.2=py38h51133e4_0
48 | - ncurses=6.3=h5eee18b_3
49 | - nettle=3.7.3=hbbd107a_1
50 | - numpy=1.23.1=py38h6c91a56_0
51 | - numpy-base=1.23.1=py38ha15fc14_0
52 | - openh264=2.1.1=h4ff587b_0
53 | - openssl=1.1.1q=h7f8727e_0
54 | - pillow=9.2.0=py38hace64e9_1
55 | - pip=22.1.2=py38h06a4308_0
56 | - pycparser=2.21=pyhd3eb1b0_0
57 | - pyopenssl=22.0.0=pyhd3eb1b0_0
58 | - pysocks=1.7.1=py38h06a4308_0
59 | - python=3.8.13=h12debd9_0
60 | - pytorch=1.11.0=py3.8_cuda11.3_cudnn8.2.0_0
61 | - pytorch-mutex=1.0=cuda
62 | - readline=8.1.2=h7f8727e_1
63 | - requests=2.28.1=py38h06a4308_0
64 | - setuptools=63.4.1=py38h06a4308_0
65 | - six=1.16.0=pyhd3eb1b0_1
66 | - sqlite=3.39.2=h5082296_0
67 | - tk=8.6.12=h1ccaba5_0
68 | - torchaudio=0.11.0=py38_cu113
69 | - torchvision=0.12.0=py38_cu113
70 | - typing_extensions=4.3.0=py38h06a4308_0
71 | - urllib3=1.26.11=py38h06a4308_0
72 | - wheel=0.37.1=pyhd3eb1b0_0
73 | - xz=5.2.5=h7f8727e_1
74 | - zlib=1.2.12=h5eee18b_3
75 | - zstd=1.5.2=ha4553b6_0
76 | - pip:
77 | - aiohttp==3.8.3
78 | - aiosignal==1.3.1
79 | - antlr4-python3-runtime==4.9.3
80 | - async-timeout==4.0.2
81 | - attrs==22.2.0
82 | - click==8.1.3
83 | - contourpy==1.0.5
84 | - cycler==0.11.0
85 | - docker-pycreds==0.4.0
86 | - einops==0.6.0
87 | - fonttools==4.37.3
88 | - frozenlist==1.3.3
89 | - fsspec==2022.11.0
90 | - ftfy==6.1.1
91 | - gitdb==4.0.9
92 | - gitpython==3.1.27
93 | - joblib==1.2.0
94 | - kiwisolver==1.4.4
95 | - lightning-utilities==0.5.0
96 | - matplotlib==3.6.0
97 | - multidict==6.0.4
98 | - omegaconf==2.3.0
99 | - opencv-python==4.6.0.66
100 | - packaging==21.3
101 | - pandas==1.5.0
102 | - pathtools==0.1.2
103 | - pot==0.8.2
104 | - promise==2.3
105 | - protobuf==3.20.1
106 | - psutil==5.9.2
107 | - pyparsing==3.0.9
108 | - python-dateutil==2.8.2
109 | - pytorch-lightning==1.8.6
110 | - pytz==2022.2.1
111 | - pyyaml==6.0
112 | - quadprog==0.1.11
113 | - regex==2022.9.13
114 | - scikit-learn==0.23.1
115 | - scipy==1.9.1
116 | - seaborn==0.12.0
117 | - sentry-sdk==1.9.8
118 | - setproctitle==1.3.2
119 | - shortuuid==1.0.9
120 | - sklearn==0.0
121 | - smmap==5.0.0
122 | - tensorboardx==2.5.1
123 | - threadpoolctl==3.1.0
124 | - timm==0.6.7
125 | - torchmetrics==0.11.0
126 | - tqdm==4.64.1
127 | - wandb==0.13.3
128 | - wcwidth==0.2.5
129 | - yarl==1.8.2
130 | prefix: /home/wangyabin/miniconda3/envs/sp
131 |
--------------------------------------------------------------------------------
/main.py:
--------------------------------------------------------------------------------
1 | import os
2 | import copy
3 | import numpy as np
4 | import time
5 | import argparse
6 | import wandb
7 |
8 | import torch
9 | from torch.utils.data import Dataset, DataLoader
10 |
11 | from networks import incremental_vitood
12 | from trainer import training, eval
13 | from utils.data_manager import DataManager
14 |
15 | def setup_parser():
16 |     parser = argparse.ArgumentParser(description='Reproduction of multiple continual learning algorithms.')
17 |     parser.add_argument('--method', default='ESN', type=str, help='method name')
18 |     parser.add_argument('--model_name', default='vit', type=str, help='backbone model name')
19 |     parser.add_argument('--dataset', default='5datasets_vit', type=str, help='cifar100_vit, 5datasets_vit, core50')
20 |     parser.add_argument('--init_cls', default=10, type=int, help='number of classes in the initial task')
21 |     parser.add_argument('--inc_cls', default=10, type=int, help='number of classes added per incremental task')
22 |     parser.add_argument('--shuffle', action='store_false', help='pass this flag to disable class-order shuffling (L2P-style fixed order)')
23 |     parser.add_argument('--random_seed', default=1993, type=int, help='random seed for the class order')
24 |     parser.add_argument('--training_device', default="2", type=str, help='GPU id(s) for CUDA_VISIBLE_DEVICES')
25 |     parser.add_argument('--max_epochs', default=50, type=int, help='number of training epochs per task')
26 |     parser.add_argument('--lr', default=0.01, type=float, help='Set learning rate')
27 | 
28 |     parser.add_argument('--using_prompt', default=True, type=bool, help='whether to use the shared prompt tokens')
29 |     parser.add_argument('--anchor_energy', default=-10, type=float, help='anchor for the energy self-normalization loss')
30 |     parser.add_argument('--energy_beta', default=1, type=float, help='temperature beta used in the free energy')
31 |     parser.add_argument('--lamda', default=0.1, type=float, help='0 means do not use energy alignment')
32 |     parser.add_argument('--temptures', default=20, type=int, help='max temperature')
33 |     parser.add_argument('--voting', default=True, type=bool, help='whether or not to use voting at inference')
34 | 
35 |     parser.add_argument('--dil', default=False, type=bool, help='for domain incremental learning evaluation')
36 |     parser.add_argument('--max_cls', default=2, type=int, help='number of classes per domain for domain incremental evaluation')
37 |     parser.add_argument('--notes', default='', type=str, help='notes for the wandb run')
38 |
39 | return parser
40 |
41 | def _set_random():
42 | torch.manual_seed(1)
43 | torch.cuda.manual_seed(1)
44 | torch.cuda.manual_seed_all(1)
45 | torch.backends.cudnn.deterministic = True
46 | torch.backends.cudnn.benchmark = False
47 |
48 | def main():
49 | args = setup_parser().parse_args()
50 | os.environ["CUDA_VISIBLE_DEVICES"] = args.training_device
51 | _set_random()
52 | args.localtime = time.strftime("%Y-%m-%d-%H:%M:%S", time.localtime())
53 |
54 | if args.dataset == "5datasets_vit" or args.dataset == "core50":
55 | args.shuffle=False
56 | data_manager = DataManager(args.dataset, args.shuffle, args.random_seed, args.init_cls, args.inc_cls, args=vars(args))
57 | args.class_order = data_manager._class_order
58 |
59 | wandb.init(project="ESN",
60 | name='{}_{}_{}_{}_{}_'.format(args.method, args.model_name, args.dataset, args.init_cls, args.inc_cls) + args.localtime,
61 | save_code=True, group='{}_{}'.format(args.dataset, args.model_name), notes=args.notes, config=args)
62 |
63 | all_tabs, all_classifiers, all_tokens, accuracy_log, vitpromptlist= [], [], [], [], []
64 | vitprompt = None
65 | _known_classes=0
66 |
67 | for taskid in range(data_manager.nb_tasks):
68 | print("current task: {}".format(taskid))
69 | _total_classes = _known_classes + data_manager.get_task_size(taskid)
70 | current_data = np.arange(_known_classes, _total_classes)
71 | train_dataset = data_manager.get_dataset(current_data, source='train', mode='train')
72 |
73 | if args.dataset == "core50":
74 | test_dataset = data_manager.get_dataset(np.arange(0, data_manager.get_task_size(0)), source='test', mode='test')
75 | else:
76 | test_dataset = data_manager.get_dataset(current_data, source='test', mode='test')
77 |
78 | train_loader = DataLoader(train_dataset, batch_size=128, shuffle=True, num_workers=8, persistent_workers=True, pin_memory=True)
79 | test_loader = DataLoader(test_dataset, batch_size=128, shuffle=False, num_workers=8)
80 |
81 | model, eval_result = training(args, incremental_vitood, taskid, train_loader, test_loader,
82 | known_classes=_known_classes, vitprompt=vitprompt, numclass=data_manager.get_task_size(taskid))
83 |
84 | all_tabs.append(copy.deepcopy(model.tabs).cpu())
85 | all_classifiers.append(copy.deepcopy(model.classifiers).cpu())
86 | all_tokens.append(copy.deepcopy(model.task_tokens).cpu())
87 | vitprompt = copy.deepcopy(model.vitprompt).cpu()
88 | vitpromptlist.append(copy.deepcopy(model.vitprompt).cpu())
89 |
90 | del model
91 |
92 | _known_classes = _total_classes
93 |
94 | assembles = {'all_tabs': all_tabs, 'all_classifiers': all_classifiers, 'all_tokens': all_tokens, 'vitpromptlist':vitpromptlist}
95 | torch.save(assembles, './checkpoints/'+wandb.run.name+'.pth')
96 |
97 | eval(args, './checkpoints/'+wandb.run.name+'.pth', data_manager)
98 |
99 |
100 | if __name__ == '__main__':
101 | main()
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Isolation and Impartial Aggregation: A Paradigm of Incremental Learning without Interference
2 |
3 |
4 | This is the official implementation of our AAAI 2023 paper "Isolation and Impartial Aggregation: A Paradigm of Incremental Learning without Interference".
5 | This paper focuses on the prevalent stage interference and stage performance imbalance in incremental learning. To avoid obvious stage learning bottlenecks, we propose a new incremental learning framework that leverages a series of stage-isolated classifiers to perform the learning task at each stage, without interference from the others. Concretely, to aggregate the multiple stage classifiers into a uniform one impartially, we first introduce a temperature-controlled energy metric that indicates the confidence levels of the stage classifiers. We then propose an anchor-based energy self-normalization strategy to ensure that the stage classifiers work at the same energy level. Finally, we design a voting-based inference augmentation strategy for robust inference. The proposed method is rehearsal-free and can work for almost all incremental learning scenarios. We evaluate the proposed method on four large datasets; extensive results demonstrate its superiority in setting new state-of-the-art overall performance.
6 |
7 |
8 |
9 | **Isolation and Impartial Aggregation: A Paradigm of Incremental Learning without Interference**
10 | Yabin Wang*, Zhiheng Ma*, Zhiwu Huang, Yaowei Wang, Zhou Su, Xiaopeng Hong. Proceedings of the AAAI Conference on Artificial Intelligence (AAAI 2023).
11 | [[Paper]](https://arxiv.org/abs/2211.15969)
12 |
13 | ## Introduction
14 |
15 |
16 |
17 |
18 |
19 | In this paper, we propose anchor-based energy self-normalization for stage classifiers.
20 | The classifiers of the current and the previous stages, $f_{\eta_s}$ and $f_{\eta_{s-1}}$, are aligned sequentially by restricting their energies around the anchor.
21 |
22 |
23 | 
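
In code, this alignment matches `_calculate_loss` in `networks.py`: the free energy of each stage classifier's logits is pulled towards a fixed anchor. A minimal sketch with toy inputs (the tensor shapes are made up for illustration):

```python
import torch

def energy_alignment_loss(logits, anchor_energy=-10.0, energy_beta=1.0, lamda=0.1):
    # free energy per sample: F(x) = -(1/beta) * logsumexp(-beta * logits)
    free_energy = -torch.logsumexp(-energy_beta * logits, dim=1) / energy_beta
    # penalize deviations of the stage energies from the shared anchor
    return lamda * ((free_energy - anchor_energy) ** 2).mean()

# toy batch: 4 samples, 10-way stage classifier
align_loss = energy_alignment_loss(torch.randn(4, 10))
```

The defaults mirror the `--anchor_energy`, `--energy_beta` and `--lamda` arguments of `main.py`; during training this term is added to the cross-entropy loss of the current stage.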
24 |
25 |
26 |
27 | ## Environment setup
28 | Create the virtual environment for ESN.
29 | ```
30 | conda env create -f environment.yaml
31 | ```
32 | After this, you will get a new environment (named **sp** in the provided `environment.yaml`) for running the ESN experiments.
33 | Run `conda activate sp` to activate it.
34 | 
35 | Note that only NVIDIA GPUs are supported for now; we used an NVIDIA RTX 3090 in our experiments.
36 |
37 | ## Dataset preparation
38 | Please refer to the following links to download the four standard incremental learning benchmark datasets.
39 | [CIFAR-100] Auto Download
40 | [CORe50](https://vlomonaco.github.io/core50/index.html#dataset)
41 | [DomainNet](http://ai.bu.edu/M3SDA/)
42 | [5-Datasets]
43 |
44 | Unzip the downloaded files, and you will get the following folders.
45 |
46 | ```
47 | core50
48 | └── core50_128x128
49 | ├── labels.pkl
50 | ├── LUP.pkl
51 | ├── paths.pkl
52 | ├── s1
53 | ├── s2
54 | ├── s3
55 | ...
56 | ```
57 |
58 | ```
59 | domainnet
60 | ├── clipart
61 | │ ├── aircraft_carrier
62 | │ ├── airplane
63 | │ ... ...
64 | ├── clipart_test.txt
65 | ├── clipart_train.txt
66 | ├── infograph
67 | │ ├── aircraft_carrier
68 | │ ├── airplane
69 | │ ... ...
70 | ├── infograph_test.txt
71 | ├── infograph_train.txt
72 | ├── painting
73 | │ ├── aircraft_carrier
74 | │ ├── airplane
75 | │ ... ...
76 | ... ...
77 | ```
78 |
79 | ## Training:
80 |
81 | Please change the dataset paths in `utils/data.py` (e.g., the hard-coded roots in `iCore50.download_data` and `iDomainnetCIL.download_data`) to the locations of the datasets, as shown below.
82 |
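The paths below are placeholders; replace them with wherever you unzipped the data:

```python
# in utils/data.py, iCore50.download_data
datagen = CORE50(root='/path/to/core50/core50_128x128', scenario="ni")

# in utils/data.py, iDomainnetCIL.download_data
rootdir = '/path/to/domainnet'
```
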
83 | ### Split-CIFAR100:
84 | ```
85 | python main.py --dataset cifar100_vit --max_epochs 30 --init_cls 10 --inc_cls 10 --shuffle
86 | ```
87 |
88 | ### Split-DomainNet:
89 | ```
90 | python main.py --dataset domainnet --max_epochs 10 --init_cls 20 --inc_cls 20 --shuffle
91 | ```
92 |
93 | ### CORe50:
94 | ```
95 | python main.py --dataset core50 --max_epochs 10 --init_cls 50 --inc_cls 50 --dil True --max_cls 50
96 | ```
97 |
98 | ### 5-Datasets:
99 | ```
100 | python main.py --dataset 5datasets_vit --max_epochs 10 --init_cls 10 --inc_cls 10
101 | ```
102 |
103 | ## Model Zoo:
104 |
105 | Pretrained models are available [here](https://drive.google.com/drive/folders/1D8qv1klFXePl-aQr5T8NVpoAf3nmxUld?usp=sharing).
106 |
107 | We assume the downloaded weights are located under the `checkpoints` directory.
108 |
109 | Otherwise, you may need to change the corresponding paths in the scripts.
110 |
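For reference, each `.pth` file written by `main.py` (and, presumably, the released weights follow the same format) is a plain dictionary holding the stage-isolated components that `eval` in `trainer.py` reassembles:

```python
import torch

# checkpoint layout written by main.py after the last task
assembles = torch.load('./checkpoints/<run_name>.pth', map_location='cpu')
print(list(assembles.keys()))
# ['all_tabs', 'all_classifiers', 'all_tokens', 'vitpromptlist']
```
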
111 |
112 | [comment]: <> (## Results)
113 |
114 | [comment]: <> ()
115 |
116 | [comment]: <> ()
117 |
118 | [comment]: <> ()
119 |
120 | ## License
121 |
122 | This repository is released under the MIT License; please check the [LICENSE](./LICENSE) file included in this repository.
123 |
124 | ## Acknowledgments
125 |
126 | We thank the following repos for providing helpful components/functions used in our work.
127 |
128 | - [S-Prompts](https://github.com/iamwangyabin/S-Prompts)
129 | - [PyCIL](https://github.com/G-U-N/PyCIL)
130 |
131 | ## Citation
132 |
133 | If you use any content of this repo for your work, please cite the following bib entry:
134 | ```
135 | @inproceedings{wang2022isolation,
136 | title={Isolation and Impartial Aggregation: A Paradigm of Incremental Learning without Interference},
137 | author={Wang, Yabin and Ma, Zhiheng and Huang, Zhiwu and Wang, Yaowei and Su, Zhou and Hong, Xiaopeng},
138 | booktitle={Proceedings of the AAAI Conference on Artificial Intelligence},
139 | year={2023}
140 | }
141 | ```
142 |
--------------------------------------------------------------------------------
/networks.py:
--------------------------------------------------------------------------------
1 | import copy
2 |
3 | import torch
4 | import torch.nn as nn
5 | from torch import optim
6 | from torch.nn import functional as F
7 | import pytorch_lightning as pl
8 |
9 | from models.vit import VisionTransformer, PatchEmbed, Block,resolve_pretrained_cfg, build_model_with_cfg, checkpoint_filter_fn
10 | from models.convit import ClassAttention
11 | from models.convit import Block as ConBlock
12 |
13 |
14 | class ViT_KPrompts(VisionTransformer):
15 | def __init__(
16 | self, img_size=224, patch_size=16, in_chans=3, num_classes=1000, global_pool='token',
17 | embed_dim=768, depth=12, num_heads=12, mlp_ratio=4., qkv_bias=True, representation_size=None,
18 | drop_rate=0., attn_drop_rate=0., drop_path_rate=0., weight_init='', init_values=None,
19 | embed_layer=PatchEmbed, norm_layer=None, act_layer=None, block_fn=Block):
20 |
21 | super().__init__(img_size=img_size, patch_size=patch_size, in_chans=in_chans, num_classes=num_classes, global_pool=global_pool,
22 | embed_dim=embed_dim, depth=depth, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, representation_size=representation_size,
23 | drop_rate=drop_rate, attn_drop_rate=attn_drop_rate, drop_path_rate=drop_path_rate, weight_init=weight_init, init_values=init_values,
24 | embed_layer=embed_layer, norm_layer=norm_layer, act_layer=act_layer, block_fn=block_fn)
25 |
26 | def forward(self, x, instance_tokens=None, returnbeforepool=False, **kwargs):
27 | x = self.patch_embed(x)
28 | x = torch.cat((self.cls_token.expand(x.shape[0], -1, -1), x), dim=1)
29 |
30 | if instance_tokens is not None:
31 | instance_tokens = instance_tokens.to(x.dtype) + torch.zeros(x.shape[0], 1, x.shape[-1], dtype=x.dtype, device=x.device)
32 |
33 | x = x + self.pos_embed.to(x.dtype)
34 | if instance_tokens is not None:
35 | x = torch.cat([x[:,:1,:], instance_tokens, x[:,1:,:]], dim=1)
36 | x = self.pos_drop(x)
37 | x = self.blocks(x)
38 | if returnbeforepool == True:
39 | return x
40 | x = self.norm(x)
41 | if self.global_pool:
42 | x = x[:, 1:].mean(dim=1) if self.global_pool == 'avg' else x[:, 0]
43 | x = self.fc_norm(x)
44 | return x
45 |
46 | def _create_vision_transformer(variant, pretrained=False, **kwargs):
47 | if kwargs.get('features_only', None):
48 | raise RuntimeError('features_only not implemented for Vision Transformer models.')
49 |
50 | # NOTE this extra code to support handling of repr size for in21k pretrained models
51 | pretrained_cfg = resolve_pretrained_cfg(variant, pretrained_cfg=kwargs.pop('pretrained_cfg', None))
52 | default_num_classes = pretrained_cfg['num_classes']
53 | num_classes = kwargs.get('num_classes', default_num_classes)
54 | repr_size = kwargs.pop('representation_size', None)
55 | if repr_size is not None and num_classes != default_num_classes:
56 | repr_size = None
57 |
58 | model = build_model_with_cfg(
59 | ViT_KPrompts, variant, pretrained,
60 | pretrained_cfg=pretrained_cfg,
61 | representation_size=repr_size,
62 | pretrained_filter_fn=checkpoint_filter_fn,
63 | pretrained_custom_load='npz' in pretrained_cfg['url'],
64 | **kwargs)
65 | return model
66 |
67 |
68 |
69 | class incremental_vitood(pl.LightningModule):
70 | def __init__(self, num_cls, lr, max_epoch, weight_decay, known_classes, freezep, using_prompt, anchor_energy=-10,
71 | lamda=0.1, energy_beta=1):
72 | super().__init__()
73 | self.save_hyperparameters()
74 |
75 | model_kwargs = dict(patch_size=16, embed_dim=768, depth=12, num_heads=12)
76 | self.image_encoder =_create_vision_transformer('vit_base_patch16_224', pretrained=True, **model_kwargs)
77 |
78 | self.classifiers = nn.Linear(self.image_encoder.embed_dim, self.hparams.num_cls, bias=True)
79 | self.tabs = ConBlock(dim=self.image_encoder.embed_dim, num_heads=12, mlp_ratio=0.5, qkv_bias=True,
80 | qk_scale=None, drop=0.,attn_drop=0., norm_layer=nn.LayerNorm, attention_type=ClassAttention)
81 | self.task_tokens = copy.deepcopy(self.image_encoder.cls_token)
82 | self.vitprompt = nn.Linear(self.image_encoder.embed_dim, 100, bias=False)
83 | self.pre_vitprompt = None
84 |
85 | for name, param in self.image_encoder.named_parameters():
86 | param.requires_grad_(False)
87 |
88 | if self.hparams.freezep:
89 | for name, param in self.vitprompt.named_parameters():
90 | param.requires_grad_(False)
91 |
92 | def forward(self, image):
93 | if self.hparams.using_prompt:
94 |             image_features = self.image_encoder(image, instance_tokens=self.vitprompt.weight, returnbeforepool=True)
95 | else:
96 | image_features = self.image_encoder(image, returnbeforepool=True)
97 |
98 | B = image_features.shape[0]
99 | task_token = self.task_tokens.expand(B, -1, -1)
100 | task_token, attn, v = self.tabs(torch.cat((task_token, image_features), dim=1), mask_heads=None)
101 | logits = self.classifiers(task_token[:, 0])
102 |
103 | return logits
104 |
105 | def configure_optimizers(self):
106 | optparams = filter(lambda p: p.requires_grad, self.parameters())
107 | optimizer = optim.SGD(optparams, momentum=0.9,lr=self.hparams.lr,weight_decay=self.hparams.weight_decay)
108 | scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer=optimizer,T_max=self.hparams.max_epoch)
109 | return [optimizer], [scheduler]
110 |
111 | def _calculate_loss(self, batch, mode='train'):
112 | _, images, labels = batch
113 | labels = labels-self.hparams.known_classes
114 |
115 | if self.hparams.using_prompt:
116 | image_features = self.image_encoder(images, instance_tokens=self.vitprompt.weight, returnbeforepool=True, )
117 | else:
118 | image_features = self.image_encoder(images, returnbeforepool=True)
119 | B = image_features.shape[0]
120 | task_token = self.task_tokens.expand(B, -1, -1)
121 | task_token, attn, v = self.tabs(torch.cat((task_token, image_features), dim=1), mask_heads=None)
122 | logits = self.classifiers(task_token[:, 0])
123 | loss = F.cross_entropy(logits, labels)
124 |
125 | output_div_t = -1.0 * self.hparams.energy_beta * logits
126 | output_logsumexp = torch.logsumexp(output_div_t, dim=1, keepdim=False)
127 | free_energy = -1.0 * output_logsumexp / self.hparams.energy_beta
128 | align_loss = self.hparams.lamda * ((free_energy - self.hparams.anchor_energy) ** 2).mean()
129 |
130 | if self.pre_vitprompt is not None:
131 | pre_feature = self.image_encoder(images, instance_tokens=self.pre_vitprompt.weight, returnbeforepool=True, )
132 | kdloss = nn.MSELoss()(pre_feature.detach(), image_features)
133 | else:
134 | kdloss = 0
135 |
136 | loss = loss+align_loss+kdloss
137 |
138 | acc = (logits.argmax(dim=-1) == labels).float().mean()
139 | self.log("%s_loss" % mode, loss)
140 | self.log("%s_acc" % mode, acc)
141 | return loss
142 |
143 | def validation_step(self, batch, batch_idx):
144 | loss = self._calculate_loss(batch, mode='val')
145 |
146 | def training_step(self, batch, batch_idx):
147 | loss = self._calculate_loss(batch, mode='train')
148 | return loss
149 |
150 | def val_step(self, batch, batch_idx):
151 | self._calculate_loss(batch, mode='val')
152 |
153 | def test_step(self, batch, batch_idx):
154 | self._calculate_loss(batch, mode='test')
155 |
156 |
--------------------------------------------------------------------------------
/trainer.py:
--------------------------------------------------------------------------------
1 | import copy
2 | import wandb
3 | import numpy as np
4 | import torch
5 | from torch.utils.data import Dataset, DataLoader
6 | import pytorch_lightning as pl
7 | from pytorch_lightning.callbacks import LearningRateMonitor, ModelCheckpoint
8 | from networks import _create_vision_transformer
9 |
10 | def training(args, prototype_mode, taskid, train_loader, test_loader, known_classes=0, vitprompt=None, numclass=10):
11 | trainer = pl.Trainer(default_root_dir='./checkpoints/'+wandb.run.name, accelerator="gpu", devices=1,
12 | max_epochs=args.max_epochs, progress_bar_refresh_rate=1,
13 | callbacks=[ModelCheckpoint(save_weights_only=True, mode="max", monitor="train_acc"),
14 | LearningRateMonitor("epoch"),])
15 |
16 | if taskid == 0:
17 | model = prototype_mode(num_cls=numclass, lr=args.lr, max_epoch=args.max_epochs, weight_decay=0.0005,
18 | known_classes=known_classes, freezep=False, using_prompt=args.using_prompt,
19 | anchor_energy=args.anchor_energy, lamda=args.lamda, energy_beta=args.energy_beta)
20 | else:
21 | model = prototype_mode(num_cls=numclass, lr=args.lr, max_epoch=args.max_epochs, weight_decay=0.0005,
22 | known_classes=known_classes, freezep=True, using_prompt=args.using_prompt,
23 | anchor_energy=args.anchor_energy, lamda=args.lamda, energy_beta=args.energy_beta)
24 | model.vitprompt = vitprompt
25 |
26 |
27 | trainer.fit(model, train_loader)
28 | if args.dataset == "core50":
29 | val_result = [{'test_acc':0}]
30 | else:
31 | val_result = trainer.test(model, dataloaders=test_loader, verbose=False)
32 | print(val_result)
33 | return model, val_result
34 |
35 |
36 | def eval(args, load_path, datamanage):
37 | assembles = torch.load(load_path, map_location=torch.device('cpu'))
38 | model_kwargs = dict(patch_size=16, embed_dim=768, depth=12, num_heads=12)
39 | ptvit = _create_vision_transformer('vit_base_patch16_224', pretrained=True, **model_kwargs)
40 |
41 | device = 'cuda:0'
42 | ptvit = ptvit.to(device)
43 | ptvit.eval()
44 | all_tabs = assembles['all_tabs']
45 | all_classifiers = assembles['all_classifiers']
46 | all_tokens = assembles['all_tokens']
47 | vitpromptlist = assembles['vitpromptlist']
48 |
49 | all_tabs = [i.to(device) for i in all_tabs]
50 | all_classifiers = [i.to(device) for i in all_classifiers]
51 | all_tokens = [i.to(device) for i in all_tokens]
52 | vitpromptlist = [i.to(device) for i in vitpromptlist]
53 |
54 | _known_classes=0
55 | # fast mode
56 | candidata_temperatures = [0.001, 0.01, 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 1.0]
57 | # slow mode
58 | # candidata_temperatures = [i/1000 for i in range(1,1000,1)]
59 | select_temperature = [0.001] # for first stage with no former data
60 | accuracy_table, accs = [], []
61 | print("Testing...")
62 | for taskid in range(datamanage.nb_tasks):
63 | _total_classes = _known_classes + datamanage.get_task_size(taskid)
64 | test_till_now_dataset = datamanage.get_dataset(np.arange(0, _total_classes), source='test', mode='test')
65 | test_till_now_loader = DataLoader(test_till_now_dataset, batch_size=128, shuffle=False, num_workers=8)
66 |
67 | if taskid > 0:
68 | classifiers = all_classifiers[:taskid+1]
69 | task_tokens = all_tokens[:taskid+1]
70 | tabs = all_tabs[:taskid+1]
71 | current_dataset = datamanage.get_dataset(np.arange(_known_classes, _total_classes), source='train', mode='test')
72 | current_dataloader = DataLoader(current_dataset, batch_size=128, shuffle=False, num_workers=8)
73 | all_energies = {i: [] for i in candidata_temperatures}
74 | with torch.no_grad():
75 | for _, (_, inputs, targets) in enumerate(current_dataloader):
76 | inputs = inputs.to(device)
77 | targets = targets.to(device)
78 | energys = {i: [] for i in candidata_temperatures}
79 | image_features = ptvit(inputs, instance_tokens=vitpromptlist[taskid].weight, returnbeforepool=True)
80 |
81 | B = image_features.shape[0]
82 | for idx, fc in enumerate(classifiers):
83 | task_token = task_tokens[idx].expand(B, -1, -1)
84 | task_token, attn, v = tabs[idx](torch.cat((task_token, image_features), dim=1), mask_heads=None)
85 | task_token = task_token[:, 0]
86 | logit = fc(task_token)
87 |
88 | for tem in candidata_temperatures:
89 | energys[tem].append(torch.logsumexp(logit / tem, axis=-1))
90 | energys = {i: torch.stack(energys[i]).T for i in candidata_temperatures}
91 | for i in candidata_temperatures:
92 | all_energies[i].append(energys[i])
93 |
94 | all_energies = {i: torch.cat(all_energies[i]) for i in candidata_temperatures}
95 | seperation_accuracy = []
96 | for i in candidata_temperatures:
97 | seperation_accuracy.append((sum(all_energies[i].max(1)[1]==(taskid))/len(all_energies[i])).item())
98 | select_temperature.append(candidata_temperatures[np.array(seperation_accuracy).argmax()])
99 |
100 | set_select_temperature = list(set(select_temperature))
101 | y_pred, y_true = [], []
102 | for _, (_, inputs, targets) in enumerate(test_till_now_loader):
103 | inputs = inputs.to(device)
104 | targets = targets.to(device)
105 | candiatetask = {i: [] for i in set_select_temperature}
106 | seperatePreds = []
107 |
108 | with torch.no_grad():
109 | image_features = ptvit(inputs, instance_tokens=vitpromptlist[taskid].weight, returnbeforepool=True)
110 | B = image_features.shape[0]
111 | for idx, fc in enumerate(all_classifiers[:taskid+1]):
112 | task_token = all_tokens[:taskid+1][idx].expand(B, -1, -1)
113 | task_token, attn, v = all_tabs[:taskid+1][idx](torch.cat((task_token, image_features), dim=1), mask_heads=None)
114 | task_token = task_token[:, 0]
115 | logit = fc(task_token)
116 | for tem in set_select_temperature:
117 | candiatetask[tem].append(torch.logsumexp(logit / tem, axis=-1))
118 | seperatePreds.append(logit.max(1)[1]+idx*logit.shape[1])
119 |
120 | candiatetask = {i: torch.stack(candiatetask[i]).T for i in set_select_temperature}
121 | seperatePreds = torch.stack(seperatePreds).T
122 |
123 | pred = []
124 | for tem in set_select_temperature:
125 | val, ind = candiatetask[tem].max(1)
126 | pred.append(ind)
127 | indexselection = torch.stack(pred, 1)
128 | selectid = torch.mode(indexselection, dim=1, keepdim=False)[0]
129 | outputs = []
130 | for row, idx in enumerate(selectid):
131 | outputs.append(seperatePreds[row][idx])
132 | outputs = torch.stack(outputs)
133 | y_pred.append(outputs.cpu().numpy())
134 | y_true.append(targets.cpu().numpy())
135 |
136 | y_pred = np.concatenate(y_pred)
137 | y_true = np.concatenate(y_true)
138 | if args.dil:
139 | accs.append(np.around((y_pred.T%args.max_cls == y_true%args.max_cls).sum() * 100 / len(y_true), decimals=2))
140 | else:
141 | accs.append(np.around((y_pred.T == y_true).sum() * 100 / len(y_true), decimals=2))
142 |
143 | tempacc = []
144 | for class_id in range(0, np.max(y_true), _total_classes-_known_classes):
145 | idxes = np.where(np.logical_and(y_true >= class_id, y_true < class_id + _total_classes-_known_classes))[0]
146 | tempacc.append(np.around((y_pred[idxes] == y_true[idxes]).sum() * 100 / len(idxes), decimals=3))
147 | accuracy_table.append(tempacc)
148 |
149 | _known_classes = _total_classes
150 |
151 |
152 | np_acctable = np.zeros([taskid + 1, taskid + 1])
153 | for idxx, line in enumerate(accuracy_table):
154 | idxy = len(line)
155 | np_acctable[idxx, :idxy] = np.array(line)
156 | # import pdb;pdb.set_trace()
157 | np_acctable = np_acctable.T
158 | print("Accuracy table:")
159 | print(np_acctable)
160 | print("Accuracy curve:")
161 | print(accs)
162 | print("FAA: {}".format(accs[-1]))
163 |
164 | forgetting = np.mean((np.max(np_acctable, axis=1) - np_acctable[:, taskid])[:taskid])
165 | print("FF: {}".format(forgetting))
166 |
--------------------------------------------------------------------------------
/utils/data_manager.py:
--------------------------------------------------------------------------------
1 | import logging
2 | import numpy as np
3 | from PIL import Image
4 | from torch.utils.data import Dataset
5 | from torchvision import transforms
6 | from utils.data import iCIFAR10, iCIFAR100, iImageNet100, iImageNet1000
7 | from utils.data import iCIFAR100_vit, iGanFake, iGanClass, i5Datasets_vit, iImageNetR, iCore50, iDomainnetCIL
8 |
9 |
10 |
11 | class DataManager(object):
12 | def __init__(self, dataset_name, shuffle, seed, init_cls, increment, args=None):
13 | self.args = args
14 | self.dataset_name = dataset_name
15 | self._setup_data(dataset_name, shuffle, seed)
16 |         assert init_cls <= len(self._class_order), 'Not enough classes.'
17 | self._increments = [init_cls]
18 | while sum(self._increments) + increment < len(self._class_order):
19 | self._increments.append(increment)
20 | offset = len(self._class_order) - sum(self._increments)
21 | if offset > 0:
22 | self._increments.append(offset)
23 |
24 | self.attack = [
25 | transforms.ToTensor(),
26 | ]
27 |
28 |
29 | @property
30 | def nb_tasks(self):
31 | return len(self._increments)
32 |
33 | def get_task_size(self, task):
34 | return self._increments[task]
35 |
36 | def get_dataset(self, indices, source, mode, appendent=None, ret_data=False):
37 | if source == 'train':
38 | x, y = self._train_data, self._train_targets
39 | elif source == 'test':
40 | x, y = self._test_data, self._test_targets
41 | else:
42 | raise ValueError('Unknown data source {}.'.format(source))
43 |
44 | if mode == 'train':
45 | trsf = transforms.Compose([*self._train_trsf, *self._common_trsf])
46 | elif mode == 'flip':
47 | trsf = transforms.Compose([*self._test_trsf, transforms.RandomHorizontalFlip(p=1.), *self._common_trsf])
48 | elif mode == 'test':
49 | trsf = transforms.Compose([*self._test_trsf, *self._common_trsf])
50 | elif mode == 'attack':
51 | trsf = transforms.Compose([*self._test_trsf, *self.attack])
52 |
53 | else:
54 | raise ValueError('Unknown mode {}.'.format(mode))
55 |
56 | data, targets = [], []
57 | for idx in indices:
58 | class_data, class_targets = self._select(x, y, low_range=idx, high_range=idx+1)
59 | data.append(class_data)
60 | targets.append(class_targets)
61 |
62 | if appendent is not None and len(appendent) != 0:
63 | appendent_data, appendent_targets = appendent
64 | data.append(appendent_data)
65 | targets.append(appendent_targets)
66 |
67 | data, targets = np.concatenate(data), np.concatenate(targets)
68 |
69 | if ret_data:
70 | return data, targets, DummyDataset(data, targets, trsf, self.use_path)
71 | else:
72 | return DummyDataset(data, targets, trsf, self.use_path)
73 |
74 | def get_anchor_dataset(self, mode, appendent=None, ret_data=False):
75 | if mode == 'train':
76 | trsf = transforms.Compose([*self._train_trsf, *self._common_trsf])
77 | elif mode == 'flip':
78 | trsf = transforms.Compose([*self._test_trsf, transforms.RandomHorizontalFlip(p=1.), *self._common_trsf])
79 | elif mode == 'test':
80 | trsf = transforms.Compose([*self._test_trsf, *self._common_trsf])
81 | else:
82 | raise ValueError('Unknown mode {}.'.format(mode))
83 |
84 | data, targets = [], []
85 | if appendent is not None and len(appendent) != 0:
86 | appendent_data, appendent_targets = appendent
87 | data.append(appendent_data)
88 | targets.append(appendent_targets)
89 |
90 | data, targets = np.concatenate(data), np.concatenate(targets)
91 |
92 | if ret_data:
93 | return data, targets, DummyDataset(data, targets, trsf, self.use_path)
94 | else:
95 | return DummyDataset(data, targets, trsf, self.use_path)
96 |
97 | def get_dataset_with_split(self, indices, source, mode, appendent=None, val_samples_per_class=0):
98 | if source == 'train':
99 | x, y = self._train_data, self._train_targets
100 | elif source == 'test':
101 | x, y = self._test_data, self._test_targets
102 | else:
103 | raise ValueError('Unknown data source {}.'.format(source))
104 |
105 | if mode == 'train':
106 | trsf = transforms.Compose([*self._train_trsf, *self._common_trsf])
107 | elif mode == 'test':
108 | trsf = transforms.Compose([*self._test_trsf, *self._common_trsf])
109 | else:
110 | raise ValueError('Unknown mode {}.'.format(mode))
111 |
112 | train_data, train_targets = [], []
113 | val_data, val_targets = [], []
114 | for idx in indices:
115 | class_data, class_targets = self._select(x, y, low_range=idx, high_range=idx+1)
116 | val_indx = np.random.choice(len(class_data), val_samples_per_class, replace=False)
117 | train_indx = list(set(np.arange(len(class_data))) - set(val_indx))
118 | val_data.append(class_data[val_indx])
119 | val_targets.append(class_targets[val_indx])
120 | train_data.append(class_data[train_indx])
121 | train_targets.append(class_targets[train_indx])
122 |
123 | if appendent is not None:
124 | appendent_data, appendent_targets = appendent
125 | for idx in range(0, int(np.max(appendent_targets))+1):
126 | append_data, append_targets = self._select(appendent_data, appendent_targets,
127 | low_range=idx, high_range=idx+1)
128 | val_indx = np.random.choice(len(append_data), val_samples_per_class, replace=False)
129 | train_indx = list(set(np.arange(len(append_data))) - set(val_indx))
130 | val_data.append(append_data[val_indx])
131 | val_targets.append(append_targets[val_indx])
132 | train_data.append(append_data[train_indx])
133 | train_targets.append(append_targets[train_indx])
134 |
135 | train_data, train_targets = np.concatenate(train_data), np.concatenate(train_targets)
136 | val_data, val_targets = np.concatenate(val_data), np.concatenate(val_targets)
137 |
138 | return DummyDataset(train_data, train_targets, trsf, self.use_path), \
139 | DummyDataset(val_data, val_targets, trsf, self.use_path)
140 |
141 | def _setup_data(self, dataset_name, shuffle, seed):
142 | idata = _get_idata(dataset_name, self.args)
143 | idata.download_data()
144 |
145 | # Data
146 | self._train_data, self._train_targets = idata.train_data, idata.train_targets
147 | self._test_data, self._test_targets = idata.test_data, idata.test_targets
148 | self.use_path = idata.use_path
149 |
150 | # Transforms
151 | self._train_trsf = idata.train_trsf
152 | self._test_trsf = idata.test_trsf
153 | self._common_trsf = idata.common_trsf
154 |
155 | # Order
156 | order = [i for i in range(len(np.unique(self._train_targets)))]
157 | if shuffle:
158 | np.random.seed(seed)
159 | order = np.random.permutation(len(order)).tolist()
160 | else:
161 | order = idata.class_order
162 | self._class_order = order
163 | logging.info(self._class_order)
164 |
165 | # Map indices
166 | self._train_targets = _map_new_class_index(self._train_targets, self._class_order)
167 | self._test_targets = _map_new_class_index(self._test_targets, self._class_order)
168 |
169 | def _select(self, x, y, low_range, high_range):
170 | idxes = np.where(np.logical_and(y >= low_range, y < high_range))[0]
171 | return x[idxes], y[idxes]
172 |
173 |
174 | class DummyDataset(Dataset):
175 | def __init__(self, images, labels, trsf, use_path=False):
176 | assert len(images) == len(labels), 'Data size error!'
177 | self.images = images
178 | self.labels = labels
179 | self.trsf = trsf
180 | self.use_path = use_path
181 |
182 | def __len__(self):
183 | return len(self.images)
184 |
185 | def __getitem__(self, idx):
186 | if self.use_path:
187 | image = self.trsf(pil_loader(self.images[idx]))
188 | else:
189 | image = self.trsf(Image.fromarray(self.images[idx]))
190 | label = self.labels[idx]
191 |
192 | return idx, image, label
193 |
194 |
195 | def _map_new_class_index(y, order):
196 | return np.array(list(map(lambda x: order.index(x), y)))
197 |
198 |
199 | def _get_idata(dataset_name, args=None):
200 | name = dataset_name.lower()
201 | if name == 'cifar10':
202 | return iCIFAR10()
203 | elif name == 'cifar100':
204 | return iCIFAR100()
205 | elif name == 'imagenet1000':
206 | return iImageNet1000()
207 | elif name == "imagenet100":
208 | return iImageNet100()
209 | elif name == "cifar100_vit":
210 | return iCIFAR100_vit()
211 | elif name == "5datasets_vit":
212 | return i5Datasets_vit()
213 | elif name == "core50":
214 | return iCore50()
215 | elif name == "ganfake":
216 | return iGanFake(args)
217 | elif name == "imagenetr":
218 | return iImageNetR()
219 | elif name == "ganclass":
220 | return iGanClass(args)
221 | elif name == "domainnet":
222 | return iDomainnetCIL()
223 | else:
224 | raise NotImplementedError('Unknown dataset {}.'.format(dataset_name))
225 |
226 |
227 | def pil_loader(path):
228 | '''
229 | Ref:
230 | https://pytorch.org/docs/stable/_modules/torchvision/datasets/folder.html#ImageFolder
231 | '''
232 | # open path as file to avoid ResourceWarning (https://github.com/python-pillow/Pillow/issues/835)
233 | with open(path, 'rb') as f:
234 | img = Image.open(f)
235 | return img.convert('RGB')
236 |
237 |
238 | def accimage_loader(path):
239 | '''
240 | Ref:
241 | https://pytorch.org/docs/stable/_modules/torchvision/datasets/folder.html#ImageFolder
242 | accimage is an accelerated Image loader and preprocessor leveraging Intel IPP.
243 | accimage is available on conda-forge.
244 | '''
245 | import accimage
246 | try:
247 | return accimage.Image(path)
248 | except IOError:
249 | # Potentially a decoding problem, fall back to PIL.Image
250 | return pil_loader(path)
251 |
252 |
253 | def default_loader(path):
254 | '''
255 | Ref:
256 | https://pytorch.org/docs/stable/_modules/torchvision/datasets/folder.html#ImageFolder
257 | '''
258 | from torchvision import get_image_backend
259 | if get_image_backend() == 'accimage':
260 | return accimage_loader(path)
261 | else:
262 | return pil_loader(path)
263 |
--------------------------------------------------------------------------------
/utils/datautils/core50/core50data.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | # -*- coding: utf-8 -*-
3 |
4 | ################################################################################
5 | # Copyright (c) 2019. Vincenzo Lomonaco. All rights reserved. #
6 | # See the accompanying LICENSE file for terms. #
7 | # #
8 | # Date: 23-07-2019 #
9 | # Author: Vincenzo Lomonaco #
10 | # E-mail: vincenzo.lomonaco@unibo.it #
11 | # Website: vincenzolomonaco.com #
12 | ################################################################################
13 |
14 | """ Data Loader for the CORe50 Dataset """
15 |
16 | # Python 2-3 compatible
17 | from __future__ import print_function
18 | from __future__ import division
19 | from __future__ import absolute_import
20 |
21 | # other imports
22 | import numpy as np
23 | import pickle as pkl
24 | import os
25 | import logging
26 | from hashlib import md5
27 | from PIL import Image
28 |
29 |
30 | class CORE50(object):
31 |     """ CORe50 Data Loader class
32 | Args:
33 | root (string): Root directory of the dataset where ``core50_128x128``,
34 | ``paths.pkl``, ``LUP.pkl``, ``labels.pkl``, ``core50_imgs.npz``
35 | live. For example ``~/data/core50``.
36 |         preload (bool, optional): If True, data is pre-loaded with look-up
37 | tables. RAM usage may be high.
38 |         scenario (string, optional): One of the scenarios of the CORe50
39 |             benchmark: ``ni``, ``nc``, ``nic``, ``nicv2_79``, ``nicv2_196``
40 |             and ``nicv2_391``.
41 | train (bool, optional): If True, creates the dataset from the training
42 | set, otherwise creates from test set.
43 | cumul (bool, optional): If True the cumulative scenario is assumed, the
44 | incremental scenario otherwise. Practically speaking ``cumul=True``
45 | means that for batch=i also batch=0,...i-1 will be added to the
46 | available training data.
47 | run (int, optional): One of the 10 runs (from 0 to 9) in which the
48 | training batch order is changed as in the official benchmark.
49 | start_batch (int, optional): One of the training incremental batches
50 | from 0 to max-batch - 1. Remember that for the ``ni``, ``nc`` and
51 | ``nic`` we have respectively 8, 9 and 79 incremental batches. If
52 | ``train=False`` this parameter will be ignored.
53 | """
54 |
55 | nbatch = {
56 | 'ni': 8,
57 | 'nc': 9,
58 | 'nic': 79,
59 | 'nicv2_79': 79,
60 | 'nicv2_196': 196,
61 | 'nicv2_391': 391
62 | }
63 |
64 | def __init__(self, root='', preload=False, scenario='ni', cumul=False,
65 | run=0, start_batch=0):
66 |         """ Initialize Object """
67 |
68 | self.root = os.path.expanduser(root)
69 | self.preload = preload
70 | self.scenario = scenario
71 | self.cumul = cumul
72 | self.run = run
73 | self.batch = start_batch
74 |
75 | if preload:
76 | print("Loading data...")
77 | bin_path = os.path.join(root, 'core50_imgs.bin')
78 | if os.path.exists(bin_path):
79 | with open(bin_path, 'rb') as f:
80 | self.x = np.fromfile(f, dtype=np.uint8) \
81 | .reshape(164866, 128, 128, 3)
82 |
83 | else:
84 | with open(os.path.join(root, 'core50_imgs.npz'), 'rb') as f:
85 | npzfile = np.load(f)
86 | self.x = npzfile['x']
87 | print("Writing bin for fast reloading...")
88 | self.x.tofile(bin_path)
89 |
90 | print("Loading paths...")
91 | with open(os.path.join(root, 'paths.pkl'), 'rb') as f:
92 | self.paths = pkl.load(f)
93 |
94 | print("Loading LUP...")
95 | with open(os.path.join(root, 'LUP.pkl'), 'rb') as f:
96 | self.LUP = pkl.load(f)
97 |
98 | print("Loading labels...")
99 | with open(os.path.join(root, 'labels.pkl'), 'rb') as f:
100 | self.labels = pkl.load(f)
101 |
102 | def __iter__(self):
103 | return self
104 |
105 | def get_data_batchidx(self, idx):
106 |
107 | scen = self.scenario
108 | run = self.run
109 | batch = idx
110 |
111 | if self.batch == self.nbatch[scen]:
112 | raise StopIteration
113 |
114 |         # Getting the right indices
115 | if self.cumul:
116 | train_idx_list = []
117 | for i in range(self.batch + 1):
118 | train_idx_list += self.LUP[scen][run][i]
119 | else:
120 | train_idx_list = self.LUP[scen][run][batch]
121 |
122 | # loading data
123 | if self.preload:
124 | train_x = np.take(self.x, train_idx_list, axis=0)\
125 | .astype(np.float32)
126 | else:
127 | print("Loading data...")
128 | # Getting the actual paths
129 | train_paths = []
130 | for idx in train_idx_list:
131 | train_paths.append(os.path.join(self.root, self.paths[idx]))
132 | # loading imgs
133 | train_x = self.get_batch_from_paths(train_paths).astype(np.float32)
134 |
135 | # In either case we have already loaded the y
136 | if self.cumul:
137 | train_y = []
138 | for i in range(self.batch + 1):
139 | train_y += self.labels[scen][run][i]
140 | else:
141 | train_y = self.labels[scen][run][batch]
142 |
143 | train_y = np.asarray(train_y, dtype=np.int)
144 |
145 | return (train_x, train_y)
146 |
147 | def __next__(self):
148 | """ Next batch based on the object parameter which can be also changed
149 | from the previous iteration. """
150 |
151 | scen = self.scenario
152 | run = self.run
153 | batch = self.batch
154 |
155 | if self.batch == self.nbatch[scen]:
156 | raise StopIteration
157 |
158 |         # Getting the right indices
159 | if self.cumul:
160 | train_idx_list = []
161 | for i in range(self.batch + 1):
162 | train_idx_list += self.LUP[scen][run][i]
163 | else:
164 | train_idx_list = self.LUP[scen][run][batch]
165 |
166 | # loading data
167 | if self.preload:
168 | train_x = np.take(self.x, train_idx_list, axis=0)\
169 | .astype(np.float32)
170 | else:
171 | print("Loading data...")
172 | # Getting the actual paths
173 | train_paths = []
174 | for idx in train_idx_list:
175 | train_paths.append(os.path.join(self.root, self.paths[idx]))
176 | # loading imgs
177 | train_x = self.get_batch_from_paths(train_paths).astype(np.float32)
178 |
179 | # In either case we have already loaded the y
180 | if self.cumul:
181 | train_y = []
182 | for i in range(self.batch + 1):
183 | train_y += self.labels[scen][run][i]
184 | else:
185 | train_y = self.labels[scen][run][batch]
186 |
187 | train_y = np.asarray(train_y, dtype=np.int)
188 |
189 | # Update state for next iter
190 | self.batch += 1
191 |
192 | return (train_x, train_y)
193 |
194 | def get_test_set(self):
195 | """ Return the test set (the same for each inc. batch). """
196 |
197 | scen = self.scenario
198 | run = self.run
199 |
200 | test_idx_list = self.LUP[scen][run][-1]
201 |
202 | if self.preload:
203 | test_x = np.take(self.x, test_idx_list, axis=0).astype(np.float32)
204 | else:
205 | # test paths
206 | test_paths = []
207 | for idx in test_idx_list:
208 | test_paths.append(os.path.join(self.root, self.paths[idx]))
209 |
210 | # test imgs
211 | test_x = self.get_batch_from_paths(test_paths).astype(np.float32)
212 |
213 | test_y = self.labels[scen][run][-1]
214 | test_y = np.asarray(test_y, dtype=np.int)
215 |
216 | return test_x, test_y
217 |
218 | next = __next__ # python2.x compatibility.
219 |
220 | @staticmethod
221 | def get_batch_from_paths(paths, compress=False, snap_dir='',
222 | on_the_fly=True, verbose=False):
223 | """ Given a number of abs. paths it returns the numpy array
224 | of all the images. """
225 |
226 | # Getting root logger
227 | log = logging.getLogger('mylogger')
228 |
229 | # If we do not process data on the fly we check if the same train
230 | # filelist has been already processed and saved. If so, we load it
231 | # directly. In either case we end up returning x and y, as the full
232 | # training set and respective labels.
233 | num_imgs = len(paths)
234 | hexdigest = md5(''.join(paths).encode('utf-8')).hexdigest()
235 | log.debug("Paths Hex: " + str(hexdigest))
236 | loaded = False
237 | x = None
238 | file_path = None
239 |
240 | if compress:
241 | file_path = snap_dir + hexdigest + ".npz"
242 | if os.path.exists(file_path) and not on_the_fly:
243 | loaded = True
244 | with open(file_path, 'rb') as f:
245 | npzfile = np.load(f)
246 | x, y = npzfile['x']
247 | else:
248 | x_file_path = snap_dir + hexdigest + "_x.bin"
249 | if os.path.exists(x_file_path) and not on_the_fly:
250 | loaded = True
251 | with open(x_file_path, 'rb') as f:
252 | x = np.fromfile(f, dtype=np.uint8) \
253 | .reshape(num_imgs, 128, 128, 3)
254 |
255 | # Here we actually load the images.
256 | if not loaded:
257 | # Pre-allocate numpy arrays
258 | x = np.zeros((num_imgs, 128, 128, 3), dtype=np.uint8)
259 |
260 | for i, path in enumerate(paths):
261 | if verbose:
262 | print("\r" + path + " processed: " + str(i + 1), end='')
263 | x[i] = np.array(Image.open(path))
264 |
265 | if verbose:
266 | print()
267 |
268 | if not on_the_fly:
269 | # Then we save x
270 | if compress:
271 | with open(file_path, 'wb') as g:
272 | np.savez_compressed(g, x=x)
273 | else:
274 | x.tofile(snap_dir + hexdigest + "_x.bin")
275 |
276 | assert (x is not None), 'Problems loading data. x is None!'
277 |
278 | return x
279 |
280 |
281 |
--------------------------------------------------------------------------------
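A minimal usage sketch of the loader above (the path is a placeholder; `iCore50.download_data` in utils/data.py follows the same pattern):

    from utils.datautils.core50.core50data import CORE50

    # iterate over the training batches of the chosen scenario, then fetch the shared test set
    datagen = CORE50(root='/path/to/core50/core50_128x128', scenario='ni')
    for batch_id, (train_x, train_y) in enumerate(datagen):
        print(batch_id, train_x.shape, train_y.shape)
    test_x, test_y = datagen.get_test_set()
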
/utils/data.py:
--------------------------------------------------------------------------------
1 | import os
2 | import numpy as np
3 | from torchvision import datasets, transforms
4 | from utils.toolkit import split_images_labels
5 | from PIL import Image
6 |
7 |
8 | class iData(object):
9 | train_trsf = []
10 | test_trsf = []
11 | common_trsf = []
12 | class_order = None
13 |
14 |
15 | class iCIFAR10(iData):
16 | use_path = False
17 | train_trsf = [
18 | transforms.RandomCrop(32, padding=4),
19 | transforms.RandomHorizontalFlip(p=0.5),
20 | transforms.ColorJitter(brightness=63/255)
21 | ]
22 | test_trsf = []
23 | common_trsf = [
24 | transforms.ToTensor(),
25 | transforms.Normalize(mean=(0.4914, 0.4822, 0.4465), std=(0.2023, 0.1994, 0.2010)),
26 | ]
27 |
28 | class_order = np.arange(10).tolist()
29 |
30 | def download_data(self):
31 | train_dataset = datasets.cifar.CIFAR10('./data', train=True, download=True)
32 | test_dataset = datasets.cifar.CIFAR10('./data', train=False, download=True)
33 | self.train_data, self.train_targets = train_dataset.data, np.array(train_dataset.targets)
34 | self.test_data, self.test_targets = test_dataset.data, np.array(test_dataset.targets)
35 |
36 | class iCIFAR100(iData):
37 | use_path = False
38 | train_trsf = [
39 | transforms.RandomCrop(32, padding=4),
40 | transforms.RandomHorizontalFlip(),
41 | transforms.ColorJitter(brightness=63/255)
42 | ]
43 | test_trsf = []
44 | common_trsf = [
45 | transforms.ToTensor(),
46 | transforms.Normalize(mean=(0.5071, 0.4867, 0.4408), std=(0.2675, 0.2565, 0.2761)),
47 | ]
48 |
49 | class_order = np.arange(100).tolist()
50 |
51 | def download_data(self):
52 | train_dataset = datasets.cifar.CIFAR100('./data', train=True, download=True)
53 | test_dataset = datasets.cifar.CIFAR100('./data', train=False, download=True)
54 | self.train_data, self.train_targets = train_dataset.data, np.array(train_dataset.targets)
55 | self.test_data, self.test_targets = test_dataset.data, np.array(test_dataset.targets)
56 |
57 | class iImageNet1000(iData):
58 | use_path = True
59 | train_trsf = [
60 | transforms.RandomResizedCrop(224),
61 | transforms.RandomHorizontalFlip(),
62 | transforms.ColorJitter(brightness=63/255)
63 | ]
64 | test_trsf = [
65 | transforms.Resize(256),
66 | transforms.CenterCrop(224),
67 | ]
68 | common_trsf = [
69 | transforms.ToTensor(),
70 | transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
71 | ]
72 |
73 | class_order = np.arange(1000).tolist()
74 |
75 | def download_data(self):
76 | # assert 0,"You should specify the folder of your dataset"
77 | train_dir = '/home/wangyabin/workspace/data/train'
78 | test_dir = '/home/wangyabin/workspace/data/val'
79 |
80 | train_dset = datasets.ImageFolder(train_dir)
81 | test_dset = datasets.ImageFolder(test_dir)
82 |
83 | self.train_data, self.train_targets = split_images_labels(train_dset.imgs)
84 | self.test_data, self.test_targets = split_images_labels(test_dset.imgs)
85 |
86 | class iImageNet100(iData):
87 | use_path = True
88 | train_trsf = [
89 | transforms.RandomResizedCrop(224),
90 | transforms.RandomHorizontalFlip(),
91 | ]
92 | test_trsf = [
93 | transforms.Resize(256),
94 | transforms.CenterCrop(224),
95 | ]
96 | common_trsf = [
97 | transforms.ToTensor(),
98 | transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
99 | ]
100 |
101 | class_order = np.arange(1000).tolist()
102 |
103 | def download_data(self):
104 | train_dir = '/home/wangyabin/workspace/datasets/imagenet100/train/'
105 | test_dir = '/home/wangyabin/workspace/datasets/imagenet100/val/'
106 |
107 | train_dset = datasets.ImageFolder(train_dir)
108 | test_dset = datasets.ImageFolder(test_dir)
109 |
110 | self.train_data, self.train_targets = split_images_labels(train_dset.imgs)
111 | self.test_data, self.test_targets = split_images_labels(test_dset.imgs)
112 |
113 | class iCore50(iData):
114 |
115 | use_path = False
116 | train_trsf = [
117 | transforms.RandomResizedCrop(224),
118 | transforms.RandomHorizontalFlip(),
119 | ]
120 | test_trsf = [
121 | transforms.Resize(256),
122 | transforms.CenterCrop(224),
123 | ]
124 | common_trsf = [
125 | transforms.ToTensor(),
126 | transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
127 | ]
128 |
129 | class_order = np.arange(8*50).tolist()
130 |
131 | def download_data(self):
132 | from utils.datautils.core50.core50data import CORE50
133 | datagen = CORE50(root='/home/wangyabin/workspace/datasets/core50/data/core50_128x128', scenario="ni")
134 |
135 | dataset_list = []
136 | for i, train_batch in enumerate(datagen):
137 | imglist, labellist = train_batch
138 | labellist += i*50
139 | imglist = imglist.astype(np.uint8)
140 | dataset_list.append([imglist, labellist])
141 | train_x = np.concatenate(np.array(dataset_list)[:, 0])
142 | train_y = np.concatenate(np.array(dataset_list)[:, 1])
143 | self.train_data = train_x
144 | self.train_targets = train_y
145 |
146 | test_x, test_y = datagen.get_test_set()
147 | test_x = test_x.astype(np.uint8)
148 | self.test_data = test_x
149 | self.test_targets = test_y
150 | # import pdb;pdb.set_trace()
151 |
152 |
153 |
154 | class iDomainnetCIL(iData):
155 | use_path = True
156 | train_trsf = [
157 | transforms.RandomResizedCrop(224),
158 | transforms.RandomHorizontalFlip(),
159 | ]
160 | test_trsf = [
161 | transforms.Resize(256),
162 | transforms.CenterCrop(224),
163 | ]
164 | common_trsf = [
165 | transforms.ToTensor(),
166 | transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
167 | ]
168 |
169 | class_order = np.arange(200).tolist()
170 |
171 | def download_data(self):
172 | rootdir = '/home/wangyabin/workspace/datasets/domainnet'
173 |
174 | train_txt = './utils/datautils/domainnet/train.txt'
175 | test_txt = './utils/datautils/domainnet/test.txt'
176 |
177 | train_images = []
178 | train_labels = []
179 | with open(train_txt, 'r') as dict_file:
180 | for line in dict_file:
181 | (value, key) = line.strip().split(' ')
182 | train_images.append(os.path.join(rootdir, value))
183 | train_labels.append(int(key))
184 | train_images = np.array(train_images)
185 | train_labels = np.array(train_labels)
186 | test_images = []
187 | test_labels = []
188 | with open(test_txt, 'r') as dict_file:
189 | for line in dict_file:
190 | (value, key) = line.strip().split(' ')
191 | test_images.append(os.path.join(rootdir, value))
192 | test_labels.append(int(key))
193 | test_images = np.array(test_images)
194 | test_labels = np.array(test_labels)
195 |
196 | self.train_data = train_images
197 | self.train_targets = train_labels
198 | self.test_data = test_images
199 | self.test_targets = test_labels
200 |
201 |
202 |
203 |
204 |
205 |
206 |
207 |
208 | class iImageNetR(iData):
209 | use_path = True
210 | train_trsf = [
211 | transforms.RandomResizedCrop(224),
212 | transforms.RandomHorizontalFlip(),
213 | ]
214 | test_trsf = [
215 | transforms.Resize(256),
216 | transforms.CenterCrop(224),
217 | ]
218 | common_trsf = [
219 | transforms.ToTensor(),
220 | transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
221 | ]
222 |
223 | class_order = np.arange(200).tolist()
224 |
225 | # first we need to get the pre-processed txt files containing the training and test splits of DualPrompt.
226 | # the original data structure is tfds, but we extract its data into a txt file
227 | """
228 | we use the following code to extract the splits
229 | def get_stats(split_name, filepath):
230 | stats = []
231 | ds = dataset_builder.as_dataset(split=split_name)
232 | label_list = []
233 | for batch in ds:
234 | label_list.append(int(batch["label"]))
235 | label_list = list(set(label_list))
236 | data_dict = {i:[] for i in label_list}
237 | for batch in ds:
238 | data_dict[int(batch["label"])].append(batch["file_name"].numpy())
239 | print(len(label_list))
240 | label_list.sort()
241 | with open(filepath, 'w') as f:
242 | for i in label_list:
243 | for line in data_dict[i]:
244 | f.write(str(IR_LABEL_MAP[i]) + "\t" + line.decode("utf-8") +"\n")
245 | return data_dict
246 | train_stats = get_stats("test[:80%]", "train.txt")
247 | test_stats = get_stats("test[80%:]", "test.txt")
248 | """
249 |
250 | def download_data(self):
251 | rootdir = '/home/wangyabin/workspace/datasets/imagenet-r'
252 |
253 | train_txt = './utils/datautils/imagenet-r/train.txt'
254 | test_txt = './utils/datautils/imagenet-r/test.txt'
255 |
256 | train_images = []
257 | train_labels = []
258 | with open(train_txt, 'r') as dict_file:
259 | for line in dict_file:
260 | (key, value) = line.strip().split('\t')
261 | train_images.append(os.path.join(rootdir, value))
262 | train_labels.append(int(key))
263 | train_images = np.array(train_images)
264 | train_labels = np.array(train_labels)
265 |
266 | test_images = []
267 | test_labels = []
268 | with open(test_txt, 'r') as dict_file:
269 | for line in dict_file:
270 | (key, value) = line.strip().split('\t')
271 | test_images.append(os.path.join(rootdir, value))
272 | test_labels.append(int(key))
273 | test_images = np.array(test_images)
274 | test_labels = np.array(test_labels)
275 |
276 | self.train_data = train_images
277 | self.train_targets = train_labels
278 | self.test_data = test_images
279 | self.test_targets = test_labels
280 |
281 | class iCIFAR100_vit(iData):
282 | use_path = False
283 | train_trsf = [
284 | transforms.Resize(256),
285 | transforms.ColorJitter(brightness=63 / 255),
286 | transforms.RandomResizedCrop(224),
287 | transforms.RandomHorizontalFlip(),
288 | # transforms.AutoAugment(policy=transforms.AutoAugmentPolicy.IMAGENET),
289 | ]
290 | test_trsf = [
291 | transforms.Resize(256),
292 | transforms.CenterCrop(224),
293 | ]
294 | common_trsf = [
295 | transforms.ToTensor(),
296 | transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
297 | ]
298 |
299 | class_order = np.arange(100).tolist()
300 | # class_order = [51, 48, 73, 93, 39, 67, 29, 49, 57, 33, 4, 32, 5, 75, 63, 7, 61, 36, 69, 62, 46, 30, 25, 47, 12, 11, 94, 18, 27, 88, 0, 99, 21, 87, 34, 24, 86, 35, 22, 42, 66, 64, 2, 97, 98, 96, 71, 14, 95, 37, 54, 31, 10, 20, 52, 79, 60, 72, 41, 91, 44, 15, 16, 83, 59, 6, 82, 45, 81, 13, 53, 28, 50, 17, 19, 85, 1, 77, 70, 58, 38, 43, 80, 26, 9, 55, 92, 3, 89, 40, 76, 74, 65, 90, 84, 23, 8, 78, 56, 68]
301 |
302 | def download_data(self):
303 | train_dataset = datasets.cifar.CIFAR100('./data', train=True, download=True)
304 | test_dataset = datasets.cifar.CIFAR100('./data', train=False, download=True)
305 | self.train_data, self.train_targets = train_dataset.data, np.array(train_dataset.targets)
306 | self.test_data, self.test_targets = test_dataset.data, np.array(test_dataset.targets)
307 |
308 | class i5Datasets_vit(iData):
309 | use_path = False
310 | train_trsf = [
311 | transforms.Resize(224),
312 | # transforms.AutoAugment(policy=transforms.AutoAugmentPolicy.IMAGENET),
313 | transforms.RandomResizedCrop(224),
314 | transforms.RandomHorizontalFlip(),
315 | transforms.ColorJitter(brightness=63 / 255)
316 | ]
317 | test_trsf = [
318 | transforms.Resize(256),
319 | transforms.CenterCrop(224),
320 | ]
321 | common_trsf = [
322 | transforms.ToTensor(),
323 | transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
324 | ]
325 |
326 |
327 | class_order = np.arange(50).tolist()
328 |
329 |
330 | def download_data(self):
331 | img_size=64
332 | train_dataset = datasets.cifar.CIFAR10('./data', train=True, download=True)
333 | test_dataset = datasets.cifar.CIFAR10('./data', train=False, download=True)
334 |
335 | trainlist = []
336 | testlist = []
337 | train_label_list = []
338 | test_label_list = []
339 |
340 | # cifar10
341 | cifar10_train_dataset = datasets.cifar.CIFAR10('./data', train=True, download=True)
342 | cifar10_test_dataset = datasets.cifar.CIFAR10('./data', train=False, download=True)
343 | for img, target in zip(cifar10_train_dataset.data, cifar10_train_dataset.targets):
344 | trainlist.append(np.array(Image.fromarray(img).resize((img_size, img_size))))
345 | train_label_list.append(target)
346 | for img, target in zip(cifar10_test_dataset.data, cifar10_test_dataset.targets):
347 | testlist.append(np.array(Image.fromarray(img).resize((img_size, img_size))))
348 | test_label_list.append(target)
349 |
350 | # MNIST
351 | mnist_train_dataset = datasets.MNIST('./data', train=True, download=True)
352 | mnist_test_dataset = datasets.MNIST('./data', train=False, download=True)
353 | for img, target in zip(mnist_train_dataset.data.numpy(), mnist_train_dataset.targets.numpy()):
354 | trainlist.append(np.array(Image.fromarray(img).resize((img_size, img_size)).convert('RGB')))
355 | train_label_list.append(target+10)
356 | for img, target in zip(mnist_test_dataset.data.numpy(), mnist_test_dataset.targets.numpy()):
357 | testlist.append(np.array(Image.fromarray(img).resize((img_size, img_size)).convert('RGB')))
358 | test_label_list.append(target+10)
359 |
360 | # notMNIST
361 | classes = ["A", "B", "C", "D", "E", "F", "G", "H", "I", "J"]
362 | train_dir = "./data/notMNIST_large"
363 | test_dir = "./data/notMNIST_small"
364 | for idx, cls in enumerate(classes):
365 | image_files = os.listdir(os.path.join(train_dir, cls))
366 | for img_path in image_files:
367 | try:
368 | image = np.array(Image.open(os.path.join(train_dir, cls, img_path)).resize((img_size, img_size)).convert('RGB'))
369 | trainlist.append(image)
370 | train_label_list.append(idx+20)
371 | except Exception:
372 | print(os.path.join(train_dir, cls, img_path))
373 | image_files = os.listdir(os.path.join(test_dir, cls))
374 | for img_path in image_files:
375 | try:
376 | image = np.array(Image.open(os.path.join(test_dir, cls, img_path)).resize((img_size, img_size)).convert('RGB'))
377 | testlist.append(image)
378 | test_label_list.append(idx+20)
379 | except Exception:
380 | print(os.path.join(test_dir, cls, img_path))
381 |
382 |
383 | # Fashion-MNIST
384 | fmnist_train_dataset = datasets.FashionMNIST('./data', train=True, download=True)
385 | fmnist_test_dataset = datasets.FashionMNIST('./data', train=False, download=True)
386 | for img, target in zip(fmnist_train_dataset.data.numpy(), fmnist_train_dataset.targets.numpy()):
387 | trainlist.append(np.array(Image.fromarray(img).resize((img_size, img_size)).convert('RGB')))
388 | train_label_list.append(target+30)
389 | for img, target in zip(fmnist_test_dataset.data.numpy(), fmnist_test_dataset.targets.numpy()):
390 | testlist.append(np.array(Image.fromarray(img).resize((img_size, img_size)).convert('RGB')))
391 | test_label_list.append(target+30)
392 |
393 | # SVHN
394 | svhn_train_dataset = datasets.SVHN('./data', split='train', download=True)
395 | svhn_test_dataset = datasets.SVHN('./data', split='test', download=True)
396 | for img, target in zip(svhn_train_dataset.data, svhn_train_dataset.labels):
397 | trainlist.append(np.array(Image.fromarray(img.transpose(1,2,0)).resize((img_size, img_size))))
398 | train_label_list.append(target+40)
399 | for img, target in zip(svhn_test_dataset.data, svhn_test_dataset.labels):
400 | testlist.append(np.array(Image.fromarray(img.transpose(1,2,0)).resize((img_size, img_size))))
401 | test_label_list.append(target+40)
402 |
403 |
404 | train_dataset.data = np.array(trainlist)
405 | train_dataset.targets = np.array(train_label_list)
406 | test_dataset.data = np.array(testlist)
407 | test_dataset.targets = np.array(test_label_list)
408 |
409 | self.train_data, self.train_targets = train_dataset.data, np.array(train_dataset.targets)
410 | self.test_data, self.test_targets = test_dataset.data, np.array(test_dataset.targets)
411 |
412 | class iGanFake(object):
413 | use_path = True
414 | train_trsf = [
415 | transforms.RandomResizedCrop(224),
416 | transforms.RandomHorizontalFlip(),
417 | transforms.ColorJitter(brightness=63/255)
418 | ]
419 | test_trsf = [
420 | transforms.Resize(256),
421 | transforms.CenterCrop(224),
422 | ]
423 | common_trsf = [
424 | transforms.ToTensor(),
425 | transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
426 | ]
427 |
428 | def __init__(self, args):
429 | self.args = args
430 | class_order = args["class_order"]
431 | self.class_order = class_order
432 |
433 | def download_data(self):
434 |
435 | train_dataset = []
436 | test_dataset = []
437 | for id, name in enumerate(self.args["task_name"]):
438 | root_ = os.path.join(self.args["data_path"], name, 'train')
439 | sub_classes = os.listdir(root_) if self.args["multiclass"][id] else ['']
440 | for cls in sub_classes:
441 | for imgname in os.listdir(os.path.join(root_, cls, '0_real')):
442 | train_dataset.append((os.path.join(root_, cls, '0_real', imgname), 0 + 2 * id))
443 | for imgname in os.listdir(os.path.join(root_, cls, '1_fake')):
444 | train_dataset.append((os.path.join(root_, cls, '1_fake', imgname), 1 + 2 * id))
445 |
446 | for id, name in enumerate(self.args["task_name"]):
447 | root_ = os.path.join(self.args["data_path"], name, 'val')
448 | sub_classes = os.listdir(root_) if self.args["multiclass"][id] else ['']
449 | for cls in sub_classes:
450 | for imgname in os.listdir(os.path.join(root_, cls, '0_real')):
451 | test_dataset.append((os.path.join(root_, cls, '0_real', imgname), 0 + 2 * id))
452 | for imgname in os.listdir(os.path.join(root_, cls, '1_fake')):
453 | test_dataset.append((os.path.join(root_, cls, '1_fake', imgname), 1 + 2 * id))
454 |
455 | self.train_data, self.train_targets = split_images_labels(train_dataset)
456 | self.test_data, self.test_targets = split_images_labels(test_dataset)
457 |
458 | class iGanClass(object):
459 | use_path = True
460 | train_trsf = [
461 | transforms.Resize(256),
462 | transforms.RandomResizedCrop(224),
463 | transforms.RandomHorizontalFlip(),
464 | transforms.ColorJitter(brightness=63/255)
465 | ]
466 | test_trsf = [
467 | transforms.Resize(256),
468 | transforms.CenterCrop(224),
469 | ]
470 | common_trsf = [
471 | transforms.ToTensor(),
472 | transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
473 | ]
474 |
475 | def __init__(self, args):
476 | self.args = args
477 | class_number = args["class_number"]
478 | self.class_order = [i for i in range(2*class_number)]
479 | self.task_name = [str(i) for i in range(class_number)]
480 |
481 | def download_data(self):
482 |
483 | train_dataset = []
484 | test_dataset = []
485 | for id, name in enumerate(self.task_name):
486 | root_ = os.path.join(self.args["data_path"], name, 'train')
487 | sub_classes = ['']
488 | for cls in sub_classes:
489 | for imgname in os.listdir(os.path.join(root_, cls, '0_real')):
490 | train_dataset.append((os.path.join(root_, cls, '0_real', imgname), 0 + 2 * id))
491 | for imgname in os.listdir(os.path.join(root_, cls, '1_fake')):
492 | train_dataset.append((os.path.join(root_, cls, '1_fake', imgname), 1 + 2 * id))
493 |
494 | for id, name in enumerate(self.task_name):
495 | root_ = os.path.join(self.args["data_path"], name, 'val')
496 | sub_classes = ['']
497 | for cls in sub_classes:
498 | for imgname in os.listdir(os.path.join(root_, cls, '0_real')):
499 | test_dataset.append((os.path.join(root_, cls, '0_real', imgname), 0 + 2 * id))
500 | for imgname in os.listdir(os.path.join(root_, cls, '1_fake')):
501 | test_dataset.append((os.path.join(root_, cls, '1_fake', imgname), 1 + 2 * id))
502 |
503 | self.train_data, self.train_targets = split_images_labels(train_dataset)
504 | self.test_data, self.test_targets = split_images_labels(test_dataset)
--------------------------------------------------------------------------------
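Note on utils/data.py above: every iData subclass only declares three transform lists (train_trsf, test_trsf, common_trsf), a class_order, and a download_data() method that fills train_data/train_targets and test_data/test_targets (raw arrays when use_path=False, file paths when use_path=True). These pieces are presumably composed by utils/data_manager.py; the snippet below is only an illustrative sketch of that composition, not the repository's own code:

    from PIL import Image
    from torchvision import transforms
    from utils.data import iCIFAR100

    def build_transform(idata, train=True):
        # Per-split augmentations first, then the shared ToTensor/Normalize tail.
        split_trsf = idata.train_trsf if train else idata.test_trsf
        return transforms.Compose([*split_trsf, *idata.common_trsf])

    data = iCIFAR100()
    data.download_data()                     # downloads CIFAR-100 into ./data
    train_tf = build_transform(data, train=True)

    # use_path=False for CIFAR, so train_data holds HxWx3 uint8 arrays.
    sample = train_tf(Image.fromarray(data.train_data[0]))
    print(sample.shape, int(data.train_targets[0]))   # torch.Size([3, 32, 32]) <label>
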
/models/convit.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2015-present, Facebook, Inc.
2 | # All rights reserved.
3 | #
4 | # This source code is licensed under the CC-by-NC license found in the
5 | # LICENSE file in the root directory of this source tree.
6 | #
7 |
8 | '''These modules are adapted from those of timm, see
9 | https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/vision_transformer.py
10 | '''
11 |
12 | import copy
13 | import math
14 | from functools import lru_cache
15 |
16 | import torch
17 | import torch.nn as nn
18 | import torch.nn.functional as F
19 | from timm.models.layers import DropPath, to_2tuple, trunc_normal_
20 |
21 | def freeze_parameters(m, requires_grad=False):
22 | if m is None:
23 | return
24 |
25 | if isinstance(m, nn.Parameter):
26 | m.requires_grad = requires_grad
27 | else:
28 | for p in m.parameters():
29 | p.requires_grad = requires_grad
30 |
31 | class BatchEnsemble(nn.Module):
32 | def __init__(self, in_features, out_features, bias=True):
33 | super().__init__()
34 |
35 | self.linear = nn.Linear(in_features, out_features, bias=bias)
36 |
37 | self.out_features, self.in_features = out_features, in_features
38 | self.bias = bias
39 |
40 | self.r = nn.Parameter(torch.randn(self.out_features))
41 | self.s = nn.Parameter(torch.randn(self.in_features))
42 |
43 | def __deepcopy__(self, memo):
44 | cls = self.__class__
45 | result = cls.__new__(cls, self.in_features, self.out_features, self.bias)
46 | memo[id(self)] = result
47 | for k, v in self.__dict__.items():
48 | setattr(result, k, copy.deepcopy(v, memo))
49 |
50 | result.linear.weight = self.linear.weight
51 | return result
52 |
53 | def reset_parameters(self):
54 | device = self.linear.weight.device
55 | self.r = nn.Parameter(torch.randn(self.out_features).to(device))
56 | self.s = nn.Parameter(torch.randn(self.in_features).to(device))
57 |
58 | if self.bias:
59 | fan_in, _ = nn.init._calculate_fan_in_and_fan_out(self.linear.weight)
60 | bound = 1 / math.sqrt(fan_in) if fan_in > 0 else 0
61 | nn.init.uniform_(self.linear.bias, -bound, bound)
62 |
63 | def forward(self, x):
64 | w = torch.outer(self.r, self.s)
65 | w = w * self.linear.weight
66 | return F.linear(x, w, self.linear.bias)
67 |
68 |
69 | class Mlp(nn.Module):
70 | def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0., fc=nn.Linear):
71 | super().__init__()
72 | out_features = out_features or in_features
73 | hidden_features = hidden_features or in_features
74 | self.fc1 = fc(in_features, hidden_features)
75 | self.act = act_layer()
76 | self.fc2 = fc(hidden_features, out_features)
77 | self.drop = nn.Dropout(drop)
78 | self.apply(self._init_weights)
79 |
80 | def _init_weights(self, m):
81 | if isinstance(m, nn.Linear):
82 | trunc_normal_(m.weight, std=.02)
83 | if isinstance(m, nn.Linear) and m.bias is not None:
84 | nn.init.constant_(m.bias, 0)
85 | elif isinstance(m, BatchEnsemble):
86 | trunc_normal_(m.linear.weight, std=.02)
87 | if isinstance(m.linear, nn.Linear) and m.linear.bias is not None:
88 | nn.init.constant_(m.linear.bias, 0)
89 | elif isinstance(m, nn.LayerNorm):
90 | nn.init.constant_(m.bias, 0)
91 | nn.init.constant_(m.weight, 1.0)
92 |
93 | def forward(self, x):
94 | x = self.fc1(x)
95 | x = self.act(x)
96 | x = self.drop(x)
97 | x = self.fc2(x)
98 | x = self.drop(x)
99 | return x
100 |
101 |
102 | class GPSA(nn.Module):
103 | def __init__(self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0., proj_drop=0.,
104 | locality_strength=1., use_local_init=True, fc=None):
105 | super().__init__()
106 | self.num_heads = num_heads
107 | self.dim = dim
108 | head_dim = dim // num_heads
109 | self.scale = qk_scale or head_dim ** -0.5
110 |
111 | self.qk = nn.Linear(dim, dim * 2, bias=qkv_bias)
112 | self.v = nn.Linear(dim, dim, bias=qkv_bias)
113 |
114 | self.attn_drop = nn.Dropout(attn_drop)
115 | self.proj = nn.Linear(dim, dim)
116 | self.pos_proj = nn.Linear(3, num_heads)
117 | self.proj_drop = nn.Dropout(proj_drop)
118 | self.locality_strength = locality_strength
119 | self.gating_param = nn.Parameter(torch.ones(self.num_heads))
120 | self.apply(self._init_weights)
121 | if use_local_init:
122 | self.local_init(locality_strength=locality_strength)
123 |
124 | def reset_parameters(self):
125 | self.apply(self._init_weights)
126 |
127 | def _init_weights(self, m):
128 | if isinstance(m, nn.Linear):
129 | trunc_normal_(m.weight, std=.02)
130 | if isinstance(m, nn.Linear) and m.bias is not None:
131 | nn.init.constant_(m.bias, 0)
132 | elif isinstance(m, nn.LayerNorm):
133 | nn.init.constant_(m.bias, 0)
134 | nn.init.constant_(m.weight, 1.0)
135 |
136 | def forward(self, x):
137 | B, N, C = x.shape
138 | if not hasattr(self, 'rel_indices') or self.rel_indices.size(1)!=N:
139 | self.get_rel_indices(N)
140 |
141 | attn = self.get_attention(x)
142 | v = self.v(x).reshape(B, N, self.num_heads, C // self.num_heads).permute(0, 2, 1, 3)
143 | x = (attn @ v).transpose(1, 2).reshape(B, N, C)
144 | x = self.proj(x)
145 | x = self.proj_drop(x)
146 | return x, attn, v
147 |
148 | def get_attention(self, x):
149 | B, N, C = x.shape
150 | qk = self.qk(x).reshape(B, N, 2, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
151 | q, k = qk[0], qk[1]
152 | pos_score = self.rel_indices.expand(B, -1, -1,-1)
153 | pos_score = self.pos_proj(pos_score).permute(0,3,1,2)
154 | patch_score = (q @ k.transpose(-2, -1)) * self.scale
155 | patch_score = patch_score.softmax(dim=-1)
156 | pos_score = pos_score.softmax(dim=-1)
157 |
158 | gating = self.gating_param.view(1,-1,1,1)
159 | attn = (1.-torch.sigmoid(gating)) * patch_score + torch.sigmoid(gating) * pos_score
160 | attn /= attn.sum(dim=-1).unsqueeze(-1)
161 | attn = self.attn_drop(attn)
162 | return attn
163 |
164 | def get_attention_map(self, x, return_map = False):
165 |
166 | attn_map = self.get_attention(x).mean(0) # average over batch
167 | distances = self.rel_indices.squeeze()[:,:,-1]**.5
168 | dist = torch.einsum('nm,hnm->h', (distances, attn_map))
169 | dist /= distances.size(0)
170 | if return_map:
171 | return dist, attn_map
172 | else:
173 | return dist
174 |
175 | def local_init(self, locality_strength=1.):
176 | self.v.weight.data.copy_(torch.eye(self.dim))
177 | locality_distance = 1 #max(1,1/locality_strength**.5)
178 |
179 | kernel_size = int(self.num_heads**.5)
180 | center = (kernel_size-1)/2 if kernel_size%2==0 else kernel_size//2
181 | for h1 in range(kernel_size):
182 | for h2 in range(kernel_size):
183 | position = h1+kernel_size*h2
184 | self.pos_proj.weight.data[position,2] = -1
185 | self.pos_proj.weight.data[position,1] = 2*(h1-center)*locality_distance
186 | self.pos_proj.weight.data[position,0] = 2*(h2-center)*locality_distance
187 | self.pos_proj.weight.data *= locality_strength
188 |
189 | def get_rel_indices(self, num_patches):
190 | img_size = int(num_patches**.5)
191 | rel_indices = torch.zeros(1, num_patches, num_patches, 3)
192 | ind = torch.arange(img_size).view(1,-1) - torch.arange(img_size).view(-1, 1)
193 | indx = ind.repeat(img_size,img_size)
194 | indy = ind.repeat_interleave(img_size,dim=0).repeat_interleave(img_size,dim=1)
195 | indd = indx**2 + indy**2
196 | rel_indices[:,:,:,2] = indd.unsqueeze(0)
197 | rel_indices[:,:,:,1] = indy.unsqueeze(0)
198 | rel_indices[:,:,:,0] = indx.unsqueeze(0)
199 | device = self.qk.weight.device
200 | self.rel_indices = rel_indices.to(device)
201 |
202 |
203 | class MHSA(nn.Module):
204 | def __init__(self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0., proj_drop=0., fc=None):
205 | super().__init__()
206 | self.num_heads = num_heads
207 | head_dim = dim // num_heads
208 | self.scale = qk_scale or head_dim ** -0.5
209 |
210 | self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
211 | self.attn_drop = nn.Dropout(attn_drop)
212 | self.proj = nn.Linear(dim, dim)
213 | self.proj_drop = nn.Dropout(proj_drop)
214 | self.apply(self._init_weights)
215 |
216 | def reset_parameters(self):
217 | self.apply(self._init_weights)
218 |
219 | def _init_weights(self, m):
220 | if isinstance(m, nn.Linear):
221 | trunc_normal_(m.weight, std=.02)
222 | if isinstance(m, nn.Linear) and m.bias is not None:
223 | nn.init.constant_(m.bias, 0)
224 | elif isinstance(m, nn.LayerNorm):
225 | nn.init.constant_(m.bias, 0)
226 | nn.init.constant_(m.weight, 1.0)
227 |
228 | def get_attention_map(self, x, return_map = False):
229 | B, N, C = x.shape
230 | qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
231 | q, k, v = qkv[0], qkv[1], qkv[2]
232 | attn_map = (q @ k.transpose(-2, -1)) * self.scale
233 | attn_map = attn_map.softmax(dim=-1).mean(0)
234 |
235 | img_size = int(N**.5)
236 | ind = torch.arange(img_size).view(1,-1) - torch.arange(img_size).view(-1, 1)
237 | indx = ind.repeat(img_size,img_size)
238 | indy = ind.repeat_interleave(img_size,dim=0).repeat_interleave(img_size,dim=1)
239 | indd = indx**2 + indy**2
240 | distances = indd**.5
241 | distances = distances.to('cuda')
242 |
243 | dist = torch.einsum('nm,hnm->h', (distances, attn_map))
244 | dist /= N
245 |
246 | if return_map:
247 | return dist, attn_map
248 | else:
249 | return dist
250 |
251 | def forward(self, x):
252 | B, N, C = x.shape
253 | qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
254 | q, k, v = qkv[0], qkv[1], qkv[2]
255 |
256 | attn = (q @ k.transpose(-2, -1)) * self.scale
257 | attn = attn.softmax(dim=-1)
258 | attn = self.attn_drop(attn)
259 |
260 | x = (attn @ v).transpose(1, 2).reshape(B, N, C)
261 | x = self.proj(x)
262 | x = self.proj_drop(x)
263 | return x, attn, v
264 |
265 |
266 | class ScaleNorm(nn.Module):
267 | """See
268 | https://github.com/lucidrains/reformer-pytorch/blob/a751fe2eb939dcdd81b736b2f67e745dc8472a09/reformer_pytorch/reformer_pytorch.py#L143
269 | """
270 | def __init__(self, dim, eps=1e-5):
271 | super().__init__()
272 | self.g = nn.Parameter(torch.ones(1))
273 | self.eps = eps
274 |
275 | def forward(self, x):
276 | n = torch.norm(x, dim=-1, keepdim=True).clamp(min=self.eps)
277 | return x / n * self.g
278 |
279 |
280 | class Block(nn.Module):
281 |
282 | def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop=0., attn_drop=0.,
283 | drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm, attention_type=GPSA,
284 | fc=nn.Linear, **kwargs):
285 | super().__init__()
286 | self.norm1 = norm_layer(dim)
287 | self.attn = attention_type(dim, num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop, fc=fc, **kwargs)
288 | self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
289 | self.norm2 = norm_layer(dim)
290 | mlp_hidden_dim = int(dim * mlp_ratio)
291 | self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop, fc=fc)
292 |
293 | def reset_parameters(self):
294 | self.norm1.reset_parameters()
295 | self.norm2.reset_parameters()
296 | self.attn.reset_parameters()
297 | self.mlp.apply(self.mlp._init_weights)
298 |
299 | def forward(self, x, mask_heads=None, task_index=1, attn_mask=None):
300 | if isinstance(self.attn, ClassAttention): # Like in CaiT
301 | cls_token = x[:, :task_index]
302 |
303 | xx = self.norm1(x)
304 | xx, attn, v = self.attn(
305 | xx,
306 | mask_heads=mask_heads,
307 | nb=task_index,
308 | attn_mask=attn_mask
309 | )
310 |
311 | cls_token = self.drop_path(xx[:, :task_index]) + cls_token
312 | cls_token = self.drop_path(self.mlp(self.norm2(cls_token))) + cls_token
313 |
314 | return cls_token, attn, v
315 |
316 | xx = self.norm1(x)
317 | xx, attn, v = self.attn(xx)
318 |
319 | x = self.drop_path(xx) + x
320 | x = self.drop_path(self.mlp(self.norm2(x))) + x
321 |
322 | return x, attn, v
323 |
324 |
325 | class ClassAttention(nn.Module):
326 | # taken from https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/vision_transformer.py
327 | # with slight modifications to do CA
328 | def __init__(self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0., proj_drop=0., fc=nn.Linear):
329 | super().__init__()
330 | self.num_heads = num_heads
331 | head_dim = dim // num_heads
332 | self.scale = qk_scale or head_dim ** -0.5
333 |
334 | self.q = fc(dim, dim, bias=qkv_bias)
335 | self.k = fc(dim, dim, bias=qkv_bias)
336 | self.v = fc(dim, dim, bias=qkv_bias)
337 | self.attn_drop = nn.Dropout(attn_drop)
338 | self.proj = fc(dim, dim)
339 | self.proj_drop = nn.Dropout(proj_drop)
340 |
341 | self.apply(self._init_weights)
342 |
343 | def reset_parameters(self):
344 | self.apply(self._init_weights)
345 |
346 | def _init_weights(self, m):
347 | if isinstance(m, nn.Linear):
348 | trunc_normal_(m.weight, std=.02)
349 | if isinstance(m, nn.Linear) and m.bias is not None:
350 | nn.init.constant_(m.bias, 0)
351 | elif isinstance(m, nn.LayerNorm):
352 | nn.init.constant_(m.bias, 0)
353 | nn.init.constant_(m.weight, 1.0)
354 |
355 | def forward(self, x, mask_heads=None, **kwargs):
356 | B, N, C = x.shape
357 | q = self.q(x[:,0]).unsqueeze(1).reshape(B, 1, self.num_heads, C // self.num_heads).permute(0, 2, 1, 3)
358 | k = self.k(x).reshape(B, N, self.num_heads, C // self.num_heads).permute(0, 2, 1, 3)
359 |
360 | q = q * self.scale
361 | v = self.v(x).reshape(B, N, self.num_heads, C // self.num_heads).permute(0, 2, 1, 3)
362 |
363 | attn = (q @ k.transpose(-2, -1))
364 | attn = attn.softmax(dim=-1)
365 | attn = self.attn_drop(attn)
366 |
367 | if mask_heads is not None:
368 | mask_heads = mask_heads.expand(B, self.num_heads, -1, N)
369 | attn = attn * mask_heads
370 |
371 | x_cls = (attn @ v).transpose(1, 2).reshape(B, 1, C)
372 | x_cls = self.proj(x_cls)
373 | x_cls = self.proj_drop(x_cls)
374 |
375 | return x_cls, attn, v
376 |
377 |
378 | class JointCA(ClassAttention):
379 | """Forward all task tokens together.
380 |
381 | It uses a masked attention so that task tokens don't interact between them.
382 | It should have the same results as independent forward per task token but being
383 | much faster.
384 |
385 | HOWEVER, it works a bit worse (like ~2pts less in 'all top-1' CIFAR100 50 steps).
386 | So if anyone knows why, please tell me!
387 | """
388 | def __init__(self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0., proj_drop=0., fc=nn.Linear):
389 | super().__init__(dim, num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=proj_drop, fc=fc)
390 | self.num_heads = num_heads
391 | head_dim = dim // num_heads
392 | self.scale = qk_scale or head_dim ** -0.5
393 |
394 | self.q = fc(dim, dim, bias=qkv_bias)
395 | self.k = fc(dim, dim, bias=qkv_bias)
396 | self.v = fc(dim, dim, bias=qkv_bias)
397 | self.attn_drop = nn.Dropout(attn_drop)
398 | self.proj = fc(dim, dim)
399 | self.proj_drop = nn.Dropout(proj_drop)
400 |
401 | self.apply(self._init_weights)
402 |
403 | def reset_parameters(self):
404 | self.apply(self._init_weights)
405 |
406 | def _init_weights(self, m):
407 | if isinstance(m, nn.Linear):
408 | trunc_normal_(m.weight, std=.02)
409 | if isinstance(m, nn.Linear) and m.bias is not None:
410 | nn.init.constant_(m.bias, 0)
411 | elif isinstance(m, nn.LayerNorm):
412 | nn.init.constant_(m.bias, 0)
413 | nn.init.constant_(m.weight, 1.0)
414 |
415 | @lru_cache(maxsize=1)
416 | def get_attention_mask(self, attn_shape, nb_task_tokens):
417 | """Mask so that task tokens don't interact together.
418 |
419 | Given two task tokens (t1, t2) and three patch tokens (p1, p2, p3), the
420 | attention matrix is:
421 |
422 | t1-t1 t1-t2 t1-p1 t1-p2 t1-p3
423 | t2-t1 t2-t2 t2-p1 t2-p2 t2-p3
424 |
425 | So that the mask (True values are deleted) should be:
426 |
427 | False True False False False
428 | True False False False False
429 | """
430 | mask = torch.zeros(attn_shape, dtype=torch.bool)
431 | for i in range(nb_task_tokens):
432 | mask[:, i, :i] = True
433 | mask[:, i, i+1:nb_task_tokens] = True
434 | return mask
435 |
436 | def forward(self, x, attn_mask=None, nb_task_tokens=1):
437 | B, N, C = x.shape
438 | q = self.q(x[:,:nb_task_tokens]).reshape(B, nb_task_tokens, self.num_heads, C // self.num_heads).permute(0, 2, 1, 3)
439 | k = self.k(x).reshape(B, N, self.num_heads, C // self.num_heads).permute(0, 2, 1, 3)
440 |
441 | q = q * self.scale
442 | v = self.v(x).reshape(B, N, self.num_heads, C // self.num_heads).permute(0, 2, 1, 3)
443 |
444 | attn = (q @ k.transpose(-2, -1))
445 | if attn_mask is not None:
446 | mask = self.get_attention_mask(attn.shape, nb_task_tokens)
447 | attn[mask] = -float('inf')
448 | attn = attn.softmax(dim=-1)
449 |
450 | attn = self.attn_drop(attn)
451 |
452 | x_cls = (attn @ v).transpose(1, 2).reshape(B, nb_task_tokens, C)
453 | x_cls = self.proj(x_cls)
454 | x_cls = self.proj_drop(x_cls)
455 |
456 | return x_cls, attn, v
457 |
458 |
459 | class PatchEmbed(nn.Module):
460 | """ Image to Patch Embedding, from timm
461 | """
462 | def __init__(self, img_size=224, patch_size=16, in_chans=3, embed_dim=768):
463 | super().__init__()
464 | img_size = to_2tuple(img_size)
465 | patch_size = to_2tuple(patch_size)
466 | num_patches = (img_size[1] // patch_size[1]) * (img_size[0] // patch_size[0])
467 | self.img_size = img_size
468 | self.patch_size = patch_size
469 | self.num_patches = num_patches
470 |
471 | self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size)
472 | self.apply(self._init_weights)
473 |
474 | def reset_parameters(self):
475 | self.apply(self._init_weights)
476 |
477 | def _init_weights(self, m):
478 | if isinstance(m, nn.Linear):
479 | trunc_normal_(m.weight, std=.02)
480 | if isinstance(m, nn.Linear) and m.bias is not None:
481 | nn.init.constant_(m.bias, 0)
482 | elif isinstance(m, nn.LayerNorm):
483 | nn.init.constant_(m.bias, 0)
484 | nn.init.constant_(m.weight, 1.0)
485 |
486 | def forward(self, x):
487 | B, C, H, W = x.shape
488 | #assert H == self.img_size[0] and W == self.img_size[1], \
489 | # f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})."
490 | x = self.proj(x).flatten(2).transpose(1, 2)
491 | return x
492 |
493 |
494 |
495 | class HybridEmbed(nn.Module):
496 | """ CNN Feature Map Embedding, from timm
497 | """
498 | def __init__(self, backbone, img_size=224, feature_size=None, in_chans=3, embed_dim=768):
499 | super().__init__()
500 | assert isinstance(backbone, nn.Module)
501 | img_size = to_2tuple(img_size)
502 | self.img_size = img_size
503 | self.backbone = backbone
504 | if feature_size is None:
505 | with torch.no_grad():
506 | training = backbone.training
507 | if training:
508 | backbone.eval()
509 | o = self.backbone(torch.zeros(1, in_chans, img_size[0], img_size[1]))[-1]
510 | feature_size = o.shape[-2:]
511 | feature_dim = o.shape[1]
512 | backbone.train(training)
513 | else:
514 | feature_size = to_2tuple(feature_size)
515 | feature_dim = self.backbone.feature_info.channels()[-1]
516 | self.num_patches = feature_size[0] * feature_size[1]
517 | self.proj = nn.Linear(feature_dim, embed_dim)
518 | self.apply(self._init_weights)
519 |
520 | def reset_parameters(self):
521 | self.apply(self._init_weights)
522 |
523 | def forward(self, x):
524 | x = self.backbone(x)[-1]
525 | x = x.flatten(2).transpose(1, 2)
526 | x = self.proj(x)
527 | return x
528 |
529 |
530 | class ConVit(nn.Module):
531 | """ Vision Transformer with support for patch or hybrid CNN input stage
532 | """
533 | def __init__(self, img_size=224, patch_size=16, in_chans=3, num_classes=1000, embed_dim=768, depth=12,
534 | num_heads=12, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop_rate=0., attn_drop_rate=0.,
535 | drop_path_rate=0., hybrid_backbone=None, norm_layer='layer',
536 | local_up_to_layer=3, locality_strength=1., use_pos_embed=True,
537 | class_attention=False, ca_type='base',
538 | ):
539 | super().__init__()
540 | self.num_classes = num_classes
541 | self.num_heads = num_heads
542 | self.embed_dim = embed_dim
543 | self.local_up_to_layer = local_up_to_layer
544 | self.num_features = self.final_dim = self.embed_dim = embed_dim # num_features for consistency with other models_2
545 | self.locality_strength = locality_strength
546 | self.use_pos_embed = use_pos_embed
547 |
548 | if norm_layer == 'layer':
549 | norm_layer = nn.LayerNorm
550 | elif norm_layer == 'scale':
551 | norm_layer = ScaleNorm
552 | else:
553 | raise NotImplementedError(f'Unknown normalization {norm_layer}')
554 |
555 | if hybrid_backbone is not None:
556 | self.patch_embed = HybridEmbed(
557 | hybrid_backbone, img_size=img_size, in_chans=in_chans, embed_dim=embed_dim)
558 | else:
559 | self.patch_embed = PatchEmbed(
560 | img_size=img_size, patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim)
561 | num_patches = self.patch_embed.num_patches
562 | self.num_patches = num_patches
563 |
564 | self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
565 | self.pos_drop = nn.Dropout(p=drop_rate)
566 |
567 | if self.use_pos_embed:
568 | self.pos_embed = nn.Parameter(torch.zeros(1, num_patches, embed_dim))
569 | trunc_normal_(self.pos_embed, std=.02)
570 |
571 | dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)] # stochastic depth decay rule
572 |
573 | blocks = []
574 |
575 | if ca_type == 'base':
576 | ca_block = ClassAttention
577 | elif ca_type == 'jointca':
578 | ca_block = JointCA
579 | else:
580 | raise ValueError(f'Unknown CA type {ca_type}')
581 |
582 | for layer_index in range(depth):
583 | if layer_index < local_up_to_layer:
584 | # Convit
585 | block = Block(
586 | dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale,
587 | drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[layer_index], norm_layer=norm_layer,
588 | attention_type=GPSA, locality_strength=locality_strength
589 | )
590 | elif not class_attention:
591 | # Convit
592 | block = Block(
593 | dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale,
594 | drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[layer_index], norm_layer=norm_layer,
595 | attention_type=MHSA
596 | )
597 | else:
598 | # CaiT
599 | block = Block(
600 | dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale,
601 | drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[layer_index], norm_layer=norm_layer,
602 | attention_type=ca_block
603 | )
604 |
605 | blocks.append(block)
606 |
607 | self.blocks = nn.ModuleList(blocks)
608 | self.norm = norm_layer(embed_dim)
609 | self.use_class_attention = class_attention
610 |
611 | # Classifier head
612 | self.feature_info = [dict(num_chs=embed_dim, reduction=0, module='head')]
613 | self.head = nn.Linear(embed_dim, num_classes) if num_classes > 0 else nn.Identity()
614 |
615 | trunc_normal_(self.cls_token, std=.02)
616 | self.head.apply(self._init_weights)
617 |
618 | def freeze(self, names):
619 | for name in names:
620 | if name == 'all':
621 | return freeze_parameters(self)
622 | elif name == 'old_heads':
623 | self.head.freeze(name)
624 | elif name == 'backbone':
625 | freeze_parameters(self.blocks)
626 | freeze_parameters(self.patch_embed)
627 | freeze_parameters(self.pos_embed)
628 | freeze_parameters(self.norm)
629 | else:
630 | raise NotImplementedError(f'Unknown name={name}.')
631 |
632 | def reset_classifier(self):
633 | self.head.apply(self._init_weights)
634 |
635 | def reset_parameters(self):
636 | for b in self.blocks:
637 | b.reset_parameters()
638 | self.norm.reset_parameters()
639 | self.head.apply(self._init_weights)
640 |
641 | def _init_weights(self, m):
642 | if isinstance(m, nn.Linear):
643 | trunc_normal_(m.weight, std=.02)
644 | if isinstance(m, nn.Linear) and m.bias is not None:
645 | nn.init.constant_(m.bias, 0)
646 | elif isinstance(m, nn.LayerNorm):
647 | nn.init.constant_(m.bias, 0)
648 | nn.init.constant_(m.weight, 1.0)
649 |
650 | def get_internal_losses(self, clf_loss):
651 | return {}
652 |
653 | def end_finetuning(self):
654 | pass
655 |
656 | def begin_finetuning(self):
657 | pass
658 |
659 | def epoch_log(self):
660 | return {}
661 |
662 | @torch.jit.ignore
663 | def no_weight_decay(self):
664 | return {'pos_embed', 'cls_token'}
665 |
666 | def get_classifier(self):
667 | return self.head
668 |
669 | def forward_sa(self, x):
670 | B = x.shape[0]
671 | x = self.patch_embed(x)
672 |
673 | if self.use_pos_embed:
674 | x = x + self.pos_embed
675 | x = self.pos_drop(x)
676 |
677 | for blk in self.blocks[:self.local_up_to_layer]:
678 | x, _, _ = blk(x)
679 |
680 | return x
681 |
682 | def forward_features(self, x, final_norm=True):
683 | B = x.shape[0]
684 | x = self.patch_embed(x)
685 |
686 | cls_tokens = self.cls_token.expand(B, -1, -1)
687 |
688 | if self.use_pos_embed:
689 | x = x + self.pos_embed
690 | x = self.pos_drop(x)
691 |
692 | for blk in self.blocks[:self.local_up_to_layer]:
693 | x, _, _ = blk(x)
694 |
695 | if self.use_class_attention:
696 | for blk in self.blocks[self.local_up_to_layer:]:
697 | cls_tokens, _, _ = blk(torch.cat((cls_tokens, x), dim=1))
698 | else:
699 | x = torch.cat((cls_tokens, x), dim=1)
700 | for blk in self.blocks[self.local_up_to_layer:]:
701 | x, _ , _ = blk(x)
702 |
703 | if final_norm:
704 | if self.use_class_attention:
705 | cls_tokens = self.norm(cls_tokens)
706 | else:
707 | x = self.norm(x)
708 |
709 | if self.use_class_attention:
710 | return cls_tokens[:, 0], None, None
711 | else:
712 | return x[:, 0], None, None
713 |
714 | def forward(self, x):
715 | x = self.forward_features(x)[0]
716 | x = self.head(x)
717 | return x
718 |
719 |
720 |
--------------------------------------------------------------------------------
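Note on the BatchEnsemble layer defined near the top of models/convit.py above: its forward pass replaces the shared nn.Linear weight W with (r s^T) ⊙ W, a rank-1 elementwise modulation, which is algebraically the same as scaling the inputs by s and the outputs by r around the shared layer. The snippet below is an illustrative check of that identity (toy sizes, not repository code):

    import torch
    import torch.nn.functional as F

    torch.manual_seed(0)
    in_features, out_features = 4, 3
    linear = torch.nn.Linear(in_features, out_features)
    r = torch.randn(out_features)   # per-output rank-1 factor
    s = torch.randn(in_features)    # per-input rank-1 factor
    x = torch.randn(2, in_features)

    # Exactly what BatchEnsemble.forward computes above.
    w = torch.outer(r, s) * linear.weight
    y = F.linear(x, w, linear.bias)

    # Equivalent view: scale inputs by s and outputs by r around the shared weight.
    y_ref = F.linear(x * s, linear.weight) * r + linear.bias
    print(torch.allclose(y, y_ref))   # True
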
/models/vit.py:
--------------------------------------------------------------------------------
1 | """ Vision Transformer (ViT) in PyTorch
2 | A PyTorch implementation of Vision Transformers as described in:
3 | 'An Image Is Worth 16 x 16 Words: Transformers for Image Recognition at Scale'
4 | - https://arxiv.org/abs/2010.11929
5 | `How to train your ViT? Data, Augmentation, and Regularization in Vision Transformers`
6 | - https://arxiv.org/abs/2106.10270
7 | The official jax code is released and available at https://github.com/google-research/vision_transformer
8 | Acknowledgments:
9 | * The paper authors for releasing code and weights, thanks!
10 | * I fixed my class token impl based on Phil Wang's https://github.com/lucidrains/vit-pytorch ... check it out
11 | for some einops/einsum fun
12 | * Simple transformer style inspired by Andrej Karpathy's https://github.com/karpathy/minGPT
13 | * Bert reference code checks against Huggingface Transformers and Tensorflow Bert
14 | Hacked together by / Copyright 2020, Ross Wightman
15 | """
16 | import math
17 | import logging
18 | from functools import partial
19 | from collections import OrderedDict
20 |
21 | import torch
22 | import torch.nn as nn
23 | import torch.nn.functional as F
24 | import torch.utils.checkpoint
25 |
26 | from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD, IMAGENET_INCEPTION_MEAN, IMAGENET_INCEPTION_STD
27 | from timm.models.helpers import build_model_with_cfg, resolve_pretrained_cfg, named_apply, adapt_input_conv, checkpoint_seq
28 | from timm.models.layers import PatchEmbed, Mlp, DropPath, trunc_normal_, lecun_normal_
29 | from timm.models.registry import register_model
30 |
31 | _logger = logging.getLogger(__name__)
32 |
33 |
34 | def _cfg(url='', **kwargs):
35 | return {
36 | 'url': url,
37 | 'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': None,
38 | 'crop_pct': .9, 'interpolation': 'bicubic', 'fixed_input_size': True,
39 | 'mean': IMAGENET_INCEPTION_MEAN, 'std': IMAGENET_INCEPTION_STD,
40 | 'first_conv': 'patch_embed.proj', 'classifier': 'head',
41 | **kwargs
42 | }
43 |
44 |
45 | default_cfgs = {
46 | # patch engine (weights from official Google JAX impl)
47 | 'vit_tiny_patch16_224': _cfg(
48 | url='https://storage.googleapis.com/vit_models/augreg/'
49 | 'Ti_16-i21k-300ep-lr_0.001-aug_none-wd_0.03-do_0.0-sd_0.0--imagenet2012-steps_20k-lr_0.03-res_224.npz'),
50 | 'vit_tiny_patch16_384': _cfg(
51 | url='https://storage.googleapis.com/vit_models/augreg/'
52 | 'Ti_16-i21k-300ep-lr_0.001-aug_none-wd_0.03-do_0.0-sd_0.0--imagenet2012-steps_20k-lr_0.03-res_384.npz',
53 | input_size=(3, 384, 384), crop_pct=1.0),
54 | 'vit_small_patch32_224': _cfg(
55 | url='https://storage.googleapis.com/vit_models/augreg/'
56 | 'S_32-i21k-300ep-lr_0.001-aug_light1-wd_0.03-do_0.0-sd_0.0--imagenet2012-steps_20k-lr_0.03-res_224.npz'),
57 | 'vit_small_patch32_384': _cfg(
58 | url='https://storage.googleapis.com/vit_models/augreg/'
59 | 'S_32-i21k-300ep-lr_0.001-aug_light1-wd_0.03-do_0.0-sd_0.0--imagenet2012-steps_20k-lr_0.03-res_384.npz',
60 | input_size=(3, 384, 384), crop_pct=1.0),
61 | 'vit_small_patch16_224': _cfg(
62 | url='https://storage.googleapis.com/vit_models/augreg/'
63 | 'S_16-i21k-300ep-lr_0.001-aug_light1-wd_0.03-do_0.0-sd_0.0--imagenet2012-steps_20k-lr_0.03-res_224.npz'),
64 | 'vit_small_patch16_384': _cfg(
65 | url='https://storage.googleapis.com/vit_models/augreg/'
66 | 'S_16-i21k-300ep-lr_0.001-aug_light1-wd_0.03-do_0.0-sd_0.0--imagenet2012-steps_20k-lr_0.03-res_384.npz',
67 | input_size=(3, 384, 384), crop_pct=1.0),
68 | 'vit_base_patch32_224': _cfg(
69 | url='https://storage.googleapis.com/vit_models/augreg/'
70 | 'B_32-i21k-300ep-lr_0.001-aug_medium1-wd_0.03-do_0.0-sd_0.0--imagenet2012-steps_20k-lr_0.03-res_224.npz'),
71 | 'vit_base_patch32_384': _cfg(
72 | url='https://storage.googleapis.com/vit_models/augreg/'
73 | 'B_32-i21k-300ep-lr_0.001-aug_light1-wd_0.1-do_0.0-sd_0.0--imagenet2012-steps_20k-lr_0.03-res_384.npz',
74 | input_size=(3, 384, 384), crop_pct=1.0),
75 | 'vit_base_patch16_224': _cfg(
76 | url='https://storage.googleapis.com/vit_models/augreg/'
77 | 'B_16-i21k-300ep-lr_0.001-aug_medium1-wd_0.1-do_0.0-sd_0.0--imagenet2012-steps_20k-lr_0.01-res_224.npz'),
78 | 'vit_base_patch16_384': _cfg(
79 | url='https://storage.googleapis.com/vit_models/augreg/'
80 | 'B_16-i21k-300ep-lr_0.001-aug_medium1-wd_0.1-do_0.0-sd_0.0--imagenet2012-steps_20k-lr_0.01-res_384.npz',
81 | input_size=(3, 384, 384), crop_pct=1.0),
82 | 'vit_base_patch8_224': _cfg(
83 | url='https://storage.googleapis.com/vit_models/augreg/'
84 | 'B_8-i21k-300ep-lr_0.001-aug_medium1-wd_0.1-do_0.0-sd_0.0--imagenet2012-steps_20k-lr_0.01-res_224.npz'),
85 | 'vit_large_patch32_224': _cfg(
86 | url='', # no official model weights for this combo, only for in21k
87 | ),
88 | 'vit_large_patch32_384': _cfg(
89 | url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-vitjx/jx_vit_large_p32_384-9b920ba8.pth',
90 | input_size=(3, 384, 384), crop_pct=1.0),
91 | 'vit_large_patch16_224': _cfg(
92 | url='https://storage.googleapis.com/vit_models/augreg/'
93 | 'L_16-i21k-300ep-lr_0.001-aug_medium1-wd_0.1-do_0.1-sd_0.1--imagenet2012-steps_20k-lr_0.01-res_224.npz'),
94 | 'vit_large_patch16_384': _cfg(
95 | url='https://storage.googleapis.com/vit_models/augreg/'
96 | 'L_16-i21k-300ep-lr_0.001-aug_medium1-wd_0.1-do_0.1-sd_0.1--imagenet2012-steps_20k-lr_0.01-res_384.npz',
97 | input_size=(3, 384, 384), crop_pct=1.0),
98 |
99 | 'vit_large_patch14_224': _cfg(url=''),
100 | 'vit_huge_patch14_224': _cfg(url=''),
101 | 'vit_giant_patch14_224': _cfg(url=''),
102 | 'vit_gigantic_patch14_224': _cfg(url=''),
103 |
104 | 'vit_base2_patch32_256': _cfg(url='', input_size=(3, 256, 256), crop_pct=0.95),
105 |
106 | # patch engine, imagenet21k (weights from official Google JAX impl)
107 | 'vit_tiny_patch16_224_in21k': _cfg(
108 | url='https://storage.googleapis.com/vit_models/augreg/Ti_16-i21k-300ep-lr_0.001-aug_none-wd_0.03-do_0.0-sd_0.0.npz',
109 | num_classes=21843),
110 | 'vit_small_patch32_224_in21k': _cfg(
111 | url='https://storage.googleapis.com/vit_models/augreg/S_32-i21k-300ep-lr_0.001-aug_light1-wd_0.03-do_0.0-sd_0.0.npz',
112 | num_classes=21843),
113 | 'vit_small_patch16_224_in21k': _cfg(
114 | url='https://storage.googleapis.com/vit_models/augreg/S_16-i21k-300ep-lr_0.001-aug_light1-wd_0.03-do_0.0-sd_0.0.npz',
115 | num_classes=21843),
116 | 'vit_base_patch32_224_in21k': _cfg(
117 | url='https://storage.googleapis.com/vit_models/augreg/B_32-i21k-300ep-lr_0.001-aug_medium1-wd_0.03-do_0.0-sd_0.0.npz',
118 | num_classes=21843),
119 | 'vit_base_patch16_224_in21k': _cfg(
120 | url='https://storage.googleapis.com/vit_models/augreg/B_16-i21k-300ep-lr_0.001-aug_medium1-wd_0.1-do_0.0-sd_0.0.npz',
121 | num_classes=21843),
122 | 'vit_base_patch8_224_in21k': _cfg(
123 | url='https://storage.googleapis.com/vit_models/augreg/B_8-i21k-300ep-lr_0.001-aug_medium1-wd_0.1-do_0.0-sd_0.0.npz',
124 | num_classes=21843),
125 | 'vit_large_patch32_224_in21k': _cfg(
126 | url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-vitjx/jx_vit_large_patch32_224_in21k-9046d2e7.pth',
127 | num_classes=21843),
128 | 'vit_large_patch16_224_in21k': _cfg(
129 | url='https://storage.googleapis.com/vit_models/augreg/L_16-i21k-300ep-lr_0.001-aug_medium1-wd_0.1-do_0.1-sd_0.1.npz',
130 | num_classes=21843),
131 | 'vit_huge_patch14_224_in21k': _cfg(
132 | url='https://storage.googleapis.com/vit_models/imagenet21k/ViT-H_14.npz',
133 | hf_hub_id='timm/vit_huge_patch14_224_in21k',
134 | num_classes=21843),
135 |
136 | # SAM trained engine (https://arxiv.org/abs/2106.01548)
137 | 'vit_base_patch32_224_sam': _cfg(
138 | url='https://storage.googleapis.com/vit_models/sam/ViT-B_32.npz'),
139 | 'vit_base_patch16_224_sam': _cfg(
140 | url='https://storage.googleapis.com/vit_models/sam/ViT-B_16.npz'),
141 |
142 | # DINO pretrained - https://arxiv.org/abs/2104.14294 (no classifier head, for fine-tune only)
143 | 'vit_small_patch16_224_dino': _cfg(
144 | url='https://dl.fbaipublicfiles.com/dino/dino_deitsmall16_pretrain/dino_deitsmall16_pretrain.pth',
145 | mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD, num_classes=0),
146 | 'vit_small_patch8_224_dino': _cfg(
147 | url='https://dl.fbaipublicfiles.com/dino/dino_deitsmall8_pretrain/dino_deitsmall8_pretrain.pth',
148 | mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD, num_classes=0),
149 | 'vit_base_patch16_224_dino': _cfg(
150 | url='https://dl.fbaipublicfiles.com/dino/dino_vitbase16_pretrain/dino_vitbase16_pretrain.pth',
151 | mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD, num_classes=0),
152 | 'vit_base_patch8_224_dino': _cfg(
153 | url='https://dl.fbaipublicfiles.com/dino/dino_vitbase8_pretrain/dino_vitbase8_pretrain.pth',
154 | mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD, num_classes=0),
155 |
156 |
157 | # ViT ImageNet-21K-P pretraining by MILL
158 | 'vit_base_patch16_224_miil_in21k': _cfg(
159 | url='https://miil-public-eu.oss-eu-central-1.aliyuncs.com/model-zoo/ImageNet_21K_P/models/timm/vit_base_patch16_224_in21k_miil.pth',
160 | mean=(0, 0, 0), std=(1, 1, 1), crop_pct=0.875, interpolation='bilinear', num_classes=11221,
161 | ),
162 | 'vit_base_patch16_224_miil': _cfg(
163 | url='https://miil-public-eu.oss-eu-central-1.aliyuncs.com/model-zoo/ImageNet_21K_P/models/timm'
164 | '/vit_base_patch16_224_1k_miil_84_4.pth',
165 | mean=(0, 0, 0), std=(1, 1, 1), crop_pct=0.875, interpolation='bilinear',
166 | ),
167 |
168 | # experimental
169 | 'vit_small_patch16_36x1_224': _cfg(url=''),
170 | 'vit_small_patch16_18x2_224': _cfg(url=''),
171 | 'vit_base_patch16_18x2_224': _cfg(url=''),
172 | }
173 |
174 |
175 | class Attention(nn.Module):
176 | def __init__(self, dim, num_heads=8, qkv_bias=False, attn_drop=0., proj_drop=0.):
177 | super().__init__()
178 | assert dim % num_heads == 0, 'dim should be divisible by num_heads'
179 | self.num_heads = num_heads
180 | head_dim = dim // num_heads
181 | self.scale = head_dim ** -0.5
182 |
183 | self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
184 | self.attn_drop = nn.Dropout(attn_drop)
185 | self.proj = nn.Linear(dim, dim)
186 | self.proj_drop = nn.Dropout(proj_drop)
187 |
188 | def forward(self, x):
189 | B, N, C = x.shape
190 | qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
191 | q, k, v = qkv.unbind(0) # make torchscript happy (cannot use tensor as tuple)
192 |
193 | attn = (q @ k.transpose(-2, -1)) * self.scale
194 | attn = attn.softmax(dim=-1)
195 | attn = self.attn_drop(attn)
196 |
197 | x = (attn @ v).transpose(1, 2).reshape(B, N, C)
198 | x = self.proj(x)
199 | x = self.proj_drop(x)
200 | return x
201 |
202 |
203 | class LayerScale(nn.Module):
204 | def __init__(self, dim, init_values=1e-5, inplace=False):
205 | super().__init__()
206 | self.inplace = inplace
207 | self.gamma = nn.Parameter(init_values * torch.ones(dim))
208 |
209 | def forward(self, x):
210 | return x.mul_(self.gamma) if self.inplace else x * self.gamma
211 |
212 |
213 | class Block(nn.Module):
214 |
215 | def __init__(
216 | self, dim, num_heads, mlp_ratio=4., qkv_bias=False, drop=0., attn_drop=0., init_values=None,
217 | drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm):
218 | super().__init__()
219 | self.norm1 = norm_layer(dim)
220 | self.attn = Attention(dim, num_heads=num_heads, qkv_bias=qkv_bias, attn_drop=attn_drop, proj_drop=drop)
221 | self.ls1 = LayerScale(dim, init_values=init_values) if init_values else nn.Identity()
222 | # NOTE: drop path for stochastic depth, we shall see if this is better than dropout here
223 | self.drop_path1 = DropPath(drop_path) if drop_path > 0. else nn.Identity()
224 |
225 | self.norm2 = norm_layer(dim)
226 | mlp_hidden_dim = int(dim * mlp_ratio)
227 | self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)
228 | self.ls2 = LayerScale(dim, init_values=init_values) if init_values else nn.Identity()
229 | self.drop_path2 = DropPath(drop_path) if drop_path > 0. else nn.Identity()
230 |
231 | def forward(self, x):
232 | x = x + self.drop_path1(self.ls1(self.attn(self.norm1(x))))
233 | x = x + self.drop_path2(self.ls2(self.mlp(self.norm2(x))))
234 | return x
235 |
236 |
237 | class ParallelBlock(nn.Module):
238 |
239 | def __init__(
240 | self, dim, num_heads, num_parallel=2, mlp_ratio=4., qkv_bias=False, init_values=None,
241 | drop=0., attn_drop=0., drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm):
242 | super().__init__()
243 | self.num_parallel = num_parallel
244 | self.attns = nn.ModuleList()
245 | self.ffns = nn.ModuleList()
246 | for _ in range(num_parallel):
247 | self.attns.append(nn.Sequential(OrderedDict([
248 | ('norm', norm_layer(dim)),
249 | ('attn', Attention(dim, num_heads=num_heads, qkv_bias=qkv_bias, attn_drop=attn_drop, proj_drop=drop)),
250 | ('ls', LayerScale(dim, init_values=init_values) if init_values else nn.Identity()),
251 | ('drop_path', DropPath(drop_path) if drop_path > 0. else nn.Identity())
252 | ])))
253 | self.ffns.append(nn.Sequential(OrderedDict([
254 | ('norm', norm_layer(dim)),
255 | ('mlp', Mlp(dim, hidden_features=int(dim * mlp_ratio), act_layer=act_layer, drop=drop)),
256 | ('ls', LayerScale(dim, init_values=init_values) if init_values else nn.Identity()),
257 | ('drop_path', DropPath(drop_path) if drop_path > 0. else nn.Identity())
258 | ])))
259 |
260 | def _forward_jit(self, x):
261 | x = x + torch.stack([attn(x) for attn in self.attns]).sum(dim=0)
262 | x = x + torch.stack([ffn(x) for ffn in self.ffns]).sum(dim=0)
263 | return x
264 |
265 | @torch.jit.ignore
266 | def _forward(self, x):
267 | x = x + sum(attn(x) for attn in self.attns)
268 | x = x + sum(ffn(x) for ffn in self.ffns)
269 | return x
270 |
271 | def forward(self, x):
272 | if torch.jit.is_scripting() or torch.jit.is_tracing():
273 | return self._forward_jit(x)
274 | else:
275 | return self._forward(x)
276 |
277 |
278 | class VisionTransformer(nn.Module):
279 | """ Vision Transformer
280 | A PyTorch impl of : `An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale`
281 | - https://arxiv.org/abs/2010.11929
282 | """
283 |
284 | def __init__(
285 | self, img_size=224, patch_size=16, in_chans=3, num_classes=1000, global_pool='token',
286 | embed_dim=768, depth=12, num_heads=12, mlp_ratio=4., qkv_bias=True, representation_size=None,
287 | drop_rate=0., attn_drop_rate=0., drop_path_rate=0., weight_init='', init_values=None,
288 | embed_layer=PatchEmbed, norm_layer=None, act_layer=None, block_fn=Block):
289 | """
290 | Args:
291 | img_size (int, tuple): input image size
292 | patch_size (int, tuple): patch size
293 | in_chans (int): number of input channels
294 | num_classes (int): number of classes for classification head
295 | global_pool (str): type of global pooling for final sequence (default: 'token')
296 | embed_dim (int): embedding dimension
297 | depth (int): depth of transformer
298 | num_heads (int): number of attention heads
299 |             mlp_ratio (float): ratio of mlp hidden dim to embedding dim
300 | qkv_bias (bool): enable bias for qkv if True
301 | representation_size (Optional[int]): enable and set representation layer (pre-logits) to this value if set
302 | drop_rate (float): dropout rate
303 | attn_drop_rate (float): attention dropout rate
304 | drop_path_rate (float): stochastic depth rate
305 |             weight_init (str): weight init scheme
306 |             init_values (float): layer-scale init values
307 |             embed_layer (nn.Module): patch embedding layer
308 |             norm_layer (nn.Module): normalization layer
309 |             act_layer (nn.Module): MLP activation layer
310 | """
311 | super().__init__()
312 | assert global_pool in ('', 'avg', 'token')
313 | norm_layer = norm_layer or partial(nn.LayerNorm, eps=1e-6)
314 | act_layer = act_layer or nn.GELU
315 |
316 | self.num_classes = num_classes
317 | self.global_pool = global_pool
318 |         self.num_features = self.embed_dim = embed_dim  # num_features for consistency with other models
319 | self.num_tokens = 1
320 | self.grad_checkpointing = False
321 |
322 | self.patch_embed = embed_layer(
323 | img_size=img_size, patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim)
324 | num_patches = self.patch_embed.num_patches
325 |
326 | self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
327 | # self.cls_token_grow = nn.Parameter(torch.zeros(1, 5000, embed_dim))
328 | self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + self.num_tokens, embed_dim))
329 | # self.pos_embed_grow = nn.Parameter(torch.zeros(1, num_patches + 1000, embed_dim))
330 | self.pos_drop = nn.Dropout(p=drop_rate)
331 |
332 | dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)] # stochastic depth decay rule
333 | self.blocks = nn.Sequential(*[
334 | block_fn(
335 | dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, init_values=init_values,
336 | drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[i], norm_layer=norm_layer, act_layer=act_layer)
337 | for i in range(depth)])
338 | use_fc_norm = self.global_pool == 'avg'
339 | self.norm = norm_layer(embed_dim) if not use_fc_norm else nn.Identity()
340 |
341 |         # Representation layer. Used for original ViT models w/ in21k pretraining.
342 | self.representation_size = representation_size
343 | self.pre_logits = nn.Identity()
344 | if representation_size:
345 | self._reset_representation(representation_size)
346 |
347 | # Classifier Head
348 | self.fc_norm = norm_layer(embed_dim) if use_fc_norm else nn.Identity()
349 | final_chs = self.representation_size if self.representation_size else self.embed_dim
350 | self.head = nn.Linear(final_chs, num_classes) if num_classes > 0 else nn.Identity()
351 | self.out_dim = final_chs
352 |
353 | if weight_init != 'skip':
354 | self.init_weights(weight_init)
355 |
356 | def _reset_representation(self, representation_size):
357 | self.representation_size = representation_size
358 | if self.representation_size:
359 | self.pre_logits = nn.Sequential(OrderedDict([
360 | ('fc', nn.Linear(self.embed_dim, self.representation_size)),
361 | ('act', nn.Tanh())
362 | ]))
363 | else:
364 | self.pre_logits = nn.Identity()
365 |
366 | def init_weights(self, mode=''):
367 | assert mode in ('jax', 'jax_nlhb', 'moco', '')
368 | head_bias = -math.log(self.num_classes) if 'nlhb' in mode else 0.
369 | trunc_normal_(self.pos_embed, std=.02)
370 | # trunc_normal_(self.pos_embed_grow, std=.02)
371 | nn.init.normal_(self.cls_token, std=1e-6)
372 | # nn.init.normal_(self.cls_token_grow, std=1e-6)
373 | named_apply(get_init_weights_vit(mode, head_bias), self)
374 |
375 | def _init_weights(self, m):
376 | # this fn left here for compat with downstream users
377 | init_weights_vit_timm(m)
378 |
379 | @torch.jit.ignore()
380 | def load_pretrained(self, checkpoint_path, prefix=''):
381 | _load_weights(self, checkpoint_path, prefix)
382 |
383 | @torch.jit.ignore
384 | def no_weight_decay(self):
385 | return {'pos_embed', 'cls_token', 'dist_token'}
386 |
387 | @torch.jit.ignore
388 | def group_matcher(self, coarse=False):
389 | return dict(
390 | stem=r'^cls_token|pos_embed|patch_embed', # stem and embed
391 | blocks=[(r'^blocks\.(\d+)', None), (r'^norm', (99999,))]
392 | )
393 |
394 | @torch.jit.ignore
395 | def set_grad_checkpointing(self, enable=True):
396 | self.grad_checkpointing = enable
397 |
398 | @torch.jit.ignore
399 | def get_classifier(self):
400 | return self.head
401 |
402 | def reset_classifier(self, num_classes: int, global_pool=None, representation_size=None):
403 | self.num_classes = num_classes
404 | if global_pool is not None:
405 | assert global_pool in ('', 'avg', 'token')
406 | self.global_pool = global_pool
407 | if representation_size is not None:
408 | self._reset_representation(representation_size)
409 | final_chs = self.representation_size if self.representation_size else self.embed_dim
410 | self.head = nn.Linear(final_chs, num_classes) if num_classes > 0 else nn.Identity()
411 |
412 | def forward_features(self, x):
413 | x = self.patch_embed(x)
414 | x = torch.cat((self.cls_token.expand(x.shape[0], -1, -1), x), dim=1)
415 |
416 | x = self.pos_drop(x + self.pos_embed)
417 | if self.grad_checkpointing and not torch.jit.is_scripting():
418 | x = checkpoint_seq(self.blocks, x)
419 | else:
420 | x = self.blocks(x)
421 | x = self.norm(x)
422 | return x
423 |
424 | def forward_features_grow(self, x, class_num):
425 | x = self.patch_embed(x)
426 | # x = torch.cat((self.cls_token_grow[:, :class_num, :].expand(x.shape[0], -1, -1), self.cls_token.expand(x.shape[0], -1, -1), x), dim=1)
427 | # x = self.pos_drop(x + self.pos_embed_grow[:, :self.patch_embed.num_patches+class_num, :])
428 | x = torch.cat((self.cls_token.expand(x.shape[0], -1, -1), x), dim=1)
429 | x = self.pos_drop(x + self.pos_embed)
430 | x = torch.cat((self.cls_token_grow[:, :class_num*2, :].expand(x.shape[0], -1, -1), x), dim=1)
431 |
432 |         # NOTE: this path uses self.cls_token_grow, which is only defined in the commented-out line in __init__; restore that parameter before calling forward with grow_flag=True.
433 | if self.grad_checkpointing and not torch.jit.is_scripting():
434 | x = checkpoint_seq(self.blocks, x)
435 | else:
436 | x = self.blocks(x)
437 | x = self.norm(x)
438 | return x
439 |
440 | def forward_head(self, x, pre_logits: bool = False):
441 | if self.global_pool:
442 | x = x[:, 1:].mean(dim=1) if self.global_pool == 'avg' else x[:, 0]
443 | x = self.fc_norm(x)
444 | x = self.pre_logits(x)
445 | return x if pre_logits else self.head(x)
446 |
447 | def forward(self, x, grow_flag=False, numcls=0):
448 | if not grow_flag:
449 | x = self.forward_features(x)
450 | else:
451 | x = self.forward_features_grow(x, numcls)
452 |
453 | if self.global_pool:
454 | x = x[:, 1:].mean(dim=1) if self.global_pool == 'avg' else x[:, 0]
455 | x = self.fc_norm(x)
456 | return {
457 | 'fmaps': [x],
458 | 'features': x
459 | }
460 |
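# --- Editor's note (not part of the original source): a minimal usage sketch for the
# class above, assuming only torch and the helpers this module already imports. With
# the default global_pool='token', forward() returns a dict whose 'features' entry is
# the CLS-token embedding after fc_norm (an identity in that configuration); the toy
# config below just keeps the example light.
#
#   import torch
#   model = VisionTransformer(embed_dim=192, depth=2, num_heads=3, num_classes=0)
#   with torch.no_grad():
#       out = model(torch.randn(2, 3, 224, 224))
#   out['features'].shape          # torch.Size([2, 192])
#   out['fmaps'][0] is out['features']   # True: 'fmaps' wraps the same tensor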
461 |
462 | def init_weights_vit_timm(module: nn.Module, name: str = ''):
463 | """ ViT weight initialization, original timm impl (for reproducibility) """
464 | if isinstance(module, nn.Linear):
465 | trunc_normal_(module.weight, std=.02)
466 | if module.bias is not None:
467 | nn.init.zeros_(module.bias)
468 |
469 |
470 | def init_weights_vit_jax(module: nn.Module, name: str = '', head_bias: float = 0.):
471 | """ ViT weight initialization, matching JAX (Flax) impl """
472 | if isinstance(module, nn.Linear):
473 | if name.startswith('head'):
474 | nn.init.zeros_(module.weight)
475 | nn.init.constant_(module.bias, head_bias)
476 | elif name.startswith('pre_logits'):
477 | lecun_normal_(module.weight)
478 | nn.init.zeros_(module.bias)
479 | else:
480 | nn.init.xavier_uniform_(module.weight)
481 | if module.bias is not None:
482 | nn.init.normal_(module.bias, std=1e-6) if 'mlp' in name else nn.init.zeros_(module.bias)
483 | elif isinstance(module, nn.Conv2d):
484 | lecun_normal_(module.weight)
485 | if module.bias is not None:
486 | nn.init.zeros_(module.bias)
487 |
488 |
489 | def init_weights_vit_moco(module: nn.Module, name: str = ''):
490 | """ ViT weight initialization, matching moco-v3 impl minus fixed PatchEmbed """
491 | if isinstance(module, nn.Linear):
492 | if 'qkv' in name:
493 | # treat the weights of Q, K, V separately
494 | val = math.sqrt(6. / float(module.weight.shape[0] // 3 + module.weight.shape[1]))
495 | nn.init.uniform_(module.weight, -val, val)
496 | else:
497 | nn.init.xavier_uniform_(module.weight)
498 | if module.bias is not None:
499 | nn.init.zeros_(module.bias)
500 |
501 |
502 | def get_init_weights_vit(mode='jax', head_bias: float = 0.):
503 | if 'jax' in mode:
504 | return partial(init_weights_vit_jax, head_bias=head_bias)
505 | elif 'moco' in mode:
506 | return init_weights_vit_moco
507 | else:
508 | return init_weights_vit_timm
509 |
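# --- Editor's note (not part of the original source): get_init_weights_vit returns a
# per-module callable that VisionTransformer.init_weights feeds to timm's named_apply
# (imported earlier in this file), which walks the module tree and passes each submodule
# together with its qualified name, so the 'head' and 'pre_logits' branches in the JAX
# scheme can match on that name.
#
#   init_fn = get_init_weights_vit(mode='jax', head_bias=0.)
#   named_apply(init_fn, model)    # `model` is any VisionTransformer instance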
510 |
511 | @torch.no_grad()
512 | def _load_weights(model: VisionTransformer, checkpoint_path: str, prefix: str = ''):
513 | """ Load weights from .npz checkpoints for official Google Brain Flax implementation
514 | """
515 | import numpy as np
516 |
517 | def _n2p(w, t=True):
518 | if w.ndim == 4 and w.shape[0] == w.shape[1] == w.shape[2] == 1:
519 | w = w.flatten()
520 | if t:
521 | if w.ndim == 4:
522 | w = w.transpose([3, 2, 0, 1])
523 | elif w.ndim == 3:
524 | w = w.transpose([2, 0, 1])
525 | elif w.ndim == 2:
526 | w = w.transpose([1, 0])
527 | return torch.from_numpy(w)
528 |
529 | w = np.load(checkpoint_path)
530 | if not prefix and 'opt/target/embedding/kernel' in w:
531 | prefix = 'opt/target/'
532 |
533 | if hasattr(model.patch_embed, 'backbone'):
534 | # hybrid
535 | backbone = model.patch_embed.backbone
536 | stem_only = not hasattr(backbone, 'stem')
537 | stem = backbone if stem_only else backbone.stem
538 | stem.conv.weight.copy_(adapt_input_conv(stem.conv.weight.shape[1], _n2p(w[f'{prefix}conv_root/kernel'])))
539 | stem.norm.weight.copy_(_n2p(w[f'{prefix}gn_root/scale']))
540 | stem.norm.bias.copy_(_n2p(w[f'{prefix}gn_root/bias']))
541 | if not stem_only:
542 | for i, stage in enumerate(backbone.stages):
543 | for j, block in enumerate(stage.blocks):
544 | bp = f'{prefix}block{i + 1}/unit{j + 1}/'
545 | for r in range(3):
546 | getattr(block, f'conv{r + 1}').weight.copy_(_n2p(w[f'{bp}conv{r + 1}/kernel']))
547 | getattr(block, f'norm{r + 1}').weight.copy_(_n2p(w[f'{bp}gn{r + 1}/scale']))
548 | getattr(block, f'norm{r + 1}').bias.copy_(_n2p(w[f'{bp}gn{r + 1}/bias']))
549 | if block.downsample is not None:
550 | block.downsample.conv.weight.copy_(_n2p(w[f'{bp}conv_proj/kernel']))
551 | block.downsample.norm.weight.copy_(_n2p(w[f'{bp}gn_proj/scale']))
552 | block.downsample.norm.bias.copy_(_n2p(w[f'{bp}gn_proj/bias']))
553 | embed_conv_w = _n2p(w[f'{prefix}embedding/kernel'])
554 | else:
555 | embed_conv_w = adapt_input_conv(
556 | model.patch_embed.proj.weight.shape[1], _n2p(w[f'{prefix}embedding/kernel']))
557 | model.patch_embed.proj.weight.copy_(embed_conv_w)
558 | model.patch_embed.proj.bias.copy_(_n2p(w[f'{prefix}embedding/bias']))
559 | model.cls_token.copy_(_n2p(w[f'{prefix}cls'], t=False))
560 | pos_embed_w = _n2p(w[f'{prefix}Transformer/posembed_input/pos_embedding'], t=False)
561 | if pos_embed_w.shape != model.pos_embed.shape:
562 | pos_embed_w = resize_pos_embed( # resize pos embedding when different size from pretrained weights
563 | pos_embed_w, model.pos_embed, getattr(model, 'num_tokens', 1), model.patch_embed.grid_size)
564 | model.pos_embed.copy_(pos_embed_w)
565 | model.norm.weight.copy_(_n2p(w[f'{prefix}Transformer/encoder_norm/scale']))
566 | model.norm.bias.copy_(_n2p(w[f'{prefix}Transformer/encoder_norm/bias']))
567 | if isinstance(model.head, nn.Linear) and model.head.bias.shape[0] == w[f'{prefix}head/bias'].shape[-1]:
568 | model.head.weight.copy_(_n2p(w[f'{prefix}head/kernel']))
569 | model.head.bias.copy_(_n2p(w[f'{prefix}head/bias']))
570 | if isinstance(getattr(model.pre_logits, 'fc', None), nn.Linear) and f'{prefix}pre_logits/bias' in w:
571 | model.pre_logits.fc.weight.copy_(_n2p(w[f'{prefix}pre_logits/kernel']))
572 | model.pre_logits.fc.bias.copy_(_n2p(w[f'{prefix}pre_logits/bias']))
573 | for i, block in enumerate(model.blocks.children()):
574 | block_prefix = f'{prefix}Transformer/encoderblock_{i}/'
575 | mha_prefix = block_prefix + 'MultiHeadDotProductAttention_1/'
576 | block.norm1.weight.copy_(_n2p(w[f'{block_prefix}LayerNorm_0/scale']))
577 | block.norm1.bias.copy_(_n2p(w[f'{block_prefix}LayerNorm_0/bias']))
578 | block.attn.qkv.weight.copy_(torch.cat([
579 | _n2p(w[f'{mha_prefix}{n}/kernel'], t=False).flatten(1).T for n in ('query', 'key', 'value')]))
580 | block.attn.qkv.bias.copy_(torch.cat([
581 | _n2p(w[f'{mha_prefix}{n}/bias'], t=False).reshape(-1) for n in ('query', 'key', 'value')]))
582 | block.attn.proj.weight.copy_(_n2p(w[f'{mha_prefix}out/kernel']).flatten(1))
583 | block.attn.proj.bias.copy_(_n2p(w[f'{mha_prefix}out/bias']))
584 | for r in range(2):
585 | getattr(block.mlp, f'fc{r + 1}').weight.copy_(_n2p(w[f'{block_prefix}MlpBlock_3/Dense_{r}/kernel']))
586 | getattr(block.mlp, f'fc{r + 1}').bias.copy_(_n2p(w[f'{block_prefix}MlpBlock_3/Dense_{r}/bias']))
587 | block.norm2.weight.copy_(_n2p(w[f'{block_prefix}LayerNorm_2/scale']))
588 | block.norm2.bias.copy_(_n2p(w[f'{block_prefix}LayerNorm_2/bias']))
589 |
590 |
591 | def resize_pos_embed(posemb, posemb_new, num_tokens=1, gs_new=()):
592 | # Rescale the grid of position embeddings when loading from state_dict. Adapted from
593 | # https://github.com/google-research/vision_transformer/blob/00883dd691c63a6830751563748663526e811cee/vit_jax/checkpoint.py#L224
594 | _logger.info('Resized position embedding: %s to %s', posemb.shape, posemb_new.shape)
595 | ntok_new = posemb_new.shape[1]
596 | if num_tokens:
597 | posemb_tok, posemb_grid = posemb[:, :num_tokens], posemb[0, num_tokens:]
598 | ntok_new -= num_tokens
599 | else:
600 | posemb_tok, posemb_grid = posemb[:, :0], posemb[0]
601 | gs_old = int(math.sqrt(len(posemb_grid)))
602 | if not len(gs_new): # backwards compatibility
603 | gs_new = [int(math.sqrt(ntok_new))] * 2
604 | assert len(gs_new) >= 2
605 | _logger.info('Position embedding grid-size from %s to %s', [gs_old, gs_old], gs_new)
606 | posemb_grid = posemb_grid.reshape(1, gs_old, gs_old, -1).permute(0, 3, 1, 2)
607 | posemb_grid = F.interpolate(posemb_grid, size=gs_new, mode='bicubic', align_corners=False)
608 | posemb_grid = posemb_grid.permute(0, 2, 3, 1).reshape(1, gs_new[0] * gs_new[1], -1)
609 | posemb = torch.cat([posemb_tok, posemb_grid], dim=1)
610 | return posemb
611 |
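# --- Editor's note (not part of the original source): an illustrative sketch with
# hypothetical tensors. A ViT-B/16 position embedding trained at 224x224 covers
# 1 class token + 14*14 patches; resizing it for 384x384 (a 24*24 grid) bicubically
# interpolates only the grid part and carries the class-token slot over unchanged.
#
#   import torch
#   old = torch.randn(1, 1 + 14 * 14, 768)      # (1, 197, 768)
#   new_ref = torch.zeros(1, 1 + 24 * 24, 768)  # (1, 577, 768) target layout
#   resized = resize_pos_embed(old, new_ref, num_tokens=1, gs_new=(24, 24))
#   resized.shape                               # torch.Size([1, 577, 768])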
612 |
613 | def checkpoint_filter_fn(state_dict, model):
614 | """ convert patch embedding weight from manual patchify + linear proj to conv"""
615 | out_dict = {}
616 | if 'model' in state_dict:
617 |         # For DeiT checkpoints (state dict nested under 'model')
618 | state_dict = state_dict['model']
619 | for k, v in state_dict.items():
620 | if 'patch_embed.proj.weight' in k and len(v.shape) < 4:
621 |             # For old models that I trained prior to conv-based patchification
622 | O, I, H, W = model.patch_embed.proj.weight.shape
623 | v = v.reshape(O, -1, H, W)
624 | elif k == 'pos_embed' and v.shape != model.pos_embed.shape:
625 | # To resize pos embedding when using model at different size from pretrained weights
626 | v = resize_pos_embed(
627 | v, model.pos_embed, getattr(model, 'num_tokens', 1), model.patch_embed.grid_size)
628 | out_dict[k] = v
629 | return out_dict
630 |
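# --- Editor's note (not part of the original source): a hedged sketch of how this
# filter is typically applied when loading a checkpoint by hand; the file name is
# hypothetical and `model` is an already-constructed VisionTransformer.
#
#   import torch
#   ckpt = torch.load('deit_checkpoint.pth', map_location='cpu')
#   model.load_state_dict(checkpoint_filter_fn(ckpt, model), strict=False)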
631 |
632 | def _create_vision_transformer(variant, pretrained=False, **kwargs):
633 | if kwargs.get('features_only', None):
634 |         raise RuntimeError('features_only not implemented for Vision Transformer models.')
635 |
636 |     # NOTE this extra code supports handling of the representation size for in21k pretrained models
637 | pretrained_cfg = resolve_pretrained_cfg(variant, kwargs=kwargs)
638 | default_num_classes = pretrained_cfg['num_classes']
639 | num_classes = kwargs.get('num_classes', default_num_classes)
640 | repr_size = kwargs.pop('representation_size', None)
641 | if repr_size is not None and num_classes != default_num_classes:
642 | # Remove representation layer if fine-tuning. This may not always be the desired action,
643 |         # but it feels safer than doing nothing by default when fine-tuning. Perhaps a better interface?
644 | _logger.warning("Removing representation layer for fine-tuning.")
645 | repr_size = None
646 |
647 | model = build_model_with_cfg(
648 | VisionTransformer, variant, pretrained,
649 | pretrained_cfg=pretrained_cfg,
650 | representation_size=repr_size,
651 | pretrained_filter_fn=checkpoint_filter_fn,
652 | pretrained_custom_load='npz' in pretrained_cfg['url'],
653 | **kwargs)
654 | return model
655 |
656 |
657 | @register_model
658 | def vit_tiny_patch16_224(pretrained=False, **kwargs):
659 |     """ ViT-Tiny (ViT-Ti/16)
660 | """
661 | model_kwargs = dict(patch_size=16, embed_dim=192, depth=12, num_heads=3, **kwargs)
662 | model = _create_vision_transformer('vit_tiny_patch16_224', pretrained=pretrained, **model_kwargs)
663 | return model
664 |
665 |
666 | @register_model
667 | def vit_tiny_patch16_384(pretrained=False, **kwargs):
668 |     """ ViT-Tiny (ViT-Ti/16) @ 384x384.
669 | """
670 | model_kwargs = dict(patch_size=16, embed_dim=192, depth=12, num_heads=3, **kwargs)
671 | model = _create_vision_transformer('vit_tiny_patch16_384', pretrained=pretrained, **model_kwargs)
672 | return model
673 |
674 |
675 | @register_model
676 | def vit_small_patch32_224(pretrained=False, **kwargs):
677 | """ ViT-Small (ViT-S/32)
678 | """
679 | model_kwargs = dict(patch_size=32, embed_dim=384, depth=12, num_heads=6, **kwargs)
680 | model = _create_vision_transformer('vit_small_patch32_224', pretrained=pretrained, **model_kwargs)
681 | return model
682 |
683 |
684 | @register_model
685 | def vit_small_patch32_384(pretrained=False, **kwargs):
686 | """ ViT-Small (ViT-S/32) at 384x384.
687 | """
688 | model_kwargs = dict(patch_size=32, embed_dim=384, depth=12, num_heads=6, **kwargs)
689 | model = _create_vision_transformer('vit_small_patch32_384', pretrained=pretrained, **model_kwargs)
690 | return model
691 |
692 |
693 | @register_model
694 | def vit_small_patch16_224(pretrained=False, **kwargs):
695 | """ ViT-Small (ViT-S/16)
696 | NOTE I've replaced my previous 'small' model definition and weights with the small variant from the DeiT paper
697 | """
698 | model_kwargs = dict(patch_size=16, embed_dim=384, depth=12, num_heads=6, **kwargs)
699 | model = _create_vision_transformer('vit_small_patch16_224', pretrained=pretrained, **model_kwargs)
700 | return model
701 |
702 |
703 | @register_model
704 | def vit_small_patch16_384(pretrained=False, **kwargs):
705 | """ ViT-Small (ViT-S/16)
706 | NOTE I've replaced my previous 'small' model definition and weights with the small variant from the DeiT paper
707 | """
708 | model_kwargs = dict(patch_size=16, embed_dim=384, depth=12, num_heads=6, **kwargs)
709 | model = _create_vision_transformer('vit_small_patch16_384', pretrained=pretrained, **model_kwargs)
710 | return model
711 |
712 |
713 | @register_model
714 | def vit_base_patch32_224(pretrained=False, **kwargs):
715 | """ ViT-Base (ViT-B/32) from original paper (https://arxiv.org/abs/2010.11929).
716 | ImageNet-1k weights fine-tuned from in21k, source https://github.com/google-research/vision_transformer.
717 | """
718 | model_kwargs = dict(patch_size=32, embed_dim=768, depth=12, num_heads=12, **kwargs)
719 | model = _create_vision_transformer('vit_base_patch32_224', pretrained=pretrained, **model_kwargs)
720 | return model
721 |
722 |
723 | @register_model
724 | def vit_base2_patch32_256(pretrained=False, **kwargs):
725 | """ ViT-Base (ViT-B/32)
726 | # FIXME experiment
727 | """
728 | model_kwargs = dict(patch_size=32, embed_dim=896, depth=12, num_heads=14, **kwargs)
729 | model = _create_vision_transformer('vit_base2_patch32_256', pretrained=pretrained, **model_kwargs)
730 | return model
731 |
732 |
733 | @register_model
734 | def vit_base_patch32_384(pretrained=False, **kwargs):
735 | """ ViT-Base model (ViT-B/32) from original paper (https://arxiv.org/abs/2010.11929).
736 | ImageNet-1k weights fine-tuned from in21k @ 384x384, source https://github.com/google-research/vision_transformer.
737 | """
738 | model_kwargs = dict(patch_size=32, embed_dim=768, depth=12, num_heads=12, **kwargs)
739 | model = _create_vision_transformer('vit_base_patch32_384', pretrained=pretrained, **model_kwargs)
740 | return model
741 |
742 |
743 | @register_model
744 | def vit_base_patch16_224(pretrained=False, **kwargs):
745 | """ ViT-Base (ViT-B/16) from original paper (https://arxiv.org/abs/2010.11929).
746 | ImageNet-1k weights fine-tuned from in21k @ 224x224, source https://github.com/google-research/vision_transformer.
747 | """
748 | model_kwargs = dict(patch_size=16, embed_dim=768, depth=12, num_heads=12, **kwargs)
749 | model = _create_vision_transformer('vit_base_patch16_224', pretrained=pretrained, **model_kwargs)
750 | return model
751 |
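# --- Editor's note (not part of the original source): a hedged usage sketch. Each of
# these @register_model factories builds a VisionTransformer through
# _create_vision_transformer, so the variant names are also reachable via
# timm.create_model once this module is imported (exact behaviour depends on the
# installed timm version); calling the factory directly is shown here.
#
#   model = vit_base_patch16_224(pretrained=False, num_classes=0)   # ViT-B/16, headless
#   model.out_dim                                                   # 768-d features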
752 |
753 | @register_model
754 | def vit_base_patch16_384(pretrained=False, **kwargs):
755 | """ ViT-Base model (ViT-B/16) from original paper (https://arxiv.org/abs/2010.11929).
756 | ImageNet-1k weights fine-tuned from in21k @ 384x384, source https://github.com/google-research/vision_transformer.
757 | """
758 | model_kwargs = dict(patch_size=16, embed_dim=768, depth=12, num_heads=12, **kwargs)
759 | model = _create_vision_transformer('vit_base_patch16_384', pretrained=pretrained, **model_kwargs)
760 | return model
761 |
762 |
763 | @register_model
764 | def vit_base_patch8_224(pretrained=False, **kwargs):
765 | """ ViT-Base (ViT-B/8) from original paper (https://arxiv.org/abs/2010.11929).
766 | ImageNet-1k weights fine-tuned from in21k @ 224x224, source https://github.com/google-research/vision_transformer.
767 | """
768 | model_kwargs = dict(patch_size=8, embed_dim=768, depth=12, num_heads=12, **kwargs)
769 | model = _create_vision_transformer('vit_base_patch8_224', pretrained=pretrained, **model_kwargs)
770 | return model
771 |
772 |
773 | @register_model
774 | def vit_large_patch32_224(pretrained=False, **kwargs):
775 | """ ViT-Large model (ViT-L/32) from original paper (https://arxiv.org/abs/2010.11929). No pretrained weights.
776 | """
777 | model_kwargs = dict(patch_size=32, embed_dim=1024, depth=24, num_heads=16, **kwargs)
778 | model = _create_vision_transformer('vit_large_patch32_224', pretrained=pretrained, **model_kwargs)
779 | return model
780 |
781 |
782 | @register_model
783 | def vit_large_patch32_384(pretrained=False, **kwargs):
784 | """ ViT-Large model (ViT-L/32) from original paper (https://arxiv.org/abs/2010.11929).
785 | ImageNet-1k weights fine-tuned from in21k @ 384x384, source https://github.com/google-research/vision_transformer.
786 | """
787 | model_kwargs = dict(patch_size=32, embed_dim=1024, depth=24, num_heads=16, **kwargs)
788 | model = _create_vision_transformer('vit_large_patch32_384', pretrained=pretrained, **model_kwargs)
789 | return model
790 |
791 |
792 | @register_model
793 | def vit_large_patch16_224(pretrained=False, **kwargs):
794 | """ ViT-Large model (ViT-L/16) from original paper (https://arxiv.org/abs/2010.11929).
795 | ImageNet-1k weights fine-tuned from in21k @ 224x224, source https://github.com/google-research/vision_transformer.
796 | """
797 | model_kwargs = dict(patch_size=16, embed_dim=1024, depth=24, num_heads=16, **kwargs)
798 | model = _create_vision_transformer('vit_large_patch16_224', pretrained=pretrained, **model_kwargs)
799 | return model
800 |
801 |
802 | @register_model
803 | def vit_large_patch16_384(pretrained=False, **kwargs):
804 | """ ViT-Large model (ViT-L/16) from original paper (https://arxiv.org/abs/2010.11929).
805 | ImageNet-1k weights fine-tuned from in21k @ 384x384, source https://github.com/google-research/vision_transformer.
806 | """
807 | model_kwargs = dict(patch_size=16, embed_dim=1024, depth=24, num_heads=16, **kwargs)
808 | model = _create_vision_transformer('vit_large_patch16_384', pretrained=pretrained, **model_kwargs)
809 | return model
810 |
811 |
812 | @register_model
813 | def vit_large_patch14_224(pretrained=False, **kwargs):
814 | """ ViT-Large model (ViT-L/14)
815 | """
816 | model_kwargs = dict(patch_size=14, embed_dim=1024, depth=24, num_heads=16, **kwargs)
817 | model = _create_vision_transformer('vit_large_patch14_224', pretrained=pretrained, **model_kwargs)
818 | return model
819 |
820 |
821 | @register_model
822 | def vit_huge_patch14_224(pretrained=False, **kwargs):
823 | """ ViT-Huge model (ViT-H/14) from original paper (https://arxiv.org/abs/2010.11929).
824 | """
825 | model_kwargs = dict(patch_size=14, embed_dim=1280, depth=32, num_heads=16, **kwargs)
826 | model = _create_vision_transformer('vit_huge_patch14_224', pretrained=pretrained, **model_kwargs)
827 | return model
828 |
829 |
830 | @register_model
831 | def vit_giant_patch14_224(pretrained=False, **kwargs):
832 | """ ViT-Giant model (ViT-g/14) from `Scaling Vision Transformers` - https://arxiv.org/abs/2106.04560
833 | """
834 | model_kwargs = dict(patch_size=14, embed_dim=1408, mlp_ratio=48/11, depth=40, num_heads=16, **kwargs)
835 | model = _create_vision_transformer('vit_giant_patch14_224', pretrained=pretrained, **model_kwargs)
836 | return model
837 |
838 |
839 | @register_model
840 | def vit_gigantic_patch14_224(pretrained=False, **kwargs):
841 | """ ViT-Gigantic model (ViT-G/14) from `Scaling Vision Transformers` - https://arxiv.org/abs/2106.04560
842 | """
843 | model_kwargs = dict(patch_size=14, embed_dim=1664, mlp_ratio=64/13, depth=48, num_heads=16, **kwargs)
844 | model = _create_vision_transformer('vit_gigantic_patch14_224', pretrained=pretrained, **model_kwargs)
845 | return model
846 |
847 |
848 | @register_model
849 | def vit_tiny_patch16_224_in21k(pretrained=False, **kwargs):
850 |     """ ViT-Tiny (ViT-Ti/16).
851 | ImageNet-21k weights @ 224x224, source https://github.com/google-research/vision_transformer.
852 | NOTE: this model has valid 21k classifier head and no representation (pre-logits) layer
853 | """
854 | model_kwargs = dict(patch_size=16, embed_dim=192, depth=12, num_heads=3, **kwargs)
855 | model = _create_vision_transformer('vit_tiny_patch16_224_in21k', pretrained=pretrained, **model_kwargs)
856 | return model
857 |
858 |
859 | @register_model
860 | def vit_small_patch32_224_in21k(pretrained=False, **kwargs):
861 |     """ ViT-Small (ViT-S/32)
862 | ImageNet-21k weights @ 224x224, source https://github.com/google-research/vision_transformer.
863 | NOTE: this model has valid 21k classifier head and no representation (pre-logits) layer
864 | """
865 | model_kwargs = dict(patch_size=32, embed_dim=384, depth=12, num_heads=6, **kwargs)
866 | model = _create_vision_transformer('vit_small_patch32_224_in21k', pretrained=pretrained, **model_kwargs)
867 | return model
868 |
869 |
870 | @register_model
871 | def vit_small_patch16_224_in21k(pretrained=False, **kwargs):
872 | """ ViT-Small (ViT-S/16)
873 | ImageNet-21k weights @ 224x224, source https://github.com/google-research/vision_transformer.
874 | NOTE: this model has valid 21k classifier head and no representation (pre-logits) layer
875 | """
876 | model_kwargs = dict(patch_size=16, embed_dim=384, depth=12, num_heads=6, **kwargs)
877 | model = _create_vision_transformer('vit_small_patch16_224_in21k', pretrained=pretrained, **model_kwargs)
878 | return model
879 |
880 |
881 | @register_model
882 | def vit_base_patch32_224_in21k(pretrained=False, **kwargs):
883 | """ ViT-Base model (ViT-B/32) from original paper (https://arxiv.org/abs/2010.11929).
884 | ImageNet-21k weights @ 224x224, source https://github.com/google-research/vision_transformer.
885 | NOTE: this model has valid 21k classifier head and no representation (pre-logits) layer
886 | """
887 | model_kwargs = dict(
888 | patch_size=32, embed_dim=768, depth=12, num_heads=12, **kwargs)
889 | model = _create_vision_transformer('vit_base_patch32_224_in21k', pretrained=pretrained, **model_kwargs)
890 | return model
891 |
892 |
893 | @register_model
894 | def vit_base_patch16_224_in21k(pretrained=False, **kwargs):
895 | """ ViT-Base model (ViT-B/16) from original paper (https://arxiv.org/abs/2010.11929).
896 | ImageNet-21k weights @ 224x224, source https://github.com/google-research/vision_transformer.
897 | NOTE: this model has valid 21k classifier head and no representation (pre-logits) layer
898 | """
899 | model_kwargs = dict(
900 | patch_size=16, embed_dim=768, depth=12, num_heads=12, **kwargs)
901 | model = _create_vision_transformer('vit_base_patch16_224_in21k', pretrained=pretrained, **model_kwargs)
902 | return model
903 |
904 |
905 | @register_model
906 | def vit_base_patch8_224_in21k(pretrained=False, **kwargs):
907 | """ ViT-Base model (ViT-B/8) from original paper (https://arxiv.org/abs/2010.11929).
908 | ImageNet-21k weights @ 224x224, source https://github.com/google-research/vision_transformer.
909 | NOTE: this model has valid 21k classifier head and no representation (pre-logits) layer
910 | """
911 | model_kwargs = dict(
912 | patch_size=8, embed_dim=768, depth=12, num_heads=12, **kwargs)
913 | model = _create_vision_transformer('vit_base_patch8_224_in21k', pretrained=pretrained, **model_kwargs)
914 | return model
915 |
916 |
917 | @register_model
918 | def vit_large_patch32_224_in21k(pretrained=False, **kwargs):
919 | """ ViT-Large model (ViT-L/32) from original paper (https://arxiv.org/abs/2010.11929).
920 | ImageNet-21k weights @ 224x224, source https://github.com/google-research/vision_transformer.
921 | NOTE: this model has a representation layer but the 21k classifier head is zero'd out in original weights
922 | """
923 | model_kwargs = dict(
924 | patch_size=32, embed_dim=1024, depth=24, num_heads=16, representation_size=1024, **kwargs)
925 | model = _create_vision_transformer('vit_large_patch32_224_in21k', pretrained=pretrained, **model_kwargs)
926 | return model
927 |
928 |
929 | @register_model
930 | def vit_large_patch16_224_in21k(pretrained=False, **kwargs):
931 | """ ViT-Large model (ViT-L/16) from original paper (https://arxiv.org/abs/2010.11929).
932 | ImageNet-21k weights @ 224x224, source https://github.com/google-research/vision_transformer.
933 | NOTE: this model has valid 21k classifier head and no representation (pre-logits) layer
934 | """
935 | model_kwargs = dict(
936 | patch_size=16, embed_dim=1024, depth=24, num_heads=16, **kwargs)
937 | model = _create_vision_transformer('vit_large_patch16_224_in21k', pretrained=pretrained, **model_kwargs)
938 | return model
939 |
940 |
941 | @register_model
942 | def vit_huge_patch14_224_in21k(pretrained=False, **kwargs):
943 | """ ViT-Huge model (ViT-H/14) from original paper (https://arxiv.org/abs/2010.11929).
944 | ImageNet-21k weights @ 224x224, source https://github.com/google-research/vision_transformer.
945 | NOTE: this model has a representation layer but the 21k classifier head is zero'd out in original weights
946 | """
947 | model_kwargs = dict(
948 | patch_size=14, embed_dim=1280, depth=32, num_heads=16, representation_size=1280, **kwargs)
949 | model = _create_vision_transformer('vit_huge_patch14_224_in21k', pretrained=pretrained, **model_kwargs)
950 | return model
951 |
952 |
953 | @register_model
954 | def vit_base_patch16_224_sam(pretrained=False, **kwargs):
955 | """ ViT-Base (ViT-B/16) w/ SAM pretrained weights. Paper: https://arxiv.org/abs/2106.01548
956 | """
957 | # NOTE original SAM weights release worked with representation_size=768
958 | model_kwargs = dict(patch_size=16, embed_dim=768, depth=12, num_heads=12, **kwargs)
959 | model = _create_vision_transformer('vit_base_patch16_224_sam', pretrained=pretrained, **model_kwargs)
960 | return model
961 |
962 |
963 | @register_model
964 | def vit_base_patch32_224_sam(pretrained=False, **kwargs):
965 | """ ViT-Base (ViT-B/32) w/ SAM pretrained weights. Paper: https://arxiv.org/abs/2106.01548
966 | """
967 | # NOTE original SAM weights release worked with representation_size=768
968 | model_kwargs = dict(patch_size=32, embed_dim=768, depth=12, num_heads=12, **kwargs)
969 | model = _create_vision_transformer('vit_base_patch32_224_sam', pretrained=pretrained, **model_kwargs)
970 | return model
971 |
972 |
973 | @register_model
974 | def vit_small_patch16_224_dino(pretrained=False, **kwargs):
975 | """ ViT-Small (ViT-S/16) w/ DINO pretrained weights (no head) - https://arxiv.org/abs/2104.14294
976 | """
977 | model_kwargs = dict(patch_size=16, embed_dim=384, depth=12, num_heads=6, **kwargs)
978 | model = _create_vision_transformer('vit_small_patch16_224_dino', pretrained=pretrained, **model_kwargs)
979 | return model
980 |
981 |
982 | @register_model
983 | def vit_small_patch8_224_dino(pretrained=False, **kwargs):
984 | """ ViT-Small (ViT-S/8) w/ DINO pretrained weights (no head) - https://arxiv.org/abs/2104.14294
985 | """
986 | model_kwargs = dict(patch_size=8, embed_dim=384, depth=12, num_heads=6, **kwargs)
987 | model = _create_vision_transformer('vit_small_patch8_224_dino', pretrained=pretrained, **model_kwargs)
988 | return model
989 |
990 |
991 | @register_model
992 | def vit_base_patch16_224_dino(pretrained=False, **kwargs):
993 |     """ ViT-Base (ViT-B/16) w/ DINO pretrained weights (no head) - https://arxiv.org/abs/2104.14294
994 | """
995 | model_kwargs = dict(patch_size=16, embed_dim=768, depth=12, num_heads=12, **kwargs)
996 | model = _create_vision_transformer('vit_base_patch16_224_dino', pretrained=pretrained, **model_kwargs)
997 | return model
998 |
999 |
1000 | @register_model
1001 | def vit_base_patch8_224_dino(pretrained=False, **kwargs):
1002 | """ ViT-Base (ViT-B/8) w/ DINO pretrained weights (no head) - https://arxiv.org/abs/2104.14294
1003 | """
1004 | model_kwargs = dict(patch_size=8, embed_dim=768, depth=12, num_heads=12, **kwargs)
1005 | model = _create_vision_transformer('vit_base_patch8_224_dino', pretrained=pretrained, **model_kwargs)
1006 | return model
1007 |
1008 |
1009 | @register_model
1010 | def vit_base_patch16_224_miil_in21k(pretrained=False, **kwargs):
1011 | """ ViT-Base (ViT-B/16) from original paper (https://arxiv.org/abs/2010.11929).
1012 | Weights taken from: https://github.com/Alibaba-MIIL/ImageNet21K
1013 | """
1014 | model_kwargs = dict(patch_size=16, embed_dim=768, depth=12, num_heads=12, qkv_bias=False, **kwargs)
1015 | model = _create_vision_transformer('vit_base_patch16_224_miil_in21k', pretrained=pretrained, **model_kwargs)
1016 | return model
1017 |
1018 |
1019 | @register_model
1020 | def vit_base_patch16_224_miil(pretrained=False, **kwargs):
1021 | """ ViT-Base (ViT-B/16) from original paper (https://arxiv.org/abs/2010.11929).
1022 | Weights taken from: https://github.com/Alibaba-MIIL/ImageNet21K
1023 | """
1024 | model_kwargs = dict(patch_size=16, embed_dim=768, depth=12, num_heads=12, qkv_bias=False, **kwargs)
1025 | model = _create_vision_transformer('vit_base_patch16_224_miil', pretrained=pretrained, **model_kwargs)
1026 | return model
1027 |
1028 |
1029 | @register_model
1030 | def vit_small_patch16_36x1_224(pretrained=False, **kwargs):
1031 |     """ ViT-Small w/ LayerScale + 36 x 1 (36 block serial) config. Experimental, may remove.
1032 | Based on `Three things everyone should know about Vision Transformers` - https://arxiv.org/abs/2203.09795
1033 | Paper focuses on 24x2 + 48x1 for 'Small' width but those are extremely slow.
1034 | """
1035 | model_kwargs = dict(patch_size=16, embed_dim=384, depth=36, num_heads=6, init_values=1e-5, **kwargs)
1036 | model = _create_vision_transformer('vit_small_patch16_36x1_224', pretrained=pretrained, **model_kwargs)
1037 | return model
1038 |
1039 |
1040 | @register_model
1041 | def vit_small_patch16_18x2_224(pretrained=False, **kwargs):
1042 | """ ViT-Small w/ LayerScale + 18 x 2 (36 block parallel) config. Experimental, may remove.
1043 | Based on `Three things everyone should know about Vision Transformers` - https://arxiv.org/abs/2203.09795
1044 | Paper focuses on 24x2 + 48x1 for 'Small' width but those are extremely slow.
1045 | """
1046 | model_kwargs = dict(
1047 | patch_size=16, embed_dim=384, depth=18, num_heads=6, init_values=1e-5, block_fn=ParallelBlock, **kwargs)
1048 | model = _create_vision_transformer('vit_small_patch16_18x2_224', pretrained=pretrained, **model_kwargs)
1049 | return model
1050 |
1051 |
1052 | @register_model
1053 | def vit_base_patch16_18x2_224(pretrained=False, **kwargs):
1054 | """ ViT-Base w/ LayerScale + 18 x 2 (36 block parallel) config. Experimental, may remove.
1055 | Based on `Three things everyone should know about Vision Transformers` - https://arxiv.org/abs/2203.09795
1056 | """
1057 | model_kwargs = dict(
1058 | patch_size=16, embed_dim=768, depth=18, num_heads=12, init_values=1e-5, block_fn=ParallelBlock, **kwargs)
1059 | model = _create_vision_transformer('vit_base_patch16_18x2_224', pretrained=pretrained, **model_kwargs)
1060 | return model
--------------------------------------------------------------------------------