├── LICENSE ├── README.md ├── __init__.py ├── augment.py ├── data ├── JigsawLoader.py ├── concat_dataset.py ├── data_helper.py ├── datasets.py └── samplers.py ├── datasets.py ├── engine_dg.py ├── hubconf.py ├── losses.py ├── main_dg.py ├── models ├── csm_triton.py ├── mamba_simple.py ├── mamba_ssm │ ├── mamba_simple_vim.py │ └── selective_scan_interface.py └── vmamba.py ├── models_mamba.py ├── perturb_style ├── ALOFT.py ├── DSU.py ├── MixStyle.py └── SeqTokenAug.py ├── samplers.py ├── scripts ├── START-M-VMamba.sh ├── START-Vim-S.sh ├── START-Vim-T.sh ├── START-X-VMamba.sh └── test_model_performance.sh ├── text_complexity.py ├── utils.py └── vim_requirements.txt /LICENSE: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. 
For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. 
You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. 
You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "[]" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright 2020 - present, Facebook, Inc 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. 
202 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | 2 | # START: A Generalized State Space Model with Saliency-Driven Token-Aware Transformation [NeurIPS 2024] 3 | 4 | 5 | ## Environments for Training 6 | 7 | - Python 3.10.13 8 | 9 | - `conda create -n your_env_name python=3.10.13` 10 | 11 | - torch 2.1.1 + cu118 12 | - `pip install torch==2.1.1 torchvision==0.16.1 torchaudio==2.1.1 --index-url https://download.pytorch.org/whl/cu118` 13 | 14 | - Requirements: vim_requirements.txt 15 | - `pip install -r vim_requirements.txt` 16 | 17 | - Install ``causal_conv1d`` and ``mamba`` 18 | - `pip install -e causal_conv1d>=1.1.0` 19 | - `pip install -e mamba-1p1p1` 20 | 21 | ## DataSets 22 | Please download the PACS dataset from [here](https://drive.google.com/drive/folders/0B6x7gtvErXgfUU1WcGY5SzdwZVk?resourcekey=0-2fvpQY_QSyJf2uIECzqPuQ). 23 | Make sure you use the official train/val/test split from the [PACS paper](https://openaccess.thecvf.com/content_iccv_2017/html/Li_Deeper_Broader_and_ICCV_2017_paper.html). 24 | Taking `/data/DataSets/` as the dataset directory, for example: 25 | ``` 26 | images -> /data/DataSets/PACS/kfold/art_painting/dog/pic_001.jpg, ... 27 | splits -> /data/DataSets/PACS/pacs_label/art_painting_crossval_kfold.txt, ... 28 | ``` 29 | Then set `"data_root"` to `"/data/DataSets/"` and `"data"` to `"PACS"` in `"main_dg.py"`. 30 | 31 | You can also set `"data_root"` and `"data"` directly in the training scripts under `scripts/` when training the model. 32 | 33 | 34 | 35 | 36 | 37 | ## Training 38 | 39 | 40 | First, download the VMamba-T model pretrained on ImageNet from [here](https://github.com/MzeroMiko/VMamba/releases/download/%2320240218/vssmtiny_dp01_ckpt_epoch_292.pth) and save it to `/pretrained_model`. To train START-M, run the following command. Note that the `--data_root` argument needs to be changed according to your dataset folder. 41 | 42 | ``` 43 | bash scripts/START-M-VMamba.sh 44 | ``` 45 | 46 | You can also train the START-X model by running: 47 | 48 | ``` 49 | bash scripts/START-X-VMamba.sh 50 | ``` 51 | 52 | 53 | ## Evaluation 54 | 55 | To evaluate the models, you can download our models trained on PACS as listed below: 56 | 57 | Methods | Photo | Art | Cartoon | Sketch | Avg. | 58 | :----: | :----: | :----: | :----: | :----: | :----: | 59 | START-M | 99.22 | 93.95 | 87.84 | 87.68 | [92.17](https://drive.google.com/drive/folders/1kSF1mK-xwpLb0SKct-_92dxjcg9YN95x?usp=sharing) | 60 | START-X | 99.16 | 92.97 | 88.40 | 87.45 | [92.00](https://drive.google.com/drive/folders/1gy2oBNTcI_y8MTkZtETCx_P030Vp0U1e?usp=sharing) | 61 | 62 | Please set `--eval` to `1`, `--target` to the domain index, and `--resume` to the saved path of the downloaded model, *e.g.*, `/trained/model/path/photo/model.pt`, in `"scripts/test_model_performance.sh"`.
Then you can directly run: 63 | 64 | ``` 65 | bash scripts/test_model_performance.sh 66 | ``` 67 | 68 | You can also run the following command: 69 | 70 | ``` 71 | CUDA_VISIBLE_DEVICES=0,1,2,3 python -m torch.distributed.launch --nproc_per_node=4 --master_port=29000 --use_env ../main_dg.py \ 72 | --model vmamba_tiny \ 73 | --batch-size 32 \ 74 | --seed 0 \ 75 | --num_workers 16 \ 76 | --no_amp \ 77 | --data "PACS" \ 78 | --data_root [dataset_path] \ 79 | --target [domain_index, e.g., 0 for photo] \ 80 | --eval 1 \ 81 | --resume "/trained/model/path/photo/checkpoint.pth" 82 | ``` 83 | 84 | ## Citations 85 | ``` 86 | @inproceedings{guo2024start, 87 | title={START: A Generalized State Space Model with Saliency-Driven Token-Aware Transformation}, 88 | author={Guo, Jintao and Qi, Lei and Shi, Yinghuan and Gao, Yang}, 89 | booktitle={The Thirty-Eighth Annual Conference on Neural Information Processing Systems}, 90 | year={2024} 91 | } 92 | ``` 93 | 94 | ## Acknowledgement 95 | Part of our code is derived from the following repositories: 96 | * [VMamba](https://github.com/MzeroMiko/VMamba): "Vmamba: Visual state space model", NeurIPS 2024 97 | * [Vim](https://github.com/hustvl/Vim): "Vision mamba: Efficient visual representation learning with bidirectional state space model", ICML 2024 98 | 99 | 100 | We thank the authors for releasing their code. Please also consider citing their work. 101 | 102 | 103 | 104 | 105 | 106 | -------------------------------------------------------------------------------- /__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lingeringlight/START/898ac4c3a64ac440f5550b3796cc1e876c24bbc2/__init__.py -------------------------------------------------------------------------------- /augment.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | """ 5 | 3Augment implementation 6 | Data-augmentation (DA) based on dino DA (https://github.com/facebookresearch/dino) 7 | and timm DA(https://github.com/rwightman/pytorch-image-models) 8 | """ 9 | import torch 10 | from torchvision import transforms 11 | 12 | # error: cannot import name '_pil_interp' from 'timm.data.transforms' 13 | # from timm.data.transforms import _pil_interp, RandomResizedCropAndInterpolation, ToNumpy, ToTensor 14 | 15 | # fix: timm version problem 16 | # from timm.data.transforms import str_pil_interp as _pil_interp 17 | from timm.data.transforms import RandomResizedCropAndInterpolation, ToNumpy, ToTensor 18 | 19 | import numpy as np 20 | from torchvision import datasets, transforms 21 | import random 22 | 23 | 24 | 25 | from PIL import ImageFilter, ImageOps 26 | import torchvision.transforms.functional as TF 27 | 28 | 29 | class GaussianBlur(object): 30 | """ 31 | Apply Gaussian Blur to the PIL image. 32 | """ 33 | def __init__(self, p=0.1, radius_min=0.1, radius_max=2.): 34 | self.prob = p 35 | self.radius_min = radius_min 36 | self.radius_max = radius_max 37 | 38 | def __call__(self, img): 39 | do_it = random.random() <= self.prob 40 | if not do_it: 41 | return img 42 | 43 | img = img.filter( 44 | ImageFilter.GaussianBlur( 45 | radius=random.uniform(self.radius_min, self.radius_max) 46 | ) 47 | ) 48 | return img 49 | 50 | class Solarization(object): 51 | """ 52 | Apply Solarization to the PIL image.
53 | """ 54 | def __init__(self, p=0.2): 55 | self.p = p 56 | 57 | def __call__(self, img): 58 | if random.random() < self.p: 59 | return ImageOps.solarize(img) 60 | else: 61 | return img 62 | 63 | class gray_scale(object): 64 | """ 65 | Apply Solarization to the PIL image. 66 | """ 67 | def __init__(self, p=0.2): 68 | self.p = p 69 | self.transf = transforms.Grayscale(3) 70 | 71 | def __call__(self, img): 72 | if random.random() < self.p: 73 | return self.transf(img) 74 | else: 75 | return img 76 | 77 | 78 | 79 | class horizontal_flip(object): 80 | """ 81 | Apply Solarization to the PIL image. 82 | """ 83 | def __init__(self, p=0.2,activate_pred=False): 84 | self.p = p 85 | self.transf = transforms.RandomHorizontalFlip(p=1.0) 86 | 87 | def __call__(self, img): 88 | if random.random() < self.p: 89 | return self.transf(img) 90 | else: 91 | return img 92 | 93 | 94 | 95 | def new_data_aug_generator(args = None): 96 | img_size = args.input_size 97 | remove_random_resized_crop = args.src 98 | mean, std = [0.485, 0.456, 0.406], [0.229, 0.224, 0.225] 99 | primary_tfl = [] 100 | scale=(0.08, 1.0) 101 | interpolation='bicubic' 102 | if remove_random_resized_crop: 103 | primary_tfl = [ 104 | transforms.Resize(img_size, interpolation=3), 105 | transforms.RandomCrop(img_size, padding=4,padding_mode='reflect'), 106 | transforms.RandomHorizontalFlip() 107 | ] 108 | else: 109 | primary_tfl = [ 110 | RandomResizedCropAndInterpolation( 111 | img_size, scale=scale, interpolation=interpolation), 112 | transforms.RandomHorizontalFlip() 113 | ] 114 | 115 | 116 | secondary_tfl = [transforms.RandomChoice([gray_scale(p=1.0), 117 | Solarization(p=1.0), 118 | GaussianBlur(p=1.0)])] 119 | 120 | if args.color_jitter is not None and not args.color_jitter==0: 121 | secondary_tfl.append(transforms.ColorJitter(args.color_jitter, args.color_jitter, args.color_jitter)) 122 | final_tfl = [ 123 | transforms.ToTensor(), 124 | transforms.Normalize( 125 | mean=torch.tensor(mean), 126 | std=torch.tensor(std)) 127 | ] 128 | return transforms.Compose(primary_tfl+secondary_tfl+final_tfl) 129 | -------------------------------------------------------------------------------- /data/JigsawLoader.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | import torch.utils.data as data 4 | import torchvision 5 | import torchvision.transforms as transforms 6 | from PIL import Image 7 | from random import sample, random 8 | import sys 9 | import os 10 | 11 | def get_random_subset(names, labels, percent): 12 | """ 13 | 14 | :param names: list of names 15 | :param labels: list of labels 16 | :param percent: 0 < float < 1 17 | :return: 18 | """ 19 | samples = len(names) 20 | amount = int(samples * percent) 21 | random_index = sample(range(samples), amount) 22 | name_val = [names[k] for k in random_index] 23 | name_train = [v for k, v in enumerate(names) if k not in random_index] 24 | labels_val = [labels[k] for k in random_index] 25 | labels_train = [v for k, v in enumerate(labels) if k not in random_index] 26 | return name_train, name_val, labels_train, labels_val 27 | 28 | 29 | def _dataset_info(txt_labels): 30 | # read from the official split txt 31 | file_names = [] 32 | labels = [] 33 | 34 | for row in open(txt_labels, 'r'): 35 | row = row.split(' ') 36 | file_names.append(row[0]) 37 | labels.append(int(row[1])) 38 | 39 | return file_names, labels 40 | 41 | 42 | def find_classes(dir_name): 43 | if sys.version_info >= (3, 5): 44 | # Faster and available in Python 3.5 
and above 45 | classes = [d.name for d in os.scandir(dir_name) if d.is_dir()] 46 | else: 47 | classes = [d for d in os.listdir(dir_name) if os.path.isdir(os.path.join(dir_name, d))] 48 | classes.sort() 49 | class_to_idx = {classes[i]: i+1 for i in range(len(classes))} 50 | return classes, class_to_idx 51 | 52 | 53 | def get_split_domain_info_from_dir(domain_path, dataset_name=None, val_percentage=None, domain_label=None): 54 | # read from the directory 55 | domain_name = domain_path.split("/")[-1] 56 | if dataset_name == "VLCS": 57 | name_train, name_val, labels_train, labels_val = [], [], [], [] 58 | classes, class_to_idx = find_classes(domain_path + "/full") 59 | for i, item in enumerate(classes): 60 | class_path = domain_path + "/" + "full" + "/" + item 61 | for root, _, fnames in sorted(os.walk(class_path)): 62 | for fname in sorted(fnames): 63 | path = os.path.join(domain_name, "full", item, fname) 64 | name_train.append(path) 65 | labels_train.append(class_to_idx[item]) 66 | 67 | for i, item in enumerate(classes): 68 | class_path = domain_path + "/" + "test" + "/" + item 69 | for root, _, fnames in sorted(os.walk(class_path)): 70 | for fname in sorted(fnames): 71 | path = os.path.join(domain_name, "test", item, fname) 72 | name_val.append(path) 73 | labels_val.append(class_to_idx[item]) 74 | domain_label_train = [domain_label for i in range(len(labels_train))] 75 | domain_label_val = [domain_label for i in range(len(labels_val))] 76 | return name_train, name_val, labels_train, labels_val, domain_label_train, domain_label_val 77 | 78 | elif dataset_name == "digits_dg": 79 | name_train, name_val, labels_train, labels_val = [], [], [], [] 80 | classes, class_to_idx = find_classes(domain_path + "/train") 81 | # train 82 | for i, item in enumerate(classes): 83 | class_path = domain_path + "/" + "train" + "/" + item 84 | for root, _, fnames in sorted(os.walk(class_path)): 85 | for fname in sorted(fnames): 86 | path = os.path.join(domain_name, "train", item, fname) 87 | name_train.append(path) 88 | labels_train.append(class_to_idx[item]) 89 | # val 90 | for i, item in enumerate(classes): 91 | class_path = domain_path + "/" + "val" + "/" + item 92 | for root, _, fnames in sorted(os.walk(class_path)): 93 | for fname in sorted(fnames): 94 | path = os.path.join(domain_name, "val", item, fname) 95 | name_val.append(path) 96 | labels_val.append(class_to_idx[item]) 97 | 98 | domain_label_train = [domain_label for i in range(len(labels_train))] 99 | domain_label_val = [domain_label for i in range(len(labels_val))] 100 | return name_train, name_val, labels_train, labels_val, domain_label_train, domain_label_val 101 | 102 | elif dataset_name == "OfficeHome" or dataset_name == "terra_incognita" or "PACS" in dataset_name or dataset_name == "DomainNet": 103 | names, labels = [], [] 104 | classes, class_to_idx = find_classes(domain_path) 105 | for i, item in enumerate(classes): 106 | class_path = domain_path + "/" + item 107 | for root, _, fnames in sorted(os.walk(class_path)): 108 | for fname in sorted(fnames): 109 | path = os.path.join(domain_name, item, fname) 110 | names.append(path) 111 | labels.append(class_to_idx[item]) 112 | name_train, name_val, labels_train, labels_val = get_random_subset(names, labels, val_percentage) 113 | domain_label_train = [domain_label for i in range(len(labels_train))] 114 | domain_label_val = [domain_label for i in range(len(labels_val))] 115 | return name_train, name_val, labels_train, labels_val, domain_label_train, domain_label_val 116 | 117 | else: 118 | raise 
ValueError("dataset is wrong.") 119 | 120 | 121 | def get_split_dataset_info_from_txt(txt_path, domain, domain_label, val_percentage=None): 122 | if "PACS" in txt_path: 123 | train_name = "_train_kfold.txt" 124 | val_name = "_crossval_kfold.txt" 125 | 126 | train_txt = txt_path + "/" + domain + train_name 127 | val_txt = txt_path + "/" + domain + val_name 128 | 129 | train_names, train_labels = _dataset_info(train_txt) 130 | val_names, val_labels = _dataset_info(val_txt) 131 | train_domain_labels = [domain_label for i in range(len(train_labels))] 132 | val_domain_labels = [domain_label for i in range(len(val_labels))] 133 | return train_names, val_names, train_labels, val_labels, train_domain_labels, val_domain_labels 134 | 135 | elif "miniDomainNet" in txt_path: 136 | # begin at 0, need to add 1 137 | train_name = "_train.txt" 138 | val_name = "_test.txt" 139 | train_txt = txt_path + "/" + domain + train_name 140 | val_txt = txt_path + "/" + domain + val_name 141 | 142 | train_names, train_labels = _dataset_info(train_txt) 143 | val_names, val_labels = _dataset_info(val_txt) 144 | train_labels = [label + 1 for label in train_labels] 145 | val_labels = [label + 1 for label in val_labels] 146 | 147 | names = train_names + val_names 148 | labels = train_labels + val_labels 149 | train_names, val_names, train_labels, val_labels = get_random_subset(names, labels, val_percentage) 150 | 151 | train_domain_labels = [domain_label for i in range(len(train_labels))] 152 | val_domain_labels = [domain_label for i in range(len(val_labels))] 153 | return train_names, val_names, train_labels, val_labels, train_domain_labels, val_domain_labels 154 | else: 155 | raise NotImplementedError 156 | 157 | 158 | def get_split_dataset_info(txt_list, val_percentage): 159 | names, labels = _dataset_info(txt_list) 160 | return get_random_subset(names, labels, val_percentage) 161 | 162 | 163 | # 原始Jigsaw 164 | class JigsawDataset(data.Dataset): 165 | def __init__(self, names, labels, jig_classes=100, img_transformer=None, tile_transformer=None, patches=True, bias_whole_image=None): 166 | self.data_path = "" 167 | self.names = names 168 | self.labels = labels 169 | 170 | self.N = len(self.names) 171 | self.permutations = self.__retrieve_permutations(jig_classes) 172 | self.grid_size = 3 173 | self.bias_whole_image = bias_whole_image 174 | if patches: 175 | self.patch_size = 64 176 | self._image_transformer = img_transformer 177 | self._augment_tile = tile_transformer 178 | if patches: 179 | self.returnFunc = lambda x: x 180 | else: 181 | def make_grid(x): 182 | return torchvision.utils.make_grid(x, self.grid_size, padding=0) 183 | self.returnFunc = make_grid 184 | 185 | def get_tile(self, img, n): 186 | w = float(img.size[0]) / self.grid_size 187 | y = int(n / self.grid_size) 188 | x = n % self.grid_size 189 | tile = img.crop([x * w, y * w, (x + 1) * w, (y + 1) * w]) 190 | tile = self._augment_tile(tile) 191 | return tile 192 | 193 | def get_image(self, index): 194 | framename = self.data_path + '/' + self.names[index] 195 | img = Image.open(framename).convert('RGB') 196 | return self._image_transformer(img) 197 | 198 | def __getitem__(self, index): 199 | img = self.get_image(index) 200 | n_grids = self.grid_size ** 2 201 | tiles = [None] * n_grids 202 | for n in range(n_grids): 203 | tiles[n] = self.get_tile(img, n) 204 | 205 | order = np.random.randint(len(self.permutations) + 1) # added 1 for class 0: unsorted 206 | if self.bias_whole_image: 207 | if self.bias_whole_image > random(): 208 | order = 0 209 | if order 
== 0: 210 | data = tiles 211 | else: 212 | data = [tiles[self.permutations[order - 1][t]] for t in range(n_grids)] 213 | 214 | data = torch.stack(data, 0) 215 | return self.returnFunc(data), int(order), int(self.labels[index]) 216 | 217 | def __len__(self): 218 | return len(self.names) 219 | 220 | def __retrieve_permutations(self, classes): 221 | all_perm = np.load('permutations_%d.npy' % (classes)) 222 | # from range [1,9] to [0,8] 223 | if all_perm.min() == 1: 224 | all_perm = all_perm - 1 225 | 226 | return all_perm 227 | 228 | 229 | class JigsawTestDataset(JigsawDataset): 230 | def __init__(self, *args, **xargs): 231 | super().__init__(*args, **xargs) 232 | 233 | def __getitem__(self, index): 234 | framename = self.data_path + '/' + self.names[index] 235 | img = Image.open(framename).convert('RGB') 236 | return self._image_transformer(img), 0, int(self.labels[index]) 237 | 238 | 239 | class JigsawTestDatasetMultiple(JigsawDataset): 240 | def __init__(self, *args, **xargs): 241 | super().__init__(*args, **xargs) 242 | self._image_transformer = transforms.Compose([ 243 | transforms.Resize(255, Image.BILINEAR), 244 | ]) 245 | self._image_transformer_full = transforms.Compose([ 246 | transforms.Resize(225, Image.BILINEAR), 247 | transforms.ToTensor(), 248 | transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]) 249 | ]) 250 | self._augment_tile = transforms.Compose([ 251 | transforms.Resize((75, 75), Image.BILINEAR), 252 | transforms.ToTensor(), 253 | transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]) 254 | ]) 255 | 256 | def __getitem__(self, index): 257 | framename = self.data_path + '/' + self.names[index] 258 | _img = Image.open(framename).convert('RGB') 259 | img = self._image_transformer(_img) 260 | 261 | w = float(img.size[0]) / self.grid_size 262 | n_grids = self.grid_size ** 2 263 | images = [] 264 | jig_labels = [] 265 | tiles = [None] * n_grids 266 | for n in range(n_grids): 267 | y = int(n / self.grid_size) 268 | x = n % self.grid_size 269 | tile = img.crop([x * w, y * w, (x + 1) * w, (y + 1) * w]) 270 | tile = self._augment_tile(tile) 271 | tiles[n] = tile 272 | for order in range(0, len(self.permutations)+1, 3): 273 | if order==0: 274 | data = tiles 275 | else: 276 | data = [tiles[self.permutations[order-1][t]] for t in range(n_grids)] 277 | data = self.returnFunc(torch.stack(data, 0)) 278 | images.append(data) 279 | jig_labels.append(order) 280 | images = torch.stack(images, 0) 281 | jig_labels = torch.LongTensor(jig_labels) 282 | return images, jig_labels, int(self.labels[index]) 283 | 284 | 285 | class JigsawNewDataset(data.Dataset): 286 | def __init__(self, names, labels, domain_labels, dataset_path, jig_classes=100, img_transformer=None, 287 | tile_transformer=None, patches=True, bias_whole_image=None): 288 | self.data_path = dataset_path 289 | 290 | self.names = names 291 | self.labels = labels 292 | self.domain_labels = domain_labels 293 | 294 | self.N = len(self.names) 295 | # self.permutations = self.__retrieve_permutations(jig_classes) 296 | self.grid_size = 3 297 | self.bias_whole_image = bias_whole_image 298 | if patches: 299 | self.patch_size = 64 300 | self._image_transformer = img_transformer 301 | self._augment_tile = tile_transformer 302 | if patches: 303 | self.returnFunc = lambda x: x 304 | else: 305 | def make_grid(x): 306 | return torchvision.utils.make_grid(x, self.grid_size, padding=0) 307 | 308 | self.returnFunc = make_grid 309 | 310 | def get_tile(self, img, n): 311 | w = float(img.size[0]) / self.grid_size 312 | y = int(n 
/ self.grid_size) 313 | x = n % self.grid_size 314 | tile = img.crop([x * w, y * w, (x + 1) * w, (y + 1) * w]) 315 | tile = self._augment_tile(tile) 316 | return tile 317 | 318 | def get_image(self, index): 319 | framename = self.data_path + '/' + self.names[index] 320 | img = Image.open(framename).convert('RGB') 321 | return self._image_transformer(img) 322 | 323 | def __getitem__(self, index): 324 | framename = self.data_path + '/' + self.names[index] 325 | img = Image.open(framename).convert('RGB') 326 | # image, image_randaug, label, domain 327 | # return self._image_transformer(img), 0, int(self.labels[index] - 1), int(self.domain_labels[index] - 1) 328 | return self._image_transformer(img), int(self.labels[index] - 1) 329 | 330 | def __len__(self): 331 | return len(self.names) 332 | 333 | def __retrieve_permutations(self, classes): 334 | all_perm = np.load('permutations_%d.npy' % (classes)) 335 | # from range [1,9] to [0,8] 336 | if all_perm.min() == 1: 337 | all_perm = all_perm - 1 338 | return all_perm 339 | 340 | 341 | class JigsawTestNewDataset(JigsawNewDataset): 342 | def __init__(self, *args, **xargs): 343 | super().__init__(*args, **xargs) 344 | 345 | def __getitem__(self, index): 346 | framename = self.data_path + '/' + self.names[index] 347 | img = Image.open(framename).convert('RGB') 348 | # return self._image_transformer(img), 0, int(self.labels[index] - 1), int(self.domain_labels[index] - 1) 349 | return self._image_transformer(img), int(self.labels[index] - 1), self.names[index] -------------------------------------------------------------------------------- /data/concat_dataset.py: -------------------------------------------------------------------------------- 1 | import bisect 2 | import warnings 3 | 4 | from torch.utils.data import Dataset 5 | 6 | # This is a small variant of the ConcatDataset class, which also returns dataset index 7 | from .JigsawLoader import JigsawTestDatasetMultiple 8 | 9 | 10 | class ConcatDataset(Dataset): 11 | """ 12 | Dataset to concatenate multiple datasets. 13 | Purpose: useful to assemble different existing datasets, possibly 14 | large-scale datasets as the concatenation operation is done in an 15 | on-the-fly manner. 
16 | 17 | Arguments: 18 | datasets (sequence): List of datasets to be concatenated 19 | """ 20 | 21 | @staticmethod 22 | def cumsum(sequence): 23 | r, s = [], 0 24 | for e in sequence: 25 | l = len(e) 26 | r.append(l + s) 27 | s += l 28 | return r 29 | 30 | def isMulti(self): 31 | return isinstance(self.datasets[0], JigsawTestDatasetMultiple) 32 | 33 | def __init__(self, datasets): 34 | super(ConcatDataset, self).__init__() 35 | assert len(datasets) > 0, 'datasets should not be an empty iterable' 36 | self.datasets = list(datasets) 37 | self.cumulative_sizes = self.cumsum(self.datasets) 38 | 39 | def __len__(self): 40 | return self.cumulative_sizes[-1] 41 | 42 | def __getitem__(self, idx): 43 | dataset_idx = bisect.bisect_right(self.cumulative_sizes, idx) 44 | if dataset_idx == 0: 45 | sample_idx = idx 46 | else: 47 | sample_idx = idx - self.cumulative_sizes[dataset_idx - 1] 48 | return self.datasets[dataset_idx][sample_idx], dataset_idx 49 | 50 | @property 51 | def cummulative_sizes(self): 52 | warnings.warn("cummulative_sizes attribute is renamed to " 53 | "cumulative_sizes", DeprecationWarning, stacklevel=2) 54 | return self.cumulative_sizes 55 | -------------------------------------------------------------------------------- /data/data_helper.py: -------------------------------------------------------------------------------- 1 | from os.path import join 2 | 3 | import torch 4 | from torch.utils.data import DataLoader 5 | from torchvision import transforms 6 | 7 | from .JigsawLoader import * 8 | from .concat_dataset import ConcatDataset 9 | from .JigsawLoader import JigsawNewDataset, JigsawTestNewDataset 10 | 11 | # from samplers import BatchSchedulerSampler 12 | # from datasets import build_transform 13 | 14 | 15 | vlcs_datasets = ["CALTECH", "LABELME", "PASCAL", "SUN"] 16 | pacs_datasets = ["art_painting", "cartoon", "photo", "sketch"] 17 | officehome_datasets = ['Art', 'Clipart', 'Product', 'RealWorld'] 18 | available_datasets = officehome_datasets + pacs_datasets + vlcs_datasets 19 | 20 | 21 | class Subset(torch.utils.data.Dataset): 22 | def __init__(self, dataset, limit): 23 | indices = torch.randperm(len(dataset))[:limit] 24 | self.dataset = dataset 25 | self.indices = indices 26 | 27 | def __getitem__(self, idx): 28 | return self.dataset[self.indices[idx]] 29 | 30 | def __len__(self): 31 | return len(self.indices) 32 | 33 | 34 | def get_train_dataloader(args, patches): 35 | dataset_list = args.source 36 | assert isinstance(dataset_list, list) 37 | datasets = [] 38 | val_datasets = [] 39 | 40 | img_transformer, tile_transformer = get_train_transformers(args) 41 | img_transformer_val = get_val_transformer(args) 42 | limit = None 43 | 44 | if "PACS" in args.data_root: 45 | dataset_path = join(args.data_root, "kfold") 46 | elif args.data == "miniDomainNet": 47 | dataset_path = "/data/DataSets/" + "DomainNet" 48 | else: 49 | dataset_path = args.data_root 50 | 51 | for i, dname in enumerate(dataset_list): 52 | if args.data == "PACS": 53 | name_train, name_val, labels_train, labels_val, domain_labels_train, domain_labels_val = \ 54 | get_split_dataset_info_from_txt(txt_path=join(args.data_root, "pacs_label"), domain=dname, 55 | domain_label=i+1) 56 | elif args.data == "miniDomainNet": 57 | name_train, name_val, labels_train, labels_val, domain_labels_train, domain_labels_val = \ 58 | get_split_dataset_info_from_txt(txt_path=args.data_root, domain=dname, domain_label=i+1, 59 | val_percentage=args.val_size) 60 | else: 61 | name_train, name_val, labels_train, labels_val, 
domain_labels_train, domain_labels_val = \ 62 | get_split_domain_info_from_dir(join(dataset_path, dname), dataset_name=args.data, 63 | val_percentage=args.val_size, domain_label=i+1) 64 | 65 | train_dataset = JigsawNewDataset(name_train, labels_train, domain_labels_train, 66 | dataset_path=dataset_path, patches=patches, 67 | img_transformer=img_transformer, tile_transformer=tile_transformer, 68 | jig_classes=30) 69 | if limit: 70 | train_dataset = Subset(train_dataset, limit) 71 | datasets.append(train_dataset) 72 | if args.freq_analyse == 1: 73 | val_datasets.append( 74 | JigsawTestDatasetFreqAnalyse(name_val, labels_val, domain_labels_val, dataset_path=dataset_path, 75 | img_transformer=img_transformer_val, args=args, dataset_list=dataset_list)) 76 | else: 77 | val_datasets.append( 78 | JigsawTestNewDataset(name_val, labels_val, domain_labels_val, dataset_path=dataset_path, 79 | img_transformer=img_transformer_val, patches=patches, jig_classes=30)) 80 | dataset = ConcatDataset(datasets) 81 | val_dataset = ConcatDataset(val_datasets) 82 | 83 | 84 | if args.domain_sampler == 1: 85 | sampler = BatchSchedulerSampler(dataset, args.batch_size) 86 | loader = torch.utils.data.DataLoader(dataset, batch_size=args.batch_size, shuffle=False, num_workers=4, 87 | pin_memory=True, drop_last=True, sampler=sampler) 88 | else: 89 | loader = torch.utils.data.DataLoader(dataset, batch_size=args.batch_size, shuffle=True, num_workers=4, 90 | pin_memory=True, drop_last=True) 91 | val_loader = torch.utils.data.DataLoader(val_dataset, batch_size=args.batch_size, shuffle=True, num_workers=4, 92 | pin_memory=True, drop_last=False) 93 | return loader, val_loader 94 | 95 | 96 | def get_val_dataloader(args, patches=False, tSNE_flag=0): 97 | if "PACS" in args.data_root: 98 | dataset_path = join(args.data_root, "kfold") 99 | elif args.data == "miniDomainNet": 100 | dataset_path = "/data/DataSets/" + "DomainNet" 101 | else: 102 | dataset_path = args.data_root 103 | 104 | if args.data == "miniDomainNet": 105 | name_train, name_val, labels_train, labels_val, domain_label_train, domain_label_val = \ 106 | get_split_dataset_info_from_txt(txt_path=args.data_root, domain=args.target, domain_label=0, 107 | val_percentage=args.val_size) 108 | else: 109 | name_train, name_val, labels_train, labels_val, domain_label_train, domain_label_val = get_split_domain_info_from_dir( 110 | join(dataset_path, args.target), dataset_name=args.data, val_percentage=args.val_size, domain_label=0) 111 | 112 | if tSNE_flag == 0: 113 | names = name_train + name_val 114 | labels = labels_train + labels_val 115 | domain_label = domain_label_train + domain_label_val 116 | else: 117 | names = name_val 118 | labels = labels_val 119 | domain_label = domain_label_val 120 | 121 | img_tr = get_val_transformer(args) 122 | dataset_list = args.source 123 | if args.freq_analyse == 1: 124 | val_dataset = JigsawTestDatasetFreqAnalyse(names, labels, domain_label, dataset_path=dataset_path, 125 | img_transformer=img_tr, args=args, dataset_list=dataset_list) 126 | else: 127 | val_dataset = JigsawTestNewDataset(names, labels, domain_label, dataset_path=dataset_path, patches=patches, 128 | img_transformer=img_tr, jig_classes=30) 129 | 130 | # if args.limit_target and len(val_dataset) > args.limit_target: 131 | # val_dataset = Subset(val_dataset, args.limit_target) 132 | # print("Using %d subset of val dataset" % args.limit_target) 133 | 134 | dataset = ConcatDataset([val_dataset]) 135 | loader = torch.utils.data.DataLoader(dataset, batch_size=args.batch_size, 
shuffle=False, num_workers=4, 136 | pin_memory=True, drop_last=False) 137 | return loader 138 | 139 | 140 | def get_train_transformers(args): 141 | 142 | img_tr = [transforms.RandomResizedCrop((int(args.image_size), int(args.image_size)), (args.min_scale, args.max_scale))] 143 | if args.random_horiz_flip > 0.0: 144 | img_tr.append(transforms.RandomHorizontalFlip(args.random_horiz_flip)) 145 | if args.jitter > 0.0: 146 | img_tr.append(transforms.ColorJitter(brightness=args.jitter, contrast=args.jitter, saturation=args.jitter, hue=min(0.5, args.jitter))) 147 | 148 | # this is special operation for JigenDG 149 | if args.gray_flag: 150 | img_tr.append(transforms.RandomGrayscale(args.tile_random_grayscale)) 151 | 152 | img_tr.append(transforms.ToTensor()) 153 | img_tr.append(transforms.Normalize([0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])) 154 | 155 | tile_tr = [] 156 | if args.tile_random_grayscale: 157 | tile_tr.append(transforms.RandomGrayscale(args.tile_random_grayscale)) 158 | tile_tr = tile_tr + [transforms.ToTensor(), transforms.Normalize([0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])] 159 | 160 | return transforms.Compose(img_tr), transforms.Compose(tile_tr) 161 | 162 | 163 | def get_val_transformer(args): 164 | img_tr = [ 165 | transforms.Resize((args.image_size, args.image_size)), 166 | transforms.ToTensor(), 167 | transforms.Normalize([0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]) 168 | ] 169 | return transforms.Compose(img_tr) 170 | 171 | 172 | # def get_train_dataloader_RandAug(args, patches): 173 | # dataset_list = args.source 174 | # assert isinstance(dataset_list, list) 175 | # datasets = [] 176 | # val_datasets = [] 177 | # img_transformer, tile_transformer = get_train_transformers(args) 178 | # limit = args.limit_source 179 | # for dname in dataset_list: 180 | # name_train, labels_train = _dataset_info(join('/data/DataSets/PACS', 'pacs_label', '%s_train_kfold.txt' % dname)) 181 | # name_val, labels_val = _dataset_info(join('/data/DataSets/PACS', 'pacs_label', '%s_crossval_kfold.txt' % dname)) 182 | # 183 | # train_dataset = JigsawDatasetRandAug(name_train, labels_train, patches=patches, img_transformer=img_transformer, 184 | # bias_whole_image=args.bias_whole_image, args=args) 185 | # if limit: 186 | # train_dataset = Subset(train_dataset, limit) 187 | # datasets.append(train_dataset) 188 | # val_datasets.append( 189 | # JigsawTestDatasetRandAug(name_val, labels_val, img_transformer=get_val_transformer(args), 190 | # patches=patches, args=args)) 191 | # dataset = ConcatDataset(datasets) 192 | # val_dataset = ConcatDataset(val_datasets) 193 | # loader = torch.utils.data.DataLoader(dataset, batch_size=args.batch_size, shuffle=True, num_workers=4, pin_memory=True, drop_last=True) 194 | # val_loader = torch.utils.data.DataLoader(val_dataset, batch_size=args.batch_size, shuffle=False, num_workers=4, pin_memory=True, drop_last=False) 195 | # return loader, val_loader 196 | 197 | 198 | # def get_val_dataloader_RandAug(args, patches=False): 199 | # names, labels = _dataset_info(join('/data/DataSets/PACS', 'pacs_label', '%s_test_kfold.txt' % args.target)) 200 | # img_tr = get_val_transformer(args) 201 | # val_dataset = JigsawTestDatasetRandAug(names, labels, patches=patches, img_transformer=img_tr, args=args) 202 | # if args.limit_target and len(val_dataset) > args.limit_target: 203 | # val_dataset = Subset(val_dataset, args.limit_target) 204 | # print("Using %d subset of val dataset" % args.limit_target) 205 | # dataset = ConcatDataset([val_dataset]) 206 | # loader = 
torch.utils.data.DataLoader(dataset, batch_size=args.batch_size, shuffle=False, num_workers=4, pin_memory=True, drop_last=False) 207 | # return loader 208 | 209 | def get_train_dataset(args, patches): 210 | dataset_list = args.source 211 | assert isinstance(dataset_list, list) 212 | datasets = [] 213 | val_datasets = [] 214 | 215 | img_transformer, tile_transformer = get_train_transformers(args) 216 | img_transformer_val = get_val_transformer(args) 217 | 218 | limit = None 219 | 220 | if "PACS" in args.data_root: 221 | dataset_path = join(args.data_root, "kfold") 222 | elif args.data == "miniDomainNet": 223 | dataset_path = "/data/DataSets/" + "DomainNet" 224 | else: 225 | dataset_path = args.data_root 226 | 227 | for i, dname in enumerate(dataset_list): 228 | if args.data == "PACS": 229 | name_train, name_val, labels_train, labels_val, domain_labels_train, domain_labels_val = \ 230 | get_split_dataset_info_from_txt(txt_path=join(args.data_root, "pacs_label"), domain=dname, 231 | domain_label=i+1) 232 | elif args.data == "miniDomainNet": 233 | name_train, name_val, labels_train, labels_val, domain_labels_train, domain_labels_val = \ 234 | get_split_dataset_info_from_txt(txt_path=args.data_root, domain=dname, domain_label=i+1, 235 | val_percentage=args.val_size) 236 | else: 237 | name_train, name_val, labels_train, labels_val, domain_labels_train, domain_labels_val = \ 238 | get_split_domain_info_from_dir(join(dataset_path, dname), dataset_name=args.data, 239 | val_percentage=args.val_size, domain_label=i+1) 240 | 241 | train_dataset = JigsawNewDataset(name_train, labels_train, domain_labels_train, 242 | dataset_path=dataset_path, patches=patches, 243 | img_transformer=img_transformer, tile_transformer=tile_transformer, 244 | jig_classes=30) 245 | if limit: 246 | train_dataset = Subset(train_dataset, limit) 247 | datasets.append(train_dataset) 248 | val_datasets.append( 249 | JigsawTestNewDataset(name_val, labels_val, domain_labels_val, dataset_path=dataset_path, 250 | img_transformer=img_transformer_val, patches=patches, jig_classes=30)) 251 | dataset = ConcatDataset(datasets) 252 | val_dataset = ConcatDataset(val_datasets) 253 | 254 | return dataset, val_dataset 255 | 256 | 257 | def get_val_dataset(args, patches=False, tSNE_flag=0): 258 | if "PACS" in args.data_root: 259 | dataset_path = join(args.data_root, "kfold") 260 | elif args.data == "miniDomainNet": 261 | dataset_path = "/data/DataSets/" + "DomainNet" 262 | else: 263 | dataset_path = args.data_root 264 | 265 | if args.data == "miniDomainNet": 266 | name_train, name_val, labels_train, labels_val, domain_label_train, domain_label_val = \ 267 | get_split_dataset_info_from_txt(txt_path=args.data_root, domain=args.target, domain_label=0, 268 | val_percentage=args.val_size) 269 | else: 270 | name_train, name_val, labels_train, labels_val, domain_label_train, domain_label_val = get_split_domain_info_from_dir( 271 | join(dataset_path, args.target), dataset_name=args.data, val_percentage=args.val_size, domain_label=0) 272 | 273 | if tSNE_flag == 0: 274 | names = name_train + name_val 275 | labels = labels_train + labels_val 276 | domain_label = domain_label_train + domain_label_val 277 | else: 278 | names = name_val 279 | labels = labels_val 280 | domain_label = domain_label_val 281 | 282 | img_tr = get_val_transformer(args) 283 | dataset_list = args.source 284 | val_dataset = JigsawTestNewDataset(names, labels, domain_label, dataset_path=dataset_path, patches=patches, 285 | img_transformer=img_tr, jig_classes=30) 286 | 287 | dataset = 
ConcatDataset([val_dataset]) 288 | return dataset 289 | -------------------------------------------------------------------------------- /data/datasets.py: -------------------------------------------------------------------------------- 1 | import os 2 | import json 3 | 4 | from torchvision import datasets, transforms 5 | from torch.utils.data import Dataset 6 | from torchvision.datasets.folder import ImageFolder, default_loader 7 | 8 | from timm.data.constants import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD 9 | from timm.data import create_transform 10 | 11 | """ Stanford Cars (Car) Dataset 12 | Created: Nov 15,2019 - Yuchong Gu 13 | Revised: Nov 15,2019 - Yuchong Gu 14 | """ 15 | import os 16 | # import pdb 17 | from PIL import Image 18 | import pickle 19 | 20 | 21 | # from scipy.io import loadmat 22 | 23 | 24 | class CarsDataset(Dataset): 25 | """ 26 | # Description: 27 | Dataset for retrieving Stanford Cars images and labels 28 | # Member Functions: 29 | __init__(self, phase, resize): initializes a dataset 30 | phase: a string in ['train', 'val', 'test'] 31 | resize: output shape/size of an image 32 | __getitem__(self, item): returns an image 33 | item: the idex of image in the whole dataset 34 | __len__(self): returns the length of dataset 35 | """ 36 | 37 | def __init__(self, root, train=True, transform=None): 38 | self.root = root 39 | self.phase = 'train' if train else 'test' 40 | # self.resize = resize 41 | self.num_classes = 196 42 | 43 | self.images = [] 44 | self.labels = [] 45 | 46 | list_path = os.path.join(root, 'cars_anno.pkl') 47 | 48 | list_mat = pickle.load(open(list_path, 'rb')) 49 | num_inst = len(list_mat['annotations']['relative_im_path'][0]) 50 | for i in range(num_inst): 51 | if self.phase == 'train' and list_mat['annotations']['test'][0][i].item() == 0: 52 | path = list_mat['annotations']['relative_im_path'][0][i].item() 53 | label = list_mat['annotations']['class'][0][i].item() 54 | self.images.append(path) 55 | self.labels.append(label) 56 | elif self.phase != 'train' and list_mat['annotations']['test'][0][i].item() == 1: 57 | path = list_mat['annotations']['relative_im_path'][0][i].item() 58 | label = list_mat['annotations']['class'][0][i].item() 59 | self.images.append(path) 60 | self.labels.append(label) 61 | 62 | print('Car Dataset with {} instances for {} phase'.format(len(self.images), self.phase)) 63 | 64 | # transform 65 | self.transform = transform 66 | 67 | def __getitem__(self, item): 68 | # image 69 | image = Image.open(os.path.join(self.root, self.images[item])).convert('RGB') # (C, H, W) 70 | image = self.transform(image) 71 | 72 | # return image and label 73 | return image, self.labels[item] - 1 # count begin from zero 74 | 75 | def __len__(self): 76 | return len(self.images) 77 | 78 | 79 | class INatDataset(ImageFolder): 80 | def __init__(self, root, train=True, year=2018, transform=None, target_transform=None, 81 | category='name', loader=default_loader): 82 | self.transform = transform 83 | self.loader = loader 84 | self.target_transform = target_transform 85 | self.year = year 86 | # assert category in ['kingdom','phylum','class','order','supercategory','family','genus','name'] 87 | path_json = os.path.join(root, f'{"train" if train else "val"}{year}.json') 88 | with open(path_json) as json_file: 89 | data = json.load(json_file) 90 | 91 | with open(os.path.join(root, 'categories.json')) as json_file: 92 | data_catg = json.load(json_file) 93 | 94 | path_json_for_targeter = os.path.join(root, f"train{year}.json") 95 | 96 | with 
open(path_json_for_targeter) as json_file: 97 | data_for_targeter = json.load(json_file) 98 | 99 | targeter = {} 100 | indexer = 0 101 | for elem in data_for_targeter['annotations']: 102 | king = [] 103 | king.append(data_catg[int(elem['category_id'])][category]) 104 | if king[0] not in targeter.keys(): 105 | targeter[king[0]] = indexer 106 | indexer += 1 107 | self.nb_classes = len(targeter) 108 | 109 | self.samples = [] 110 | for elem in data['images']: 111 | cut = elem['file_name'].split('/') 112 | target_current = int(cut[2]) 113 | path_current = os.path.join(root, cut[0], cut[2], cut[3]) 114 | 115 | categors = data_catg[target_current] 116 | target_current_true = targeter[categors[category]] 117 | self.samples.append((path_current, target_current_true)) 118 | 119 | # __getitem__ and __len__ inherited from ImageFolder 120 | 121 | 122 | def build_dataset(is_train, args, infer_no_resize=False): 123 | transform = build_transform(is_train, args, infer_no_resize) 124 | 125 | if args.data_set == 'CIFAR100': 126 | dataset = datasets.CIFAR100(args.data_path, train=is_train, transform=transform, download=True) 127 | nb_classes = 100 128 | elif args.data_set == 'CIFAR10': 129 | dataset = datasets.CIFAR10(args.data_path, train=is_train, transform=transform, download=True) 130 | nb_classes = 10 131 | elif args.data_set == 'CARS': 132 | dataset = CarsDataset(args.data_path, train=is_train, transform=transform) 133 | nb_classes = 196 134 | elif args.data_set == 'FLOWERS': 135 | root = os.path.join(args.data_path, 'train' if is_train else 'test') 136 | dataset = datasets.ImageFolder(root, transform=transform) 137 | nb_classes = 102 138 | elif args.data_set == 'IMNET': 139 | root = os.path.join(args.data_path, 'train' if is_train else 'val') 140 | dataset = datasets.ImageFolder(root, transform=transform) 141 | nb_classes = 1000 142 | elif args.data_set == 'INAT': 143 | dataset = INatDataset(args.data_path, train=is_train, year=2018, 144 | category=args.inat_category, transform=transform) 145 | nb_classes = dataset.nb_classes 146 | elif args.data_set == 'INAT19': 147 | dataset = INatDataset(args.data_path, train=is_train, year=2019, 148 | category=args.inat_category, transform=transform) 149 | nb_classes = dataset.nb_classes 150 | 151 | return dataset, nb_classes 152 | 153 | 154 | def build_transform(is_train, args, infer_no_resize=False): 155 | if hasattr(args, 'arch'): 156 | if 'cait' in args.arch and not is_train: 157 | print('# using cait eval transform') 158 | transformations = {} 159 | transformations = transforms.Compose( 160 | [transforms.Resize(args.input_size, interpolation=3), 161 | transforms.CenterCrop(args.input_size), 162 | transforms.ToTensor(), 163 | transforms.Normalize(IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD)]) 164 | return transformations 165 | 166 | if infer_no_resize: 167 | print('# using cait eval transform') 168 | transformations = {} 169 | transformations = transforms.Compose( 170 | [transforms.Resize(args.input_size, interpolation=3), 171 | transforms.CenterCrop(args.input_size), 172 | transforms.ToTensor(), 173 | transforms.Normalize(IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD)]) 174 | return transformations 175 | 176 | resize_im = args.input_size > 32 177 | if is_train: 178 | # this should always dispatch to transforms_imagenet_train 179 | transform = create_transform( 180 | input_size=args.input_size, 181 | is_training=True, 182 | color_jitter=args.color_jitter, 183 | auto_augment=args.aa, 184 | interpolation=args.train_interpolation, 185 | re_prob=args.reprob, 186 | 
re_mode=args.remode, 187 | re_count=args.recount, 188 | ) 189 | if not resize_im: 190 | # replace RandomResizedCropAndInterpolation with 191 | # RandomCrop 192 | transform.transforms[0] = transforms.RandomCrop( 193 | args.input_size, padding=4) 194 | return transform 195 | 196 | t = [] 197 | if resize_im: 198 | size = int((256 / 224) * args.input_size) 199 | t.append( 200 | transforms.Resize(size, interpolation=3), # to maintain same ratio w.r.t. 224 images 201 | ) 202 | t.append(transforms.CenterCrop(args.input_size)) 203 | 204 | t.append(transforms.ToTensor()) 205 | t.append(transforms.Normalize(IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD)) 206 | return transforms.Compose(t) 207 | -------------------------------------------------------------------------------- /data/samplers.py: -------------------------------------------------------------------------------- 1 | import copy 2 | import numpy as np 3 | import random 4 | from collections import defaultdict 5 | from torch.utils.data.sampler import Sampler, RandomSampler, SequentialSampler 6 | import math 7 | 8 | 9 | class BatchSchedulerSampler(Sampler): 10 | """ 11 | iterate over tasks and provide a random batch per task in each mini-batch 12 | """ 13 | def __init__(self, dataset, batch_size): 14 | self.dataset = dataset 15 | self.batch_size = batch_size 16 | self.number_of_datasets = len(dataset.datasets) 17 | self.mini_batch_size = int(batch_size / self.number_of_datasets) 18 | # self.largest_dataset_size = max([len(cur_dataset.samples) for cur_dataset in dataset.datasets]) 19 | self.largest_dataset_size = max([len(cur_dataset) for cur_dataset in dataset.datasets]) 20 | 21 | def __len__(self): 22 | return self.mini_batch_size * math.ceil(self.largest_dataset_size / self.mini_batch_size) * len(self.dataset.datasets) 23 | 24 | def __iter__(self): 25 | samplers_list = [] 26 | sampler_iterators = [] 27 | for dataset_idx in range(self.number_of_datasets): 28 | cur_dataset = self.dataset.datasets[dataset_idx] 29 | sampler = RandomSampler(cur_dataset) 30 | samplers_list.append(sampler) 31 | cur_sampler_iterator = sampler.__iter__() 32 | sampler_iterators.append(cur_sampler_iterator) 33 | 34 | push_index_val = [0] + self.dataset.cumulative_sizes[:-1] 35 | step = self.batch_size 36 | # for this case we want to get all samples in dataset, this force us to resample from the smaller datasets 37 | epoch_samples = self.largest_dataset_size * self.number_of_datasets 38 | 39 | final_samples_list = [] # this is a list of indexes from the combined dataset 40 | for _ in range(0, epoch_samples, step): 41 | for i in range(self.number_of_datasets): 42 | cur_batch_sampler = sampler_iterators[i] 43 | cur_samples = [] 44 | for _ in range(self.mini_batch_size): 45 | try: 46 | cur_sample_org = cur_batch_sampler.__next__() 47 | cur_sample = cur_sample_org + push_index_val[i] 48 | cur_samples.append(cur_sample) 49 | except StopIteration: 50 | # got to the end of iterator - restart the iterator and continue to get samples 51 | # until reaching "epoch_samples" 52 | sampler_iterators[i] = samplers_list[i].__iter__() 53 | cur_batch_sampler = sampler_iterators[i] 54 | cur_sample_org = cur_batch_sampler.__next__() 55 | cur_sample = cur_sample_org + push_index_val[i] 56 | cur_samples.append(cur_sample) 57 | final_samples_list.extend(cur_samples) 58 | 59 | return iter(final_samples_list) 60 | 61 | 62 | class RandomDomainSampler(Sampler): 63 | """Randomly samples N domains each with K images 64 | to form a minibatch of size N*K. 
65 | Args: 66 | data_source (list): list of Datums. 67 | batch_size (int): batch size. 68 | n_domain (int): number of domains to sample in a minibatch. 69 | """ 70 | 71 | def __init__(self, data_source, batch_size, n_domain): 72 | self.data_source = data_source 73 | 74 | # Keep track of image indices for each domain 75 | self.domain_dict = defaultdict(list) 76 | for i, item in enumerate(data_source): 77 | self.domain_dict[item.domain].append(i) 78 | self.domains = list(self.domain_dict.keys()) 79 | 80 | # Make sure each domain has equal number of images 81 | if n_domain is None or n_domain <= 0: 82 | n_domain = len(self.domains) 83 | assert batch_size % n_domain == 0 84 | self.n_img_per_domain = batch_size // n_domain 85 | 86 | self.batch_size = batch_size 87 | # n_domain denotes number of domains sampled in a minibatch 88 | self.n_domain = n_domain 89 | self.length = len(list(self.__iter__())) 90 | 91 | def __iter__(self): 92 | domain_dict = copy.deepcopy(self.domain_dict) 93 | final_idxs = [] 94 | stop_sampling = False 95 | 96 | while not stop_sampling: 97 | selected_domains = random.sample(self.domains, self.n_domain) 98 | 99 | for domain in selected_domains: 100 | idxs = domain_dict[domain] 101 | selected_idxs = random.sample(idxs, self.n_img_per_domain) 102 | final_idxs.extend(selected_idxs) 103 | 104 | for idx in selected_idxs: 105 | domain_dict[domain].remove(idx) 106 | 107 | remaining = len(domain_dict[domain]) 108 | if remaining < self.n_img_per_domain: 109 | stop_sampling = True 110 | 111 | return iter(final_idxs) 112 | 113 | def __len__(self): 114 | return self.length 115 | 116 | 117 | class SeqDomainSampler(Sampler): 118 | """Sequential domain sampler, which randomly samples K 119 | images from each domain to form a minibatch. 120 | Args: 121 | data_source (list): list of Datums. 122 | batch_size (int): batch size. 123 | """ 124 | 125 | def __init__(self, data_source, batch_size): 126 | self.data_source = data_source 127 | 128 | # Keep track of image indices for each domain 129 | self.domain_dict = defaultdict(list) 130 | for i, item in enumerate(data_source): 131 | self.domain_dict[item.domain].append(i) 132 | self.domains = list(self.domain_dict.keys()) 133 | self.domains.sort() 134 | 135 | # Make sure each domain has equal number of images 136 | n_domain = len(self.domains) 137 | assert batch_size % n_domain == 0 138 | self.n_img_per_domain = batch_size // n_domain 139 | 140 | self.batch_size = batch_size 141 | # n_domain denotes number of domains sampled in a minibatch 142 | self.n_domain = n_domain 143 | self.length = len(list(self.__iter__())) 144 | 145 | def __iter__(self): 146 | domain_dict = copy.deepcopy(self.domain_dict) 147 | final_idxs = [] 148 | stop_sampling = False 149 | 150 | while not stop_sampling: 151 | for domain in self.domains: 152 | idxs = domain_dict[domain] 153 | selected_idxs = random.sample(idxs, self.n_img_per_domain) 154 | final_idxs.extend(selected_idxs) 155 | 156 | for idx in selected_idxs: 157 | domain_dict[domain].remove(idx) 158 | 159 | remaining = len(domain_dict[domain]) 160 | if remaining < self.n_img_per_domain: 161 | stop_sampling = True 162 | 163 | return iter(final_idxs) 164 | 165 | def __len__(self): 166 | return self.length 167 | 168 | 169 | class RandomClassSampler(Sampler): 170 | """Randomly samples N classes each with K instances to 171 | form a minibatch of size N*K. 172 | Modified from https://github.com/KaiyangZhou/deep-person-reid. 173 | Args: 174 | data_source (list): list of Datums. 175 | batch_size (int): batch size. 
176 | n_ins (int): number of instances per class to sample in a minibatch. 177 | """ 178 | 179 | def __init__(self, data_source, batch_size, n_ins): 180 | if batch_size < n_ins: 181 | raise ValueError( 182 | "batch_size={} must be no less " 183 | "than n_ins={}".format(batch_size, n_ins) 184 | ) 185 | 186 | self.data_source = data_source 187 | self.batch_size = batch_size 188 | self.n_ins = n_ins 189 | self.ncls_per_batch = self.batch_size // self.n_ins 190 | self.index_dic = defaultdict(list) 191 | for index, item in enumerate(data_source): 192 | self.index_dic[item.label].append(index) 193 | self.labels = list(self.index_dic.keys()) 194 | assert len(self.labels) >= self.ncls_per_batch 195 | 196 | # estimate number of images in an epoch 197 | self.length = len(list(self.__iter__())) 198 | 199 | def __iter__(self): 200 | batch_idxs_dict = defaultdict(list) 201 | 202 | for label in self.labels: 203 | idxs = copy.deepcopy(self.index_dic[label]) 204 | if len(idxs) < self.n_ins: 205 | idxs = np.random.choice(idxs, size=self.n_ins, replace=True) 206 | random.shuffle(idxs) 207 | batch_idxs = [] 208 | for idx in idxs: 209 | batch_idxs.append(idx) 210 | if len(batch_idxs) == self.n_ins: 211 | batch_idxs_dict[label].append(batch_idxs) 212 | batch_idxs = [] 213 | 214 | avai_labels = copy.deepcopy(self.labels) 215 | final_idxs = [] 216 | 217 | while len(avai_labels) >= self.ncls_per_batch: 218 | selected_labels = random.sample(avai_labels, self.ncls_per_batch) 219 | for label in selected_labels: 220 | batch_idxs = batch_idxs_dict[label].pop(0) 221 | final_idxs.extend(batch_idxs) 222 | if len(batch_idxs_dict[label]) == 0: 223 | avai_labels.remove(label) 224 | 225 | return iter(final_idxs) 226 | 227 | def __len__(self): 228 | return self.length 229 | 230 | 231 | def build_sampler( 232 | sampler_type, 233 | cfg=None, 234 | data_source=None, 235 | batch_size=32, 236 | n_domain=0, 237 | n_ins=16 238 | ): 239 | if sampler_type == "RandomSampler": 240 | return RandomSampler(data_source) 241 | 242 | elif sampler_type == "SequentialSampler": 243 | return SequentialSampler(data_source) 244 | 245 | elif sampler_type == "RandomDomainSampler": 246 | return RandomDomainSampler(data_source, batch_size, n_domain) 247 | 248 | elif sampler_type == "SeqDomainSampler": 249 | return SeqDomainSampler(data_source, batch_size) 250 | 251 | elif sampler_type == "RandomClassSampler": 252 | return RandomClassSampler(data_source, batch_size, n_ins) 253 | 254 | else: 255 | raise ValueError("Unknown sampler type: {}".format(sampler_type)) 256 | 257 | -------------------------------------------------------------------------------- /datasets.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2015-present, Facebook, Inc. 2 | # All rights reserved. 
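# The INatDataset class defined below builds its (path, label) list from the
# official iNaturalist JSON annotations.  A minimal call sketch (illustrative
# only; the root path is a placeholder and assumes the standard layout with
# train{year}.json, val{year}.json and categories.json inside the dataset root):
#
#     dataset = INatDataset('/path/to/inat2018', train=True, year=2018,
#                           category='name', transform=transform)
#     print(dataset.nb_classes, len(dataset.samples))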
3 | import os 4 | import json 5 | 6 | from torchvision import datasets, transforms 7 | from torchvision.datasets.folder import ImageFolder, default_loader 8 | 9 | from timm.data.constants import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD 10 | from timm.data import create_transform 11 | 12 | 13 | class INatDataset(ImageFolder): 14 | def __init__(self, root, train=True, year=2018, transform=None, target_transform=None, 15 | category='name', loader=default_loader): 16 | self.transform = transform 17 | self.loader = loader 18 | self.target_transform = target_transform 19 | self.year = year 20 | # assert category in ['kingdom','phylum','class','order','supercategory','family','genus','name'] 21 | path_json = os.path.join(root, f'{"train" if train else "val"}{year}.json') 22 | with open(path_json) as json_file: 23 | data = json.load(json_file) 24 | 25 | with open(os.path.join(root, 'categories.json')) as json_file: 26 | data_catg = json.load(json_file) 27 | 28 | path_json_for_targeter = os.path.join(root, f"train{year}.json") 29 | 30 | with open(path_json_for_targeter) as json_file: 31 | data_for_targeter = json.load(json_file) 32 | 33 | targeter = {} 34 | indexer = 0 35 | for elem in data_for_targeter['annotations']: 36 | king = [] 37 | king.append(data_catg[int(elem['category_id'])][category]) 38 | if king[0] not in targeter.keys(): 39 | targeter[king[0]] = indexer 40 | indexer += 1 41 | self.nb_classes = len(targeter) 42 | 43 | self.samples = [] 44 | for elem in data['images']: 45 | cut = elem['file_name'].split('/') 46 | target_current = int(cut[2]) 47 | path_current = os.path.join(root, cut[0], cut[2], cut[3]) 48 | 49 | categors = data_catg[target_current] 50 | target_current_true = targeter[categors[category]] 51 | self.samples.append((path_current, target_current_true)) 52 | 53 | # __getitem__ and __len__ inherited from ImageFolder 54 | 55 | 56 | def build_dataset(is_train, args): 57 | transform = build_transform(is_train, args) 58 | 59 | if args.data_set == 'CIFAR': 60 | dataset = datasets.CIFAR100(args.data_path, train=is_train, transform=transform) 61 | nb_classes = 100 62 | elif args.data_set == 'IMNET': 63 | root = os.path.join(args.data_path, 'train' if is_train else 'val') 64 | dataset = datasets.ImageFolder(root, transform=transform) 65 | nb_classes = 1000 66 | elif args.data_set == 'INAT': 67 | dataset = INatDataset(args.data_path, train=is_train, year=2018, 68 | category=args.inat_category, transform=transform) 69 | nb_classes = dataset.nb_classes 70 | elif args.data_set == 'INAT19': 71 | dataset = INatDataset(args.data_path, train=is_train, year=2019, 72 | category=args.inat_category, transform=transform) 73 | nb_classes = dataset.nb_classes 74 | 75 | return dataset, nb_classes 76 | 77 | 78 | def build_transform(is_train, args): 79 | resize_im = args.input_size > 32 80 | if is_train: 81 | # this should always dispatch to transforms_imagenet_train 82 | transform = create_transform( 83 | input_size=args.input_size, 84 | is_training=True, 85 | color_jitter=args.color_jitter, 86 | auto_augment=args.aa, 87 | interpolation=args.train_interpolation, 88 | re_prob=args.reprob, 89 | re_mode=args.remode, 90 | re_count=args.recount, 91 | ) 92 | if not resize_im: 93 | # replace RandomResizedCropAndInterpolation with 94 | # RandomCrop 95 | transform.transforms[0] = transforms.RandomCrop( 96 | args.input_size, padding=4) 97 | return transform 98 | 99 | t = [] 100 | if resize_im: 101 | size = int(args.input_size / args.eval_crop_ratio) 102 | t.append( 103 | transforms.Resize(size, 
interpolation=3), # to maintain same ratio w.r.t. 224 images 104 | ) 105 | t.append(transforms.CenterCrop(args.input_size)) 106 | 107 | t.append(transforms.ToTensor()) 108 | t.append(transforms.Normalize(IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD)) 109 | return transforms.Compose(t) 110 | -------------------------------------------------------------------------------- /engine_dg.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2015-present, Facebook, Inc. 2 | # All rights reserved. 3 | """ 4 | Train and eval functions used in main.py 5 | """ 6 | import math 7 | import sys 8 | from typing import Iterable, Optional 9 | 10 | import torch 11 | 12 | import timm 13 | from timm.data import Mixup 14 | from timm.utils import accuracy, ModelEma 15 | 16 | from losses import DistillationLoss 17 | import utils as utils 18 | 19 | def train_one_epoch(model: torch.nn.Module, criterion: DistillationLoss, 20 | data_loader: Iterable, optimizer: torch.optim.Optimizer, 21 | device: torch.device, epoch: int, loss_scaler, amp_autocast, max_norm: float = 0, 22 | model_ema: Optional[ModelEma] = None, mixup_fn: Optional[Mixup] = None, 23 | set_training_mode=True, args=None): 24 | model.train(set_training_mode) 25 | metric_logger = utils.MetricLogger(delimiter=" ") 26 | metric_logger.add_meter('lr', utils.SmoothedValue(window_size=1, fmt='{value:.6f}')) 27 | header = 'Epoch: [{}]'.format(epoch) 28 | print_freq = 10 29 | 30 | if args.cosub: 31 | criterion = torch.nn.BCEWithLogitsLoss() 32 | 33 | for (samples, targets), domains in metric_logger.log_every(data_loader, print_freq, header): 34 | 35 | samples = samples.to(device, non_blocking=True) 36 | targets = targets.to(device, non_blocking=True) 37 | 38 | if mixup_fn is not None: 39 | samples, targets = mixup_fn(samples, targets) 40 | 41 | if args.cosub: 42 | samples = torch.cat((samples, samples), dim=0) 43 | 44 | if args.bce_loss: 45 | targets = targets.gt(0.0).type(targets.dtype) 46 | 47 | with amp_autocast(): 48 | outputs = model(samples, if_random_cls_token_position=args.if_random_cls_token_position, 49 | if_random_token_rank=args.if_random_token_rank) 50 | # outputs = model(samples) 51 | if not args.cosub: 52 | loss = criterion(samples, outputs, targets) 53 | else: 54 | outputs = torch.split(outputs, outputs.shape[0] // 2, dim=0) 55 | loss = 0.25 * criterion(outputs[0], targets) 56 | loss = loss + 0.25 * criterion(outputs[1], targets) 57 | loss = loss + 0.25 * criterion(outputs[0], outputs[1].detach().sigmoid()) 58 | loss = loss + 0.25 * criterion(outputs[1], outputs[0].detach().sigmoid()) 59 | 60 | if args.if_nan2num: 61 | with amp_autocast(): 62 | loss = torch.nan_to_num(loss) 63 | 64 | loss_value = loss.item() 65 | 66 | if not math.isfinite(loss_value): 67 | print("Loss is {}, stopping training".format(loss_value)) 68 | if args.if_continue_inf: 69 | optimizer.zero_grad() 70 | continue 71 | else: 72 | sys.exit(1) 73 | 74 | optimizer.zero_grad() 75 | 76 | # this attribute is added by timm on one optimizer (adahessian) 77 | if isinstance(loss_scaler, timm.utils.NativeScaler): 78 | is_second_order = hasattr(optimizer, 'is_second_order') and optimizer.is_second_order 79 | loss_scaler(loss, optimizer, clip_grad=max_norm, 80 | parameters=model.parameters(), create_graph=is_second_order) 81 | else: 82 | loss.backward() 83 | if max_norm != None: 84 | torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm) 85 | optimizer.step() 86 | 87 | torch.cuda.synchronize() 88 | if model_ema is not None: 89 | 
model_ema.update(model) 90 | 91 | metric_logger.update(loss=loss_value) 92 | metric_logger.update(lr=optimizer.param_groups[0]["lr"]) 93 | # gather the stats from all processes 94 | metric_logger.synchronize_between_processes() 95 | print("Averaged stats:", metric_logger) 96 | return {k: meter.global_avg for k, meter in metric_logger.meters.items()} 97 | 98 | 99 | @torch.no_grad() 100 | def evaluate(data_loader, model, device, amp_autocast): 101 | criterion = torch.nn.CrossEntropyLoss() 102 | 103 | metric_logger = utils.MetricLogger(delimiter=" ") 104 | header = 'Test:' 105 | 106 | # switch to evaluation mode 107 | model.eval() 108 | 109 | for (images, target, _), _ in metric_logger.log_every(data_loader, 200, header): 110 | images = images.to(device, non_blocking=True) 111 | 112 | target = target.to(device, non_blocking=True) 113 | 114 | with amp_autocast(): 115 | output = model(images) 116 | loss = criterion(output, target) 117 | 118 | acc1, acc5 = accuracy(output, target, topk=(1, 5)) 119 | 120 | batch_size = images.shape[0] 121 | metric_logger.update(loss=loss.item()) 122 | metric_logger.meters['acc1'].update(acc1.item(), n=batch_size) 123 | metric_logger.meters['acc5'].update(acc5.item(), n=batch_size) 124 | # gather the stats from all processes 125 | metric_logger.synchronize_between_processes() 126 | print('* Acc@1 {top1.global_avg:.3f} Acc@5 {top5.global_avg:.3f} loss {losses.global_avg:.3f}' 127 | .format(top1=metric_logger.acc1, top5=metric_logger.acc5, losses=metric_logger.loss)) 128 | 129 | return {k: meter.global_avg for k, meter in metric_logger.meters.items()} 130 | -------------------------------------------------------------------------------- /hubconf.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2015-present, Facebook, Inc. 2 | # All rights reserved. 3 | from models import * 4 | from cait_models import * 5 | from resmlp_models import * 6 | #from patchconvnet_models import * 7 | 8 | dependencies = ["torch", "torchvision", "timm"] 9 | -------------------------------------------------------------------------------- /losses.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2015-present, Facebook, Inc. 2 | # All rights reserved. 3 | """ 4 | Implements the knowledge distillation loss 5 | """ 6 | import torch 7 | from torch.nn import functional as F 8 | 9 | 10 | class DistillationLoss(torch.nn.Module): 11 | """ 12 | This module wraps a standard criterion and adds an extra knowledge distillation loss by 13 | taking a teacher model prediction and using it as additional supervision. 14 | """ 15 | def __init__(self, base_criterion: torch.nn.Module, teacher_model: torch.nn.Module, 16 | distillation_type: str, alpha: float, tau: float): 17 | super().__init__() 18 | self.base_criterion = base_criterion 19 | self.teacher_model = teacher_model 20 | assert distillation_type in ['none', 'soft', 'hard'] 21 | self.distillation_type = distillation_type 22 | self.alpha = alpha 23 | self.tau = tau 24 | 25 | def forward(self, inputs, outputs, labels): 26 | """ 27 | Args: 28 | inputs: The original inputs that are feed to the teacher model 29 | outputs: the outputs of the model to be trained. 
It is expected to be
30 |                 either a Tensor, or a Tuple[Tensor, Tensor], with the original output
31 |                 in the first position and the distillation predictions as the second output
32 |             labels: the labels for the base criterion
33 |         """
34 |         outputs_kd = None
35 |         if not isinstance(outputs, torch.Tensor):
36 |             # assume that the model outputs a tuple of [outputs, outputs_kd]
37 |             outputs, outputs_kd = outputs
38 |         base_loss = self.base_criterion(outputs, labels)
39 |         if self.distillation_type == 'none':
40 |             return base_loss
41 | 
42 |         if outputs_kd is None:
43 |             raise ValueError("When knowledge distillation is enabled, the model is "
44 |                              "expected to return a Tuple[Tensor, Tensor] with the output of the "
45 |                              "class_token and the dist_token")
46 |         # don't backprop through the teacher
47 |         with torch.no_grad():
48 |             teacher_outputs = self.teacher_model(inputs)
49 | 
50 |         if self.distillation_type == 'soft':
51 |             T = self.tau
52 |             # taken from https://github.com/peterliht/knowledge-distillation-pytorch/blob/master/model/net.py#L100
53 |             # with slight modifications
54 |             distillation_loss = F.kl_div(
55 |                 F.log_softmax(outputs_kd / T, dim=1),
56 |                 #We provide the teacher's targets in log probability because we use log_target=True
57 |                 #(as recommended in pytorch https://github.com/pytorch/pytorch/blob/9324181d0ac7b4f7949a574dbc3e8be30abe7041/torch/nn/functional.py#L2719)
58 |                 #but it is possible to give just the probabilities and set log_target=False. In our experiments we tried both.
59 |                 F.log_softmax(teacher_outputs / T, dim=1),
60 |                 reduction='sum',
61 |                 log_target=True
62 |             ) * (T * T) / outputs_kd.numel()
63 |             #We divide by outputs_kd.numel() to have the legacy PyTorch behavior.
64 |             #But we also experimented with outputs_kd.size(0)
65 |             #see issue 61 (https://github.com/facebookresearch/deit/issues/61) for more details
66 |         elif self.distillation_type == 'hard':
67 |             distillation_loss = F.cross_entropy(outputs_kd, teacher_outputs.argmax(dim=1))
68 | 
69 |         loss = base_loss * (1 - self.alpha) + distillation_loss * self.alpha
70 |         return loss
71 | 
--------------------------------------------------------------------------------
/main_dg.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2015-present, Facebook, Inc.
2 | # All rights reserved.
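# DistillationLoss (losses.py above) mixes the base objective with the teacher
# signal as  loss = (1 - alpha) * base_loss + alpha * distillation_loss.
# A minimal usage sketch (illustrative only; the base criterion, the teacher and
# the alpha/tau values are placeholders, not values fixed by this repository):
#
#     criterion = DistillationLoss(
#         base_criterion=LabelSmoothingCrossEntropy(smoothing=0.1),
#         teacher_model=teacher_model,          # frozen, already in eval() mode
#         distillation_type='soft', alpha=0.5, tau=1.0,
#     )
#     loss = criterion(samples, model(samples), targets)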
3 | import argparse 4 | import datetime 5 | import numpy as np 6 | import time 7 | import torch 8 | import torch.backends.cudnn as cudnn 9 | import json 10 | 11 | from pathlib import Path 12 | 13 | from timm.data import Mixup 14 | from timm.models import create_model 15 | from timm.loss import LabelSmoothingCrossEntropy, SoftTargetCrossEntropy 16 | from timm.scheduler import create_scheduler 17 | from timm.optim import create_optimizer 18 | from timm.utils import NativeScaler, get_state_dict, ModelEma 19 | 20 | from datasets import build_dataset 21 | from engine_dg import train_one_epoch, evaluate 22 | from losses import DistillationLoss 23 | from samplers import RASampler 24 | from augment import new_data_aug_generator 25 | 26 | from contextlib import suppress 27 | 28 | from data import data_helper 29 | 30 | import os 31 | 32 | import models_mamba 33 | import utils 34 | 35 | import models.vmamba 36 | 37 | 38 | # log about 39 | import mlflow 40 | 41 | from thop import profile 42 | from thop import clever_format 43 | 44 | import threading 45 | 46 | lock = threading.Lock() 47 | def create_folder(folder_path): 48 | with lock: 49 | if not os.path.exists(folder_path): 50 | os.makedirs(folder_path) 51 | 52 | def get_args_parser(): 53 | parser = argparse.ArgumentParser('DeiT training and evaluation script', add_help=False) 54 | parser.add_argument('--target', default=1, type=int) 55 | 56 | parser.add_argument('--batch-size', default=64, type=int) 57 | parser.add_argument('--epochs', default=50, type=int) 58 | parser.add_argument('--bce-loss', action='store_true') 59 | parser.add_argument('--unscale-lr', action='store_true') 60 | 61 | parser.add_argument('--flops_flag', default=0, type=int, help="whether show flops") 62 | 63 | parser.add_argument('--ssm_version', default="v2") 64 | parser.add_argument('--spatial_aug_flag', default=1, type=int, help="1: MixStyle; 2: DSU; 3: ALOFT") 65 | parser.add_argument('--START_flag', default=0, type=int, help="") 66 | parser.add_argument('--START_p', default=1.0, type=float, help="") 67 | parser.add_argument('--START_token_prob', default=0.75, type=float, help="") 68 | parser.add_argument('--START_batch_prob', default=0.5, type=float, help="") 69 | parser.add_argument('--START_attention_mode', default=0, type=int, help="") 70 | parser.add_argument('--Vim_START_flag', default=0, type=int, help="") 71 | 72 | # Model parameters 73 | parser.add_argument('--model', 74 | default='vmamba_tiny', 75 | type=str, metavar='MODEL', 76 | help='Name of model to train') 77 | parser.add_argument('--input-size', default=224, type=int, help='images input size') 78 | parser.add_argument('--image-size', default=224, type=int, help='images input size') 79 | 80 | parser.add_argument("--min_scale", default=0.8, type=float, help="Minimum scale percent") 81 | parser.add_argument("--max_scale", default=1.0, type=float, help="Maximum scale percent") 82 | parser.add_argument("--gray_flag", default=1, type=int, help="whether use random gray") 83 | parser.add_argument("--random_horiz_flip", default=0.5, type=float, help="Chance of random horizontal flip") 84 | parser.add_argument("--jitter", default=0.4, type=float, help="Color jitter amount") 85 | parser.add_argument("--tile_random_grayscale", default=0.1, type=float, 86 | help="Chance of randomly greyscaling a tile") 87 | 88 | parser.add_argument('--drop', type=float, default=0.0, metavar='PCT', 89 | help='Dropout rate (default: 0.)') 90 | parser.add_argument('--drop-path', type=float, default=0.1, metavar='PCT', 91 | help='Drop path rate 
(default: 0.1)') 92 | 93 | parser.add_argument('--model-ema', action='store_true') 94 | parser.add_argument('--no-model-ema', action='store_false', dest='model_ema') 95 | parser.set_defaults(model_ema=True) 96 | parser.add_argument('--model-ema-decay', type=float, default=0.99996, help='') 97 | parser.add_argument('--model-ema-force-cpu', action='store_true', default=False, help='') 98 | 99 | # Optimizer parameters 100 | parser.add_argument('--opt', default='adamw', type=str, metavar='OPTIMIZER', 101 | help='Optimizer (default: "adamw"') 102 | parser.add_argument('--opt-eps', default=1e-8, type=float, metavar='EPSILON', 103 | help='Optimizer Epsilon (default: 1e-8)') 104 | parser.add_argument('--opt-betas', default=None, type=float, nargs='+', metavar='BETA', 105 | help='Optimizer Betas (default: None, use opt default)') 106 | parser.add_argument('--clip-grad', type=float, default=None, metavar='NORM', 107 | help='Clip gradient norm (default: None, no clipping)') 108 | parser.add_argument('--momentum', type=float, default=0.9, metavar='M', 109 | help='SGD momentum (default: 0.9)') 110 | parser.add_argument('--weight-decay', type=float, default=0.05, 111 | help='weight decay (default: 0.05)') 112 | # Learning rate schedule parameters 113 | parser.add_argument('--sched', default='cosine', type=str, metavar='SCHEDULER', 114 | help='LR scheduler (default: "cosine"') 115 | parser.add_argument('--lr', type=float, default=5e-4, metavar='LR', 116 | help='learning rate (default: 5e-4)') 117 | parser.add_argument('--lr-noise', type=float, nargs='+', default=None, metavar='pct, pct', 118 | help='learning rate noise on/off epoch percentages') 119 | parser.add_argument('--lr-noise-pct', type=float, default=0.67, metavar='PERCENT', 120 | help='learning rate noise limit percent (default: 0.67)') 121 | parser.add_argument('--lr-noise-std', type=float, default=1.0, metavar='STDDEV', 122 | help='learning rate noise std-dev (default: 1.0)') 123 | parser.add_argument('--warmup-lr', type=float, default=1e-6, metavar='LR', 124 | help='warmup learning rate (default: 1e-6)') 125 | parser.add_argument('--min-lr', type=float, default=1e-5, metavar='LR', 126 | help='lower lr bound for cyclic schedulers that hit 0 (1e-5)') 127 | 128 | parser.add_argument('--decay-epochs', type=float, default=30, metavar='N', 129 | help='epoch interval to decay LR') 130 | parser.add_argument('--warmup-epochs', type=int, default=5, metavar='N', 131 | help='epochs to warmup LR, if scheduler supports') 132 | parser.add_argument('--cooldown-epochs', type=int, default=10, metavar='N', 133 | help='epochs to cooldown LR at min_lr, after cyclic schedule ends') 134 | parser.add_argument('--patience-epochs', type=int, default=10, metavar='N', 135 | help='patience epochs for Plateau LR scheduler (default: 10') 136 | parser.add_argument('--decay-rate', '--dr', type=float, default=0.1, metavar='RATE', 137 | help='LR decay rate (default: 0.1)') 138 | 139 | # Augmentation parameters 140 | parser.add_argument('--color-jitter', type=float, default=0.3, metavar='PCT', 141 | help='Color jitter factor (default: 0.3)') 142 | parser.add_argument('--aa', type=str, default='rand-m9-mstd0.5-inc1', metavar='NAME', 143 | help='Use AutoAugment policy. "v0" or "original". 
" + \ 144 | "(default: rand-m9-mstd0.5-inc1)'), 145 | parser.add_argument('--smoothing', type=float, default=0.1, help='Label smoothing (default: 0.1)') 146 | parser.add_argument('--train-interpolation', type=str, default='bicubic', 147 | help='Training interpolation (random, bilinear, bicubic default: "bicubic")') 148 | 149 | parser.add_argument('--repeated-aug', action='store_true') 150 | parser.add_argument('--no-repeated-aug', action='store_false', dest='repeated_aug') 151 | parser.set_defaults(repeated_aug=True) 152 | 153 | parser.add_argument('--train-mode', action='store_true') 154 | parser.add_argument('--no-train-mode', action='store_false', dest='train_mode') 155 | parser.set_defaults(train_mode=True) 156 | 157 | parser.add_argument('--ThreeAugment', action='store_true') # 3augment 158 | 159 | parser.add_argument('--src', action='store_true') # simple random crop 160 | 161 | # * Random Erase params 162 | parser.add_argument('--reprob', type=float, default=0.25, metavar='PCT', 163 | help='Random erase prob (default: 0.25)') 164 | parser.add_argument('--remode', type=str, default='pixel', 165 | help='Random erase mode (default: "pixel")') 166 | parser.add_argument('--recount', type=int, default=1, 167 | help='Random erase count (default: 1)') 168 | parser.add_argument('--resplit', action='store_true', default=False, 169 | help='Do not random erase first (clean) augmentation split') 170 | 171 | # * Mixup params 172 | parser.add_argument('--mixup', type=float, default=0.8, 173 | help='mixup alpha, mixup enabled if > 0. (default: 0.8)') 174 | parser.add_argument('--cutmix', type=float, default=1.0, 175 | help='cutmix alpha, cutmix enabled if > 0. (default: 1.0)') 176 | parser.add_argument('--cutmix-minmax', type=float, nargs='+', default=None, 177 | help='cutmix min/max ratio, overrides alpha and enables cutmix if set (default: None)') 178 | parser.add_argument('--mixup-prob', type=float, default=1.0, 179 | help='Probability of performing mixup or cutmix when either/both is enabled') 180 | parser.add_argument('--mixup-switch-prob', type=float, default=0.5, 181 | help='Probability of switching to cutmix when both mixup and cutmix enabled') 182 | parser.add_argument('--mixup-mode', type=str, default='batch', 183 | help='How to apply mixup/cutmix params. 
Per "batch", "pair", or "elem"') 184 | 185 | # Distillation parameters 186 | parser.add_argument('--teacher-model', default='regnety_160', type=str, metavar='MODEL', 187 | help='Name of teacher model to train (default: "regnety_160"') 188 | parser.add_argument('--teacher-path', type=str, default='') 189 | parser.add_argument('--distillation-type', default='none', choices=['none', 'soft', 'hard'], type=str, help="") 190 | parser.add_argument('--distillation-alpha', default=0.5, type=float, help="") 191 | parser.add_argument('--distillation-tau', default=1.0, type=float, help="") 192 | 193 | # * Cosub params 194 | parser.add_argument('--cosub', action='store_true') 195 | 196 | # * Finetuning params 197 | parser.add_argument('--finetune', default='', help='finetune from checkpoint') 198 | parser.add_argument('--attn-only', action='store_true') 199 | 200 | # Dataset parameters 201 | parser.add_argument('--data_root', default='/data/DataSets/', type=str, 202 | help='dataset path') 203 | parser.add_argument('--data', default='PACS', 204 | choices=['CIFAR', 'IMNET', 'INAT', 'INAT19', 'PACS', 'OfficeHome', 'VLCS', 'digits_dg', 'terra_incognita', 'DomainNet'], 205 | type=str, help='Image Net dataset path') 206 | parser.add_argument('--inat-category', default='name', 207 | choices=['kingdom', 'phylum', 'class', 'order', 'supercategory', 'family', 'genus', 'name'], 208 | type=str, help='semantic granularity') 209 | 210 | # parser.add_argument('--output_dir', default='/output/Vim_results_test/', 211 | # help='path where to save, empty for no saving') 212 | parser.add_argument('--output_dir', default='/data/gjt/Mamba/Vim_for_DG/vim/output/', 213 | help='path where to save, empty for no saving') 214 | parser.add_argument('--device', default='cuda', 215 | help='device to use for training / testing') 216 | parser.add_argument('--seed', default=0, type=int) 217 | parser.add_argument('--resume', default='', help='resume from checkpoint') 218 | 219 | parser.add_argument('--start_epoch', default=0, type=int, metavar='N', 220 | help='start epoch') 221 | parser.add_argument('--eval', default=0, type=int, help='Perform evaluation only') 222 | parser.add_argument('--eval-crop-ratio', default=0.875, type=float, help="Crop ratio for evaluation") 223 | parser.add_argument('--dist-eval', action='store_true', default=False, help='Enabling distributed evaluation') 224 | parser.add_argument('--num_workers', default=10, type=int) 225 | parser.add_argument('--pin-mem', action='store_true', 226 | help='Pin CPU memory in DataLoader for more efficient (sometimes) transfer to GPU.') 227 | parser.add_argument('--no-pin-mem', action='store_false', dest='pin_mem', 228 | help='') 229 | parser.set_defaults(pin_mem=True) 230 | 231 | # distributed training parameters 232 | parser.add_argument('--distributed', action='store_true', default=False, help='Enabling distributed training') 233 | parser.add_argument('--world_size', default=1, type=int, 234 | help='number of distributed processes') 235 | parser.add_argument('--dist_url', default='env://', help='url used to set up distributed training') 236 | 237 | # amp about 238 | parser.add_argument('--if_amp', action='store_true') 239 | parser.add_argument('--no_amp', action='store_false', dest='if_amp') 240 | parser.set_defaults(if_amp=True) 241 | 242 | # if continue with inf 243 | parser.add_argument('--if_continue_inf', action='store_true') 244 | parser.add_argument('--no_continue_inf', action='store_false', dest='if_continue_inf') 245 | parser.set_defaults(if_continue_inf=False) 246 | 
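    # The store_true / store_false pairs in this parser implement boolean
    # toggles: both flags of a pair write to the same dest, and
    # parser.set_defaults(...) supplies the value used when neither flag is
    # passed.  Illustrative behaviour for the pair defined just above:
    #
    #     parser.parse_args([]).if_continue_inf                      # -> False (default)
    #     parser.parse_args(['--if_continue_inf']).if_continue_inf   # -> True
    #     parser.parse_args(['--no_continue_inf']).if_continue_inf   # -> False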
247 | # if use nan to num 248 | parser.add_argument('--if_nan2num', action='store_true') 249 | parser.add_argument('--no_nan2num', action='store_false', dest='if_nan2num') 250 | parser.set_defaults(if_nan2num=False) 251 | 252 | # if use random token position 253 | parser.add_argument('--if_random_cls_token_position', action='store_true') 254 | parser.add_argument('--no_random_cls_token_position', action='store_false', dest='if_random_cls_token_position') 255 | parser.set_defaults(if_random_cls_token_position=False) 256 | 257 | # if use random token rank 258 | parser.add_argument('--if_random_token_rank', action='store_true') 259 | parser.add_argument('--no_random_token_rank', action='store_false', dest='if_random_token_rank') 260 | parser.set_defaults(if_random_token_rank=False) 261 | 262 | parser.add_argument('--local-rank', default=0, type=int) 263 | return parser 264 | 265 | domain_map = { 266 | 'PACS': ['photo', 'art_painting', 'cartoon', 'sketch'], 267 | 'PACS_random_split': ['photo', 'art_painting', 'cartoon', 'sketch'], 268 | 'OfficeHome': ['Art', 'Clipart', 'Product', 'RealWorld'], 269 | 'VLCS': ["CALTECH", "LABELME", "PASCAL", "SUN"], 270 | 'digits_dg': ['mnist', 'mnist_m', 'svhn', 'syn'], 271 | 'miniDomainNet': ['clipart', 'painting', 'real', 'sketch'], 272 | 'terra_incognita': ['location_46', 'location_43', 'location_38', 'location_100'], 273 | 'DomainNet': ['clipart', 'infograph', 'painting', 'quickdraw', 'real', 'sketch'], 274 | } 275 | 276 | 277 | classes_map = { 278 | 'PACS': 7, 279 | 'PACS_random_split': 7, 280 | 'OfficeHome': 65, 281 | 'VLCS': 5, 282 | 'digits_dg': 10, 283 | 'miniDomainNet': 126, 284 | 'terra_incognita': 10, 285 | 'DomainNet': 345, 286 | } 287 | 288 | val_size_map = { 289 | 'PACS': 0.1, 290 | 'PACS_random_split': 0.1, 291 | 'OfficeHome': 0.1, 292 | 'VLCS': 0.3, 293 | 'digits_dg': 0.2, 294 | 'miniDomainNet': 0.3, 295 | 'terra_incognita': 0.2, 296 | 'DomainNet': 0.2, 297 | } 298 | 299 | def get_domain(name): 300 | if name not in domain_map: 301 | raise ValueError('Name of dataset unknown %s' %name) 302 | return domain_map[name] 303 | 304 | def main(args): 305 | utils.init_distributed_mode(args) 306 | 307 | domain = get_domain(args.data) 308 | args.target = domain.pop(args.target) 309 | args.source = domain 310 | print("Target domain: {}".format(args.target)) 311 | args.data_root = os.path.join(args.data_root, "PACS") if "PACS" in args.data else os.path.join(args.data_root, 312 | args.data) 313 | args.nb_classes = classes_map[args.data] 314 | args.n_domains = len(domain) 315 | args.val_size = val_size_map[args.data] 316 | 317 | # print(args) 318 | 319 | if args.distillation_type != 'none' and args.finetune and not args.eval: 320 | raise NotImplementedError("Finetuning with distillation not yet supported") 321 | 322 | device = torch.device(args.device) 323 | 324 | # fix the seed for reproducibility 325 | seed = args.seed + utils.get_rank() 326 | torch.manual_seed(seed) 327 | np.random.seed(seed) 328 | # random.seed(seed) 329 | 330 | cudnn.benchmark = True 331 | 332 | # log about 333 | run_name = args.output_dir.split("/")[-1] 334 | # args.gpu = 1 335 | if args.local_rank == 0 and args.gpu == 0: 336 | mlflow.start_run(run_name=run_name) 337 | for key, value in vars(args).items(): 338 | mlflow.log_param(key, value) 339 | 340 | dataset_train, dataset_val = data_helper.get_train_dataset(args, patches=False) 341 | dataset_test = data_helper.get_val_dataset(args, patches=False) 342 | 343 | if args.distributed: 344 | num_tasks = utils.get_world_size() 345 | 
global_rank = utils.get_rank() 346 | if args.repeated_aug: 347 | sampler_train = RASampler( 348 | dataset_train, num_replicas=num_tasks, rank=global_rank, shuffle=True 349 | ) 350 | else: 351 | sampler_train = torch.utils.data.DistributedSampler( 352 | dataset_train, num_replicas=num_tasks, rank=global_rank, shuffle=True 353 | ) 354 | if args.dist_eval: 355 | if len(dataset_val) % num_tasks != 0: 356 | print('Warning: Enabling distributed evaluation with an eval dataset not divisible by process number. ' 357 | 'This will slightly alter validation results as extra duplicate entries are added to achieve ' 358 | 'equal num of samples per-process.') 359 | sampler_val = torch.utils.data.DistributedSampler( 360 | dataset_val, num_replicas=num_tasks, rank=global_rank, shuffle=False) 361 | 362 | sampler_test = torch.utils.data.DistributedSampler( 363 | dataset_test, num_replicas=num_tasks, rank=global_rank, shuffle=False) 364 | else: 365 | sampler_val = torch.utils.data.SequentialSampler(dataset_val) 366 | sampler_test = torch.utils.data.SequentialSampler(dataset_test) 367 | else: 368 | sampler_train = torch.utils.data.RandomSampler(dataset_train) 369 | sampler_val = torch.utils.data.SequentialSampler(dataset_val) 370 | sampler_test = torch.utils.data.SequentialSampler(dataset_test) 371 | 372 | data_loader_train = torch.utils.data.DataLoader( 373 | dataset_train, sampler=sampler_train, 374 | batch_size=args.batch_size, 375 | num_workers=args.num_workers, 376 | pin_memory=args.pin_mem, 377 | drop_last=True, 378 | ) 379 | if args.ThreeAugment: 380 | data_loader_train.dataset.transform = new_data_aug_generator(args) 381 | 382 | data_loader_val = torch.utils.data.DataLoader( 383 | dataset_val, sampler=sampler_val, 384 | batch_size=int(1.5 * args.batch_size), 385 | num_workers=args.num_workers, 386 | pin_memory=args.pin_mem, 387 | drop_last=False 388 | ) 389 | 390 | data_loader_test = torch.utils.data.DataLoader( 391 | dataset_test, sampler=sampler_test, 392 | batch_size=int(1.5 * args.batch_size), 393 | num_workers=args.num_workers, 394 | pin_memory=args.pin_mem, 395 | drop_last=False 396 | ) 397 | 398 | mixup_fn = None 399 | mixup_active = args.mixup > 0 or args.cutmix > 0. 
or args.cutmix_minmax is not None 400 | if mixup_active: 401 | mixup_fn = Mixup( 402 | mixup_alpha=args.mixup, cutmix_alpha=args.cutmix, cutmix_minmax=args.cutmix_minmax, 403 | prob=args.mixup_prob, switch_prob=args.mixup_switch_prob, mode=args.mixup_mode, 404 | label_smoothing=args.smoothing, num_classes=args.nb_classes) 405 | 406 | print(f"Creating model: {args.model}") 407 | 408 | if "vmamba" in args.model: 409 | model = create_model( 410 | args.model, 411 | pretrained=False, 412 | num_classes=args.nb_classes, 413 | drop_block_rate=None, 414 | img_size=args.input_size, 415 | spatial_aug_flag=args.spatial_aug_flag, 416 | START_flag=args.START_flag, 417 | START_p=args.START_p, 418 | START_token_prob=args.START_token_prob, 419 | START_batch_prob=args.START_batch_prob, 420 | START_attention_mode=args.START_attention_mode, 421 | ) 422 | else: 423 | model = create_model( 424 | args.model, 425 | pretrained=False, 426 | num_classes=args.nb_classes, 427 | drop_rate=args.drop, 428 | drop_path_rate=args.drop_path, 429 | drop_block_rate=None, 430 | img_size=args.input_size, 431 | START_flag=args.Vim_START_flag, 432 | ) 433 | 434 | if args.finetune: 435 | if args.finetune.startswith('https'): 436 | checkpoint = torch.hub.load_state_dict_from_url( 437 | args.finetune, map_location='cpu', check_hash=True) 438 | else: 439 | checkpoint = torch.load(args.finetune, map_location='cpu') 440 | # checkpoint = torch.load(args.finetune) 441 | 442 | checkpoint_model = checkpoint['model'] 443 | state_dict = model.state_dict() 444 | for k in ['head.weight', 'head.bias', 'head_dist.weight', 'head_dist.bias']: 445 | if k in checkpoint_model and checkpoint_model[k].shape != state_dict[k].shape: 446 | print(f"Removing key {k} from pretrained checkpoint") 447 | del checkpoint_model[k] 448 | 449 | if "vmamba" in args.model: 450 | pass 451 | else: 452 | # interpolate position embedding 453 | pos_embed_checkpoint = checkpoint_model['pos_embed'] 454 | embedding_size = pos_embed_checkpoint.shape[-1] 455 | num_patches = model.patch_embed.num_patches 456 | num_extra_tokens = model.pos_embed.shape[-2] - num_patches 457 | # height (== width) for the checkpoint position embedding 458 | orig_size = int((pos_embed_checkpoint.shape[-2] - num_extra_tokens) ** 0.5) 459 | # height (== width) for the new position embedding 460 | new_size = int(num_patches ** 0.5) 461 | # class_token and dist_token are kept unchanged 462 | extra_tokens = pos_embed_checkpoint[:, :num_extra_tokens] 463 | # only the position tokens are interpolated 464 | pos_tokens = pos_embed_checkpoint[:, num_extra_tokens:] 465 | pos_tokens = pos_tokens.reshape(-1, orig_size, orig_size, embedding_size).permute(0, 3, 1, 2) 466 | pos_tokens = torch.nn.functional.interpolate( 467 | pos_tokens, size=(new_size, new_size), mode='bicubic', align_corners=False) 468 | pos_tokens = pos_tokens.permute(0, 2, 3, 1).flatten(1, 2) 469 | new_pos_embed = torch.cat((extra_tokens, pos_tokens), dim=1) 470 | checkpoint_model['pos_embed'] = new_pos_embed 471 | 472 | model.load_state_dict(checkpoint_model, strict=False) 473 | 474 | if args.attn_only: 475 | for name_p, p in model.named_parameters(): 476 | if '.attn.' 
in name_p: 477 | p.requires_grad = True 478 | else: 479 | p.requires_grad = False 480 | try: 481 | model.head.weight.requires_grad = True 482 | model.head.bias.requires_grad = True 483 | except: 484 | model.fc.weight.requires_grad = True 485 | model.fc.bias.requires_grad = True 486 | try: 487 | model.pos_embed.requires_grad = True 488 | except: 489 | print('no position encoding') 490 | try: 491 | for p in model.patch_embed.parameters(): 492 | p.requires_grad = False 493 | except: 494 | print('no patch embed') 495 | 496 | # model.half() 497 | model.to(device) 498 | if args.flops_flag == 1: 499 | input = torch.randn(1, 3, 224, 224).to(device) 500 | flops, params = profile(model, inputs=(input, )) 501 | flops, params = clever_format([flops, params], "%.3f") 502 | print(flops, params) 503 | return 504 | 505 | model_ema = None 506 | if args.model_ema: 507 | # Important to create EMA model after cuda(), DP wrapper, and AMP but before SyncBN and DDP wrapper 508 | model_ema = ModelEma( 509 | model, 510 | decay=args.model_ema_decay, 511 | device='cpu' if args.model_ema_force_cpu else '', 512 | resume='') 513 | 514 | model_without_ddp = model 515 | if args.distributed: 516 | model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.gpu]) 517 | model_without_ddp = model.module 518 | n_parameters = sum(p.numel() for p in model.parameters() if p.requires_grad) 519 | print('number of params:', n_parameters) 520 | 521 | if not args.unscale_lr: 522 | linear_scaled_lr = args.lr * args.batch_size * utils.get_world_size() / 512.0 523 | args.lr = linear_scaled_lr 524 | optimizer = create_optimizer(args, model_without_ddp) 525 | 526 | # amp about 527 | amp_autocast = suppress 528 | loss_scaler = "none" 529 | if args.if_amp: 530 | amp_autocast = torch.cuda.amp.autocast 531 | loss_scaler = NativeScaler() 532 | 533 | lr_scheduler, _ = create_scheduler(args, optimizer) 534 | 535 | criterion = LabelSmoothingCrossEntropy() 536 | 537 | if mixup_active: 538 | # smoothing is handled with mixup label transform 539 | criterion = SoftTargetCrossEntropy() 540 | elif args.smoothing: 541 | criterion = LabelSmoothingCrossEntropy(smoothing=args.smoothing) 542 | else: 543 | criterion = torch.nn.CrossEntropyLoss() 544 | 545 | if args.bce_loss: 546 | criterion = torch.nn.BCEWithLogitsLoss() 547 | 548 | teacher_model = None 549 | if args.distillation_type != 'none': 550 | assert args.teacher_path, 'need to specify teacher-path when using distillation' 551 | print(f"Creating teacher model: {args.teacher_model}") 552 | teacher_model = create_model( 553 | args.teacher_model, 554 | pretrained=False, 555 | num_classes=args.nb_classes, 556 | global_pool='avg', 557 | ) 558 | if args.teacher_path.startswith('https'): 559 | checkpoint = torch.hub.load_state_dict_from_url( 560 | args.teacher_path, map_location='cpu', check_hash=True) 561 | else: 562 | checkpoint = torch.load(args.teacher_path, map_location='cpu') 563 | teacher_model.load_state_dict(checkpoint['model']) 564 | teacher_model.to(device) 565 | teacher_model.eval() 566 | 567 | # wrap the criterion in our custom DistillationLoss, which 568 | # just dispatches to the original criterion if args.distillation_type is 'none' 569 | criterion = DistillationLoss( 570 | criterion, teacher_model, args.distillation_type, args.distillation_alpha, args.distillation_tau 571 | ) 572 | 573 | if "vmamba" in args.model: 574 | if args.START_flag != 0: 575 | args.output_dir += "_START" 576 | 577 | if args.START_attention_mode == 1: 578 | args.output_dir += "_X" 579 | elif 
args.START_attention_mode == 2: 580 | args.output_dir += "_M" 581 | 582 | if args.spatial_aug_flag == 1: 583 | args.output_dir += "_MixStyle" 584 | elif args.spatial_aug_flag == 2: 585 | args.output_dir += "_DSU" 586 | elif args.spatial_aug_flag == 3: 587 | args.output_dir += "_ALOFT" 588 | 589 | args.output_dir += "_P" + str(args.START_p) 590 | args.output_dir += "_Token" + str(args.START_token_prob) 591 | args.output_dir += "_Batch" + str(args.START_batch_prob) 592 | 593 | args.output_dir += "_lr" + str(args.lr) 594 | else: 595 | if args.Vim_START_flag == 1: 596 | args.output_dir += "START_X" 597 | elif args.Vim_START_flag == 2: 598 | args.output_dir += "START_M" 599 | 600 | if args.resume: 601 | if args.resume.startswith('https'): 602 | checkpoint = torch.hub.load_state_dict_from_url( 603 | args.resume, map_location='cpu', check_hash=True) 604 | else: 605 | checkpoint = torch.load(args.resume, map_location='cpu') 606 | model_without_ddp.load_state_dict(checkpoint['model']) 607 | if not args.eval and 'optimizer' in checkpoint and 'lr_scheduler' in checkpoint and 'epoch' in checkpoint: 608 | optimizer.load_state_dict(checkpoint['optimizer']) 609 | lr_scheduler.load_state_dict(checkpoint['lr_scheduler']) 610 | args.start_epoch = checkpoint['epoch'] + 1 611 | if args.model_ema: 612 | utils._load_checkpoint_for_ema(model_ema, checkpoint['model_ema']) 613 | if 'scaler' in checkpoint and args.if_amp: # change loss_scaler if not amp 614 | loss_scaler.load_state_dict(checkpoint['scaler']) 615 | elif 'scaler' in checkpoint and not args.if_amp: 616 | loss_scaler = 'none' 617 | lr_scheduler.step(args.start_epoch) 618 | 619 | if args.eval == 1: 620 | print("eval") 621 | val_stats = evaluate(data_loader_val, model, device, amp_autocast) 622 | test_stats = evaluate(data_loader_test, model, device, amp_autocast) 623 | print(f"Accuracy of the network on the val images: {val_stats['acc1']:.2f}%") 624 | print(f"Accuracy of the network on the test images: {test_stats['acc1']:.2f}%") 625 | return 626 | 627 | # log about 628 | if args.local_rank == 0 and args.gpu == 0: 629 | mlflow.log_param("n_parameters", n_parameters) 630 | 631 | print(f"Start training for {args.epochs} epochs") 632 | start_time = time.time() 633 | # max_accuracy = 0.0 634 | 635 | output_dir = os.path.join(args.output_dir, args.data, args.target + str(args.seed)) 636 | create_folder(output_dir) 637 | output_dir = Path(output_dir) 638 | 639 | max_accuracy_test = 0.0 640 | max_accuracy_val = 0.0 641 | max_val_test = 0.0 642 | max_val_epoch = 0 643 | max_test_epoch = 0 644 | 645 | for epoch in range(args.start_epoch, args.epochs): 646 | if args.distributed: 647 | data_loader_train.sampler.set_epoch(epoch) 648 | 649 | train_stats = train_one_epoch( 650 | model, criterion, data_loader_train, 651 | optimizer, device, epoch, loss_scaler, amp_autocast, 652 | args.clip_grad, model_ema, mixup_fn, 653 | set_training_mode=args.train_mode, 654 | # keep in eval mode for deit finetuning / train mode for training and deit III finetuning 655 | args=args, 656 | ) 657 | 658 | lr_scheduler.step(epoch) 659 | if args.output_dir: 660 | checkpoint_paths = [output_dir / 'checkpoint.pth'] 661 | for checkpoint_path in checkpoint_paths: 662 | utils.save_on_master({ 663 | 'model': model_without_ddp.state_dict(), 664 | }, checkpoint_path) 665 | 666 | val_stats = evaluate(data_loader_val, model, device, amp_autocast) 667 | test_stats = evaluate(data_loader_test, model, device, amp_autocast) 668 | print(f"Accuracy of the network on the {len(dataset_val)} val images: 
{val_stats['acc1']:.2f}%") 669 | print(f"Accuracy of the network on the {len(dataset_test)} test images: {test_stats['acc1']:.2f}%") 670 | 671 | max_accuracy_val = max(max_accuracy_val, val_stats["acc1"]) 672 | print(f'Max accuracy val: {max_accuracy_val:.2f}%') 673 | if max_accuracy_val == val_stats["acc1"]: 674 | max_val_epoch = epoch 675 | max_val_test = test_stats['acc1'] 676 | print(f"Corresponding test accuracy: {max_val_test:.2f}%") 677 | 678 | max_accuracy_test = max(max_accuracy_test, test_stats["acc1"]) 679 | if max_accuracy_test == test_stats["acc1"]: 680 | max_test_epoch = epoch 681 | 682 | if max_accuracy_val == val_stats["acc1"]: 683 | if args.output_dir: 684 | checkpoint_paths = [output_dir / 'best_checkpoint.pth'] 685 | for checkpoint_path in checkpoint_paths: 686 | utils.save_on_master({ 687 | 'model': model_without_ddp.state_dict(), 688 | }, checkpoint_path) 689 | 690 | print(f'Max accuracy: {max_val_test:.2f}%') 691 | 692 | log_stats = {**{f'train_{k}': v for k, v in train_stats.items()}, 693 | **{f'val_{k}': v for k, v in val_stats.items()}, 694 | **{f'test_{k}': v for k, v in test_stats.items()}, 695 | 'epoch': epoch, 696 | 'n_parameters': n_parameters} 697 | 698 | # log about 699 | if args.local_rank == 0 and args.gpu == 0: 700 | for key, value in log_stats.items(): 701 | mlflow.log_metric(key, value, log_stats['epoch']) 702 | 703 | if args.output_dir and utils.is_main_process(): 704 | with (output_dir / "log.txt").open("a") as f: 705 | f.write(json.dumps(log_stats) + "\n") 706 | 707 | log_stats = {**{f'Best val': max_accuracy_val}, 708 | **{f'Corresponding test': max_val_test}, 709 | **{f'At Epoch': max_val_epoch}, 710 | **{f'Best test': max_accuracy_test}, 711 | **{f'At Epoch': max_test_epoch}, 712 | } 713 | with (output_dir / "log.txt").open("a") as f: 714 | f.write(json.dumps(log_stats) + "\n") 715 | 716 | total_time = time.time() - start_time 717 | total_time_str = str(datetime.timedelta(seconds=int(total_time))) 718 | print('Training time {}'.format(total_time_str)) 719 | 720 | 721 | if __name__ == '__main__': 722 | parser = argparse.ArgumentParser('ViM training and evaluation script', parents=[get_args_parser()]) 723 | args = parser.parse_args() 724 | if args.output_dir: 725 | Path(args.output_dir).mkdir(parents=True, exist_ok=True) 726 | main(args) 727 | -------------------------------------------------------------------------------- /models/csm_triton.py: -------------------------------------------------------------------------------- 1 | # triton cross scan, 2x speed than pytorch implementation ========================= 2 | import torch 3 | import triton 4 | import triton.language as tl 5 | 6 | 7 | @triton.jit 8 | def triton_cross_scan( 9 | x, # (B, C, H, W) 10 | y, # (B, 4, C, H, W) 11 | BC: tl.constexpr, 12 | BH: tl.constexpr, 13 | BW: tl.constexpr, 14 | DC: tl.constexpr, 15 | DH: tl.constexpr, 16 | DW: tl.constexpr, 17 | NH: tl.constexpr, 18 | NW: tl.constexpr, 19 | ): 20 | i_hw, i_c, i_b = tl.program_id(0), tl.program_id(1), tl.program_id(2) 21 | i_h, i_w = (i_hw // NW), (i_hw % NW) 22 | _mask_h = (i_h * BH + tl.arange(0, BH)) < DH 23 | _mask_w = (i_w * BW + tl.arange(0, BW)) < DW 24 | _mask_hw = _mask_h[:, None] & _mask_w[None, :] 25 | _for_C = min(DC - i_c * BC, BC) 26 | 27 | _tmp0 = i_c * BC * DH * DW 28 | _tmp1 = DC * DH * DW 29 | _tmp2 = _tmp0 + i_h * BH * DW + tl.arange(0, BH)[:, None] * DW + i_w * BW + tl.arange(0, BW)[None, :] 30 | p_x = x + i_b * _tmp1 + _tmp2 31 | p_y1 = y + i_b * 4 * _tmp1 + _tmp2 # same 32 | p_y2 = y + i_b * 4 * _tmp1 + _tmp1 
+ _tmp0 + i_w * BW * DH + tl.arange(0, BW)[None, :] * DH + i_h * BH + tl.arange( 33 | 0, BH)[:, None] # trans 34 | p_y3 = y + i_b * 4 * _tmp1 + 2 * _tmp1 + _tmp0 + (NH - i_h - 1) * BH * DW + ( 35 | BH - 1 - tl.arange(0, BH)[:, None]) * DW + (NW - i_w - 1) * BW + ( 36 | BW - 1 - tl.arange(0, BW)[None, :]) + (DH - NH * BH) * DW + (DW - NW * BW) # flip 37 | p_y4 = y + i_b * 4 * _tmp1 + 3 * _tmp1 + _tmp0 + (NW - i_w - 1) * BW * DH + ( 38 | BW - 1 - tl.arange(0, BW)[None, :]) * DH + (NH - i_h - 1) * BH + ( 39 | BH - 1 - tl.arange(0, BH)[:, None]) + (DH - NH * BH) + (DW - NW * BW) * DH # trans + flip 40 | 41 | for idxc in range(_for_C): 42 | _idx = idxc * DH * DW 43 | _x = tl.load(p_x + _idx, mask=_mask_hw) 44 | tl.store(p_y1 + _idx, _x, mask=_mask_hw) 45 | tl.store(p_y2 + _idx, _x, mask=_mask_hw) 46 | tl.store(p_y3 + _idx, _x, mask=_mask_hw) 47 | tl.store(p_y4 + _idx, _x, mask=_mask_hw) 48 | tl.debug_barrier() 49 | 50 | 51 | @triton.jit 52 | def triton_cross_merge( 53 | x, # (B, C, H, W) 54 | y, # (B, 4, C, H, W) 55 | BC: tl.constexpr, 56 | BH: tl.constexpr, 57 | BW: tl.constexpr, 58 | DC: tl.constexpr, 59 | DH: tl.constexpr, 60 | DW: tl.constexpr, 61 | NH: tl.constexpr, 62 | NW: tl.constexpr, 63 | ): 64 | i_hw, i_c, i_b = tl.program_id(0), tl.program_id(1), tl.program_id(2) 65 | i_h, i_w = (i_hw // NW), (i_hw % NW) 66 | _mask_h = (i_h * BH + tl.arange(0, BH)) < DH 67 | _mask_w = (i_w * BW + tl.arange(0, BW)) < DW 68 | _mask_hw = _mask_h[:, None] & _mask_w[None, :] 69 | _for_C = min(DC - i_c * BC, BC) 70 | 71 | _tmp0 = i_c * BC * DH * DW 72 | _tmp1 = DC * DH * DW 73 | _tmp2 = _tmp0 + i_h * BH * DW + tl.arange(0, BH)[:, None] * DW + i_w * BW + tl.arange(0, BW)[None, :] 74 | p_x = x + i_b * _tmp1 + _tmp2 75 | p_y1 = y + i_b * 4 * _tmp1 + _tmp2 # same 76 | p_y2 = y + i_b * 4 * _tmp1 + _tmp1 + _tmp0 + i_w * BW * DH + tl.arange(0, BW)[None, :] * DH + i_h * BH + tl.arange( 77 | 0, BH)[:, None] # trans 78 | p_y3 = y + i_b * 4 * _tmp1 + 2 * _tmp1 + _tmp0 + (NH - i_h - 1) * BH * DW + ( 79 | BH - 1 - tl.arange(0, BH)[:, None]) * DW + (NW - i_w - 1) * BW + ( 80 | BW - 1 - tl.arange(0, BW)[None, :]) + (DH - NH * BH) * DW + (DW - NW * BW) # flip 81 | p_y4 = y + i_b * 4 * _tmp1 + 3 * _tmp1 + _tmp0 + (NW - i_w - 1) * BW * DH + ( 82 | BW - 1 - tl.arange(0, BW)[None, :]) * DH + (NH - i_h - 1) * BH + ( 83 | BH - 1 - tl.arange(0, BH)[:, None]) + (DH - NH * BH) + (DW - NW * BW) * DH # trans + flip 84 | 85 | for idxc in range(_for_C): 86 | _idx = idxc * DH * DW 87 | _y1 = tl.load(p_y1 + _idx, mask=_mask_hw) 88 | _y2 = tl.load(p_y2 + _idx, mask=_mask_hw) 89 | _y3 = tl.load(p_y3 + _idx, mask=_mask_hw) 90 | _y4 = tl.load(p_y4 + _idx, mask=_mask_hw) 91 | tl.store(p_x + _idx, _y1 + _y2 + _y3 + _y4, mask=_mask_hw) 92 | tl.debug_barrier() 93 | 94 | 95 | @triton.jit 96 | def triton_cross_scan_1b1( 97 | x, # (B, C, H, W) 98 | y, # (B, 4, C, H, W) 99 | BC: tl.constexpr, 100 | BH: tl.constexpr, 101 | BW: tl.constexpr, 102 | DC: tl.constexpr, 103 | DH: tl.constexpr, 104 | DW: tl.constexpr, 105 | NH: tl.constexpr, 106 | NW: tl.constexpr, 107 | ): 108 | i_hw, i_c, i_b = tl.program_id(0), tl.program_id(1), tl.program_id(2) 109 | i_h, i_w = (i_hw // NW), (i_hw % NW) 110 | _mask_h = (i_h * BH + tl.arange(0, BH)) < DH 111 | _mask_w = (i_w * BW + tl.arange(0, BW)) < DW 112 | _mask_hw = _mask_h[:, None] & _mask_w[None, :] 113 | _for_C = min(DC - i_c * BC, BC) 114 | 115 | _tmp0 = i_c * BC * DH * DW 116 | _tmp1 = DC * DH * DW 117 | _tmp2 = _tmp0 + i_h * BH * DW + tl.arange(0, BH)[:, None] * DW + i_w * BW + tl.arange(0, BW)[None, :] 118 
| p_y1 = y + i_b * 4 * _tmp1 + _tmp2 # same 119 | p_y2 = y + i_b * 4 * _tmp1 + _tmp1 + _tmp0 + i_w * BW * DH + tl.arange(0, BW)[None, :] * DH + i_h * BH + tl.arange( 120 | 0, BH)[:, None] # trans 121 | p_y3 = y + i_b * 4 * _tmp1 + 2 * _tmp1 + _tmp0 + (NH - i_h - 1) * BH * DW + ( 122 | BH - 1 - tl.arange(0, BH)[:, None]) * DW + (NW - i_w - 1) * BW + ( 123 | BW - 1 - tl.arange(0, BW)[None, :]) + (DH - NH * BH) * DW + (DW - NW * BW) # flip 124 | p_y4 = y + i_b * 4 * _tmp1 + 3 * _tmp1 + _tmp0 + (NW - i_w - 1) * BW * DH + ( 125 | BW - 1 - tl.arange(0, BW)[None, :]) * DH + (NH - i_h - 1) * BH + ( 126 | BH - 1 - tl.arange(0, BH)[:, None]) + (DH - NH * BH) + (DW - NW * BW) * DH # trans + flip 127 | 128 | p_x1 = x + i_b * 4 * _tmp1 + _tmp2 129 | p_x2 = p_x1 + _tmp1 130 | p_x3 = p_x2 + _tmp1 131 | p_x4 = p_x3 + _tmp1 132 | for idxc in range(_for_C): 133 | _idx = idxc * DH * DW 134 | tl.store(p_y1 + _idx, tl.load(p_x1 + _idx), mask=_mask_hw) 135 | tl.store(p_y2 + _idx, tl.load(p_x2 + _idx), mask=_mask_hw) 136 | tl.store(p_y3 + _idx, tl.load(p_x3 + _idx), mask=_mask_hw) 137 | tl.store(p_y4 + _idx, tl.load(p_x4 + _idx), mask=_mask_hw) 138 | tl.debug_barrier() 139 | 140 | 141 | @triton.jit 142 | def triton_cross_merge_1b1( 143 | x, # (B, C, H, W) 144 | y, # (B, 4, C, H, W) 145 | BC: tl.constexpr, 146 | BH: tl.constexpr, 147 | BW: tl.constexpr, 148 | DC: tl.constexpr, 149 | DH: tl.constexpr, 150 | DW: tl.constexpr, 151 | NH: tl.constexpr, 152 | NW: tl.constexpr, 153 | ): 154 | i_hw, i_c, i_b = tl.program_id(0), tl.program_id(1), tl.program_id(2) 155 | i_h, i_w = (i_hw // NW), (i_hw % NW) 156 | _mask_h = (i_h * BH + tl.arange(0, BH)) < DH 157 | _mask_w = (i_w * BW + tl.arange(0, BW)) < DW 158 | _mask_hw = _mask_h[:, None] & _mask_w[None, :] 159 | _for_C = min(DC - i_c * BC, BC) 160 | 161 | _tmp0 = i_c * BC * DH * DW 162 | _tmp1 = DC * DH * DW 163 | _tmp2 = _tmp0 + i_h * BH * DW + tl.arange(0, BH)[:, None] * DW + i_w * BW + tl.arange(0, BW)[None, :] 164 | p_y1 = y + i_b * 4 * _tmp1 + _tmp2 # same 165 | p_y2 = y + i_b * 4 * _tmp1 + _tmp1 + _tmp0 + i_w * BW * DH + tl.arange(0, BW)[None, :] * DH + i_h * BH + tl.arange( 166 | 0, BH)[:, None] # trans 167 | p_y3 = y + i_b * 4 * _tmp1 + 2 * _tmp1 + _tmp0 + (NH - i_h - 1) * BH * DW + ( 168 | BH - 1 - tl.arange(0, BH)[:, None]) * DW + (NW - i_w - 1) * BW + ( 169 | BW - 1 - tl.arange(0, BW)[None, :]) + (DH - NH * BH) * DW + (DW - NW * BW) # flip 170 | p_y4 = y + i_b * 4 * _tmp1 + 3 * _tmp1 + _tmp0 + (NW - i_w - 1) * BW * DH + ( 171 | BW - 1 - tl.arange(0, BW)[None, :]) * DH + (NH - i_h - 1) * BH + ( 172 | BH - 1 - tl.arange(0, BH)[:, None]) + (DH - NH * BH) + (DW - NW * BW) * DH # trans + flip 173 | 174 | p_x1 = x + i_b * 4 * _tmp1 + _tmp2 175 | p_x2 = p_x1 + _tmp1 176 | p_x3 = p_x2 + _tmp1 177 | p_x4 = p_x3 + _tmp1 178 | for idxc in range(_for_C): 179 | _idx = idxc * DH * DW 180 | tl.store(p_x1 + _idx, tl.load(p_y1 + _idx), mask=_mask_hw) 181 | tl.store(p_x2 + _idx, tl.load(p_y2 + _idx), mask=_mask_hw) 182 | tl.store(p_x3 + _idx, tl.load(p_y3 + _idx), mask=_mask_hw) 183 | tl.store(p_x4 + _idx, tl.load(p_y4 + _idx), mask=_mask_hw) 184 | tl.debug_barrier() 185 | 186 | 187 | class CrossScanTriton(torch.autograd.Function): 188 | @staticmethod 189 | def forward(ctx, x: torch.Tensor): 190 | B, C, H, W = x.shape 191 | B, C, H, W = int(B), int(C), int(H), int(W) 192 | BC, BH, BW = min(triton.next_power_of_2(C), 1), min(triton.next_power_of_2(H), 64), min( 193 | triton.next_power_of_2(W), 64) 194 | NH, NW, NC = triton.cdiv(H, BH), triton.cdiv(W, BW), triton.cdiv(C, BC) 195 
| ctx.shape = (B, C, H, W) 196 | ctx.triton_shape = (BC, BH, BW, NC, NH, NW) 197 | x = x.contiguous() 198 | y = x.new_empty((B, 4, C, H, W)) 199 | triton_cross_scan[(NH * NW, NC, B)](x, y, BC, BH, BW, C, H, W, NH, NW) 200 | return y.view(B, 4, C, -1) 201 | 202 | @staticmethod 203 | def backward(ctx, y: torch.Tensor): 204 | # out: (b, k, d, l) 205 | B, C, H, W = ctx.shape 206 | BC, BH, BW, NC, NH, NW = ctx.triton_shape 207 | y = y.contiguous().view(B, 4, C, H, W) 208 | x = y.new_empty((B, C, H, W)) 209 | triton_cross_merge[(NH * NW, NC, B)](x, y, BC, BH, BW, C, H, W, NH, NW) 210 | return x 211 | 212 | 213 | class CrossMergeTriton(torch.autograd.Function): 214 | @staticmethod 215 | def forward(ctx, y: torch.Tensor): 216 | B, K, C, H, W = y.shape 217 | B, C, H, W = int(B), int(C), int(H), int(W) 218 | BC, BH, BW = min(triton.next_power_of_2(C), 1), min(triton.next_power_of_2(H), 64), min( 219 | triton.next_power_of_2(W), 64) 220 | NH, NW, NC = triton.cdiv(H, BH), triton.cdiv(W, BW), triton.cdiv(C, BC) 221 | ctx.shape = (B, C, H, W) 222 | ctx.triton_shape = (BC, BH, BW, NC, NH, NW) 223 | y = y.contiguous().view(B, 4, C, H, W) 224 | x = y.new_empty((B, C, H, W)) 225 | triton_cross_merge[(NH * NW, NC, B)](x, y, BC, BH, BW, C, H, W, NH, NW) 226 | return x.view(B, C, -1) 227 | 228 | @staticmethod 229 | def backward(ctx, x: torch.Tensor): 230 | # out: (b, d, l) 231 | B, C, H, W = ctx.shape 232 | BC, BH, BW, NC, NH, NW = ctx.triton_shape 233 | x = x.contiguous() 234 | y = x.new_empty((B, 4, C, H, W)) 235 | triton_cross_scan[(NH * NW, NC, B)](x, y, BC, BH, BW, C, H, W, NH, NW) 236 | return y 237 | 238 | 239 | class CrossScanTriton1b1(torch.autograd.Function): 240 | @staticmethod 241 | def forward(ctx, x: torch.Tensor): 242 | B, K, C, H, W = x.shape 243 | B, C, H, W = int(B), int(C), int(H), int(W) 244 | BC, BH, BW = min(triton.next_power_of_2(C), 1), min(triton.next_power_of_2(H), 64), min( 245 | triton.next_power_of_2(W), 64) 246 | NH, NW, NC = triton.cdiv(H, BH), triton.cdiv(W, BW), triton.cdiv(C, BC) 247 | ctx.shape = (B, C, H, W) 248 | ctx.triton_shape = (BC, BH, BW, NC, NH, NW) 249 | x = x.contiguous() 250 | y = x.new_empty((B, 4, C, H, W)) 251 | triton_cross_scan_1b1[(NH * NW, NC, B)](x, y, BC, BH, BW, C, H, W, NH, NW) 252 | return y.view(B, 4, C, -1) 253 | 254 | @staticmethod 255 | def backward(ctx, y: torch.Tensor): 256 | # out: (b, k, d, l) 257 | B, C, H, W = ctx.shape 258 | BC, BH, BW, NC, NH, NW = ctx.triton_shape 259 | y = y.contiguous().view(B, 4, C, H, W) 260 | x = y.new_empty((B, 4, C, H, W)) 261 | triton_cross_merge_1b1[(NH * NW, NC, B)](x, y, BC, BH, BW, C, H, W, NH, NW) 262 | return x 263 | 264 | -------------------------------------------------------------------------------- /models/mamba_simple.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2023, Tri Dao, Albert Gu. 
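# For orientation, a minimal sketch of how the CrossScanTriton / CrossMergeTriton wrappers defined
# above (models/csm_triton.py) are typically exercised. Illustrative only: the helper name, the
# import path, the tensor sizes, and the CUDA device (required by Triton) are assumptions.
def _cross_scan_merge_example():
    import torch
    from models.csm_triton import CrossScanTriton, CrossMergeTriton

    B, C, H, W = 2, 96, 14, 14
    x = torch.randn(B, C, H, W, device="cuda", requires_grad=True)

    # Cross-scan: unfold the feature map into four 1-D traversals
    # (row-major, column-major, and their two reversed counterparts).
    ys = CrossScanTriton.apply(x)                              # (B, 4, C, H*W)

    # Cross-merge: fold the four directional sequences back into a single (B, C, H*W) map.
    x_merged = CrossMergeTriton.apply(ys.view(B, 4, C, H, W))  # (B, C, H*W)

    # Both wrappers are autograd.Functions, so gradients flow through the Triton kernels.
    x_merged.sum().backward()
    return x.grad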
2 | 3 | import math 4 | from typing import Optional 5 | 6 | import torch 7 | import torch.nn as nn 8 | import torch.nn.functional as F 9 | from torch import Tensor 10 | 11 | from einops import rearrange, repeat 12 | 13 | from mamba_ssm.ops.selective_scan_interface import selective_scan_fn, mamba_inner_fn 14 | 15 | try: 16 | from causal_conv1d import causal_conv1d_fn, causal_conv1d_update 17 | except ImportError: 18 | causal_conv1d_fn, causal_conv1d_update = None 19 | 20 | try: 21 | from mamba_ssm.ops.triton.selective_state_update import selective_state_update 22 | except ImportError: 23 | selective_state_update = None 24 | 25 | try: 26 | from mamba_ssm.ops.triton.layernorm import RMSNorm, layer_norm_fn, rms_norm_fn 27 | except ImportError: 28 | RMSNorm, layer_norm_fn, rms_norm_fn = None, None, None 29 | 30 | 31 | from perturb_style.MixStyle import MixStyle 32 | from perturb_style.SeqTokenAug import SeqTokenAug 33 | 34 | class Mamba(nn.Module): 35 | def __init__( 36 | self, 37 | d_model, 38 | d_state=16, 39 | d_conv=4, 40 | expand=2, 41 | dt_rank="auto", 42 | dt_min=0.001, 43 | dt_max=0.1, 44 | dt_init="random", 45 | dt_scale=1.0, 46 | dt_init_floor=1e-4, 47 | conv_bias=True, 48 | bias=False, 49 | use_fast_path=False, # Fused kernel options 50 | layer_idx=None, 51 | device=None, 52 | dtype=None, 53 | START_flag=0, 54 | ): 55 | factory_kwargs = {"device": device, "dtype": dtype} 56 | super().__init__() 57 | self.d_model = d_model 58 | self.d_state = d_state 59 | self.d_conv = d_conv 60 | self.expand = expand 61 | self.d_inner = int(self.expand * self.d_model) 62 | self.dt_rank = math.ceil(self.d_model / 16) if dt_rank == "auto" else dt_rank 63 | self.use_fast_path = use_fast_path 64 | self.layer_idx = layer_idx 65 | 66 | self.in_proj = nn.Linear(self.d_model, self.d_inner * 2, bias=bias, **factory_kwargs) 67 | 68 | self.conv1d = nn.Conv1d( 69 | in_channels=self.d_inner, 70 | out_channels=self.d_inner, 71 | bias=conv_bias, 72 | kernel_size=d_conv, 73 | groups=self.d_inner, 74 | padding=d_conv - 1, 75 | **factory_kwargs, 76 | ) 77 | 78 | self.activation = "silu" 79 | self.act = nn.SiLU() 80 | 81 | self.x_proj = nn.Linear( 82 | self.d_inner, self.dt_rank + self.d_state * 2, bias=False, **factory_kwargs 83 | ) 84 | self.dt_proj = nn.Linear(self.dt_rank, self.d_inner, bias=True, **factory_kwargs) 85 | 86 | # Initialize special dt projection to preserve variance at initialization 87 | dt_init_std = self.dt_rank**-0.5 * dt_scale 88 | if dt_init == "constant": 89 | nn.init.constant_(self.dt_proj.weight, dt_init_std) 90 | elif dt_init == "random": 91 | nn.init.uniform_(self.dt_proj.weight, -dt_init_std, dt_init_std) 92 | else: 93 | raise NotImplementedError 94 | 95 | # Initialize dt bias so that F.softplus(dt_bias) is between dt_min and dt_max 96 | dt = torch.exp( 97 | torch.rand(self.d_inner, **factory_kwargs) * (math.log(dt_max) - math.log(dt_min)) 98 | + math.log(dt_min) 99 | ).clamp(min=dt_init_floor) 100 | # Inverse of softplus: https://github.com/pytorch/pytorch/issues/72759 101 | inv_dt = dt + torch.log(-torch.expm1(-dt)) 102 | with torch.no_grad(): 103 | self.dt_proj.bias.copy_(inv_dt) 104 | # Our initialization would set all Linear.bias to zero, need to mark this one as _no_reinit 105 | self.dt_proj.bias._no_reinit = True 106 | 107 | # S4D real initialization 108 | A = repeat( 109 | torch.arange(1, self.d_state + 1, dtype=torch.float32, device=device), 110 | "n -> d n", 111 | d=self.d_inner, 112 | ).contiguous() 113 | A_log = torch.log(A) # Keep A_log in fp32 114 | self.A_log = 
nn.Parameter(A_log) 115 | self.A_log._no_weight_decay = True 116 | 117 | # D "skip" parameter 118 | self.D = nn.Parameter(torch.ones(self.d_inner, device=device)) # Keep in fp32 119 | self.D._no_weight_decay = True 120 | 121 | self.out_proj = nn.Linear(self.d_inner, self.d_model, bias=bias, **factory_kwargs) 122 | 123 | self.START_flag = START_flag 124 | if self.START_flag != 0: 125 | self.spatial_aug = MixStyle(p=1.0) 126 | self.SeqTokenAug = SeqTokenAug( 127 | p=1.0, 128 | aug_token_prob=0.75, 129 | batch_prob=0.5, 130 | token_or_seq=0, 131 | back_fill_mode=0, 132 | token_attention_flag=1, 133 | seq_token_flag=0, 134 | ) 135 | 136 | def forward(self, hidden_states, inference_params=None): 137 | """ 138 | hidden_states: (B, L, D) 139 | Returns: same shape as hidden_states 140 | """ 141 | batch, seqlen, dim = hidden_states.shape 142 | 143 | conv_state, ssm_state = None, None 144 | if inference_params is not None: 145 | conv_state, ssm_state = self._get_states_from_cache(inference_params, batch) 146 | if inference_params.seqlen_offset > 0: 147 | # The states are updated inplace 148 | out, _, _ = self.step(hidden_states, conv_state, ssm_state) 149 | return out 150 | 151 | # We do matmul and transpose BLH -> HBL at the same time 152 | xz = rearrange( 153 | self.in_proj.weight @ rearrange(hidden_states, "b l d -> d (b l)"), 154 | "d (b l) -> b d l", 155 | l=seqlen, 156 | ) 157 | if self.in_proj.bias is not None: 158 | xz = xz + rearrange(self.in_proj.bias.to(dtype=xz.dtype), "d -> d 1") 159 | 160 | A = -torch.exp(self.A_log.float()) # (d_inner, d_state) 161 | # In the backward pass we write dx and dz next to each other to avoid torch.cat 162 | if self.use_fast_path and inference_params is None: # Doesn't support outputting the states 163 | out = mamba_inner_fn( 164 | xz, 165 | self.conv1d.weight, 166 | self.conv1d.bias, 167 | self.x_proj.weight, 168 | self.dt_proj.weight, 169 | self.out_proj.weight, 170 | self.out_proj.bias, 171 | A, 172 | None, # input-dependent B 173 | None, # input-dependent C 174 | self.D.float(), 175 | delta_bias=self.dt_proj.bias.float(), 176 | delta_softplus=True, 177 | ) 178 | else: 179 | x, z = xz.chunk(2, dim=1) 180 | # Compute short convolution 181 | if conv_state is not None: 182 | # If we just take x[:, :, -self.d_conv :], it will error if seqlen < self.d_conv 183 | # Instead F.pad will pad with zeros if seqlen < self.d_conv, and truncate otherwise. 
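# Concrete example of the F.pad call below (illustrative): with d_conv = 4 and seqlen = 3 the pad is
# (1, 0), so one zero column is prepended and conv_state keeps shape (B, D, d_conv); with seqlen = 10
# the pad is (-6, 0), i.e. x is truncated to its last d_conv time steps.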
184 | conv_state.copy_(F.pad(x, (self.d_conv - x.shape[-1], 0))) # Update state (B D W) 185 | if causal_conv1d_fn is None: 186 | x = self.act(self.conv1d(x)[..., :seqlen]) 187 | else: 188 | assert self.activation in ["silu", "swish"] 189 | x = causal_conv1d_fn( 190 | x=x, 191 | weight=rearrange(self.conv1d.weight, "d 1 w -> d w"), 192 | bias=self.conv1d.bias, 193 | activation=self.activation, 194 | ) 195 | 196 | if self.training: 197 | if self.START_flag == 1: 198 | x_bkdl = x.unsqueeze(dim=1) 199 | x_aug = self.spatial_aug(x_bkdl) 200 | x = self.SeqTokenAug(x_bkdl, x_aug, Bx=x_bkdl) 201 | x = x.squeeze() 202 | elif self.START_flag == 2: 203 | x_dbl = self.x_proj(rearrange(x, "b d l -> (b l) d")) # (bl d) 204 | dt, B, C = torch.split(x_dbl, [self.dt_rank, self.d_state, self.d_state], dim=-1) 205 | dt = self.dt_proj.weight @ dt.t() 206 | dt = rearrange(dt, "d (b l) -> b d l", l=seqlen) 207 | B = rearrange(B, "(b l) dstate -> b dstate l", l=seqlen).contiguous() 208 | C = rearrange(C, "(b l) dstate -> b dstate l", l=seqlen).contiguous() 209 | deltaB_u = torch.einsum('bnl,bdl,bnl,bdl->bdln', C, dt, B, x) 210 | 211 | x_bkdl = x.unsqueeze(dim=1) 212 | deltaB_u = deltaB_u.unsqueeze(dim=1) 213 | 214 | x_aug = self.spatial_aug(x_bkdl) 215 | x = self.SeqTokenAug(x_bkdl, x_aug, Bx=deltaB_u.mean(dim=-1)) 216 | x = x.squeeze() 217 | 218 | # We're careful here about the layout, to avoid extra transposes. 219 | # We want dt to have d as the slowest moving dimension 220 | # and L as the fastest moving dimension, since those are what the ssm_scan kernel expects. 221 | x_dbl = self.x_proj(rearrange(x, "b d l -> (b l) d")) # (bl d) 222 | dt, B, C = torch.split(x_dbl, [self.dt_rank, self.d_state, self.d_state], dim=-1) 223 | dt = self.dt_proj.weight @ dt.t() 224 | dt = rearrange(dt, "d (b l) -> b d l", l=seqlen) 225 | B = rearrange(B, "(b l) dstate -> b dstate l", l=seqlen).contiguous() 226 | C = rearrange(C, "(b l) dstate -> b dstate l", l=seqlen).contiguous() 227 | 228 | assert self.activation in ["silu", "swish"] 229 | y = selective_scan_fn( 230 | x, 231 | dt, 232 | A, 233 | B, 234 | C, 235 | self.D.float(), 236 | z=z, 237 | delta_bias=self.dt_proj.bias.float(), 238 | delta_softplus=True, 239 | return_last_state=ssm_state is not None, 240 | ) 241 | if ssm_state is not None: 242 | y, last_state = y 243 | ssm_state.copy_(last_state) 244 | y = rearrange(y, "b d l -> b l d") 245 | out = self.out_proj(y) 246 | return out 247 | 248 | def step(self, hidden_states, conv_state, ssm_state): 249 | dtype = hidden_states.dtype 250 | assert hidden_states.shape[1] == 1, "Only support decoding with 1 token at a time for now" 251 | xz = self.in_proj(hidden_states.squeeze(1)) # (B 2D) 252 | x, z = xz.chunk(2, dim=-1) # (B D) 253 | 254 | # Conv step 255 | if causal_conv1d_update is None: 256 | conv_state.copy_(torch.roll(conv_state, shifts=-1, dims=-1)) # Update state (B D W) 257 | conv_state[:, :, -1] = x 258 | x = torch.sum(conv_state * rearrange(self.conv1d.weight, "d 1 w -> d w"), dim=-1) # (B D) 259 | if self.conv1d.bias is not None: 260 | x = x + self.conv1d.bias 261 | x = self.act(x).to(dtype=dtype) 262 | else: 263 | x = causal_conv1d_update( 264 | x, 265 | conv_state, 266 | rearrange(self.conv1d.weight, "d 1 w -> d w"), 267 | self.conv1d.bias, 268 | self.activation, 269 | ) 270 | 271 | x_db = self.x_proj(x) # (B dt_rank+2*d_state) 272 | dt, B, C = torch.split(x_db, [self.dt_rank, self.d_state, self.d_state], dim=-1) 273 | # Don't add dt_bias here 274 | dt = F.linear(dt, self.dt_proj.weight) # (B d_inner) 275 | A = 
-torch.exp(self.A_log.float()) # (d_inner, d_state) 276 | 277 | # SSM step 278 | if selective_state_update is None: 279 | # Discretize A and B 280 | dt = F.softplus(dt + self.dt_proj.bias.to(dtype=dt.dtype)) 281 | dA = torch.exp(torch.einsum("bd,dn->bdn", dt, A)) 282 | dB = torch.einsum("bd,bn->bdn", dt, B) 283 | ssm_state.copy_(ssm_state * dA + rearrange(x, "b d -> b d 1") * dB) 284 | y = torch.einsum("bdn,bn->bd", ssm_state.to(dtype), C) 285 | y = y + self.D.to(dtype) * x 286 | y = y * self.act(z) # (B D) 287 | else: 288 | y = selective_state_update( 289 | ssm_state, x, dt, A, B, C, self.D, z=z, dt_bias=self.dt_proj.bias, dt_softplus=True 290 | ) 291 | 292 | out = self.out_proj(y) 293 | return out.unsqueeze(1), conv_state, ssm_state 294 | 295 | def allocate_inference_cache(self, batch_size, max_seqlen, dtype=None, **kwargs): 296 | device = self.out_proj.weight.device 297 | conv_dtype = self.conv1d.weight.dtype if dtype is None else dtype 298 | conv_state = torch.zeros( 299 | batch_size, self.d_model * self.expand, self.d_conv, device=device, dtype=conv_dtype 300 | ) 301 | ssm_dtype = self.dt_proj.weight.dtype if dtype is None else dtype 302 | # ssm_dtype = torch.float32 303 | ssm_state = torch.zeros( 304 | batch_size, self.d_model * self.expand, self.d_state, device=device, dtype=ssm_dtype 305 | ) 306 | return conv_state, ssm_state 307 | 308 | def _get_states_from_cache(self, inference_params, batch_size, initialize_states=False): 309 | assert self.layer_idx is not None 310 | if self.layer_idx not in inference_params.key_value_memory_dict: 311 | batch_shape = (batch_size,) 312 | conv_state = torch.zeros( 313 | batch_size, 314 | self.d_model * self.expand, 315 | self.d_conv, 316 | device=self.conv1d.weight.device, 317 | dtype=self.conv1d.weight.dtype, 318 | ) 319 | ssm_state = torch.zeros( 320 | batch_size, 321 | self.d_model * self.expand, 322 | self.d_state, 323 | device=self.dt_proj.weight.device, 324 | dtype=self.dt_proj.weight.dtype, 325 | # dtype=torch.float32, 326 | ) 327 | inference_params.key_value_memory_dict[self.layer_idx] = (conv_state, ssm_state) 328 | else: 329 | conv_state, ssm_state = inference_params.key_value_memory_dict[self.layer_idx] 330 | # TODO: What if batch size changes between generation, and we reuse the same states? 331 | if initialize_states: 332 | conv_state.zero_() 333 | ssm_state.zero_() 334 | return conv_state, ssm_state 335 | 336 | 337 | class Block(nn.Module): 338 | def __init__( 339 | self, dim, mixer_cls, norm_cls=nn.LayerNorm, fused_add_norm=False, residual_in_fp32=False 340 | ): 341 | """ 342 | Simple block wrapping a mixer class with LayerNorm/RMSNorm and residual connection" 343 | 344 | This Block has a slightly different structure compared to a regular 345 | prenorm Transformer block. 346 | The standard block is: LN -> MHA/MLP -> Add. 347 | [Ref: https://arxiv.org/abs/2002.04745] 348 | Here we have: Add -> LN -> Mixer, returning both 349 | the hidden_states (output of the mixer) and the residual. 350 | This is purely for performance reasons, as we can fuse add and LayerNorm. 351 | The residual needs to be provided (except for the very first block). 
352 | """ 353 | super().__init__() 354 | self.residual_in_fp32 = residual_in_fp32 355 | self.fused_add_norm = fused_add_norm 356 | self.mixer = mixer_cls(dim) 357 | self.norm = norm_cls(dim) 358 | if self.fused_add_norm: 359 | assert RMSNorm is not None, "RMSNorm import fails" 360 | assert isinstance( 361 | self.norm, (nn.LayerNorm, RMSNorm) 362 | ), "Only LayerNorm and RMSNorm are supported for fused_add_norm" 363 | 364 | def forward( 365 | self, hidden_states: Tensor, residual: Optional[Tensor] = None, inference_params=None 366 | ): 367 | r"""Pass the input through the encoder layer. 368 | 369 | Args: 370 | hidden_states: the sequence to the encoder layer (required). 371 | residual: hidden_states = Mixer(LN(residual)) 372 | """ 373 | if not self.fused_add_norm: 374 | residual = (hidden_states + residual) if residual is not None else hidden_states 375 | hidden_states = self.norm(residual.to(dtype=self.norm.weight.dtype)) 376 | if self.residual_in_fp32: 377 | residual = residual.to(torch.float32) 378 | else: 379 | fused_add_norm_fn = rms_norm_fn if isinstance(self.norm, RMSNorm) else layer_norm_fn 380 | hidden_states, residual = fused_add_norm_fn( 381 | hidden_states, 382 | self.norm.weight, 383 | self.norm.bias, 384 | residual=residual, 385 | prenorm=True, 386 | residual_in_fp32=self.residual_in_fp32, 387 | eps=self.norm.eps, 388 | ) 389 | hidden_states = self.mixer(hidden_states, inference_params=inference_params) 390 | return hidden_states, residual 391 | 392 | def allocate_inference_cache(self, batch_size, max_seqlen, dtype=None, **kwargs): 393 | return self.mixer.allocate_inference_cache(batch_size, max_seqlen, dtype=dtype, **kwargs) 394 | -------------------------------------------------------------------------------- /models/mamba_ssm/mamba_simple_vim.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2023, Tri Dao, Albert Gu. 
2 | 3 | import math 4 | from typing import Optional 5 | 6 | import torch 7 | import torch.nn as nn 8 | import torch.nn.functional as F 9 | from torch import Tensor 10 | 11 | from einops import rearrange, repeat 12 | 13 | try: 14 | from causal_conv1d import causal_conv1d_fn, causal_conv1d_update 15 | except ImportError: 16 | causal_conv1d_fn, causal_conv1d_update = None 17 | 18 | # try: 19 | # from mamba_ssm.ops.selective_scan_interface import selective_scan_fn, mamba_inner_fn, bimamba_inner_fn, mamba_inner_fn_no_out_proj 20 | # except ImportError: 21 | # selective_scan_fn, mamba_inner_fn, bimamba_inner_fn, mamba_inner_fn_no_out_proj = None, None, None, None, None 22 | 23 | from .selective_scan_interface import selective_scan_fn, mamba_inner_fn, bimamba_inner_fn, mamba_inner_fn_no_out_proj 24 | 25 | # try: 26 | # from selective_scan_interface import selective_scan_fn, mamba_inner_fn, bimamba_inner_fn, mamba_inner_fn_no_out_proj 27 | # except ImportError: 28 | # selective_scan_fn, mamba_inner_fn, bimamba_inner_fn, mamba_inner_fn_no_out_proj = None, None, None, None 29 | 30 | 31 | try: 32 | from mamba_ssm.ops.triton.selective_state_update import selective_state_update 33 | except ImportError: 34 | selective_state_update = None 35 | 36 | try: 37 | from mamba_ssm.ops.triton.layernorm import RMSNorm, layer_norm_fn, rms_norm_fn 38 | except ImportError: 39 | RMSNorm, layer_norm_fn, rms_norm_fn = None, None, None 40 | 41 | class Mamba(nn.Module): 42 | def __init__( 43 | self, 44 | d_model, 45 | d_state=16, 46 | d_conv=4, 47 | expand=2, 48 | dt_rank="auto", 49 | dt_min=0.001, 50 | dt_max=0.1, 51 | dt_init="random", 52 | dt_scale=1.0, 53 | dt_init_floor=1e-4, 54 | conv_bias=True, 55 | bias=False, 56 | use_fast_path=True, # Fused kernel options 57 | layer_idx=None, 58 | device=None, 59 | dtype=None, 60 | bimamba_type="none", 61 | if_devide_out=False, 62 | init_layer_scale=None, 63 | START_flag=0, 64 | ): 65 | factory_kwargs = {"device": device, "dtype": dtype} 66 | super().__init__() 67 | self.d_model = d_model 68 | self.d_state = d_state 69 | self.d_conv = d_conv 70 | self.expand = expand 71 | self.d_inner = int(self.expand * self.d_model) 72 | self.dt_rank = math.ceil(self.d_model / 16) if dt_rank == "auto" else dt_rank 73 | self.use_fast_path = use_fast_path 74 | self.layer_idx = layer_idx 75 | self.bimamba_type = bimamba_type 76 | self.if_devide_out = if_devide_out 77 | 78 | self.init_layer_scale = init_layer_scale 79 | if init_layer_scale is not None: 80 | self.gamma = nn.Parameter(init_layer_scale * torch.ones((d_model)), requires_grad=True) 81 | 82 | self.in_proj = nn.Linear(self.d_model, self.d_inner * 2, bias=bias, **factory_kwargs) 83 | 84 | self.conv1d = nn.Conv1d( 85 | in_channels=self.d_inner, 86 | out_channels=self.d_inner, 87 | bias=conv_bias, 88 | kernel_size=d_conv, 89 | groups=self.d_inner, 90 | padding=d_conv - 1, 91 | **factory_kwargs, 92 | ) 93 | 94 | self.activation = "silu" 95 | self.act = nn.SiLU() 96 | 97 | self.x_proj = nn.Linear( 98 | self.d_inner, self.dt_rank + self.d_state * 2, bias=False, **factory_kwargs 99 | ) 100 | self.dt_proj = nn.Linear(self.dt_rank, self.d_inner, bias=True, **factory_kwargs) 101 | 102 | # Initialize special dt projection to preserve variance at initialization 103 | dt_init_std = self.dt_rank**-0.5 * dt_scale 104 | if dt_init == "constant": 105 | nn.init.constant_(self.dt_proj.weight, dt_init_std) 106 | elif dt_init == "random": 107 | nn.init.uniform_(self.dt_proj.weight, -dt_init_std, dt_init_std) 108 | else: 109 | raise NotImplementedError 110 | 
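# Background for the initialization below: dt is sampled log-uniformly in [dt_min, dt_max] (clamped at
# dt_init_floor) and then mapped through the inverse of softplus, inv(y) = y + log(-expm1(-y)) = log(expm1(y)),
# so that F.softplus(dt_proj.bias) lands back in the [dt_min, dt_max] range at initialization.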
111 | # Initialize dt bias so that F.softplus(dt_bias) is between dt_min and dt_max 112 | dt = torch.exp( 113 | torch.rand(self.d_inner, **factory_kwargs) * (math.log(dt_max) - math.log(dt_min)) 114 | + math.log(dt_min) 115 | ).clamp(min=dt_init_floor) 116 | # Inverse of softplus: https://github.com/pytorch/pytorch/issues/72759 117 | inv_dt = dt + torch.log(-torch.expm1(-dt)) 118 | with torch.no_grad(): 119 | self.dt_proj.bias.copy_(inv_dt) 120 | # Our initialization would set all Linear.bias to zero, need to mark this one as _no_reinit 121 | self.dt_proj.bias._no_reinit = True 122 | 123 | # S4D real initialization 124 | A = repeat( 125 | torch.arange(1, self.d_state + 1, dtype=torch.float32, device=device), 126 | "n -> d n", 127 | d=self.d_inner, 128 | ).contiguous() 129 | A_log = torch.log(A) # Keep A_log in fp32 130 | self.A_log = nn.Parameter(A_log) 131 | self.A_log._no_weight_decay = True 132 | 133 | # D "skip" parameter 134 | self.D = nn.Parameter(torch.ones(self.d_inner, device=device)) # Keep in fp32 135 | self.D._no_weight_decay = True 136 | 137 | # bidirectional 138 | if bimamba_type == "v1": 139 | A_b = repeat( 140 | torch.arange(1, self.d_state + 1, dtype=torch.float32, device=device), 141 | "n -> d n", 142 | d=self.d_inner, 143 | ).contiguous() 144 | A_b_log = torch.log(A_b) # Keep A_b_log in fp32 145 | self.A_b_log = nn.Parameter(A_b_log) 146 | self.A_b_log._no_weight_decay = True 147 | elif bimamba_type == "v2": 148 | A_b = repeat( 149 | torch.arange(1, self.d_state + 1, dtype=torch.float32, device=device), 150 | "n -> d n", 151 | d=self.d_inner, 152 | ).contiguous() 153 | A_b_log = torch.log(A_b) # Keep A_b_log in fp32 154 | self.A_b_log = nn.Parameter(A_b_log) 155 | self.A_b_log._no_weight_decay = True 156 | 157 | self.conv1d_b = nn.Conv1d( 158 | in_channels=self.d_inner, 159 | out_channels=self.d_inner, 160 | bias=conv_bias, 161 | kernel_size=d_conv, 162 | groups=self.d_inner, 163 | padding=d_conv - 1, 164 | **factory_kwargs, 165 | ) 166 | 167 | self.x_proj_b = nn.Linear( 168 | self.d_inner, self.dt_rank + self.d_state * 2, bias=False, **factory_kwargs 169 | ) 170 | self.dt_proj_b = nn.Linear(self.dt_rank, self.d_inner, bias=True, **factory_kwargs) 171 | 172 | self.D_b = nn.Parameter(torch.ones(self.d_inner, device=device)) # Keep in fp32 173 | self.D_b._no_weight_decay = True 174 | 175 | self.out_proj = nn.Linear(self.d_inner, self.d_model, bias=bias, **factory_kwargs) 176 | 177 | self.START_flag = START_flag 178 | 179 | def forward(self, hidden_states, inference_params=None): 180 | """ 181 | hidden_states: (B, L, D) 182 | Returns: same shape as hidden_states 183 | """ 184 | batch, seqlen, dim = hidden_states.shape 185 | 186 | conv_state, ssm_state = None, None 187 | if inference_params is not None: 188 | conv_state, ssm_state = self._get_states_from_cache(inference_params, batch) 189 | if inference_params.seqlen_offset > 0: 190 | # The states are updated inplace 191 | out, _, _ = self.step(hidden_states, conv_state, ssm_state) 192 | return out 193 | 194 | # We do matmul and transpose BLH -> HBL at the same time 195 | xz = rearrange( 196 | self.in_proj.weight @ rearrange(hidden_states, "b l d -> d (b l)"), 197 | "d (b l) -> b d l", 198 | l=seqlen, 199 | ) 200 | if self.in_proj.bias is not None: 201 | xz = xz + rearrange(self.in_proj.bias.to(dtype=xz.dtype), "d -> d 1") 202 | 203 | A = -torch.exp(self.A_log.float()) # (d_inner, d_state) 204 | # In the backward pass we write dx and dz next to each other to avoid torch.cat 205 | if self.use_fast_path and inference_params is 
None: # Doesn't support outputting the states 206 | if self.bimamba_type == "v1": 207 | A_b = -torch.exp(self.A_b_log.float()) 208 | out = bimamba_inner_fn( 209 | xz, 210 | self.conv1d.weight, 211 | self.conv1d.bias, 212 | self.x_proj.weight, 213 | self.dt_proj.weight, 214 | self.out_proj.weight, 215 | self.out_proj.bias, 216 | A, 217 | A_b, 218 | None, # input-dependent B 219 | None, # input-dependent C 220 | self.D.float(), 221 | delta_bias=self.dt_proj.bias.float(), 222 | delta_softplus=True, 223 | ) 224 | elif self.bimamba_type == "v2": 225 | A_b = -torch.exp(self.A_b_log.float()) 226 | out = mamba_inner_fn_no_out_proj( 227 | xz, 228 | self.conv1d.weight, 229 | self.conv1d.bias, 230 | self.x_proj.weight, 231 | self.dt_proj.weight, 232 | A, 233 | None, # input-dependent B 234 | None, # input-dependent C 235 | self.D.float(), 236 | delta_bias=self.dt_proj.bias.float(), 237 | delta_softplus=True, 238 | START_flag=self.START_flag, 239 | training_flag=self.training, 240 | ) 241 | out_b = mamba_inner_fn_no_out_proj( 242 | xz.flip([-1]), 243 | self.conv1d_b.weight, 244 | self.conv1d_b.bias, 245 | self.x_proj_b.weight, 246 | self.dt_proj_b.weight, 247 | A_b, 248 | None, 249 | None, 250 | self.D_b.float(), 251 | delta_bias=self.dt_proj_b.bias.float(), 252 | delta_softplus=True, 253 | START_flag=self.START_flag, 254 | training_flag=self.training, 255 | ) 256 | # F.linear(rearrange(out_z, "b d l -> b l d"), out_proj_weight, out_proj_bias) 257 | if not self.if_devide_out: 258 | out = F.linear(rearrange(out + out_b.flip([-1]), "b d l -> b l d"), self.out_proj.weight, self.out_proj.bias) 259 | else: 260 | out = F.linear(rearrange(out + out_b.flip([-1]), "b d l -> b l d") / 2, self.out_proj.weight, self.out_proj.bias) 261 | 262 | else: 263 | out = mamba_inner_fn( 264 | xz, 265 | self.conv1d.weight, 266 | self.conv1d.bias, 267 | self.x_proj.weight, 268 | self.dt_proj.weight, 269 | self.out_proj.weight, 270 | self.out_proj.bias, 271 | A, 272 | None, # input-dependent B 273 | None, # input-dependent C 274 | self.D.float(), 275 | delta_bias=self.dt_proj.bias.float(), 276 | delta_softplus=True, 277 | ) 278 | else: 279 | x, z = xz.chunk(2, dim=1) 280 | # Compute short convolution 281 | if conv_state is not None: 282 | # If we just take x[:, :, -self.d_conv :], it will error if seqlen < self.d_conv 283 | # Instead F.pad will pad with zeros if seqlen < self.d_conv, and truncate otherwise. 284 | conv_state.copy_(F.pad(x, (self.d_conv - x.shape[-1], 0))) # Update state (B D W) 285 | if causal_conv1d_fn is None: 286 | x = self.act(self.conv1d(x)[..., :seqlen]) 287 | else: 288 | assert self.activation in ["silu", "swish"] 289 | x = causal_conv1d_fn( 290 | x=x, 291 | weight=rearrange(self.conv1d.weight, "d 1 w -> d w"), 292 | bias=self.conv1d.bias, 293 | activation=self.activation, 294 | ) 295 | 296 | # We're careful here about the layout, to avoid extra transposes. 297 | # We want dt to have d as the slowest moving dimension 298 | # and L as the fastest moving dimension, since those are what the ssm_scan kernel expects. 
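# Shape notes for the projections below (B = batch, L = seqlen, D = d_inner, N = d_state, R = dt_rank):
#   x_dbl        : (B*L, R + 2N), split into dt (B*L, R) and the SSM inputs B and C (each (B*L, N))
#   dt           : dt_proj.weight @ dt.t() -> (D, B*L), then rearranged to (B, D, L)
#   B, C (SSM)   : rearranged to (B, N, L) and made contiguous for the selective-scan kernel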
299 | x_dbl = self.x_proj(rearrange(x, "b d l -> (b l) d")) # (bl d) 300 | dt, B, C = torch.split(x_dbl, [self.dt_rank, self.d_state, self.d_state], dim=-1) 301 | dt = self.dt_proj.weight @ dt.t() 302 | dt = rearrange(dt, "d (b l) -> b d l", l=seqlen) 303 | B = rearrange(B, "(b l) dstate -> b dstate l", l=seqlen).contiguous() 304 | C = rearrange(C, "(b l) dstate -> b dstate l", l=seqlen).contiguous() 305 | assert self.activation in ["silu", "swish"] 306 | 307 | y = selective_scan_fn( 308 | x, 309 | dt, 310 | A, 311 | B, 312 | C, 313 | self.D.float(), 314 | z=z, 315 | delta_bias=self.dt_proj.bias.float(), 316 | delta_softplus=True, 317 | return_last_state=ssm_state is not None, 318 | ) 319 | if ssm_state is not None: 320 | y, last_state = y 321 | ssm_state.copy_(last_state) 322 | y = rearrange(y, "b d l -> b l d") 323 | out = self.out_proj(y) 324 | if self.init_layer_scale is not None: 325 | out = out * self.gamma 326 | return out 327 | 328 | def step(self, hidden_states, conv_state, ssm_state): 329 | dtype = hidden_states.dtype 330 | assert hidden_states.shape[1] == 1, "Only support decoding with 1 token at a time for now" 331 | xz = self.in_proj(hidden_states.squeeze(1)) # (B 2D) 332 | x, z = xz.chunk(2, dim=-1) # (B D) 333 | 334 | # Conv step 335 | if causal_conv1d_update is None: 336 | conv_state.copy_(torch.roll(conv_state, shifts=-1, dims=-1)) # Update state (B D W) 337 | conv_state[:, :, -1] = x 338 | x = torch.sum(conv_state * rearrange(self.conv1d.weight, "d 1 w -> d w"), dim=-1) # (B D) 339 | if self.conv1d.bias is not None: 340 | x = x + self.conv1d.bias 341 | x = self.act(x).to(dtype=dtype) 342 | else: 343 | x = causal_conv1d_update( 344 | x, 345 | conv_state, 346 | rearrange(self.conv1d.weight, "d 1 w -> d w"), 347 | self.conv1d.bias, 348 | self.activation, 349 | ) 350 | 351 | x_db = self.x_proj(x) # (B dt_rank+2*d_state) 352 | dt, B, C = torch.split(x_db, [self.dt_rank, self.d_state, self.d_state], dim=-1) 353 | # Don't add dt_bias here 354 | dt = F.linear(dt, self.dt_proj.weight) # (B d_inner) 355 | A = -torch.exp(self.A_log.float()) # (d_inner, d_state) 356 | 357 | # SSM step 358 | if selective_state_update is None: 359 | # Discretize A and B 360 | dt = F.softplus(dt + self.dt_proj.bias.to(dtype=dt.dtype)) 361 | dA = torch.exp(torch.einsum("bd,dn->bdn", dt, A)) 362 | dB = torch.einsum("bd,bn->bdn", dt, B) 363 | ssm_state.copy_(ssm_state * dA + rearrange(x, "b d -> b d 1") * dB) 364 | y = torch.einsum("bdn,bn->bd", ssm_state.to(dtype), C) 365 | y = y + self.D.to(dtype) * x 366 | y = y * self.act(z) # (B D) 367 | else: 368 | y = selective_state_update( 369 | ssm_state, x, dt, A, B, C, self.D, z=z, dt_bias=self.dt_proj.bias, dt_softplus=True 370 | ) 371 | 372 | out = self.out_proj(y) 373 | return out.unsqueeze(1), conv_state, ssm_state 374 | 375 | def allocate_inference_cache(self, batch_size, max_seqlen, dtype=None, **kwargs): 376 | device = self.out_proj.weight.device 377 | conv_dtype = self.conv1d.weight.dtype if dtype is None else dtype 378 | conv_state = torch.zeros( 379 | batch_size, self.d_model * self.expand, self.d_conv, device=device, dtype=conv_dtype 380 | ) 381 | ssm_dtype = self.dt_proj.weight.dtype if dtype is None else dtype 382 | # ssm_dtype = torch.float32 383 | ssm_state = torch.zeros( 384 | batch_size, self.d_model * self.expand, self.d_state, device=device, dtype=ssm_dtype 385 | ) 386 | return conv_state, ssm_state 387 | 388 | def _get_states_from_cache(self, inference_params, batch_size, initialize_states=False): 389 | assert self.layer_idx is not None 390 
| if self.layer_idx not in inference_params.key_value_memory_dict: 391 | batch_shape = (batch_size,) 392 | conv_state = torch.zeros( 393 | batch_size, 394 | self.d_model * self.expand, 395 | self.d_conv, 396 | device=self.conv1d.weight.device, 397 | dtype=self.conv1d.weight.dtype, 398 | ) 399 | ssm_state = torch.zeros( 400 | batch_size, 401 | self.d_model * self.expand, 402 | self.d_state, 403 | device=self.dt_proj.weight.device, 404 | dtype=self.dt_proj.weight.dtype, 405 | # dtype=torch.float32, 406 | ) 407 | inference_params.key_value_memory_dict[self.layer_idx] = (conv_state, ssm_state) 408 | else: 409 | conv_state, ssm_state = inference_params.key_value_memory_dict[self.layer_idx] 410 | # TODO: What if batch size changes between generation, and we reuse the same states? 411 | if initialize_states: 412 | conv_state.zero_() 413 | ssm_state.zero_() 414 | return conv_state, ssm_state 415 | 416 | 417 | class Block(nn.Module): 418 | def __init__( 419 | self, dim, mixer_cls, norm_cls=nn.LayerNorm, fused_add_norm=False, residual_in_fp32=False 420 | ): 421 | """ 422 | Simple block wrapping a mixer class with LayerNorm/RMSNorm and residual connection" 423 | 424 | This Block has a slightly different structure compared to a regular 425 | prenorm Transformer block. 426 | The standard block is: LN -> MHA/MLP -> Add. 427 | [Ref: https://arxiv.org/abs/2002.04745] 428 | Here we have: Add -> LN -> Mixer, returning both 429 | the hidden_states (output of the mixer) and the residual. 430 | This is purely for performance reasons, as we can fuse add and LayerNorm. 431 | The residual needs to be provided (except for the very first block). 432 | """ 433 | super().__init__() 434 | self.residual_in_fp32 = residual_in_fp32 435 | self.fused_add_norm = fused_add_norm 436 | self.mixer = mixer_cls(dim) 437 | self.norm = norm_cls(dim) 438 | if self.fused_add_norm: 439 | assert RMSNorm is not None, "RMSNorm import fails" 440 | assert isinstance( 441 | self.norm, (nn.LayerNorm, RMSNorm) 442 | ), "Only LayerNorm and RMSNorm are supported for fused_add_norm" 443 | 444 | def forward( 445 | self, hidden_states: Tensor, residual: Optional[Tensor] = None, inference_params=None 446 | ): 447 | r"""Pass the input through the encoder layer. 448 | 449 | Args: 450 | hidden_states: the sequence to the encoder layer (required). 451 | residual: hidden_states = Mixer(LN(residual)) 452 | """ 453 | if not self.fused_add_norm: 454 | residual = (hidden_states + residual) if residual is not None else hidden_states 455 | hidden_states = self.norm(residual.to(dtype=self.norm.weight.dtype)) 456 | if self.residual_in_fp32: 457 | residual = residual.to(torch.float32) 458 | else: 459 | fused_add_norm_fn = rms_norm_fn if isinstance(self.norm, RMSNorm) else layer_norm_fn 460 | hidden_states, residual = fused_add_norm_fn( 461 | hidden_states, 462 | self.norm.weight, 463 | self.norm.bias, 464 | residual=residual, 465 | prenorm=True, 466 | residual_in_fp32=self.residual_in_fp32, 467 | eps=self.norm.eps, 468 | ) 469 | hidden_states = self.mixer(hidden_states, inference_params=inference_params) 470 | return hidden_states, residual 471 | 472 | def allocate_inference_cache(self, batch_size, max_seqlen, dtype=None, **kwargs): 473 | return self.mixer.allocate_inference_cache(batch_size, max_seqlen, dtype=dtype, **kwargs) 474 | -------------------------------------------------------------------------------- /models_mamba.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2015-present, Facebook, Inc. 
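# For reference, a minimal sketch of the bidirectional mixer defined above in
# models/mamba_ssm/mamba_simple_vim.py. Illustrative only: the helper name, sizes, and CUDA device
# are assumptions. With bimamba_type="v2" and use_fast_path=True the sequence is scanned
# left-to-right with one set of SSM parameters and right-to-left with a second set; the two results
# are summed (and halved when if_devide_out=True) before the shared output projection.
def _bimamba_v2_example():
    import torch
    from models.mamba_ssm.mamba_simple_vim import Mamba

    mixer = Mamba(d_model=192, d_state=16, bimamba_type="v2",
                  if_devide_out=True, use_fast_path=True).cuda()
    x = torch.randn(4, 197, 192, device="cuda")   # (B, L, D) token sequence
    y = mixer(x)                                   # (4, 197, 192)
    return y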
2 | # All rights reserved. 3 | import torch 4 | import torch.nn as nn 5 | from functools import partial 6 | from torch import Tensor 7 | from typing import Optional 8 | 9 | from timm.models.vision_transformer import VisionTransformer, _cfg 10 | from timm.models.registry import register_model 11 | from timm.models.layers import trunc_normal_, lecun_normal_ 12 | 13 | from timm.models.layers import DropPath, to_2tuple 14 | from timm.models.vision_transformer import _load_weights 15 | 16 | import math 17 | 18 | from collections import namedtuple 19 | # from mamba_ssm.modules.mamba_simple import Mamba 20 | # from models.mamba_simple_vim import Mamba 21 | from models.mamba_ssm.mamba_simple_vim import Mamba 22 | from mamba_ssm.utils.generation import GenerationMixin 23 | from mamba_ssm.utils.hf import load_config_hf, load_state_dict_hf 24 | 25 | from rope import * 26 | import random 27 | 28 | try: 29 | from mamba_ssm.ops.triton.layernorm import RMSNorm, layer_norm_fn, rms_norm_fn 30 | except ImportError: 31 | RMSNorm, layer_norm_fn, rms_norm_fn = None, None, None 32 | 33 | 34 | __all__ = [ 35 | 'vim_tiny_patch16_224', 'vim_small_patch16_224', 'vim_base_patch16_224', 36 | 'vim_tiny_patch16_384', 'vim_small_patch16_384', 'vim_base_patch16_384', 37 | ] 38 | 39 | 40 | class PatchEmbed(nn.Module): 41 | """ 2D Image to Patch Embedding 42 | """ 43 | def __init__(self, img_size=224, patch_size=16, stride=16, in_chans=3, embed_dim=768, norm_layer=None, flatten=True): 44 | super().__init__() 45 | img_size = to_2tuple(img_size) 46 | patch_size = to_2tuple(patch_size) 47 | self.img_size = img_size 48 | self.patch_size = patch_size 49 | self.grid_size = ((img_size[0] - patch_size[0]) // stride + 1, (img_size[1] - patch_size[1]) // stride + 1) 50 | self.num_patches = self.grid_size[0] * self.grid_size[1] 51 | self.flatten = flatten 52 | 53 | self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=stride) 54 | self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity() 55 | 56 | def forward(self, x): 57 | B, C, H, W = x.shape 58 | assert H == self.img_size[0] and W == self.img_size[1], \ 59 | f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})." 60 | x = self.proj(x) 61 | if self.flatten: 62 | x = x.flatten(2).transpose(1, 2) # BCHW -> BNC 63 | x = self.norm(x) 64 | return x 65 | 66 | 67 | class Block(nn.Module): 68 | def __init__( 69 | self, dim, mixer_cls, norm_cls=nn.LayerNorm, fused_add_norm=False, residual_in_fp32=False,drop_path=0., 70 | ): 71 | """ 72 | Simple block wrapping a mixer class with LayerNorm/RMSNorm and residual connection" 73 | 74 | This Block has a slightly different structure compared to a regular 75 | prenorm Transformer block. 76 | The standard block is: LN -> MHA/MLP -> Add. 77 | [Ref: https://arxiv.org/abs/2002.04745] 78 | Here we have: Add -> LN -> Mixer, returning both 79 | the hidden_states (output of the mixer) and the residual. 80 | This is purely for performance reasons, as we can fuse add and LayerNorm. 81 | The residual needs to be provided (except for the very first block). 82 | """ 83 | super().__init__() 84 | self.residual_in_fp32 = residual_in_fp32 85 | self.fused_add_norm = fused_add_norm 86 | self.mixer = mixer_cls(dim) 87 | self.norm = norm_cls(dim) 88 | self.drop_path = DropPath(drop_path) if drop_path > 0. 
else nn.Identity() 89 | if self.fused_add_norm: 90 | assert RMSNorm is not None, "RMSNorm import fails" 91 | assert isinstance( 92 | self.norm, (nn.LayerNorm, RMSNorm) 93 | ), "Only LayerNorm and RMSNorm are supported for fused_add_norm" 94 | 95 | def forward( 96 | self, hidden_states: Tensor, residual: Optional[Tensor] = None, inference_params=None 97 | ): 98 | r"""Pass the input through the encoder layer. 99 | 100 | Args: 101 | hidden_states: the sequence to the encoder layer (required). 102 | residual: hidden_states = Mixer(LN(residual)) 103 | """ 104 | if not self.fused_add_norm: 105 | if residual is None: 106 | residual = hidden_states 107 | else: 108 | residual = residual + self.drop_path(hidden_states) 109 | 110 | hidden_states = self.norm(residual.to(dtype=self.norm.weight.dtype)) 111 | if self.residual_in_fp32: 112 | residual = residual.to(torch.float32) 113 | else: 114 | fused_add_norm_fn = rms_norm_fn if isinstance(self.norm, RMSNorm) else layer_norm_fn 115 | if residual is None: 116 | hidden_states, residual = fused_add_norm_fn( 117 | hidden_states, 118 | self.norm.weight, 119 | self.norm.bias, 120 | residual=residual, 121 | prenorm=True, 122 | residual_in_fp32=self.residual_in_fp32, 123 | eps=self.norm.eps, 124 | ) 125 | else: 126 | hidden_states, residual = fused_add_norm_fn( 127 | self.drop_path(hidden_states), 128 | self.norm.weight, 129 | self.norm.bias, 130 | residual=residual, 131 | prenorm=True, 132 | residual_in_fp32=self.residual_in_fp32, 133 | eps=self.norm.eps, 134 | ) 135 | 136 | # augment x 137 | hidden_states = self.mixer(hidden_states, inference_params=inference_params) 138 | return hidden_states, residual 139 | 140 | def allocate_inference_cache(self, batch_size, max_seqlen, dtype=None, **kwargs): 141 | return self.mixer.allocate_inference_cache(batch_size, max_seqlen, dtype=dtype, **kwargs) 142 | 143 | 144 | def create_block( 145 | d_model, 146 | ssm_cfg=None, 147 | norm_epsilon=1e-5, 148 | drop_path=0., 149 | rms_norm=False, 150 | residual_in_fp32=False, 151 | fused_add_norm=False, 152 | layer_idx=None, 153 | device=None, 154 | dtype=None, 155 | if_bimamba=False, 156 | bimamba_type="none", 157 | if_devide_out=False, 158 | init_layer_scale=None, 159 | START_flag=0, 160 | ): 161 | if if_bimamba: 162 | bimamba_type = "v1" 163 | if ssm_cfg is None: 164 | ssm_cfg = {} 165 | factory_kwargs = {"device": device, "dtype": dtype} 166 | mixer_cls = partial(Mamba, START_flag=START_flag, layer_idx=layer_idx, bimamba_type=bimamba_type, if_devide_out=if_devide_out, init_layer_scale=init_layer_scale, **ssm_cfg, **factory_kwargs) 167 | # mixer_cls = partial(Mamba, START_flag=START_flag, layer_idx=layer_idx, bimamba_type=bimamba_type, if_devide_out=if_devide_out, init_layer_scale=init_layer_scale, **ssm_cfg, **factory_kwargs) 168 | norm_cls = partial( 169 | nn.LayerNorm if not rms_norm else RMSNorm, eps=norm_epsilon, **factory_kwargs 170 | ) 171 | block = Block( 172 | d_model, 173 | mixer_cls, 174 | norm_cls=norm_cls, 175 | drop_path=drop_path, 176 | fused_add_norm=fused_add_norm, 177 | residual_in_fp32=residual_in_fp32, 178 | ) 179 | block.layer_idx = layer_idx 180 | return block 181 | 182 | 183 | # https://github.com/huggingface/transformers/blob/c28d04e9e252a1a099944e325685f14d242ecdcd/src/transformers/models/gpt2/modeling_gpt2.py#L454 184 | def _init_weights( 185 | module, 186 | n_layer, 187 | initializer_range=0.02, # Now only used for embedding layer. 
188 | rescale_prenorm_residual=True, 189 | n_residuals_per_layer=1, # Change to 2 if we have MLP 190 | ): 191 | if isinstance(module, nn.Linear): 192 | if module.bias is not None: 193 | if not getattr(module.bias, "_no_reinit", False): 194 | nn.init.zeros_(module.bias) 195 | elif isinstance(module, nn.Embedding): 196 | nn.init.normal_(module.weight, std=initializer_range) 197 | 198 | if rescale_prenorm_residual: 199 | # Reinitialize selected weights subject to the OpenAI GPT-2 Paper Scheme: 200 | # > A modified initialization which accounts for the accumulation on the residual path with model depth. Scale 201 | # > the weights of residual layers at initialization by a factor of 1/√N where N is the # of residual layers. 202 | # > -- GPT-2 :: https://openai.com/blog/better-language-models/ 203 | # 204 | # Reference (Megatron-LM): https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/model/gpt_model.py 205 | for name, p in module.named_parameters(): 206 | if name in ["out_proj.weight", "fc2.weight"]: 207 | # Special Scaled Initialization --> There are 2 Layer Norms per Transformer Block 208 | # Following Pytorch init, except scale by 1/sqrt(2 * n_layer) 209 | # We need to reinit p since this code could be called multiple times 210 | # Having just p *= scale would repeatedly scale it down 211 | nn.init.kaiming_uniform_(p, a=math.sqrt(5)) 212 | with torch.no_grad(): 213 | p /= math.sqrt(n_residuals_per_layer * n_layer) 214 | 215 | 216 | def segm_init_weights(m): 217 | if isinstance(m, nn.Linear): 218 | trunc_normal_(m.weight, std=0.02) 219 | if isinstance(m, nn.Linear) and m.bias is not None: 220 | nn.init.constant_(m.bias, 0) 221 | elif isinstance(m, nn.Conv2d): 222 | # NOTE conv was left to pytorch default in my original init 223 | lecun_normal_(m.weight) 224 | if m.bias is not None: 225 | nn.init.zeros_(m.bias) 226 | elif isinstance(m, (nn.LayerNorm, nn.GroupNorm, nn.BatchNorm2d)): 227 | nn.init.zeros_(m.bias) 228 | nn.init.ones_(m.weight) 229 | 230 | 231 | class VisionMamba(nn.Module): 232 | def __init__(self, 233 | img_size=224, 234 | patch_size=16, 235 | stride=16, 236 | depth=24, 237 | embed_dim=192, 238 | channels=3, 239 | num_classes=1000, 240 | ssm_cfg=None, 241 | drop_rate=0., 242 | drop_path_rate=0.1, 243 | norm_epsilon: float = 1e-5, 244 | rms_norm: bool = False, 245 | initializer_cfg=None, 246 | fused_add_norm=False, 247 | residual_in_fp32=False, 248 | device=None, 249 | dtype=None, 250 | ft_seq_len=None, 251 | pt_hw_seq_len=14, 252 | if_bidirectional=False, 253 | final_pool_type='none', 254 | if_abs_pos_embed=False, 255 | if_rope=False, 256 | if_rope_residual=False, 257 | flip_img_sequences_ratio=-1., 258 | if_bimamba=False, 259 | bimamba_type="none", 260 | if_cls_token=False, 261 | if_devide_out=False, 262 | init_layer_scale=None, 263 | use_double_cls_token=False, 264 | use_middle_cls_token=False, 265 | START_flag=0, 266 | **kwargs): 267 | factory_kwargs = {"device": device, "dtype": dtype} 268 | # add factory_kwargs into kwargs 269 | kwargs.update(factory_kwargs) 270 | super().__init__() 271 | self.residual_in_fp32 = residual_in_fp32 # True 272 | self.fused_add_norm = fused_add_norm # True 273 | self.if_bidirectional = if_bidirectional # False 274 | self.final_pool_type = final_pool_type # 'mean' 275 | self.if_abs_pos_embed = if_abs_pos_embed # True 276 | self.if_rope = if_rope # False 277 | self.if_rope_residual = if_rope_residual # False 278 | self.flip_img_sequences_ratio = flip_img_sequences_ratio 279 | self.if_cls_token = if_cls_token # True 280 | 
self.use_double_cls_token = use_double_cls_token # False 281 | self.use_middle_cls_token = use_middle_cls_token # True 282 | self.num_tokens = 1 if if_cls_token else 0 # 1 283 | 284 | # rms_norm=True, bimamba_type="v2", if_devide_out=True, 285 | 286 | 287 | # pretrain parameters 288 | self.num_classes = num_classes 289 | self.d_model = self.num_features = self.embed_dim = embed_dim # num_features for consistency with other models, 192 / 384 290 | 291 | # patch_size = 16, stride = 8 / 16, channels=3, embed_dim=192 / 384 292 | self.patch_embed = PatchEmbed( 293 | img_size=img_size, patch_size=patch_size, stride=stride, in_chans=channels, embed_dim=embed_dim) 294 | num_patches = self.patch_embed.num_patches 295 | 296 | if if_cls_token: 297 | if use_double_cls_token: 298 | self.cls_token_head = nn.Parameter(torch.zeros(1, 1, self.embed_dim)) 299 | self.cls_token_tail = nn.Parameter(torch.zeros(1, 1, self.embed_dim)) 300 | self.num_tokens = 2 301 | else: 302 | self.cls_token = nn.Parameter(torch.zeros(1, 1, self.embed_dim)) 303 | # self.num_tokens = 1 304 | 305 | if if_abs_pos_embed: 306 | self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + self.num_tokens, self.embed_dim)) 307 | self.pos_drop = nn.Dropout(p=drop_rate) 308 | 309 | if if_rope: 310 | half_head_dim = embed_dim // 2 311 | hw_seq_len = img_size // patch_size 312 | self.rope = VisionRotaryEmbeddingFast( 313 | dim=half_head_dim, 314 | pt_seq_len=pt_hw_seq_len, 315 | ft_seq_len=hw_seq_len 316 | ) 317 | self.head = nn.Linear(self.num_features, num_classes) if num_classes > 0 else nn.Identity() 318 | 319 | 320 | # TODO: release this comment 321 | dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)] # stochastic depth decay rule 322 | # import ipdb;ipdb.set_trace() 323 | inter_dpr = [0.0] + dpr 324 | self.drop_path = DropPath(drop_path_rate) if drop_path_rate > 0. 
else nn.Identity() 325 | # transformer blocks 326 | self.layers = nn.ModuleList( 327 | [ 328 | create_block( 329 | embed_dim, 330 | ssm_cfg=ssm_cfg, 331 | norm_epsilon=norm_epsilon, 332 | rms_norm=rms_norm, 333 | residual_in_fp32=residual_in_fp32, 334 | fused_add_norm=fused_add_norm, 335 | layer_idx=i, 336 | if_bimamba=if_bimamba, 337 | bimamba_type=bimamba_type, 338 | drop_path=inter_dpr[i], 339 | if_devide_out=if_devide_out, 340 | init_layer_scale=init_layer_scale, 341 | START_flag=START_flag, 342 | **factory_kwargs, 343 | ) 344 | for i in range(depth) 345 | ] 346 | ) 347 | 348 | # output head 349 | self.norm_f = (nn.LayerNorm if not rms_norm else RMSNorm)( 350 | embed_dim, eps=norm_epsilon, **factory_kwargs 351 | ) 352 | 353 | # self.pre_logits = nn.Identity() 354 | 355 | # original init 356 | self.patch_embed.apply(segm_init_weights) 357 | self.head.apply(segm_init_weights) 358 | if if_abs_pos_embed: 359 | trunc_normal_(self.pos_embed, std=.02) 360 | if if_cls_token: 361 | if use_double_cls_token: 362 | trunc_normal_(self.cls_token_head, std=.02) 363 | trunc_normal_(self.cls_token_tail, std=.02) 364 | else: 365 | trunc_normal_(self.cls_token, std=.02) 366 | 367 | # mamba init 368 | self.apply( 369 | partial( 370 | _init_weights, 371 | n_layer=depth, 372 | **(initializer_cfg if initializer_cfg is not None else {}), 373 | ) 374 | ) 375 | 376 | 377 | def allocate_inference_cache(self, batch_size, max_seqlen, dtype=None, **kwargs): 378 | return { 379 | i: layer.allocate_inference_cache(batch_size, max_seqlen, dtype=dtype, **kwargs) 380 | for i, layer in enumerate(self.layers) 381 | } 382 | 383 | @torch.jit.ignore 384 | def no_weight_decay(self): 385 | return {"pos_embed", "cls_token", "dist_token", "cls_token_head", "cls_token_tail"} 386 | 387 | @torch.jit.ignore() 388 | def load_pretrained(self, checkpoint_path, prefix=""): 389 | _load_weights(self, checkpoint_path, prefix) 390 | 391 | def forward_features(self, x, inference_params=None, if_random_cls_token_position=False, if_random_token_rank=False): 392 | # taken from https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/vision_transformer.py 393 | # with slight modifications to add the dist_token 394 | x = self.patch_embed(x) 395 | B, M, _ = x.shape 396 | 397 | if self.if_cls_token: 398 | if self.use_double_cls_token: 399 | cls_token_head = self.cls_token_head.expand(B, -1, -1) 400 | cls_token_tail = self.cls_token_tail.expand(B, -1, -1) 401 | token_position = [0, M + 1] 402 | x = torch.cat((cls_token_head, x, cls_token_tail), dim=1) 403 | M = x.shape[1] 404 | else: 405 | if self.use_middle_cls_token: 406 | cls_token = self.cls_token.expand(B, -1, -1) 407 | token_position = M // 2 408 | # add cls token in the middle 409 | x = torch.cat((x[:, :token_position, :], cls_token, x[:, token_position:, :]), dim=1) 410 | elif if_random_cls_token_position: 411 | cls_token = self.cls_token.expand(B, -1, -1) 412 | token_position = random.randint(0, M) 413 | x = torch.cat((x[:, :token_position, :], cls_token, x[:, token_position:, :]), dim=1) 414 | print("token_position: ", token_position) 415 | else: 416 | cls_token = self.cls_token.expand(B, -1, -1) # stole cls_tokens impl from Phil Wang, thanks 417 | token_position = 0 418 | x = torch.cat((cls_token, x), dim=1) 419 | M = x.shape[1] # Tokens Number 420 | 421 | if self.if_abs_pos_embed: 422 | # if new_grid_size[0] == self.patch_embed.grid_size[0] and new_grid_size[1] == self.patch_embed.grid_size[1]: 423 | # x = x + self.pos_embed 424 | # else: 425 | # pos_embed = 
interpolate_pos_embed_online( 426 | # self.pos_embed, self.patch_embed.grid_size, new_grid_size,0 427 | # ) 428 | x = x + self.pos_embed 429 | x = self.pos_drop(x) 430 | 431 | if if_random_token_rank: 432 | 433 | # 生成随机 shuffle 索引 434 | shuffle_indices = torch.randperm(M) 435 | 436 | if isinstance(token_position, list): 437 | print("original value: ", x[0, token_position[0], 0], x[0, token_position[1], 0]) 438 | else: 439 | print("original value: ", x[0, token_position, 0]) 440 | print("original token_position: ", token_position) 441 | 442 | # 执行 shuffle 443 | x = x[:, shuffle_indices, :] 444 | 445 | if isinstance(token_position, list): 446 | # 找到 cls token 在 shuffle 之后的新位置 447 | new_token_position = [torch.where(shuffle_indices == token_position[i])[0].item() for i in range(len(token_position))] 448 | token_position = new_token_position 449 | else: 450 | # 找到 cls token 在 shuffle 之后的新位置 451 | token_position = torch.where(shuffle_indices == token_position)[0].item() 452 | 453 | if isinstance(token_position, list): 454 | print("new value: ", x[0, token_position[0], 0], x[0, token_position[1], 0]) 455 | else: 456 | print("new value: ", x[0, token_position, 0]) 457 | print("new token_position: ", token_position) 458 | 459 | 460 | if_flip_img_sequences = False 461 | if self.flip_img_sequences_ratio > 0 and (self.flip_img_sequences_ratio - random.random()) > 1e-5: 462 | x = x.flip([1]) 463 | if_flip_img_sequences = True 464 | 465 | # mamba impl 466 | residual = None 467 | hidden_states = x 468 | if not self.if_bidirectional: 469 | for layer in self.layers: 470 | 471 | if if_flip_img_sequences and self.if_rope: 472 | hidden_states = hidden_states.flip([1]) 473 | if residual is not None: 474 | residual = residual.flip([1]) 475 | 476 | # rope about 477 | if self.if_rope: 478 | hidden_states = self.rope(hidden_states) 479 | if residual is not None and self.if_rope_residual: 480 | residual = self.rope(residual) 481 | 482 | if if_flip_img_sequences and self.if_rope: 483 | hidden_states = hidden_states.flip([1]) 484 | if residual is not None: 485 | residual = residual.flip([1]) 486 | 487 | hidden_states, residual = layer( 488 | hidden_states, residual, inference_params=inference_params 489 | ) 490 | else: 491 | # get two layers in a single for-loop 492 | for i in range(len(self.layers) // 2): 493 | if self.if_rope: 494 | hidden_states = self.rope(hidden_states) 495 | if residual is not None and self.if_rope_residual: 496 | residual = self.rope(residual) 497 | 498 | hidden_states_f, residual_f = self.layers[i * 2]( 499 | hidden_states, residual, inference_params=inference_params 500 | ) 501 | hidden_states_b, residual_b = self.layers[i * 2 + 1]( 502 | hidden_states.flip([1]), None if residual == None else residual.flip([1]), inference_params=inference_params 503 | ) 504 | hidden_states = hidden_states_f + hidden_states_b.flip([1]) 505 | residual = residual_f + residual_b.flip([1]) 506 | 507 | if not self.fused_add_norm: 508 | if residual is None: 509 | residual = hidden_states 510 | else: 511 | residual = residual + self.drop_path(hidden_states) 512 | hidden_states = self.norm_f(residual.to(dtype=self.norm_f.weight.dtype)) 513 | else: 514 | # Set prenorm=False here since we don't need the residual 515 | fused_add_norm_fn = rms_norm_fn if isinstance(self.norm_f, RMSNorm) else layer_norm_fn 516 | hidden_states = fused_add_norm_fn( 517 | self.drop_path(hidden_states), 518 | self.norm_f.weight, 519 | self.norm_f.bias, 520 | eps=self.norm_f.eps, 521 | residual=residual, 522 | prenorm=False, 523 | 
residual_in_fp32=self.residual_in_fp32, 524 | ) 525 | 526 | # return only cls token if it exists 527 | if self.if_cls_token: 528 | if self.use_double_cls_token: 529 | return (hidden_states[:, token_position[0], :] + hidden_states[:, token_position[1], :]) / 2 530 | else: 531 | if self.use_middle_cls_token: 532 | return hidden_states[:, token_position, :] 533 | elif if_random_cls_token_position: 534 | return hidden_states[:, token_position, :] 535 | else: 536 | return hidden_states[:, token_position, :] 537 | 538 | if self.final_pool_type == 'none': 539 | return hidden_states[:, -1, :] 540 | elif self.final_pool_type == 'mean': 541 | return hidden_states.mean(dim=1) 542 | elif self.final_pool_type == 'max': 543 | return hidden_states 544 | elif self.final_pool_type == 'all': 545 | return hidden_states 546 | else: 547 | raise NotImplementedError 548 | 549 | def forward(self, x, return_features=False, inference_params=None, if_random_cls_token_position=False, if_random_token_rank=False): 550 | x = self.forward_features(x, inference_params, if_random_cls_token_position=if_random_cls_token_position, if_random_token_rank=if_random_token_rank) 551 | if return_features: 552 | return x 553 | x = self.head(x) 554 | if self.final_pool_type == 'max': 555 | x = x.max(dim=1)[0] 556 | return x 557 | 558 | @register_model 559 | def vim_tiny_patch16_224_bimambav2_final_pool_mean_abs_pos_embed_with_midclstok_div2(pretrained=False, **kwargs): 560 | model = VisionMamba( 561 | patch_size=16, embed_dim=192, depth=24, rms_norm=True, residual_in_fp32=True, fused_add_norm=True, final_pool_type='mean', if_abs_pos_embed=True, if_rope=False, if_rope_residual=False, bimamba_type="v2", if_cls_token=True, if_devide_out=True, use_middle_cls_token=True, **kwargs) 562 | model.default_cfg = _cfg() 563 | if pretrained: 564 | checkpoint = torch.hub.load_state_dict_from_url( 565 | url="to.do", 566 | map_location="cpu", check_hash=True 567 | ) 568 | model.load_state_dict(checkpoint["model"]) 569 | return model 570 | 571 | @register_model 572 | def vim_tiny_patch16_stride8_224_bimambav2_final_pool_mean_abs_pos_embed_with_midclstok_div2(pretrained=False, **kwargs): 573 | model = VisionMamba( 574 | patch_size=16, stride=8, embed_dim=192, depth=24, rms_norm=True, residual_in_fp32=True, fused_add_norm=True, final_pool_type='mean', if_abs_pos_embed=True, if_rope=False, if_rope_residual=False, bimamba_type="v2", if_cls_token=True, if_devide_out=True, use_middle_cls_token=True, **kwargs) 575 | model.default_cfg = _cfg() 576 | if pretrained: 577 | checkpoint = torch.hub.load_state_dict_from_url( 578 | url="to.do", 579 | map_location="cpu", check_hash=True 580 | ) 581 | model.load_state_dict(checkpoint["model"]) 582 | return model 583 | 584 | @register_model 585 | def vim_small_patch16_224_bimambav2_final_pool_mean_abs_pos_embed_with_midclstok_div2(pretrained=False, **kwargs): 586 | model = VisionMamba( 587 | patch_size=16, embed_dim=384, depth=24, rms_norm=True, residual_in_fp32=True, fused_add_norm=True, final_pool_type='mean', if_abs_pos_embed=True, if_rope=False, if_rope_residual=False, bimamba_type="v2", if_cls_token=True, if_devide_out=True, use_middle_cls_token=True, **kwargs) 588 | model.default_cfg = _cfg() 589 | if pretrained: 590 | checkpoint = torch.hub.load_state_dict_from_url( 591 | url="to.do", 592 | map_location="cpu", check_hash=True 593 | ) 594 | model.load_state_dict(checkpoint["model"]) 595 | return model 596 | 597 | @register_model 598 | def 
vim_small_patch16_stride8_224_bimambav2_final_pool_mean_abs_pos_embed_with_midclstok_div2(pretrained=False, **kwargs): 599 | model = VisionMamba( 600 | patch_size=16, stride=8, embed_dim=384, depth=24, rms_norm=True, residual_in_fp32=True, fused_add_norm=True, final_pool_type='mean', if_abs_pos_embed=True, if_rope=False, if_rope_residual=False, bimamba_type="v2", if_cls_token=True, if_devide_out=True, use_middle_cls_token=True, **kwargs) 601 | model.default_cfg = _cfg() 602 | if pretrained: 603 | checkpoint = torch.hub.load_state_dict_from_url( 604 | url="to.do", 605 | map_location="cpu", check_hash=True 606 | ) 607 | model.load_state_dict(checkpoint["model"]) 608 | return model -------------------------------------------------------------------------------- /perturb_style/ALOFT.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import math 3 | import torch.nn as nn 4 | from math import sqrt 5 | import numpy as np 6 | 7 | 8 | class ALOFT(nn.Module): 9 | """ 10 | Frequency Distribution Uncertainty Module 11 | Args: 12 | p (float): probabilty of foward distribution uncertainty module, p in [0,1]. 13 | """ 14 | 15 | def __init__(self, p=0.5, eps=1e-6, mask_size=0.5, factor=0.8, mask_or_model=0): 16 | super(ALOFT, self).__init__() 17 | self.eps = eps 18 | self.p = p 19 | 20 | self.mask_size = mask_size 21 | self.factor = factor 22 | self.mask_or_model = mask_or_model 23 | 24 | def forward(self, img_fft): 25 | if (not self.training) or (np.random.random()) > self.p: 26 | return img_fft 27 | 28 | # img = img.to(torch.float32) 29 | # img_fft = torch.fft.fft2(img, dim=(2, 3), norm='ortho') 30 | B, C, h_fft, w_fft = img_fft.shape 31 | img_abs, img_pha = torch.abs(img_fft), torch.angle(img_fft) 32 | 33 | h_crop = int(h_fft * sqrt(self.mask_size)) 34 | w_crop = int(w_fft * sqrt(self.mask_size)) 35 | h_start = h_fft // 2 - h_crop // 2 36 | # w_start = w_fft // 2 - w_crop // 2 37 | w_start = 0 38 | 39 | img_abs = torch.fft.fftshift(img_abs, dim=(2, )) 40 | img_abs_ = img_abs.clone() 41 | 42 | if self.mask_or_model == 0: 43 | masks = torch.ones_like(img_abs) 44 | masks[:, :, h_start:h_start + h_crop, w_start:w_start + w_crop] = 0 45 | img_abs = img_abs_ * masks.cuda() 46 | freq_avg = torch.mean(img_abs_[:, :, h_start:h_start + h_crop, w_start:w_start + w_crop], 47 | dim=(1, 2, 3), keepdim=True) # Bx1x1x1 48 | freq_avg_mask = torch.zeros_like(img_abs_) 49 | freq_avg_mask[:, :, h_start:h_start + h_crop, w_start:w_start + w_crop] = 1 50 | freq_avg_mask = freq_avg * freq_avg_mask.cuda() 51 | img_abs += freq_avg_mask 52 | else: 53 | var_of_elem = torch.var(img_abs_[:, :, h_start:h_start + h_crop, w_start:w_start + w_crop], dim=0, 54 | keepdim=True) 55 | sig_of_elem = (var_of_elem + 1e-6).sqrt() # 1xHxWxC 56 | 57 | epsilon_sig = torch.randn_like( 58 | img_abs[:, :, h_start:h_start + h_crop, w_start:w_start + w_crop]) # BxHxWxC N(0,1) 59 | gamma = epsilon_sig * sig_of_elem * self.factor 60 | 61 | img_abs[:, :, h_start:h_start + h_crop, w_start:w_start + w_crop] = \ 62 | img_abs[:, :, h_start:h_start + h_crop, w_start:w_start + w_crop] + gamma 63 | 64 | img_abs = torch.fft.ifftshift(img_abs, dim=(2, )) 65 | img_stylized = img_abs * (np.e ** (1j * img_pha)) 66 | 67 | return img_stylized 68 | 69 | 70 | class ALOFT_image(nn.Module): 71 | """ 72 | Frequency Distribution Uncertainty Module 73 | Args: 74 | p (float): probabilty of foward distribution uncertainty module, p in [0,1]. 
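Example (hypothetical usage sketch; assumes B x K x C x L token features with L a perfect square, and a CUDA device since the module builds its masks with .cuda()):
    >>> import torch
    >>> aloft = ALOFT_image(p=1.0, mask_size=0.5, factor=0.8, mask_or_model=0)   # training mode by default, so p=1.0 always perturbs
    >>> x = torch.randn(2, 3, 192, 196).cuda()   # B, K, C, L with L = 14 * 14
    >>> out = aloft(x)   # same shape; a low-frequency amplitude block is replaced by its mean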
75 | """ 76 | 77 | def __init__(self, p=0.5, eps=1e-6, mask_size=0.5, factor=0.8, mask_or_model=0): 78 | super(ALOFT_image, self).__init__() 79 | self.eps = eps 80 | self.p = p 81 | 82 | self.mask_size = mask_size 83 | self.factor = factor 84 | self.mask_or_model = mask_or_model 85 | 86 | def forward(self, img): 87 | if (not self.training) or (np.random.random()) > self.p: 88 | return img 89 | 90 | # img: B K C L 91 | B, K, C, L = img.shape 92 | h = w = int(sqrt(L)) 93 | img = img.view(B, K, C, h, w) 94 | 95 | img = img.to(torch.float32) 96 | img_fft = torch.fft.rfft2(img, dim=(3, 4), norm='ortho') 97 | 98 | B, K, C, h_fft, w_fft = img_fft.shape 99 | img_abs, img_pha = torch.abs(img_fft), torch.angle(img_fft) 100 | 101 | h_crop = int(h_fft * sqrt(self.mask_size)) 102 | w_crop = int(w_fft * sqrt(self.mask_size)) 103 | h_start = h_fft // 2 - h_crop // 2 104 | # w_start = w_fft // 2 - w_crop // 2 105 | w_start = 0 106 | 107 | img_abs = torch.fft.fftshift(img_abs, dim=(3, )) 108 | img_abs_ = img_abs.clone() 109 | 110 | if self.mask_or_model == 0: 111 | masks = torch.ones_like(img_abs) 112 | masks[:, :, :, h_start:h_start + h_crop, w_start:w_start + w_crop] = 0 113 | img_abs = img_abs_ * masks.cuda() 114 | freq_avg = torch.mean(img_abs_[:, :, :, h_start:h_start + h_crop, w_start:w_start + w_crop], 115 | dim=(2, 3, 4), keepdim=True) # Bx1x1x1 116 | freq_avg_mask = torch.zeros_like(img_abs_) 117 | freq_avg_mask[:, :, :, h_start:h_start + h_crop, w_start:w_start + w_crop] = 1 118 | freq_avg_mask = freq_avg * freq_avg_mask.cuda() 119 | img_abs += freq_avg_mask 120 | else: 121 | var_of_elem = torch.var(img_abs_[:, :, :, h_start:h_start + h_crop, w_start:w_start + w_crop], dim=0, 122 | keepdim=True) 123 | sig_of_elem = (var_of_elem + 1e-6).sqrt() # 1xHxWxC 124 | 125 | epsilon_sig = torch.randn_like( 126 | img_abs[:, :, :, h_start:h_start + h_crop, w_start:w_start + w_crop]) # BxKxHxWxC N(0,1) 127 | gamma = epsilon_sig * sig_of_elem * self.factor 128 | 129 | img_abs[:, :, :, h_start:h_start + h_crop, w_start:w_start + w_crop] = \ 130 | img_abs[:, :, :, h_start:h_start + h_crop, w_start:w_start + w_crop] + gamma 131 | 132 | img_abs = torch.fft.ifftshift(img_abs, dim=(3, )) 133 | img_stylized = img_abs * (np.e ** (1j * img_pha)) 134 | 135 | img_stylized = torch.fft.irfft2(img_stylized, s=(h, w), dim=(3, 4), norm='ortho') 136 | img_stylized = img_stylized.view(B, K, C, -1) 137 | return img_stylized 138 | 139 | 140 | 141 | -------------------------------------------------------------------------------- /perturb_style/DSU.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.utils.model_zoo as model_zoo 4 | import numpy as np 5 | from torch.nn import functional as F 6 | 7 | 8 | class DSU(nn.Module): 9 | """ 10 | Distribution Uncertainty Module 11 | Args: 12 | p (float): probabilty of foward distribution uncertainty module, p in [0,1]. 
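Example (hypothetical usage sketch; assumes token features of shape B x K x C x L):
    >>> import torch
    >>> dsu = DSU(p=1.0)   # training mode by default, so p=1.0 always perturbs
    >>> x = torch.randn(8, 3, 192, 196)
    >>> out = dsu(x)   # same shape; mean/std over the token dimension re-sampled with batch-level uncertainty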
13 | """ 14 | 15 | def __init__(self, p=0.5, eps=1e-6): 16 | super(DSU, self).__init__() 17 | self.eps = eps 18 | self.p = p 19 | self.factor = 1.0 20 | 21 | def _reparameterize(self, mu, std): 22 | epsilon = torch.randn_like(std) * self.factor 23 | return mu + epsilon * std 24 | 25 | def sqrtvar(self, x): 26 | t = (x.var(dim=0, keepdim=True) + self.eps).sqrt() # 1xKxCx1 27 | # t = t.repeat(x.shape[0], 1) 28 | return t 29 | 30 | def forward(self, x): 31 | # B K C L 32 | if (not self.training) or (np.random.random()) > self.p: 33 | return x 34 | 35 | # mean = x.mean(dim=[2, 3], keepdim=False) 36 | # std = (x.var(dim=[2, 3], keepdim=False) + self.eps).sqrt() 37 | 38 | mean = x.mean(dim=3, keepdim=True) # BxKxCx1 39 | std = (x.var(dim=3, keepdim=True) + self.eps).sqrt() 40 | 41 | sqrtvar_mu = self.sqrtvar(mean) 42 | sqrtvar_std = self.sqrtvar(std) 43 | 44 | beta = self._reparameterize(mean, sqrtvar_mu) 45 | gamma = self._reparameterize(std, sqrtvar_std) 46 | 47 | x = (x - mean) / std 48 | x = x * gamma + beta 49 | 50 | # x = (x - mean.reshape(x.shape[0], 1, x.shape[2])) / std.reshape(x.shape[0], 1, x.shape[2]) 51 | # x = x * gamma.reshape(x.shape[0], 1, x.shape[2]) + beta.reshape(x.shape[0], 1, x.shape[2]) 52 | 53 | return x -------------------------------------------------------------------------------- /perturb_style/MixStyle.py: -------------------------------------------------------------------------------- 1 | import random 2 | import torch 3 | import torch.nn as nn 4 | 5 | 6 | class MixStyle(nn.Module): 7 | """MixStyle. 8 | Reference: 9 | Zhou et al. Domain Generalization with MixStyle. ICLR 2021. 10 | """ 11 | 12 | def __init__(self, p=0.5, alpha=0.1, eps=1e-6, mix='random'): 13 | """ 14 | Args: 15 | p (float): probability of using MixStyle. 16 | alpha (float): parameter of the Beta distribution. 17 | eps (float): scaling parameter to avoid numerical issues. 18 | mix (str): how to mix. 
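Example (hypothetical usage sketch; assumes B x K x C x L token features and mix='random'; 'crossdomain' additionally assumes the batch is ordered by source domain and divisible by three):
    >>> import torch
    >>> ms = MixStyle(p=1.0, alpha=0.1, mix='random')   # training mode by default
    >>> x = torch.randn(6, 4, 192, 196)
    >>> out = ms(x)   # same shape; token-dimension statistics mixed across a shuffled batch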
19 | """ 20 | super().__init__() 21 | self.p = p 22 | self.beta = torch.distributions.Beta(alpha, alpha) 23 | self.eps = eps 24 | self.alpha = alpha 25 | self.mix = mix 26 | self._activated = True 27 | 28 | def __repr__(self): 29 | return f'MixStyle(p={self.p}, alpha={self.alpha}, eps={self.eps}, mix={self.mix})' 30 | 31 | def set_activation_status(self, status=True): 32 | self._activated = status 33 | 34 | def update_mix_method(self, mix='random'): 35 | self.mix = mix 36 | 37 | def forward(self, x): 38 | if not self.training or not self._activated: 39 | return x 40 | 41 | # BxKxCxL 42 | if random.random() > self.p: 43 | return x 44 | 45 | B = x.size(0) 46 | 47 | mu = x.mean(dim=3, keepdim=True) 48 | var = x.var(dim=3, keepdim=True) 49 | sig = (var + self.eps).sqrt() 50 | mu, sig = mu.detach(), sig.detach() 51 | x_normed = (x-mu) / sig 52 | 53 | lmda = self.beta.sample((B, 1, 1, 1)) 54 | lmda = lmda.to(x.device) 55 | 56 | if self.mix == 'random': 57 | # random shuffle 58 | perm = torch.randperm(B) 59 | 60 | elif self.mix == 'crossdomain': 61 | # split into two halves and swap the order 62 | # perm = torch.arange(B - 1, -1, -1) # inverse index 63 | # perm_b, perm_a = perm.chunk(2) 64 | # perm_b = perm_b[torch.randperm(B // 2)] 65 | # perm_a = perm_a[torch.randperm(B // 2)] 66 | # perm = torch.cat([perm_b, perm_a], 0) 67 | 68 | perm = torch.arange(B) # 0, b-1 69 | perm_a, perm_b, perm_c = perm.chunk(3) # split into three parts 70 | domain_batch_size = B // 3 71 | perm_a = perm_a[torch.randperm(domain_batch_size)] 72 | perm_b = perm_b[torch.randperm(domain_batch_size)] 73 | perm_c = perm_c[torch.randperm(domain_batch_size)] 74 | 75 | if random.random() < 0.5: 76 | perm = torch.cat([perm_b, perm_c, perm_a], dim=0) 77 | else: 78 | perm = torch.cat([perm_c, perm_a, perm_b], dim=0) 79 | 80 | else: 81 | raise NotImplementedError 82 | 83 | mu2, sig2 = mu[perm], sig[perm] 84 | mu_mix = mu*lmda + mu2 * (1-lmda) 85 | sig_mix = sig*lmda + sig2 * (1-lmda) 86 | 87 | return x_normed * sig_mix + mu_mix -------------------------------------------------------------------------------- /perturb_style/SeqTokenAug.py: -------------------------------------------------------------------------------- 1 | import random 2 | import torch 3 | import torch.nn as nn 4 | from math import sqrt 5 | import numpy as np 6 | 7 | 8 | class SeqTokenAug(nn.Module): 9 | """SeqTokenAug 10 | """ 11 | 12 | def __init__(self, p=0.5, aug_token_prob=0.75, batch_prob=1.0, token_attention_flag=0, seq_token_flag=0, eps=1e-6): 13 | """ 14 | Args: 15 | p (float): probability of using SeqTokenMix. 16 | eps (float): scaling parameter to avoid numerical issues. 17 | aug_token_prob (float): prob of augmented tokens. 18 | batch_prob (float): prob of augmented samples. 
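Example (hypothetical usage sketch; Bx stands in for per-token scores such as attention maps, and a CUDA device is assumed because scores_to_mask builds its masks with .cuda()):
    >>> import torch
    >>> aug = SeqTokenAug(p=1.0, aug_token_prob=0.75, batch_prob=0.5)   # training mode by default
    >>> x = torch.randn(2, 3, 192, 196).cuda()       # original token features, B K C L
    >>> x_aug = torch.randn(2, 3, 192, 196).cuda()   # style-perturbed token features
    >>> Bx = torch.rand(2, 3, 192, 196).cuda()       # per-token scores used to select tokens
    >>> out = aug(x, x_aug, Bx)   # high-score tokens take their values from x_aug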
19 | """ 20 | super().__init__() 21 | self.p = p 22 | 23 | self.eps = eps 24 | self.aug_token_prob = aug_token_prob 25 | self.batch_prob = batch_prob 26 | 27 | self.token_attention_flag = token_attention_flag 28 | self.seq_token_flag = seq_token_flag 29 | 30 | def scores_to_mask(self, scores, mask_prob=0.75): 31 | # scores: B, K, C, N 32 | B, K, C, N = scores.shape 33 | scores_BK_C_N = scores.reshape(B*K, C, N) 34 | scores_channel_mean = torch.mean(scores_BK_C_N, dim=1, keepdim=False) # BKxN 35 | K_phase = int(N * mask_prob) # percent of the phase-related patches 36 | 37 | if mask_prob == 1.0: 38 | K_phase = N - 1 39 | 40 | threshold = torch.sort(scores_channel_mean, dim=1, descending=True)[0][:, K_phase] 41 | threshold_expand = threshold.view(B*K, 1).expand(B*K, N) 42 | mask_phase = torch.where(scores_channel_mean > threshold_expand, 43 | torch.tensor(1.).cuda(), torch.tensor(0.).cuda()) 44 | mask_phase = mask_phase.unsqueeze(dim=1).view(B, K, 1, N) 45 | return mask_phase 46 | 47 | def forward(self, x, x_aug, Bx=None): 48 | # BxKxCxL 49 | if not self.training or (random.random() > self.p): 50 | return x 51 | 52 | # x, x_aug: B, K, C, L 53 | B, K, C, L = x.shape 54 | 55 | token_mask = self.scores_to_mask(scores=Bx, mask_prob=self.aug_token_prob) 56 | x_final = x_aug * token_mask + x * (1 - token_mask) 57 | batch_mask = x.new_empty((B, 1, 1, 1)).bernoulli_(self.batch_prob).float().cuda() 58 | x_final = (batch_mask * x_final + (1 - batch_mask) * x) 59 | return x_final -------------------------------------------------------------------------------- /samplers.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2015-present, Facebook, Inc. 2 | # All rights reserved. 3 | import torch 4 | import torch.distributed as dist 5 | import math 6 | 7 | 8 | class RASampler(torch.utils.data.Sampler): 9 | """Sampler that restricts data loading to a subset of the dataset for distributed, 10 | with repeated augmentation. 
11 | It ensures that different each augmented version of a sample will be visible to a 12 | different process (GPU) 13 | Heavily based on torch.utils.data.DistributedSampler 14 | """ 15 | 16 | def __init__(self, dataset, num_replicas=None, rank=None, shuffle=True, num_repeats: int = 3): 17 | if num_replicas is None: 18 | if not dist.is_available(): 19 | raise RuntimeError("Requires distributed package to be available") 20 | num_replicas = dist.get_world_size() 21 | if rank is None: 22 | if not dist.is_available(): 23 | raise RuntimeError("Requires distributed package to be available") 24 | rank = dist.get_rank() 25 | if num_repeats < 1: 26 | raise ValueError("num_repeats should be greater than 0") 27 | self.dataset = dataset 28 | self.num_replicas = num_replicas 29 | self.rank = rank 30 | self.num_repeats = num_repeats 31 | self.epoch = 0 32 | self.num_samples = int(math.ceil(len(self.dataset) * self.num_repeats / self.num_replicas)) 33 | self.total_size = self.num_samples * self.num_replicas 34 | # self.num_selected_samples = int(math.ceil(len(self.dataset) / self.num_replicas)) 35 | self.num_selected_samples = int(math.floor(len(self.dataset) // 256 * 256 / self.num_replicas)) 36 | self.shuffle = shuffle 37 | 38 | def __iter__(self): 39 | if self.shuffle: 40 | # deterministically shuffle based on epoch 41 | g = torch.Generator() 42 | g.manual_seed(self.epoch) 43 | indices = torch.randperm(len(self.dataset), generator=g) 44 | else: 45 | indices = torch.arange(start=0, end=len(self.dataset)) 46 | 47 | # add extra samples to make it evenly divisible 48 | indices = torch.repeat_interleave(indices, repeats=self.num_repeats, dim=0).tolist() 49 | padding_size: int = self.total_size - len(indices) 50 | if padding_size > 0: 51 | indices += indices[:padding_size] 52 | assert len(indices) == self.total_size 53 | 54 | # subsample 55 | indices = indices[self.rank:self.total_size:self.num_replicas] 56 | assert len(indices) == self.num_samples 57 | 58 | return iter(indices[:self.num_selected_samples]) 59 | 60 | def __len__(self): 61 | return self.num_selected_samples 62 | 63 | def set_epoch(self, epoch): 64 | self.epoch = epoch 65 | -------------------------------------------------------------------------------- /scripts/START-M-VMamba.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | data='PACS' 4 | 5 | for t in `seq 0 4` 6 | do 7 | for domain in `seq 0 3` 8 | do 9 | CUDA_VISIBLE_DEVICES=0,1,2,3 python -m torch.distributed.launch --nproc_per_node=4 --master_port=29100 --use_env main_dg.py \ 10 | --model vmamba_tiny \ 11 | --batch-size 16 \ 12 | --data $data \ 13 | --target $domain \ 14 | --seed $t \ 15 | --data_root "/data/DataSets/" \ 16 | --lr 5e-4 \ 17 | --min-lr 1e-5 \ 18 | --warmup-lr 1e-5 \ 19 | --drop-path 0.0 \ 20 | --weight-decay 1e-8 \ 21 | --num_workers 16 \ 22 | --output_dir ./output/vmamba_t \ 23 | --epochs 50 \ 24 | --finetune /path/pretrained_models/vmamba_tiny_e292.pth \ 25 | --no_amp \ 26 | --spatial_aug_flag 1 \ 27 | --START_flag 1 \ 28 | --START_p 1.0 \ 29 | --START_token_prob 0.75 \ 30 | --SeqTokenAug_batch_prob 0.5 \ 31 | --START_attention_mode 2 32 | done 33 | done 34 | 35 | -------------------------------------------------------------------------------- /scripts/START-Vim-S.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | data='PACS' 4 | 5 | for t in `seq 0 4` 6 | do 7 | for domain in `seq 0 3` 8 | do 9 | CUDA_VISIBLE_DEVICES=0,1,2,3 python -m 
torch.distributed.launch --nproc_per_node=4 --master_port=29200 --use_env main_dg.py \ 10 | --model vim_small_patch16_stride8_224_bimambav2_final_pool_mean_abs_pos_embed_with_midclstok_div2 \ 11 | --batch-size 16 \ 12 | --data $data \ 13 | --target $domain \ 14 | --seed $t \ 15 | --data_root "/data/DataSets/" \ 16 | --lr 5e-6 \ 17 | --min-lr 1e-5 \ 18 | --warmup-lr 1e-5 \ 19 | --drop-path 0.0 \ 20 | --weight-decay 1e-8 \ 21 | --num_workers 16 \ 22 | --output_dir ./output/vim_small \ 23 | --epochs 50 \ 24 | --finetune /data/gjt/Mamba/Vim_for_DG/vim/pretrained_models/vim_s_midclstok_ft_81p6acc.pth \ 25 | --no_amp \ 26 | --Vim_START_flag 2 27 | done 28 | done 29 | -------------------------------------------------------------------------------- /scripts/START-Vim-T.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | data='PACS' 4 | 5 | for t in `seq 0 4` 6 | do 7 | for domain in `seq 0 3` 8 | do 9 | CUDA_VISIBLE_DEVICES=0,1,2,3 python -m torch.distributed.launch --nproc_per_node=4 --master_port=29100 --use_env main_dg.py \ 10 | --model vim_tiny_patch16_stride8_224_bimambav2_final_pool_mean_abs_pos_embed_with_midclstok_div2 \ 11 | --batch-size 16 \ 12 | --data $data \ 13 | --target $domain \ 14 | --seed $t \ 15 | --data_root "/data/DataSets/" \ 16 | --lr 5e-6 \ 17 | --min-lr 1e-5 \ 18 | --warmup-lr 1e-5 \ 19 | --drop-path 0.0 \ 20 | --weight-decay 1e-8 \ 21 | --num_workers 16 \ 22 | --output_dir ./output/vim_tiny \ 23 | --epochs 50 \ 24 | --finetune /data/gjt/Mamba/Vim_for_DG/vim/pretrained_models/vim_t_midclstok_ft_78p3acc.pth \ 25 | --no_amp \ 26 | --Vim_START_flag 2 27 | done 28 | done 29 | 30 | -------------------------------------------------------------------------------- /scripts/START-X-VMamba.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | data='PACS' 4 | 5 | for t in `seq 0 4` 6 | do 7 | for domain in `seq 0 3` 8 | do 9 | CUDA_VISIBLE_DEVICES=0,1,2,3 python -m torch.distributed.launch --nproc_per_node=4 --master_port=29100 --use_env main_dg.py \ 10 | --model vmamba_tiny \ 11 | --batch-size 16 \ 12 | --data $data \ 13 | --target $domain \ 14 | --seed $t \ 15 | --data_root "/data/DataSets/" \ 16 | --lr 5e-4 \ 17 | --min-lr 1e-5 \ 18 | --warmup-lr 1e-5 \ 19 | --drop-path 0.0 \ 20 | --weight-decay 1e-8 \ 21 | --num_workers 16 \ 22 | --output_dir ./output/vmamba_t \ 23 | --epochs 50 \ 24 | --finetune /path/pretrained_models/vmamba_tiny_e292.pth \ 25 | --no_amp \ 26 | --spatial_aug_flag 1 \ 27 | --START_flag 1 \ 28 | --START_p 1.0 \ 29 | --SeqTokenAug_token_prob 0.75 \ 30 | --SeqTokenAug_batch_prob 0.5 \ 31 | --START_attention_mode 1 32 | done 33 | done 34 | 35 | -------------------------------------------------------------------------------- /scripts/test_model_performance.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | data='PACS' 4 | 5 | CUDA_VISIBLE_DEVICES=0,1,2,3 python -m torch.distributed.launch --nproc_per_node=4 --master_port=29000 --use_env ../main_dg.py \ 6 | --model vmamba_tiny \ 7 | --batch-size 32 \ 8 | --data $data \ 9 | --seed 0 \ 10 | --data_root "/data/DataSets/" \ 11 | --num_workers 16 \ 12 | --no_amp \ 13 | --target 3 \ 14 | --eval 1 \ 15 | --resume /data/gjt/Mamba/Vim_for_DG/vim/output/vmamba_t_SeqToken_MixStyle_Atten_CdeltaBx_P1.0_Token0.75_Batch0.5_Origin_gray_lr6.25e-05/PACS/sketch0/checkpoint.pth 16 | 17 | 18 | 19 | 20 | 
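# Hypothetical extension (sketch only, not part of the original script): evaluate a separate
# checkpoint for every PACS target domain by looping over --target; the per-domain checkpoint
# paths below are placeholders.
# for domain in `seq 0 3`
# do
#     CUDA_VISIBLE_DEVICES=0,1,2,3 python -m torch.distributed.launch --nproc_per_node=4 --master_port=29000 --use_env ../main_dg.py \
#         --model vmamba_tiny --batch-size 32 --data $data --seed 0 --data_root "/data/DataSets/" \
#         --num_workers 16 --no_amp --eval 1 --target $domain \
#         --resume /path/to/output/PACS/domain_${domain}/checkpoint.pth
# done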
-------------------------------------------------------------------------------- /utils.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2015-present, Facebook, Inc. 2 | # All rights reserved. 3 | """ 4 | Misc functions, including distributed helpers. 5 | 6 | Mostly copy-paste from torchvision references. 7 | """ 8 | import io 9 | import os 10 | import time 11 | from collections import defaultdict, deque 12 | import datetime 13 | 14 | import torch 15 | import torch.distributed as dist 16 | 17 | 18 | class SmoothedValue(object): 19 | """Track a series of values and provide access to smoothed values over a 20 | window or the global series average. 21 | """ 22 | 23 | def __init__(self, window_size=20, fmt=None): 24 | if fmt is None: 25 | fmt = "{median:.4f} ({global_avg:.4f})" 26 | self.deque = deque(maxlen=window_size) 27 | self.total = 0.0 28 | self.count = 0 29 | self.fmt = fmt 30 | 31 | def update(self, value, n=1): 32 | self.deque.append(value) 33 | self.count += n 34 | self.total += value * n 35 | 36 | def synchronize_between_processes(self): 37 | """ 38 | Warning: does not synchronize the deque! 39 | """ 40 | if not is_dist_avail_and_initialized(): 41 | return 42 | t = torch.tensor([self.count, self.total], dtype=torch.float64, device='cuda') 43 | dist.barrier() 44 | dist.all_reduce(t) 45 | t = t.tolist() 46 | self.count = int(t[0]) 47 | self.total = t[1] 48 | 49 | @property 50 | def median(self): 51 | d = torch.tensor(list(self.deque)) 52 | return d.median().item() 53 | 54 | @property 55 | def avg(self): 56 | d = torch.tensor(list(self.deque), dtype=torch.float32) 57 | return d.mean().item() 58 | 59 | @property 60 | def global_avg(self): 61 | return self.total / self.count 62 | 63 | @property 64 | def max(self): 65 | return max(self.deque) 66 | 67 | @property 68 | def value(self): 69 | return self.deque[-1] 70 | 71 | def __str__(self): 72 | return self.fmt.format( 73 | median=self.median, 74 | avg=self.avg, 75 | global_avg=self.global_avg, 76 | max=self.max, 77 | value=self.value) 78 | 79 | 80 | class MetricLogger(object): 81 | def __init__(self, delimiter="\t"): 82 | self.meters = defaultdict(SmoothedValue) 83 | self.delimiter = delimiter 84 | 85 | def update(self, **kwargs): 86 | for k, v in kwargs.items(): 87 | if isinstance(v, torch.Tensor): 88 | v = v.item() 89 | assert isinstance(v, (float, int)) 90 | self.meters[k].update(v) 91 | 92 | def __getattr__(self, attr): 93 | if attr in self.meters: 94 | return self.meters[attr] 95 | if attr in self.__dict__: 96 | return self.__dict__[attr] 97 | raise AttributeError("'{}' object has no attribute '{}'".format( 98 | type(self).__name__, attr)) 99 | 100 | def __str__(self): 101 | loss_str = [] 102 | for name, meter in self.meters.items(): 103 | loss_str.append( 104 | "{}: {}".format(name, str(meter)) 105 | ) 106 | return self.delimiter.join(loss_str) 107 | 108 | def synchronize_between_processes(self): 109 | for meter in self.meters.values(): 110 | meter.synchronize_between_processes() 111 | 112 | def add_meter(self, name, meter): 113 | self.meters[name] = meter 114 | 115 | def log_every(self, iterable, print_freq, header=None): 116 | i = 0 117 | if not header: 118 | header = '' 119 | start_time = time.time() 120 | end = time.time() 121 | iter_time = SmoothedValue(fmt='{avg:.4f}') 122 | data_time = SmoothedValue(fmt='{avg:.4f}') 123 | space_fmt = ':' + str(len(str(len(iterable)))) + 'd' 124 | log_msg = [ 125 | header, 126 | '[{0' + space_fmt + '}/{1}]', 127 | 'eta: {eta}', 128 | 
'{meters}', 129 | 'time: {time}', 130 | 'data: {data}' 131 | ] 132 | if torch.cuda.is_available(): 133 | log_msg.append('max mem: {memory:.0f}') 134 | log_msg = self.delimiter.join(log_msg) 135 | MB = 1024.0 * 1024.0 136 | for obj in iterable: 137 | data_time.update(time.time() - end) 138 | yield obj 139 | iter_time.update(time.time() - end) 140 | if i % print_freq == 0 or i == len(iterable) - 1: 141 | eta_seconds = iter_time.global_avg * (len(iterable) - i) 142 | eta_string = str(datetime.timedelta(seconds=int(eta_seconds))) 143 | if torch.cuda.is_available(): 144 | print(log_msg.format( 145 | i, len(iterable), eta=eta_string, 146 | meters=str(self), 147 | time=str(iter_time), data=str(data_time), 148 | memory=torch.cuda.max_memory_allocated() / MB)) 149 | else: 150 | print(log_msg.format( 151 | i, len(iterable), eta=eta_string, 152 | meters=str(self), 153 | time=str(iter_time), data=str(data_time))) 154 | i += 1 155 | end = time.time() 156 | total_time = time.time() - start_time 157 | total_time_str = str(datetime.timedelta(seconds=int(total_time))) 158 | print('{} Total time: {} ({:.4f} s / it)'.format( 159 | header, total_time_str, total_time / len(iterable))) 160 | 161 | 162 | def _load_checkpoint_for_ema(model_ema, checkpoint): 163 | """ 164 | Workaround for ModelEma._load_checkpoint to accept an already-loaded object 165 | """ 166 | mem_file = io.BytesIO() 167 | torch.save({'state_dict_ema':checkpoint}, mem_file) 168 | mem_file.seek(0) 169 | model_ema._load_checkpoint(mem_file) 170 | 171 | 172 | def setup_for_distributed(is_master): 173 | """ 174 | This function disables printing when not in master process 175 | """ 176 | import builtins as __builtin__ 177 | builtin_print = __builtin__.print 178 | 179 | def print(*args, **kwargs): 180 | force = kwargs.pop('force', False) 181 | if is_master or force: 182 | builtin_print(*args, **kwargs) 183 | 184 | __builtin__.print = print 185 | 186 | 187 | def is_dist_avail_and_initialized(): 188 | if not dist.is_available(): 189 | return False 190 | if not dist.is_initialized(): 191 | return False 192 | return True 193 | 194 | 195 | def get_world_size(): 196 | if not is_dist_avail_and_initialized(): 197 | return 1 198 | return dist.get_world_size() 199 | 200 | 201 | def get_rank(): 202 | if not is_dist_avail_and_initialized(): 203 | return 0 204 | return dist.get_rank() 205 | 206 | 207 | def is_main_process(): 208 | return get_rank() == 0 209 | 210 | 211 | def save_on_master(*args, **kwargs): 212 | if is_main_process(): 213 | torch.save(*args, **kwargs) 214 | 215 | 216 | def init_distributed_mode(args): 217 | if 'RANK' in os.environ and 'WORLD_SIZE' in os.environ: 218 | args.rank = int(os.environ["RANK"]) 219 | args.world_size = int(os.environ['WORLD_SIZE']) 220 | args.gpu = int(os.environ['LOCAL_RANK']) 221 | elif 'SLURM_PROCID' in os.environ: 222 | args.rank = int(os.environ['SLURM_PROCID']) 223 | args.gpu = args.rank % torch.cuda.device_count() 224 | else: 225 | print('Not using distributed mode') 226 | args.distributed = False 227 | return 228 | 229 | args.distributed = True 230 | 231 | torch.cuda.set_device(args.gpu) 232 | args.dist_backend = 'nccl' 233 | print('| distributed init (rank {}): {}'.format( 234 | args.rank, args.dist_url), flush=True) 235 | torch.distributed.init_process_group(backend=args.dist_backend, init_method=args.dist_url, 236 | world_size=args.world_size, rank=args.rank) 237 | torch.distributed.barrier() 238 | setup_for_distributed(args.rank == 0) 239 | 240 | 241 | # if 'pos_embed' in state_dict: 242 | def 
interpolate_pos_embed(model, state_dict): 243 | pos_embed_checkpoint = state_dict['pos_embed'] 244 | embedding_size = pos_embed_checkpoint.shape[-1] 245 | num_patches = model.patch_embed.num_patches 246 | num_extra_tokens = model.pos_embed.shape[-2] - num_patches 247 | # height (== width) for the checkpoint position embedding 248 | orig_size = int((pos_embed_checkpoint.shape[-2] - num_extra_tokens) ** 0.5) 249 | # height (== width) for the new position embedding 250 | new_size = int(num_patches ** 0.5) 251 | # class_token and dist_token are kept unchanged 252 | if orig_size != new_size: 253 | print("Position interpolate from %dx%d to %dx%d" % (orig_size, orig_size, new_size, new_size)) 254 | extra_tokens = pos_embed_checkpoint[:, :num_extra_tokens] 255 | # only the position tokens are interpolated 256 | pos_tokens = pos_embed_checkpoint[:, num_extra_tokens:] 257 | pos_tokens = pos_tokens.reshape(-1, orig_size, orig_size, embedding_size).permute(0, 3, 1, 2) 258 | pos_tokens = torch.nn.functional.interpolate( 259 | pos_tokens, size=(new_size, new_size), mode='bicubic', align_corners=False) 260 | pos_tokens = pos_tokens.permute(0, 2, 3, 1).flatten(1, 2) 261 | new_pos_embed = torch.cat((extra_tokens, pos_tokens), dim=1) 262 | state_dict['pos_embed'] = new_pos_embed -------------------------------------------------------------------------------- /vim_requirements.txt: -------------------------------------------------------------------------------- 1 | addict==2.4.0 2 | aiohttp==3.9.1 3 | aiosignal==1.3.1 4 | alembic==1.13.0 5 | async-timeout==4.0.3 6 | attrs==23.1.0 7 | blinker==1.7.0 8 | # causal-conv1d @ file:///home/zhulianghui/VisionProjects/mamba/lib/causal_conv1d-1.0.0%2Bcu118torch2.1cxx11abiFALSE-cp310-cp310-linux_x86_64.whl#sha256=79a4bab633ebff031e615d5e8ba396b0dc0c046f4406980ee238fb86a9090038 9 | certifi==2023.11.17 10 | charset-normalizer==3.3.2 11 | click==8.1.7 12 | cloudpickle==3.0.0 13 | contourpy==1.2.0 14 | cycler==0.12.1 15 | databricks-cli==0.18.0 16 | datasets==2.15.0 17 | dill==0.3.7 18 | docker==6.1.3 19 | einops==0.7.0 20 | entrypoints==0.4 21 | filelock==3.13.1 22 | Flask==3.0.0 23 | fonttools==4.46.0 24 | frozenlist==1.4.0 25 | fsspec==2023.10.0 26 | gitdb==4.0.11 27 | GitPython==3.1.40 28 | greenlet==3.0.2 29 | gunicorn==21.2.0 30 | huggingface-hub==0.19.4 31 | idna==3.6 32 | importlib-metadata==7.0.0 33 | itsdangerous==2.1.2 34 | Jinja2==3.1.2 35 | joblib==1.3.2 36 | kiwisolver==1.4.5 37 | Mako==1.3.0 38 | # mamba-ssm @ file:///home/zhulianghui/VisionProjects/mamba/lib/mamba_ssm-1.0.1%2Bcu118torch2.1cxx11abiFALSE-cp310-cp310-linux_x86_64.whl#sha256=71ad1b1eafb05a6e8a41fd82e046fe85511d6378fa3a583e55215b6aa1d65ab9 39 | Markdown==3.5.1 40 | MarkupSafe==2.1.3 41 | matplotlib==3.8.2 42 | mlflow==2.9.1 43 | mmcv==1.3.8 44 | mmsegmentation==0.14.1 45 | mpmath==1.3.0 46 | multidict==6.0.4 47 | multiprocess==0.70.15 48 | networkx==3.2.1 49 | ninja==1.11.1.1 50 | numpy==1.26.2 51 | # nvidia-cublas-cu12==12.1.3.1 52 | # nvidia-cuda-cupti-cu12==12.1.105 53 | # nvidia-cuda-nvrtc-cu12==12.1.105 54 | # nvidia-cuda-runtime-cu12==12.1.105 55 | # nvidia-cudnn-cu12==8.9.2.26 56 | # nvidia-cufft-cu12==11.0.2.54 57 | # nvidia-curand-cu12==10.3.2.106 58 | # nvidia-cusolver-cu12==11.4.5.107 59 | # nvidia-cusparse-cu12==12.1.0.106 60 | # nvidia-nccl-cu12==2.18.1 61 | # nvidia-nvjitlink-cu12==12.3.101 62 | # nvidia-nvtx-cu12==12.1.105 63 | oauthlib==3.2.2 64 | opencv-python==4.8.1.78 65 | packaging==23.2 66 | pandas==2.1.3 67 | Pillow==10.1.0 68 | platformdirs==4.1.0 69 | prettytable==3.9.0 
70 | protobuf==4.25.1 71 | pyarrow==14.0.1 72 | pyarrow-hotfix==0.6 73 | PyJWT==2.8.0 74 | pyparsing==3.1.1 75 | python-dateutil==2.8.2 76 | python-hostlist==1.23.0 77 | pytz==2023.3.post1 78 | PyYAML==6.0.1 79 | querystring-parser==1.2.4 80 | regex==2023.10.3 81 | requests==2.31.0 82 | safetensors==0.4.1 83 | scikit-learn==1.3.2 84 | scipy==1.11.4 85 | six==1.16.0 86 | smmap==5.0.1 87 | SQLAlchemy==2.0.23 88 | sqlparse==0.4.4 89 | sympy==1.12 90 | tabulate==0.9.0 91 | threadpoolctl==3.2.0 92 | timm==0.4.12 93 | tokenizers==0.15.0 94 | tomli==2.0.1 95 | # torch==2.1.1+cu118 96 | # torchvision==0.16.1+cu118 97 | tqdm==4.66.1 98 | transformers==4.35.2 99 | triton==2.1.0 100 | typing_extensions==4.8.0 101 | tzdata==2023.3 102 | urllib3==2.1.0 103 | wcwidth==0.2.12 104 | websocket-client==1.7.0 105 | Werkzeug==3.0.1 106 | xxhash==3.4.1 107 | yapf==0.40.2 108 | yarl==1.9.4 109 | zipp==3.17.0 110 | --------------------------------------------------------------------------------