├── .gitignore ├── README.md ├── assets ├── intro.png ├── logo.png └── poster.png ├── conf ├── aptos.json ├── bloodmnist.json ├── ham10000.json └── organcmnist.json ├── data └── ham10000.py ├── dino_variant.py ├── rein ├── __init__.py └── models │ ├── __init__.py │ └── backbones │ ├── __init__.py │ ├── beit.py │ ├── clip.py │ ├── dino_layers │ ├── __init__.py │ ├── attention.py │ ├── block.py │ ├── dino_head.py │ ├── drop_path.py │ ├── layer_scale.py │ ├── mlp.py │ ├── patch_embed.py │ └── swiglu_ffn.py │ ├── dino_v2.py │ ├── eva_02.py │ ├── reins.py │ ├── reins_dinov2.py │ ├── reins_eva_02.py │ ├── reins_resnet.py │ └── utils.py ├── requirement.txt ├── train_cufit.py ├── train_fully.py ├── train_linear.py ├── train_rein.py └── utils ├── __init__.py ├── aptos.py ├── dataset.py └── metric.py /.gitignore: -------------------------------------------------------------------------------- 1 | *.tar 2 | *.pyc 3 | *.pth 4 | *.zip 5 | *.jpg 6 | 7 | data/ham10000/* 8 | data/aptos-2019/* 9 | checkpoints/* -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 |

2 | drawing 3 |

Curriculum Fine-tuning of Vision Foundation Model for Medical Image Classification Under Label Noise

4 |

5 | Yeonguk Yu 6 | · 7 | Minhwan Ko 8 | · 9 | Sungho Shin 10 | · 11 | Kangmin Kim 12 | · 13 | Kyoobin Lee 14 |
15 | Artificial Intelligence LAB 16 | GIST, South Korea 17 |

18 |

NeurIPS 2024 - Poster Presentation

19 |

20 | 21 |

22 | 23 | Results 24 | 25 | 26 | Paper PDF 27 | 28 | 29 | Poster 30 | 31 |

32 | 33 | --- 34 |

35 |

36 | 37 |

38 | 39 | 40 | **TL;DR**: We propose CUFIT, a robust fine-tuning method for vision foundation models under noisy label conditions, based on the advantages of linear probing and adapters. 41 | 42 |
43 | 44 | Our **CU**rriculum **FI**ne-**T**uning of Vision Foundation Model **(CUFIT)** offers a robust training framework for medical multi-class image classification under noisy label conditions. 45 | Leveraging vision foundation models (VFMs) pretrained on large-scale datasets, CUFIT handles noisy labels effectively without modifying the feature extractor, using linear probing. It then applies a curriculum fine-tuning approach: linear probing first provides robustness to noisy samples, after which two adapters are fine-tuned for stronger classification performance. CUFIT outperforms conventional methods across various medical image benchmarks, achieving superior results at various noise rates on datasets such as HAM10000 and APTOS-2019, highlighting its ability to address the challenges posed by noisy labels in medical datasets. 46 | 47 | 48 | ## 🚀 Getting Started 49 | ### Clone the Repository 50 | ```bash 51 | git clone https://github.com/gist-ailab/CUFIT.git 52 | cd CUFIT 53 | ``` 54 | 55 | ### Environment Setup 56 | This code was tested on Ubuntu 20.04 with Python 3.8.18 and requires the following main packages: 57 | 58 | - [Pytorch](https://pytorch.org/): Tested with PyTorch-GPU 2.0.1. 59 | - [torchvision](https://pytorch.org/vision/stable/index.html): Installed along with PyTorch. Tested with version 0.15.2. 60 | - [MedMNIST](https://medmnist.com/): Required for the BloodMNIST and OrganCMNIST experiments. Tested with version 3.0.1. 61 | 62 | You may use the following commands: 63 | ```bash 64 | conda create -n cufit python=3.8 65 | conda activate cufit 66 | pip install -r requirement.txt 67 | ``` 68 | 69 | 70 | ### Dataset Preparation 71 | Several public datasets must be downloaded to run the experiments. 72 |
73 | HAM10000 preparation 74 | 75 | 1. Download the training data, training ground truth, test data, and test ground truth of Task 3 from this link. 76 | 77 | 2. Place the zip files in the "CUFIT/data" folder and extract them. 78 | 79 | 3. Run the Python script "ham10000.py" in "CUFIT/data". 80 | 81 | 4. This will create a folder named "ham10000" where images are sorted into their corresponding disease classes.
83 | 84 |
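In shell form, steps 2–4 of the HAM10000 preparation above roughly correspond to the following sketch. The zip file names are assumptions based on the ISIC 2018 Task 3 download page; the extracted folder names are the ones `data/ham10000.py` expects.

```bash
# Rough sketch of the HAM10000 preparation steps (zip names may differ from your download).
cd CUFIT/data
unzip ISIC2018_Task3_Training_Input.zip
unzip ISIC2018_Task3_Training_GroundTruth.zip
unzip ISIC2018_Task3_Test_Input.zip
unzip ISIC2018_Task3_Test_GroundTruth.zip
python ham10000.py   # sorts images into ham10000/train/<class> and ham10000/test/<class>
```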
85 | APTOS-2019 preparation 86 | 87 | 1. Download the zip file by clicking the "Download All" button on the Kaggle site. 88 | 89 | 2. Place the zip file in the "CUFIT/data" folder and extract it. 90 | 91 | 3. Create a folder named "APTOS-2019" in "CUFIT/data". 92 | 93 | 4. Place the extracted files in the "APTOS-2019" folder. 94 | 95 |
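The BloodMNIST and OrganCMNIST experiments use the MedMNIST package listed above. Below is a minimal download sketch, assuming the `root` locations follow `conf/bloodmnist.json` and `conf/organcmnist.json`; the training scripts may also handle downloading themselves.

```python
# Minimal sketch: fetch the MedMNIST subsets into the paths used by the conf files.
# The exact directory layout expected by the training scripts is an assumption.
import os
from medmnist import BloodMNIST, OrganCMNIST

for cls, root in [(BloodMNIST, "./data/bloodmnist"), (OrganCMNIST, "./data/organcmnist")]:
    os.makedirs(root, exist_ok=True)          # medmnist requires the root directory to exist
    for split in ("train", "val", "test"):
        cls(split=split, download=True, root=root)
```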
96 | 97 | ### The config files may need to be edited to match your dataset and checkpoint paths. For example (the `#` comments below are explanatory only and must not appear in the actual JSON file), 98 | ~~~ 99 | # conf/ham10000.json 100 | { 101 | "epoch" : "100", 102 | "id_dataset" : "./data/ham10000", # Your path to dataset 103 | "batch_size" : 32, 104 | "save_path" : "./checkpoints/ham10000", # Your path to checkpoint 105 | "num_classes" : 7 106 | } 107 | ~~~ 108 | 109 | 110 | Place the data and create the checkpoint folders following this directory structure: 111 | ```plaintext 112 | CUFIT/ 113 | ├── assets/ 114 | ├── checkpoints/ 115 | ├── HAM10000/ 116 | └── APTOS-2019/ 117 | ├── conf/ 118 | ├── ham10000.json 119 | └── aptos.json 120 | ├── data/ 121 | ├── ham10000/ 122 | ├── test/ 123 | └── train/ 124 | └── APTOS-2019/ 125 | ├── test_images/ 126 | ├── train_images/ 127 | ├── val_images/ 128 | ├── test.csv 129 | ├── train_1.csv 130 | └── valid.csv 131 | ├── rein/ 132 | └── utils/ 133 | ``` 134 | 135 | --- 136 | ## How to Run 137 | ### - To train a model with linear probing using the DINOv2-small architecture 138 | ~~~ 139 | python train_linear.py -d 'data_name' -g 'gpu-num' -n 'noise_rate' -s 'save_name' 140 | ~~~ 141 | For example, 142 | ~~~ 143 | python train_linear.py -d ham10000 -g 0 -n 0.2 -s dinov2s_linear_0.2 144 | ~~~ 145 |
146 | 147 | ### - To train a model with a single Rein adapter using the DINOv2-small architecture 148 | ~~~ 149 | python train_rein.py -d 'data_name' -g 'gpu-num' -n 'noise_rate' -s 'save_name' 150 | ~~~ 151 | For example, 152 | ~~~ 153 | python train_rein.py -d ham10000 -g 0 -n 0.2 -s dinov2s_single_rein_0.2 154 | ~~~ 155 |
156 | 157 | ### - To train a model with CUFIT using the DINOv2-small architecture 158 | ~~~ 159 | python train_cufit.py -d 'data_name' -g 'gpu-num' -n 'noise_rate' -s 'save_name' 160 | ~~~ 161 | For example, 162 | ~~~ 163 | python train_cufit.py -d ham10000 -g 0 -n 0.2 -s dinov2s_cufit_0.2 164 | ~~~ 165 |
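To compare the three training modes on the same dataset and noise rate, the commands above can simply be run back-to-back, for example:

~~~
# Compare linear probing, a single Rein adapter, and CUFIT
# on HAM10000 with a 0.4 noise rate on GPU 0.
python train_linear.py -d ham10000 -g 0 -n 0.4 -s dinov2s_linear_0.4
python train_rein.py   -d ham10000 -g 0 -n 0.4 -s dinov2s_single_rein_0.4
python train_cufit.py  -d ham10000 -g 0 -n 0.4 -s dinov2s_cufit_0.4
~~~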
166 | 167 | ## 🤝 Acknowledgements & Support 168 | This work was partly supported by Institute of Information & communications Technology Planning & Evaluation (IITP) grant funded by the Korea government (MSIT) (No. RS-2022-II0951, Development of Uncertainty-Aware Agents Learning by Asking Questions, 90%) and Institute of Civil Military 169 | Technology Cooperation funded by the Defense Acquisition Program Administration and Ministry of Trade, Industry and Energy of the Korean government under grant No. 22-CM-GU-08, 10%. 170 | 171 | ### 🌟 License 172 | The source code of this repository is released only for academic use. See the [license](LICENSE) file for details. 173 | 174 | ### 📚 Citation 175 | If you use CUFIT in your research, please consider citing us. 176 | ```bibtex 177 | @inproceedings{ 178 | yu2024curriculum, 179 | title={Curriculum Fine-tuning of Vision Foundation Model for Medical Image Classification Under Label Noise}, 180 | author={Yeonguk Yu and Minhwan Ko and Sungho Shin and Kangmin Kim and Kyoobin Lee}, 181 | booktitle={The Thirty-eighth Annual Conference on Neural Information Processing Systems}, 182 | year={2024}, 183 | url={https://openreview.net/forum?id=vYUx8j5KK2} 184 | } 185 | ``` 186 | -------------------------------------------------------------------------------- /assets/intro.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/gist-ailab/CUFIT/5a521cdaa41a326962ebb4d20d6a79142f33ac4d/assets/intro.png -------------------------------------------------------------------------------- /assets/logo.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/gist-ailab/CUFIT/5a521cdaa41a326962ebb4d20d6a79142f33ac4d/assets/logo.png -------------------------------------------------------------------------------- /assets/poster.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/gist-ailab/CUFIT/5a521cdaa41a326962ebb4d20d6a79142f33ac4d/assets/poster.png -------------------------------------------------------------------------------- /conf/aptos.json: -------------------------------------------------------------------------------- 1 | { 2 | "epoch" : "100", 3 | "id_dataset" : "../data/APTOS-2019", 4 | "batch_size" : 32, 5 | "save_path" : "./checkpoints/APTOS2019/", 6 | "num_classes" : 5 7 | } -------------------------------------------------------------------------------- /conf/bloodmnist.json: -------------------------------------------------------------------------------- 1 | { 2 | "epoch" : "100", 3 | "id_dataset" : "./data/bloodmnist", 4 | "batch_size" : 32, 5 | "save_path" : "./checkpoints/BLOODMNIST/", 6 | "num_classes" : 8 7 | } -------------------------------------------------------------------------------- /conf/ham10000.json: -------------------------------------------------------------------------------- 1 | { 2 | "epoch" : "100", 3 | "id_dataset" : "./data/ham10000", 4 | "batch_size" : 32, 5 | "save_path" : "./checkpoints/HAM10000/", 6 | "num_classes" : 7 7 | } -------------------------------------------------------------------------------- /conf/organcmnist.json: -------------------------------------------------------------------------------- 1 | { 2 | "epoch" : "100", 3 | "id_dataset" : "./data/organcmnist", 4 | "batch_size" : 32, 5 | "save_path" : "./checkpoints/ORGANCMNIST/", 6 | "num_classes" : 11 7 | }
-------------------------------------------------------------------------------- /data/ham10000.py: -------------------------------------------------------------------------------- 1 | import os 2 | import shutil 3 | 4 | def create_folder(folder_name): 5 | folder = os.path.join('ham10000', folder_name) 6 | class_list = ['MEL', 'NV','BCC','AKIEC','BKL','DF','VASC'] 7 | 8 | os.mkdir(folder) 9 | for c in class_list: 10 | folder = os.path.join('ham10000', folder_name, c) 11 | os.mkdir(folder) 12 | 13 | def read_and_move(csv_file, img_folder, is_train=True): 14 | class_list = ['MEL', 'NV','BCC','AKIEC','BKL','DF','VASC'] 15 | 16 | if is_train: 17 | dst_folder = 'ham10000/train/' 18 | else: 19 | dst_folder = 'ham10000/test/' 20 | 21 | 22 | with open(csv_file, 'r') as f: 23 | lines = f.readlines()[1:] 24 | 25 | for line in lines: 26 | items = line.split(',') 27 | items[-1] = items[-1].replace('\n', '') 28 | class_info = [float(x) for x in items[1:]] 29 | class_info = class_info.index(1.0) 30 | 31 | img = '{}.jpg'.format(items[0]) 32 | label = class_info 33 | 34 | img_path_src = os.path.join(img_folder, img) 35 | img_path_dst = os.path.join(dst_folder, class_list[label], img) 36 | 37 | shutil.copy2(img_path_src, img_path_dst) 38 | 39 | 40 | if __name__ == '__main__': 41 | train_image_folder = 'ISIC2018_Task3_Training_Input' 42 | train_image_gt_csv = 'ISIC2018_Task3_Training_GroundTruth/ISIC2018_Task3_Training_GroundTruth.csv' 43 | 44 | test_image_folder = 'ISIC2018_Task3_Test_Input' 45 | test_image_gt_csv = 'ISIC2018_Task3_Test_GroundTruth/ISIC2018_Task3_Test_GroundTruth.csv' 46 | 47 | os.mkdir('./ham10000') 48 | # Create train folder 49 | create_folder('train') 50 | read_and_move(train_image_gt_csv, train_image_folder) 51 | 52 | 53 | # Create test folder 54 | create_folder('test') 55 | read_and_move(test_image_gt_csv, test_image_folder, is_train=False) 56 | -------------------------------------------------------------------------------- /dino_variant.py: -------------------------------------------------------------------------------- 1 | 2 | _small_variant = dict( 3 | patch_size=14, 4 | embed_dim=384, 5 | depth=12, 6 | num_heads=6, 7 | mlp_ratio=4, 8 | img_size=518, 9 | ffn_layer="mlp", 10 | init_values=1e-05, 11 | block_chunks=0, 12 | qkv_bias=True, 13 | proj_bias=True, 14 | ffn_bias=True 15 | ) 16 | _small_dino = 'dinov2_vits14' 17 | 18 | _base_variant = dict( 19 | patch_size=14, 20 | embed_dim=768, 21 | depth=12, 22 | num_heads=12, 23 | mlp_ratio=4, 24 | img_size=518, 25 | ffn_layer="mlp", 26 | init_values=1e-05, 27 | block_chunks=0, 28 | qkv_bias=True, 29 | proj_bias=True, 30 | ffn_bias=True, 31 | out_indices = [7, 11, 14, 17] 32 | ) 33 | _base_dino = 'dinov2_vitb14' 34 | 35 | _large_variant = dict( 36 | patch_size=14, 37 | embed_dim=1024, 38 | depth=24, 39 | num_heads=16, 40 | mlp_ratio=4, 41 | img_size=518, 42 | ffn_layer="mlp", 43 | init_values=1e-05, 44 | block_chunks=0, 45 | qkv_bias=True, 46 | proj_bias=True, 47 | ffn_bias=True 48 | ) 49 | _large_dino = 'dinov2_vitl14' 50 | -------------------------------------------------------------------------------- /rein/__init__.py: -------------------------------------------------------------------------------- 1 | from .models import * 2 | -------------------------------------------------------------------------------- /rein/models/__init__.py: -------------------------------------------------------------------------------- 1 | from .backbones import ReinsDinoVisionTransformer, ReinsDinoVisionTransformer_3_head 2 | from .backbones import 
ReinsResNet 3 | -------------------------------------------------------------------------------- /rein/models/backbones/__init__.py: -------------------------------------------------------------------------------- 1 | # from .dino_v2 import DinoVisionTransformer 2 | from .reins_dinov2 import ReinsDinoVisionTransformer, ReinsDinoVisionTransformer_3_head 3 | from .reins_resnet import ReinsResNet 4 | # from .reins_eva_02 import ReinsEVA2 5 | # from .clip import CLIPVisionTransformer 6 | 7 | __all__ = [ 8 | "CLIPVisionTransformer", 9 | "DinoVisionTransformer", 10 | "ReinsDinoVisionTransformer", 11 | "ReinsDinoVisionTransformer_3_head", 12 | "ReinsEVA2", 13 | ] 14 | -------------------------------------------------------------------------------- /rein/models/backbones/beit.py: -------------------------------------------------------------------------------- 1 | # -------------------------------------------------------- 2 | # BEIT: BERT Pre-Training of Image Transformers (https://arxiv.org/abs/2106.08254) 3 | # Github source: https://github.com/microsoft/unilm/tree/master/beit 4 | # Copyright (c) 2021 Microsoft 5 | # Licensed under The MIT License [see LICENSE for details] 6 | # By Hangbo Bao 7 | # Based on timm, mmseg, setr, xcit and swin code bases 8 | # https://github.com/rwightman/pytorch-image-models/tree/master/timm 9 | # https://github.com/fudan-zvg/SETR 10 | # https://github.com/facebookresearch/xcit/ 11 | # https://github.com/microsoft/Swin-Transformer 12 | # --------------------------------------------------------' 13 | import math 14 | from functools import partial 15 | 16 | import torch 17 | import torch.nn as nn 18 | import torch.nn.functional as F 19 | import torch.utils.checkpoint as cp 20 | # from mmseg.models.builder import BACKBONES 21 | 22 | # from mmengine.logging import MMLogger 23 | from timm.models.layers import drop_path, to_2tuple, trunc_normal_ 24 | 25 | # Copyright (c) Open-MMLab. All rights reserved. 26 | import io 27 | import math 28 | import os 29 | import os.path as osp 30 | import pkgutil 31 | import time 32 | import warnings 33 | from collections import OrderedDict 34 | from importlib import import_module 35 | from tempfile import TemporaryDirectory 36 | 37 | import mmcv 38 | import numpy as np 39 | import torch 40 | import torchvision 41 | from mmengine.fileio import FileClient 42 | from mmengine.fileio import load as load_file 43 | from mmengine.dist import get_dist_info 44 | from mmengine.model import is_model_wrapper 45 | from mmengine import mkdir_or_exist 46 | from scipy import interpolate 47 | from torch.nn import functional as F 48 | from torch.optim import Optimizer 49 | from torch.utils import model_zoo 50 | 51 | ENV_MMCV_HOME = "MMCV_HOME" 52 | ENV_XDG_CACHE_HOME = "XDG_CACHE_HOME" 53 | DEFAULT_CACHE_DIR = "~/.cache" 54 | 55 | 56 | def _get_mmcv_home(): 57 | mmcv_home = os.path.expanduser( 58 | os.getenv( 59 | ENV_MMCV_HOME, 60 | os.path.join(os.getenv(ENV_XDG_CACHE_HOME, DEFAULT_CACHE_DIR), "mmcv"), 61 | ) 62 | ) 63 | 64 | mkdir_or_exist(mmcv_home) 65 | return mmcv_home 66 | 67 | 68 | def load_state_dict(module, state_dict, strict=False, logger=None): 69 | """Load state_dict to a module. 70 | 71 | This method is modified from :meth:`torch.nn.Module.load_state_dict`. 72 | Default value for ``strict`` is set to ``False`` and the message for 73 | param mismatch will be shown even if strict is False. 74 | Args: 75 | module (Module): Module that receives the state_dict. 76 | state_dict (OrderedDict): Weights. 
77 | strict (bool): whether to strictly enforce that the keys 78 | in :attr:`state_dict` match the keys returned by this module's 79 | :meth:`~torch.nn.Module.state_dict` function. Default: ``False``. 80 | logger (:obj:`logging.Logger`, optional): Logger to log the error 81 | message. If not specified, print function will be used. 82 | """ 83 | unexpected_keys = [] 84 | all_missing_keys = [] 85 | err_msg = [] 86 | 87 | metadata = getattr(state_dict, "_metadata", None) 88 | state_dict = state_dict.copy() 89 | if metadata is not None: 90 | state_dict._metadata = metadata 91 | 92 | # use _load_from_state_dict to enable checkpoint version control 93 | def load(module, prefix=""): 94 | # recursively check parallel module in case that the model has a 95 | # complicated structure, e.g., nn.Module(nn.Module(DDP)) 96 | if is_model_wrapper(module): 97 | module = module.module 98 | local_metadata = {} if metadata is None else metadata.get(prefix[:-1], {}) 99 | module._load_from_state_dict( 100 | state_dict, 101 | prefix, 102 | local_metadata, 103 | True, 104 | all_missing_keys, 105 | unexpected_keys, 106 | err_msg, 107 | ) 108 | for name, child in module._modules.items(): 109 | if child is not None: 110 | load(child, prefix + name + ".") 111 | 112 | load(module) 113 | load = None # break load->load reference cycle 114 | 115 | # ignore "num_batches_tracked" of BN layers 116 | missing_keys = [key for key in all_missing_keys if "num_batches_tracked" not in key] 117 | 118 | if unexpected_keys: 119 | err_msg.append( 120 | "unexpected key in source " f'state_dict: {", ".join(unexpected_keys)}\n' 121 | ) 122 | if missing_keys: 123 | err_msg.append( 124 | f'missing keys in source state_dict: {", ".join(missing_keys)}\n' 125 | ) 126 | 127 | rank, _ = get_dist_info() 128 | if len(err_msg) > 0 and rank == 0: 129 | err_msg.insert(0, "The model and loaded state dict do not match exactly\n") 130 | err_msg = "\n".join(err_msg) 131 | if strict: 132 | raise RuntimeError(err_msg) 133 | elif logger is not None: 134 | logger.warning(err_msg) 135 | else: 136 | print(err_msg) 137 | 138 | 139 | def load_url_dist(url, model_dir=None, map_location="cpu"): 140 | """In distributed setting, this function only download checkpoint at local 141 | rank 0.""" 142 | rank, world_size = get_dist_info() 143 | rank = int(os.environ.get("LOCAL_RANK", rank)) 144 | if rank == 0: 145 | checkpoint = model_zoo.load_url( 146 | url, model_dir=model_dir, map_location=map_location 147 | ) 148 | if world_size > 1: 149 | torch.distributed.barrier() 150 | if rank > 0: 151 | checkpoint = model_zoo.load_url( 152 | url, model_dir=model_dir, map_location=map_location 153 | ) 154 | return checkpoint 155 | 156 | 157 | def load_pavimodel_dist(model_path, map_location=None): 158 | """In distributed setting, this function only download checkpoint at local 159 | rank 0.""" 160 | try: 161 | from pavi import modelcloud 162 | except ImportError: 163 | raise ImportError("Please install pavi to load checkpoint from modelcloud.") 164 | rank, world_size = get_dist_info() 165 | rank = int(os.environ.get("LOCAL_RANK", rank)) 166 | if rank == 0: 167 | model = modelcloud.get(model_path) 168 | with TemporaryDirectory() as tmp_dir: 169 | downloaded_file = osp.join(tmp_dir, model.name) 170 | model.download(downloaded_file) 171 | checkpoint = torch.load(downloaded_file, map_location=map_location) 172 | if world_size > 1: 173 | torch.distributed.barrier() 174 | if rank > 0: 175 | model = modelcloud.get(model_path) 176 | with TemporaryDirectory() as tmp_dir: 177 | 
downloaded_file = osp.join(tmp_dir, model.name) 178 | model.download(downloaded_file) 179 | checkpoint = torch.load(downloaded_file, map_location=map_location) 180 | return checkpoint 181 | 182 | 183 | def load_fileclient_dist(filename, backend, map_location): 184 | """In distributed setting, this function only download checkpoint at local 185 | rank 0.""" 186 | rank, world_size = get_dist_info() 187 | rank = int(os.environ.get("LOCAL_RANK", rank)) 188 | allowed_backends = ["ceph"] 189 | if backend not in allowed_backends: 190 | raise ValueError(f"Load from Backend {backend} is not supported.") 191 | if rank == 0: 192 | fileclient = FileClient(backend=backend) 193 | buffer = io.BytesIO(fileclient.get(filename)) 194 | checkpoint = torch.load(buffer, map_location=map_location) 195 | if world_size > 1: 196 | torch.distributed.barrier() 197 | if rank > 0: 198 | fileclient = FileClient(backend=backend) 199 | buffer = io.BytesIO(fileclient.get(filename)) 200 | checkpoint = torch.load(buffer, map_location=map_location) 201 | return checkpoint 202 | 203 | 204 | def get_torchvision_models(): 205 | model_urls = dict() 206 | for _, name, ispkg in pkgutil.walk_packages(torchvision.models.__path__): 207 | if ispkg: 208 | continue 209 | _zoo = import_module(f"torchvision.models.{name}") 210 | if hasattr(_zoo, "model_urls"): 211 | _urls = getattr(_zoo, "model_urls") 212 | model_urls.update(_urls) 213 | return model_urls 214 | 215 | 216 | def get_external_models(): 217 | mmcv_home = _get_mmcv_home() 218 | default_json_path = osp.join(mmcv.__path__[0], "model_zoo/open_mmlab.json") 219 | default_urls = load_file(default_json_path) 220 | assert isinstance(default_urls, dict) 221 | external_json_path = osp.join(mmcv_home, "open_mmlab.json") 222 | if osp.exists(external_json_path): 223 | external_urls = load_file(external_json_path) 224 | assert isinstance(external_urls, dict) 225 | default_urls.update(external_urls) 226 | 227 | return default_urls 228 | 229 | 230 | def get_mmcls_models(): 231 | mmcls_json_path = osp.join(mmcv.__path__[0], "model_zoo/mmcls.json") 232 | mmcls_urls = load_file(mmcls_json_path) 233 | 234 | return mmcls_urls 235 | 236 | 237 | def get_deprecated_model_names(): 238 | deprecate_json_path = osp.join(mmcv.__path__[0], "model_zoo/deprecated.json") 239 | deprecate_urls = load_file(deprecate_json_path) 240 | assert isinstance(deprecate_urls, dict) 241 | 242 | return deprecate_urls 243 | 244 | 245 | def _process_mmcls_checkpoint(checkpoint): 246 | state_dict = checkpoint["state_dict"] 247 | new_state_dict = OrderedDict() 248 | for k, v in state_dict.items(): 249 | if k.startswith("backbone."): 250 | new_state_dict[k[9:]] = v 251 | new_checkpoint = dict(state_dict=new_state_dict) 252 | 253 | return new_checkpoint 254 | 255 | 256 | def _load_checkpoint(filename, map_location=None): 257 | """Load checkpoint from somewhere (modelzoo, file, url). 258 | 259 | Args: 260 | filename (str): Accept local filepath, URL, ``torchvision://xxx``, 261 | ``open-mmlab://xxx``. Please refer to ``docs/model_zoo.md`` for 262 | details. 263 | map_location (str | None): Same as :func:`torch.load`. Default: None. 264 | Returns: 265 | dict | OrderedDict: The loaded checkpoint. It can be either an 266 | OrderedDict storing model weights or a dict containing other 267 | information, which depends on the checkpoint. 
268 | """ 269 | if filename.startswith("modelzoo://"): 270 | warnings.warn( 271 | 'The URL scheme of "modelzoo://" is deprecated, please ' 272 | 'use "torchvision://" instead' 273 | ) 274 | model_urls = get_torchvision_models() 275 | model_name = filename[11:] 276 | checkpoint = load_url_dist(model_urls[model_name]) 277 | elif filename.startswith("torchvision://"): 278 | model_urls = get_torchvision_models() 279 | model_name = filename[14:] 280 | checkpoint = load_url_dist(model_urls[model_name]) 281 | elif filename.startswith("open-mmlab://"): 282 | model_urls = get_external_models() 283 | model_name = filename[13:] 284 | deprecated_urls = get_deprecated_model_names() 285 | if model_name in deprecated_urls: 286 | warnings.warn( 287 | f"open-mmlab://{model_name} is deprecated in favor " 288 | f"of open-mmlab://{deprecated_urls[model_name]}" 289 | ) 290 | model_name = deprecated_urls[model_name] 291 | model_url = model_urls[model_name] 292 | # check if is url 293 | if model_url.startswith(("http://", "https://")): 294 | checkpoint = load_url_dist(model_url) 295 | else: 296 | filename = osp.join(_get_mmcv_home(), model_url) 297 | if not osp.isfile(filename): 298 | raise IOError(f"{filename} is not a checkpoint file") 299 | checkpoint = torch.load(filename, map_location=map_location) 300 | elif filename.startswith("mmcls://"): 301 | model_urls = get_mmcls_models() 302 | model_name = filename[8:] 303 | checkpoint = load_url_dist(model_urls[model_name]) 304 | checkpoint = _process_mmcls_checkpoint(checkpoint) 305 | elif filename.startswith(("http://", "https://")): 306 | checkpoint = load_url_dist(filename) 307 | elif filename.startswith("pavi://"): 308 | model_path = filename[7:] 309 | checkpoint = load_pavimodel_dist(model_path, map_location=map_location) 310 | elif filename.startswith("s3://"): 311 | checkpoint = load_fileclient_dist( 312 | filename, backend="ceph", map_location=map_location 313 | ) 314 | else: 315 | if not osp.isfile(filename): 316 | raise IOError(f"{filename} is not a checkpoint file") 317 | checkpoint = torch.load(filename, map_location=map_location) 318 | return checkpoint 319 | 320 | 321 | def cosine_scheduler( 322 | base_value, 323 | final_value, 324 | epochs, 325 | niter_per_ep, 326 | warmup_epochs=0, 327 | start_warmup_value=0, 328 | warmup_steps=-1, 329 | ): 330 | warmup_schedule = np.array([]) 331 | warmup_iters = warmup_epochs * niter_per_ep 332 | if warmup_steps > 0: 333 | warmup_iters = warmup_steps 334 | print("Set warmup steps = %d" % warmup_iters) 335 | if warmup_epochs > 0: 336 | warmup_schedule = np.linspace(start_warmup_value, base_value, warmup_iters) 337 | 338 | iters = np.arange(epochs * niter_per_ep - warmup_iters) 339 | schedule = np.array( 340 | [ 341 | final_value 342 | + 0.5 343 | * (base_value - final_value) 344 | * (1 + math.cos(math.pi * i / (len(iters)))) 345 | for i in iters 346 | ] 347 | ) 348 | 349 | schedule = np.concatenate((warmup_schedule, schedule)) 350 | 351 | assert len(schedule) == epochs * niter_per_ep 352 | return schedule 353 | 354 | 355 | def load_checkpoint(model, filename, map_location="cpu", strict=False, logger=None): 356 | """Load checkpoint from a file or URI. 357 | 358 | Args: 359 | model (Module): Module to load checkpoint. 360 | filename (str): Accept local filepath, URL, ``torchvision://xxx``, 361 | ``open-mmlab://xxx``. Please refer to ``docs/model_zoo.md`` for 362 | details. 363 | map_location (str): Same as :func:`torch.load`. 364 | strict (bool): Whether to allow different params for the model and 365 | checkpoint. 
366 | logger (:mod:`logging.Logger` or None): The logger for error message. 367 | Returns: 368 | dict or OrderedDict: The loaded checkpoint. 369 | """ 370 | checkpoint = _load_checkpoint(filename, map_location) 371 | # OrderedDict is a subclass of dict 372 | if not isinstance(checkpoint, dict): 373 | raise RuntimeError(f"No state_dict found in checkpoint file {filename}") 374 | # get state_dict from checkpoint 375 | if "state_dict" in checkpoint: 376 | state_dict = checkpoint["state_dict"] 377 | elif "model" in checkpoint: 378 | state_dict = checkpoint["model"] 379 | elif "module" in checkpoint: 380 | state_dict = checkpoint["module"] 381 | else: 382 | state_dict = checkpoint 383 | # strip prefix of state_dict 384 | if list(state_dict.keys())[0].startswith("module."): 385 | state_dict = {k[7:]: v for k, v in state_dict.items()} 386 | 387 | # for MoBY, load model of online branch 388 | if sorted(list(state_dict.keys()))[0].startswith("encoder"): 389 | state_dict = { 390 | k.replace("encoder.", ""): v 391 | for k, v in state_dict.items() 392 | if k.startswith("encoder.") 393 | } 394 | 395 | # reshape absolute position embedding for Swin 396 | if state_dict.get("absolute_pos_embed") is not None: 397 | absolute_pos_embed = state_dict["absolute_pos_embed"] 398 | N1, L, C1 = absolute_pos_embed.size() 399 | N2, C2, H, W = model.absolute_pos_embed.size() 400 | if N1 != N2 or C1 != C2 or L != H * W: 401 | logger.warning("Error in loading absolute_pos_embed, pass") 402 | else: 403 | state_dict["absolute_pos_embed"] = absolute_pos_embed.view( 404 | N2, H, W, C2 405 | ).permute(0, 3, 1, 2) 406 | 407 | rank, _ = get_dist_info() 408 | if "rel_pos_bias.relative_position_bias_table" in state_dict: 409 | if rank == 0: 410 | print("Expand the shared relative position embedding to each layers. 
") 411 | num_layers = model.get_num_layers() 412 | rel_pos_bias = state_dict["rel_pos_bias.relative_position_bias_table"] 413 | for i in range(num_layers): 414 | state_dict[ 415 | "blocks.%d.attn.relative_position_bias_table" % i 416 | ] = rel_pos_bias.clone() 417 | 418 | state_dict.pop("rel_pos_bias.relative_position_bias_table") 419 | 420 | all_keys = list(state_dict.keys()) 421 | for key in all_keys: 422 | if "relative_position_index" in key: 423 | state_dict.pop(key) 424 | 425 | if "relative_position_bias_table" in key: 426 | rel_pos_bias = state_dict[key] 427 | src_num_pos, num_attn_heads = rel_pos_bias.size() 428 | dst_num_pos, _ = model.state_dict()[key].size() 429 | dst_patch_shape = model.patch_embed.patch_shape 430 | if dst_patch_shape[0] != dst_patch_shape[1]: 431 | raise NotImplementedError() 432 | num_extra_tokens = dst_num_pos - (dst_patch_shape[0] * 2 - 1) * ( 433 | dst_patch_shape[1] * 2 - 1 434 | ) 435 | src_size = int((src_num_pos - num_extra_tokens) ** 0.5) 436 | dst_size = int((dst_num_pos - num_extra_tokens) ** 0.5) 437 | if src_size != dst_size: 438 | if rank == 0: 439 | print( 440 | "Position interpolate for %s from %dx%d to %dx%d" 441 | % (key, src_size, src_size, dst_size, dst_size) 442 | ) 443 | extra_tokens = rel_pos_bias[-num_extra_tokens:, :] 444 | rel_pos_bias = rel_pos_bias[:-num_extra_tokens, :] 445 | 446 | def geometric_progression(a, r, n): 447 | return a * (1.0 - r**n) / (1.0 - r) 448 | 449 | left, right = 1.01, 1.5 450 | while right - left > 1e-6: 451 | q = (left + right) / 2.0 452 | gp = geometric_progression(1, q, src_size // 2) 453 | if gp > dst_size // 2: 454 | right = q 455 | else: 456 | left = q 457 | 458 | # if q > 1.13492: 459 | # q = 1.13492 460 | 461 | dis = [] 462 | cur = 1 463 | for i in range(src_size // 2): 464 | dis.append(cur) 465 | cur += q ** (i + 1) 466 | 467 | r_ids = [-_ for _ in reversed(dis)] 468 | 469 | x = r_ids + [0] + dis 470 | y = r_ids + [0] + dis 471 | 472 | t = dst_size // 2.0 473 | dx = np.arange(-t, t + 0.1, 1.0) 474 | dy = np.arange(-t, t + 0.1, 1.0) 475 | if rank == 0: 476 | print("x = {}".format(x)) 477 | print("dx = {}".format(dx)) 478 | 479 | all_rel_pos_bias = [] 480 | 481 | for i in range(num_attn_heads): 482 | z = rel_pos_bias[:, i].view(src_size, src_size).float().numpy() 483 | f = interpolate.interp2d(x, y, z, kind="cubic") 484 | all_rel_pos_bias.append( 485 | torch.Tensor(f(dx, dy)) 486 | .contiguous() 487 | .view(-1, 1) 488 | .to(rel_pos_bias.device) 489 | ) 490 | 491 | rel_pos_bias = torch.cat(all_rel_pos_bias, dim=-1) 492 | new_rel_pos_bias = torch.cat((rel_pos_bias, extra_tokens), dim=0) 493 | state_dict[key] = new_rel_pos_bias 494 | 495 | if "pos_embed" in state_dict: 496 | pos_embed_checkpoint = state_dict["pos_embed"] 497 | embedding_size = pos_embed_checkpoint.shape[-1] 498 | num_patches = model.patch_embed.num_patches 499 | num_extra_tokens = model.pos_embed.shape[-2] - num_patches 500 | # height (== width) for the checkpoint position embedding 501 | orig_size = int((pos_embed_checkpoint.shape[-2] - num_extra_tokens) ** 0.5) 502 | # height (== width) for the new position embedding 503 | new_size = int(num_patches**0.5) 504 | # class_token and dist_token are kept unchanged 505 | if orig_size != new_size: 506 | if rank == 0: 507 | print( 508 | "Position interpolate from %dx%d to %dx%d" 509 | % (orig_size, orig_size, new_size, new_size) 510 | ) 511 | extra_tokens = pos_embed_checkpoint[:, :num_extra_tokens] 512 | # only the position tokens are interpolated 513 | pos_tokens = pos_embed_checkpoint[:, 
num_extra_tokens:] 514 | pos_tokens = pos_tokens.reshape( 515 | -1, orig_size, orig_size, embedding_size 516 | ).permute(0, 3, 1, 2) 517 | pos_tokens = torch.nn.functional.interpolate( 518 | pos_tokens, 519 | size=(new_size, new_size), 520 | mode="bicubic", 521 | align_corners=False, 522 | ) 523 | pos_tokens = pos_tokens.permute(0, 2, 3, 1).flatten(1, 2) 524 | new_pos_embed = torch.cat((extra_tokens, pos_tokens), dim=1) 525 | state_dict["pos_embed"] = new_pos_embed 526 | 527 | # interpolate position bias table if needed 528 | relative_position_bias_table_keys = [ 529 | k for k in state_dict.keys() if "relative_position_bias_table" in k 530 | ] 531 | for table_key in relative_position_bias_table_keys: 532 | table_pretrained = state_dict[table_key] 533 | table_current = model.state_dict()[table_key] 534 | L1, nH1 = table_pretrained.size() 535 | L2, nH2 = table_current.size() 536 | if nH1 != nH2: 537 | logger.warning(f"Error in loading {table_key}, pass") 538 | else: 539 | if L1 != L2: 540 | S1 = int(L1**0.5) 541 | S2 = int(L2**0.5) 542 | table_pretrained_resized = F.interpolate( 543 | table_pretrained.permute(1, 0).view(1, nH1, S1, S1), 544 | size=(S2, S2), 545 | mode="bicubic", 546 | ) 547 | state_dict[table_key] = table_pretrained_resized.view(nH2, L2).permute( 548 | 1, 0 549 | ) 550 | 551 | # load state_dict 552 | load_state_dict(model, state_dict, strict, logger) 553 | return checkpoint 554 | 555 | 556 | def weights_to_cpu(state_dict): 557 | """Copy a model state_dict to cpu. 558 | 559 | Args: 560 | state_dict (OrderedDict): Model weights on GPU. 561 | Returns: 562 | OrderedDict: Model weights on GPU. 563 | """ 564 | state_dict_cpu = OrderedDict() 565 | for key, val in state_dict.items(): 566 | state_dict_cpu[key] = val.cpu() 567 | return state_dict_cpu 568 | 569 | 570 | def _save_to_state_dict(module, destination, prefix, keep_vars): 571 | """Saves module state to `destination` dictionary. 572 | 573 | This method is modified from :meth:`torch.nn.Module._save_to_state_dict`. 574 | Args: 575 | module (nn.Module): The module to generate state_dict. 576 | destination (dict): A dict where state will be stored. 577 | prefix (str): The prefix for parameters and buffers used in this 578 | module. 579 | """ 580 | for name, param in module._parameters.items(): 581 | if param is not None: 582 | destination[prefix + name] = param if keep_vars else param.detach() 583 | for name, buf in module._buffers.items(): 584 | # remove check of _non_persistent_buffers_set to allow nn.BatchNorm2d 585 | if buf is not None: 586 | destination[prefix + name] = buf if keep_vars else buf.detach() 587 | 588 | 589 | def get_state_dict(module, destination=None, prefix="", keep_vars=False): 590 | """Returns a dictionary containing a whole state of the module. 591 | 592 | Both parameters and persistent buffers (e.g. running averages) are 593 | included. Keys are corresponding parameter and buffer names. 594 | This method is modified from :meth:`torch.nn.Module.state_dict` to 595 | recursively check parallel module in case that the model has a complicated 596 | structure, e.g., nn.Module(nn.Module(DDP)). 597 | Args: 598 | module (nn.Module): The module to generate state_dict. 599 | destination (OrderedDict): Returned dict for the state of the 600 | module. 601 | prefix (str): Prefix of the key. 602 | keep_vars (bool): Whether to keep the variable property of the 603 | parameters. Default: False. 604 | Returns: 605 | dict: A dictionary containing a whole state of the module. 
606 | """ 607 | # recursively check parallel module in case that the model has a 608 | # complicated structure, e.g., nn.Module(nn.Module(DDP)) 609 | if is_model_wrapper(module): 610 | module = module.module 611 | 612 | # below is the same as torch.nn.Module.state_dict() 613 | if destination is None: 614 | destination = OrderedDict() 615 | destination._metadata = OrderedDict() 616 | destination._metadata[prefix[:-1]] = local_metadata = dict(version=module._version) 617 | _save_to_state_dict(module, destination, prefix, keep_vars) 618 | for name, child in module._modules.items(): 619 | if child is not None: 620 | get_state_dict(child, destination, prefix + name + ".", keep_vars=keep_vars) 621 | for hook in module._state_dict_hooks.values(): 622 | hook_result = hook(module, destination, prefix, local_metadata) 623 | if hook_result is not None: 624 | destination = hook_result 625 | return destination 626 | 627 | 628 | def save_checkpoint(model, filename, optimizer=None, meta=None): 629 | """Save checkpoint to file. 630 | 631 | The checkpoint will have 3 fields: ``meta``, ``state_dict`` and 632 | ``optimizer``. By default ``meta`` will contain version and time info. 633 | Args: 634 | model (Module): Module whose params are to be saved. 635 | filename (str): Checkpoint filename. 636 | optimizer (:obj:`Optimizer`, optional): Optimizer to be saved. 637 | meta (dict, optional): Metadata to be saved in checkpoint. 638 | """ 639 | if meta is None: 640 | meta = {} 641 | elif not isinstance(meta, dict): 642 | raise TypeError(f"meta must be a dict or None, but got {type(meta)}") 643 | meta.update(mmcv_version=mmcv.__version__, time=time.asctime()) 644 | 645 | if is_model_wrapper(model): 646 | model = model.module 647 | 648 | if hasattr(model, "CLASSES") and model.CLASSES is not None: 649 | # save class name to the meta 650 | meta.update(CLASSES=model.CLASSES) 651 | 652 | checkpoint = {"meta": meta, "state_dict": weights_to_cpu(get_state_dict(model))} 653 | # save optimizer state dict in the checkpoint 654 | if isinstance(optimizer, Optimizer): 655 | checkpoint["optimizer"] = optimizer.state_dict() 656 | elif isinstance(optimizer, dict): 657 | checkpoint["optimizer"] = {} 658 | for name, optim in optimizer.items(): 659 | checkpoint["optimizer"][name] = optim.state_dict() 660 | 661 | if filename.startswith("pavi://"): 662 | try: 663 | from pavi import modelcloud 664 | from pavi.exception import NodeNotFoundError 665 | except ImportError: 666 | raise ImportError("Please install pavi to load checkpoint from modelcloud.") 667 | model_path = filename[7:] 668 | root = modelcloud.Folder() 669 | model_dir, model_name = osp.split(model_path) 670 | try: 671 | model = modelcloud.get(model_dir) 672 | except NodeNotFoundError: 673 | model = root.create_training_model(model_dir) 674 | with TemporaryDirectory() as tmp_dir: 675 | checkpoint_file = osp.join(tmp_dir, model_name) 676 | with open(checkpoint_file, "wb") as f: 677 | torch.save(checkpoint, f) 678 | f.flush() 679 | model.create_file(checkpoint_file, name=model_name) 680 | else: 681 | mmcv.mkdir_or_exist(osp.dirname(filename)) 682 | # immediately flush buffer 683 | with open(filename, "wb") as f: 684 | torch.save(checkpoint, f) 685 | f.flush() 686 | 687 | 688 | class DropPath(nn.Module): 689 | """Drop paths (Stochastic Depth) per sample (when applied in main path of 690 | residual blocks).""" 691 | 692 | def __init__(self, drop_prob=None): 693 | super(DropPath, self).__init__() 694 | self.drop_prob = drop_prob 695 | 696 | def forward(self, x): 697 | return 
drop_path(x, self.drop_prob, self.training) 698 | 699 | def extra_repr(self) -> str: 700 | return "p={}".format(self.drop_prob) 701 | 702 | 703 | class Mlp(nn.Module): 704 | def __init__( 705 | self, 706 | in_features, 707 | hidden_features=None, 708 | out_features=None, 709 | act_layer=nn.GELU, 710 | drop=0.0, 711 | ): 712 | super().__init__() 713 | out_features = out_features or in_features 714 | hidden_features = hidden_features or in_features 715 | self.fc1 = nn.Linear(in_features, hidden_features) 716 | self.act = act_layer() 717 | self.fc2 = nn.Linear(hidden_features, out_features) 718 | self.drop = nn.Dropout(drop) 719 | 720 | def forward(self, x): 721 | x = self.fc1(x) 722 | x = self.act(x) 723 | # x = self.drop(x) 724 | # commit this for the original BERT implement 725 | x = self.fc2(x) 726 | x = self.drop(x) 727 | return x 728 | 729 | 730 | class Attention(nn.Module): 731 | def __init__( 732 | self, 733 | dim, 734 | num_heads=8, 735 | qkv_bias=False, 736 | qk_scale=None, 737 | attn_drop=0.0, 738 | proj_drop=0.0, 739 | window_size=None, 740 | attn_head_dim=None, 741 | ): 742 | super().__init__() 743 | self.num_heads = num_heads 744 | head_dim = dim // num_heads 745 | if attn_head_dim is not None: 746 | head_dim = attn_head_dim 747 | all_head_dim = head_dim * self.num_heads 748 | # NOTE scale factor was wrong in my original version, can set manually to be compat with prev weights 749 | self.scale = qk_scale or head_dim**-0.5 750 | 751 | self.qkv = nn.Linear(dim, all_head_dim * 3, bias=False) 752 | if qkv_bias: 753 | self.q_bias = nn.Parameter(torch.zeros(all_head_dim)) 754 | self.v_bias = nn.Parameter(torch.zeros(all_head_dim)) 755 | else: 756 | self.q_bias = None 757 | self.v_bias = None 758 | 759 | if window_size: 760 | self.window_size = window_size 761 | self.num_relative_distance = (2 * window_size[0] - 1) * ( 762 | 2 * window_size[1] - 1 763 | ) + 3 764 | self.relative_position_bias_table = nn.Parameter( 765 | torch.zeros(self.num_relative_distance, num_heads) 766 | ) # 2*Wh-1 * 2*Ww-1, nH 767 | # cls to token & token 2 cls & cls to cls 768 | 769 | # get pair-wise relative position index for each token inside the window 770 | coords_h = torch.arange(window_size[0]) 771 | coords_w = torch.arange(window_size[1]) 772 | coords = torch.stack(torch.meshgrid([coords_h, coords_w])) # 2, Wh, Ww 773 | coords_flatten = torch.flatten(coords, 1) # 2, Wh*Ww 774 | relative_coords = ( 775 | coords_flatten[:, :, None] - coords_flatten[:, None, :] 776 | ) # 2, Wh*Ww, Wh*Ww 777 | relative_coords = relative_coords.permute( 778 | 1, 2, 0 779 | ).contiguous() # Wh*Ww, Wh*Ww, 2 780 | relative_coords[:, :, 0] += window_size[0] - 1 # shift to start from 0 781 | relative_coords[:, :, 1] += window_size[1] - 1 782 | relative_coords[:, :, 0] *= 2 * window_size[1] - 1 783 | relative_position_index = torch.zeros( 784 | size=(window_size[0] * window_size[1] + 1,) * 2, 785 | dtype=relative_coords.dtype, 786 | ) 787 | relative_position_index[1:, 1:] = relative_coords.sum(-1) # Wh*Ww, Wh*Ww 788 | relative_position_index[0, 0:] = self.num_relative_distance - 3 789 | relative_position_index[0:, 0] = self.num_relative_distance - 2 790 | relative_position_index[0, 0] = self.num_relative_distance - 1 791 | self.register_buffer("relative_position_index", relative_position_index) 792 | 793 | # trunc_normal_(self.relative_position_bias_table, std=.0) 794 | else: 795 | self.window_size = None 796 | self.relative_position_bias_table = None 797 | self.relative_position_index = None 798 | 799 | self.attn_drop = 
nn.Dropout(attn_drop) 800 | self.proj = nn.Linear(all_head_dim, dim) 801 | self.proj_drop = nn.Dropout(proj_drop) 802 | 803 | def forward(self, x, rel_pos_bias=None): 804 | B, N, C = x.shape 805 | qkv_bias = None 806 | if self.q_bias is not None: 807 | qkv_bias = torch.cat( 808 | ( 809 | self.q_bias, 810 | torch.zeros_like(self.v_bias, requires_grad=False), 811 | self.v_bias, 812 | ) 813 | ) 814 | # qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4) 815 | qkv = F.linear(input=x, weight=self.qkv.weight, bias=qkv_bias) 816 | qkv = qkv.reshape(B, N, 3, self.num_heads, -1).permute(2, 0, 3, 1, 4) 817 | #qkv: B,N,3,K,C->3,B,K,N,C 818 | q, k, v = ( 819 | qkv[0], 820 | qkv[1], 821 | qkv[2], 822 | ) # make torchscript happy (cannot use tensor as tuple) 823 | 824 | q = q * self.scale 825 | attn = q @ k.transpose(-2, -1) 826 | # attn : B,K,N,C@B,K,N,C->B,K,N,N 827 | 828 | if self.relative_position_bias_table is not None: 829 | relative_position_bias = self.relative_position_bias_table[ 830 | self.relative_position_index.view(-1) 831 | ].view( 832 | self.window_size[0] * self.window_size[1] + 1, 833 | self.window_size[0] * self.window_size[1] + 1, 834 | -1, 835 | ) # Wh*Ww,Wh*Ww,nH 836 | relative_position_bias = relative_position_bias.permute( 837 | 2, 0, 1 838 | ).contiguous() # nH, Wh*Ww, Wh*Ww 839 | # relative_position_bias = relative_position_bias[:, 1:, 1:] 840 | attn = attn + relative_position_bias.unsqueeze(0) 841 | 842 | if rel_pos_bias is not None: 843 | attn = attn + rel_pos_bias 844 | 845 | attn = attn.softmax(dim=-1) 846 | attn = self.attn_drop(attn) 847 | 848 | x = (attn @ v).transpose(1, 2).reshape(B, N, -1) 849 | x = self.proj(x) 850 | x = self.proj_drop(x) 851 | return x 852 | 853 | 854 | class Block(nn.Module): 855 | def __init__( 856 | self, 857 | dim, 858 | num_heads, 859 | mlp_ratio=4.0, 860 | qkv_bias=False, 861 | qk_scale=None, 862 | drop=0.0, 863 | attn_drop=0.0, 864 | drop_path=0.0, 865 | init_values=None, 866 | act_layer=nn.GELU, 867 | norm_layer=nn.LayerNorm, 868 | window_size=None, 869 | attn_head_dim=None, 870 | with_cp=False, 871 | ): 872 | super().__init__() 873 | self.with_cp = with_cp 874 | self.norm1 = norm_layer(dim) 875 | self.attn = Attention( 876 | dim, 877 | num_heads=num_heads, 878 | qkv_bias=qkv_bias, 879 | qk_scale=qk_scale, 880 | attn_drop=attn_drop, 881 | proj_drop=drop, 882 | window_size=window_size, 883 | attn_head_dim=attn_head_dim, 884 | ) 885 | # NOTE: drop path for stochastic depth, we shall see if this is better than dropout here 886 | self.drop_path = DropPath(drop_path) if drop_path > 0.0 else nn.Identity() 887 | self.norm2 = norm_layer(dim) 888 | mlp_hidden_dim = int(dim * mlp_ratio) 889 | self.mlp = Mlp( 890 | in_features=dim, 891 | hidden_features=mlp_hidden_dim, 892 | act_layer=act_layer, 893 | drop=drop, 894 | ) 895 | 896 | if init_values is not None: 897 | self.gamma_1 = nn.Parameter( 898 | init_values * torch.ones((dim)), requires_grad=True 899 | ) 900 | self.gamma_2 = nn.Parameter( 901 | init_values * torch.ones((dim)), requires_grad=True 902 | ) 903 | else: 904 | self.gamma_1, self.gamma_2 = None, None 905 | 906 | def forward(self, x, H, W, rel_pos_bias=None): 907 | def _inner_forward(x): 908 | if self.gamma_1 is None: 909 | x = x + self.drop_path( 910 | self.attn(self.norm1(x), rel_pos_bias=rel_pos_bias) 911 | ) 912 | x = x + self.drop_path(self.mlp(self.norm2(x))) 913 | else: 914 | x = x + self.drop_path( 915 | self.gamma_1 * self.attn(self.norm1(x), rel_pos_bias=rel_pos_bias) 916 | ) 917 | x = x 
+ self.drop_path(self.gamma_2 * self.mlp(self.norm2(x))) 918 | return x 919 | 920 | if self.with_cp and x.requires_grad: 921 | x = cp.checkpoint(_inner_forward, x) 922 | else: 923 | x = _inner_forward(x) 924 | return x 925 | 926 | 927 | class PatchEmbed(nn.Module): 928 | """Image to Patch Embedding""" 929 | 930 | def __init__(self, img_size=224, patch_size=16, in_chans=3, embed_dim=768): 931 | super().__init__() 932 | img_size = to_2tuple(img_size) 933 | patch_size = to_2tuple(patch_size) 934 | num_patches = (img_size[1] // patch_size[1]) * (img_size[0] // patch_size[0]) 935 | self.patch_shape = (img_size[0] // patch_size[0], img_size[1] // patch_size[1]) 936 | self.img_size = img_size 937 | self.patch_size = patch_size 938 | self.num_patches = num_patches 939 | 940 | self.proj = nn.Conv2d( 941 | in_chans, embed_dim, kernel_size=patch_size, stride=patch_size 942 | ) 943 | 944 | def forward(self, x, **kwargs): 945 | B, C, H, W = x.shape 946 | # FIXME look at relaxing size constraints 947 | # assert H == self.img_size[0] and W == self.img_size[1], \ 948 | # f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})." 949 | x = self.proj(x) 950 | Hp, Wp = x.shape[2], x.shape[3] 951 | 952 | x = x.flatten(2).transpose(1, 2) 953 | return x, Hp, Wp 954 | 955 | 956 | class HybridEmbed(nn.Module): 957 | """CNN Feature Map Embedding 958 | Extract feature map from CNN, flatten, project to embedding dim. 959 | """ 960 | 961 | def __init__( 962 | self, backbone, img_size=224, feature_size=None, in_chans=3, embed_dim=768 963 | ): 964 | super().__init__() 965 | assert isinstance(backbone, nn.Module) 966 | img_size = to_2tuple(img_size) 967 | self.img_size = img_size 968 | self.backbone = backbone 969 | if feature_size is None: 970 | with torch.no_grad(): 971 | # FIXME this is hacky, but most reliable way of determining the exact dim of the output feature 972 | # map for all networks, the feature metadata has reliable channel and stride info, but using 973 | # stride to calc feature dim requires info about padding of each stage that isn't captured. 
974 | training = backbone.training 975 | if training: 976 | backbone.eval() 977 | o = self.backbone(torch.zeros(1, in_chans, img_size[0], img_size[1]))[ 978 | -1 979 | ] 980 | feature_size = o.shape[-2:] 981 | feature_dim = o.shape[1] 982 | backbone.train(training) 983 | else: 984 | feature_size = to_2tuple(feature_size) 985 | feature_dim = self.backbone.feature_info.channels()[-1] 986 | self.num_patches = feature_size[0] * feature_size[1] 987 | self.proj = nn.Linear(feature_dim, embed_dim) 988 | 989 | def forward(self, x): 990 | x = self.backbone(x)[-1] 991 | x = x.flatten(2).transpose(1, 2) 992 | x = self.proj(x) 993 | return x 994 | 995 | 996 | class RelativePositionBias(nn.Module): 997 | def __init__(self, window_size, num_heads): 998 | super().__init__() 999 | self.window_size = window_size 1000 | self.num_relative_distance = (2 * window_size[0] - 1) * ( 1001 | 2 * window_size[1] - 1 1002 | ) + 3 1003 | self.relative_position_bias_table = nn.Parameter( 1004 | torch.zeros(self.num_relative_distance, num_heads) 1005 | ) # 2*Wh-1 * 2*Ww-1, nH 1006 | # cls to token & token 2 cls & cls to cls 1007 | 1008 | # get pair-wise relative position index for each token inside the window 1009 | coords_h = torch.arange(window_size[0]) 1010 | coords_w = torch.arange(window_size[1]) 1011 | coords = torch.stack(torch.meshgrid([coords_h, coords_w])) # 2, Wh, Ww 1012 | coords_flatten = torch.flatten(coords, 1) # 2, Wh*Ww 1013 | relative_coords = ( 1014 | coords_flatten[:, :, None] - coords_flatten[:, None, :] 1015 | ) # 2, Wh*Ww, Wh*Ww 1016 | relative_coords = relative_coords.permute( 1017 | 1, 2, 0 1018 | ).contiguous() # Wh*Ww, Wh*Ww, 2 1019 | relative_coords[:, :, 0] += window_size[0] - 1 # shift to start from 0 1020 | relative_coords[:, :, 1] += window_size[1] - 1 1021 | relative_coords[:, :, 0] *= 2 * window_size[1] - 1 1022 | relative_position_index = torch.zeros( 1023 | size=(window_size[0] * window_size[1] + 1,) * 2, dtype=relative_coords.dtype 1024 | ) 1025 | relative_position_index[1:, 1:] = relative_coords.sum(-1) # Wh*Ww, Wh*Ww 1026 | relative_position_index[0, 0:] = self.num_relative_distance - 3 1027 | relative_position_index[0:, 0] = self.num_relative_distance - 2 1028 | relative_position_index[0, 0] = self.num_relative_distance - 1 1029 | 1030 | self.register_buffer("relative_position_index", relative_position_index) 1031 | 1032 | # trunc_normal_(self.relative_position_bias_table, std=.02) 1033 | 1034 | def forward(self): 1035 | relative_position_bias = self.relative_position_bias_table[ 1036 | self.relative_position_index.view(-1) 1037 | ].view( 1038 | self.window_size[0] * self.window_size[1] + 1, 1039 | self.window_size[0] * self.window_size[1] + 1, 1040 | -1, 1041 | ) # Wh*Ww,Wh*Ww,nH 1042 | return relative_position_bias.permute(2, 0, 1).contiguous() # nH, Wh*Ww, Wh*Ww 1043 | 1044 | 1045 | # @BACKBONES.register_module() 1046 | class BEiT(nn.Module): 1047 | """Vision Transformer with support for patch or hybrid CNN input stage""" 1048 | 1049 | def __init__( 1050 | self, 1051 | img_size=512, 1052 | patch_size=16, 1053 | in_chans=3, 1054 | num_classes=80, 1055 | embed_dim=768, 1056 | depth=12, 1057 | num_heads=12, 1058 | mlp_ratio=4.0, 1059 | qkv_bias=False, 1060 | qk_scale=None, 1061 | drop_rate=0.0, 1062 | attn_drop_rate=0.0, 1063 | drop_path_rate=0.0, 1064 | hybrid_backbone=None, 1065 | norm_layer=None, 1066 | init_values=None, 1067 | use_checkpoint=False, 1068 | use_abs_pos_emb=False, 1069 | use_rel_pos_bias=True, 1070 | use_shared_rel_pos_bias=False, 1071 | pretrained=None, 1072 | 
with_cp=False, 1073 | ): 1074 | super().__init__() 1075 | norm_layer = norm_layer or partial(nn.LayerNorm, eps=1e-6) 1076 | self.norm_layer = norm_layer 1077 | self.num_classes = num_classes 1078 | self.num_features = ( 1079 | self.embed_dim 1080 | ) = embed_dim # num_features for consistency with other models 1081 | self.drop_path_rate = drop_path_rate 1082 | if hybrid_backbone is not None: 1083 | self.patch_embed = HybridEmbed( 1084 | hybrid_backbone, 1085 | img_size=img_size, 1086 | in_chans=in_chans, 1087 | embed_dim=embed_dim, 1088 | ) 1089 | else: 1090 | self.patch_embed = PatchEmbed( 1091 | img_size=img_size, 1092 | patch_size=patch_size, 1093 | in_chans=in_chans, 1094 | embed_dim=embed_dim, 1095 | ) 1096 | num_patches = self.patch_embed.num_patches 1097 | 1098 | self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim)) 1099 | # self.mask_token = nn.Parameter(torch.zeros(1, 1, embed_dim)) 1100 | if use_abs_pos_emb: 1101 | self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + 1, embed_dim)) 1102 | else: 1103 | self.pos_embed = None 1104 | self.pos_drop = nn.Dropout(p=drop_rate) 1105 | 1106 | if use_shared_rel_pos_bias: 1107 | self.rel_pos_bias = RelativePositionBias( 1108 | window_size=self.patch_embed.patch_shape, num_heads=num_heads 1109 | ) 1110 | else: 1111 | self.rel_pos_bias = None 1112 | 1113 | dpr = [ 1114 | x.item() for x in torch.linspace(0, drop_path_rate, depth) 1115 | ] # stochastic depth decay rule 1116 | self.use_rel_pos_bias = use_rel_pos_bias 1117 | self.use_checkpoint = use_checkpoint 1118 | self.blocks = nn.ModuleList( 1119 | [ 1120 | Block( 1121 | dim=embed_dim, 1122 | num_heads=num_heads, 1123 | mlp_ratio=mlp_ratio, 1124 | qkv_bias=qkv_bias, 1125 | qk_scale=qk_scale, 1126 | drop=drop_rate, 1127 | attn_drop=attn_drop_rate, 1128 | drop_path=dpr[i], 1129 | norm_layer=norm_layer, 1130 | with_cp=with_cp, 1131 | init_values=init_values, 1132 | window_size=self.patch_embed.patch_shape 1133 | if use_rel_pos_bias 1134 | else None, 1135 | ) 1136 | for i in range(depth) 1137 | ] 1138 | ) 1139 | 1140 | # if self.pos_embed is not None: 1141 | # trunc_normal_(self.pos_embed, std=.02) 1142 | trunc_normal_(self.cls_token, std=0.02) 1143 | self.apply(self._init_weights) 1144 | self.init_weights(pretrained) 1145 | 1146 | # self.fix_init_weight() 1147 | 1148 | def init_weights(self, pretrained=None): 1149 | """Initialize the weights in backbone. 1150 | 1151 | Args: 1152 | pretrained (str, optional): Path to pre-trained weights. 1153 | Defaults to None. 
1154 | """ 1155 | # pretrained = 'pretrained/beit_large_patch16_512_pt22k_ft22kto1k.pth' 1156 | if isinstance(pretrained, str): 1157 | logger = MMLogger.get_current_instance() 1158 | load_checkpoint(self, pretrained, strict=False, logger=logger) 1159 | 1160 | def fix_init_weight(self): 1161 | def rescale(param, layer_id): 1162 | param.div_(math.sqrt(2.0 * layer_id)) 1163 | 1164 | for layer_id, layer in enumerate(self.blocks): 1165 | rescale(layer.attn.proj.weight.data, layer_id + 1) 1166 | rescale(layer.mlp.fc2.weight.data, layer_id + 1) 1167 | 1168 | def _init_weights(self, m): 1169 | if isinstance(m, nn.Linear): 1170 | trunc_normal_(m.weight, std=0.02) 1171 | if isinstance(m, nn.Linear) and m.bias is not None: 1172 | nn.init.constant_(m.bias, 0) 1173 | elif isinstance(m, nn.LayerNorm): 1174 | nn.init.constant_(m.bias, 0) 1175 | nn.init.constant_(m.weight, 1.0) 1176 | 1177 | def get_num_layers(self): 1178 | return len(self.blocks) 1179 | -------------------------------------------------------------------------------- /rein/models/backbones/clip.py: -------------------------------------------------------------------------------- 1 | from collections import OrderedDict 2 | import torch 3 | import torch.nn.functional as F 4 | from torch import nn 5 | from timm.models.layers import drop_path, trunc_normal_ 6 | from mmseg.models.builder import BACKBONES 7 | 8 | 9 | class LayerNorm(nn.LayerNorm): 10 | """Subclass torch's LayerNorm to handle fp16.""" 11 | 12 | def forward(self, x: torch.Tensor): 13 | orig_type = x.dtype 14 | ret = super().forward(x.type(torch.float32)) 15 | return ret.type(orig_type) 16 | 17 | 18 | class QuickGELU(nn.Module): 19 | def forward(self, x: torch.Tensor): 20 | return x * torch.sigmoid(1.702 * x) 21 | 22 | 23 | class DropPath(nn.Module): 24 | """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).""" 25 | 26 | def __init__(self, drop_prob=None): 27 | super(DropPath, self).__init__() 28 | self.drop_prob = drop_prob 29 | 30 | def forward(self, x): 31 | return drop_path(x, self.drop_prob, self.training) 32 | 33 | def extra_repr(self) -> str: 34 | return "p={}".format(self.drop_prob) 35 | 36 | 37 | class ResidualAttentionBlock(nn.Module): 38 | def __init__( 39 | self, d_model: int, n_head: int, attn_mask: torch.Tensor = None, drop_path=0.0 40 | ): 41 | super().__init__() 42 | 43 | self.attn = nn.MultiheadAttention(d_model, n_head) 44 | self.ln_1 = LayerNorm(d_model) 45 | self.mlp = nn.Sequential( 46 | OrderedDict( 47 | [ 48 | ("c_fc", nn.Linear(d_model, d_model * 4)), 49 | ("gelu", QuickGELU()), 50 | ("c_proj", nn.Linear(d_model * 4, d_model)), 51 | ] 52 | ) 53 | ) 54 | self.ln_2 = LayerNorm(d_model) 55 | self.attn_mask = attn_mask 56 | 57 | self.drop_path = DropPath(drop_path) if drop_path > 0.0 else nn.Identity() 58 | 59 | def attention(self, x: torch.Tensor): 60 | self.attn_mask = ( 61 | self.attn_mask.to(dtype=x.dtype, device=x.device) 62 | if self.attn_mask is not None 63 | else None 64 | ) 65 | return self.attn(x, x, x, need_weights=False, attn_mask=self.attn_mask)[0] 66 | 67 | def forward(self, x: torch.Tensor): 68 | x = x + self.drop_path(self.attention(self.ln_1(x))) 69 | x = x + self.drop_path(self.mlp(self.ln_2(x))) 70 | return x 71 | 72 | 73 | class Transformer(nn.Module): 74 | def __init__( 75 | self, 76 | width: int, 77 | layers: int, 78 | heads: int, 79 | attn_mask: torch.Tensor = None, 80 | drop_path_rate=0.0, 81 | ): 82 | super().__init__() 83 | self.width = width 84 | self.layers = layers 85 | dpr = [ 86 | x.item() for 
x in torch.linspace(0, drop_path_rate, layers) 87 | ] # stochastic depth decay rule 88 | self.resblocks = nn.Sequential( 89 | *[ 90 | ResidualAttentionBlock(width, heads, attn_mask, dpr[i]) 91 | for i in range(layers) 92 | ] 93 | ) 94 | 95 | def forward(self, x: torch.Tensor): 96 | return self.resblocks(x) 97 | 98 | 99 | class Attention(nn.Module): 100 | def __init__( 101 | self, 102 | dim, 103 | num_heads=8, 104 | qkv_bias=False, 105 | qk_scale=None, 106 | attn_drop=0.0, 107 | proj_drop=0.0, 108 | ): 109 | super().__init__() 110 | self.num_heads = num_heads 111 | head_dim = dim // num_heads 112 | # NOTE scale factor was wrong in my original version, can set manually to be compat with prev weights 113 | self.scale = qk_scale or head_dim**-0.5 114 | 115 | self.q_proj = nn.Linear(dim, dim, bias=qkv_bias) 116 | self.k_proj = nn.Linear(dim, dim, bias=qkv_bias) 117 | self.v_proj = nn.Linear(dim, dim, bias=qkv_bias) 118 | 119 | self.attn_drop = nn.Dropout(attn_drop) 120 | self.proj = nn.Linear(dim, dim) 121 | self.proj_drop = nn.Dropout(proj_drop) 122 | 123 | def forward(self, q, k, v): 124 | B, N, C = q.shape 125 | assert k.shape == v.shape 126 | B, M, C = k.shape 127 | q = self.q_proj(q).reshape(B, N, self.num_heads, C // self.num_heads) 128 | k = self.k_proj(k).reshape(B, M, self.num_heads, C // self.num_heads) 129 | v = self.v_proj(v).reshape(B, M, self.num_heads, C // self.num_heads) 130 | 131 | attn = torch.einsum("bnkc,bmkc->bknm", q, k) * self.scale 132 | 133 | attn = attn.softmax(dim=-1) 134 | 135 | x = torch.einsum("bknm,bmkc->bnkc", attn, v).reshape(B, N, C) 136 | 137 | x = self.proj(x) 138 | x = self.proj_drop(x) 139 | return x 140 | 141 | 142 | class TransformerDecoderLayer(nn.Module): 143 | def __init__( 144 | self, 145 | d_model, 146 | nhead, 147 | dropout=0.1, 148 | ): 149 | super().__init__() 150 | self.self_attn = Attention(d_model, nhead, proj_drop=dropout) 151 | self.cross_attn = Attention(d_model, nhead, proj_drop=dropout) 152 | 153 | self.norm1 = nn.LayerNorm(d_model) 154 | self.norm2 = nn.LayerNorm(d_model) 155 | self.norm3 = nn.LayerNorm(d_model) 156 | self.dropout = nn.Dropout(dropout) 157 | 158 | self.mlp = nn.Sequential( 159 | nn.Linear(d_model, d_model * 4), 160 | nn.GELU(), 161 | nn.Dropout(dropout), 162 | nn.Linear(d_model * 4, d_model), 163 | ) 164 | 165 | def forward(self, x, mem): 166 | q = k = v = self.norm1(x) 167 | x = x + self.self_attn(q, k, v) 168 | q = self.norm2(x) 169 | x = x + self.cross_attn(q, mem, mem) 170 | x = x + self.dropout(self.mlp(self.norm3(x))) 171 | return x 172 | 173 | 174 | @BACKBONES.register_module() 175 | class CLIPVisionTransformer(nn.Module): 176 | def __init__( 177 | self, 178 | input_resolution=224, 179 | patch_size=32, 180 | width=768, 181 | layers=12, 182 | heads=12, 183 | output_dim=512, 184 | drop_path_rate=0.0, 185 | out_indices=[3, 5, 7, 11], 186 | pretrained=None, 187 | get_embeddings=False, 188 | **kwargs, 189 | ): 190 | super().__init__() 191 | self.pretrained = pretrained 192 | self.input_resolution = input_resolution 193 | self.output_dim = output_dim 194 | self.patch_size = patch_size 195 | self.conv1 = nn.Conv2d( 196 | in_channels=3, 197 | out_channels=width, 198 | kernel_size=patch_size, 199 | stride=patch_size, 200 | bias=False, 201 | ) 202 | 203 | scale = width**-0.5 204 | self.width = width 205 | self.class_embedding = nn.Parameter(scale * torch.randn(width)) 206 | self.positional_embedding = nn.Parameter( 207 | scale * torch.randn((input_resolution // patch_size) ** 2 + 1, width) 208 | ) 209 | self.spatial_size = 
input_resolution // patch_size 210 | self.ln_pre = LayerNorm(width) 211 | self.get_embeddings = get_embeddings 212 | 213 | self.transformer = Transformer( 214 | width, layers, heads, drop_path_rate=drop_path_rate 215 | ) 216 | 217 | self.out_indices = out_indices 218 | 219 | if get_embeddings: 220 | self.ln_post = LayerNorm(width) 221 | self.proj = nn.Parameter(scale * torch.randn(width, output_dim)) 222 | 223 | embed_dim = width 224 | 225 | def init_weights(self, pretrained=None): 226 | pretrained = pretrained or self.pretrained 227 | if isinstance(pretrained, str): 228 | checkpoint = ( 229 | torch.jit.load(pretrained, map_location="cpu").float().state_dict() 230 | ) 231 | 232 | state_dict = {} 233 | 234 | for k in checkpoint.keys(): 235 | if k.startswith("visual."): 236 | new_k = k.replace("visual.", "") 237 | state_dict[new_k] = checkpoint[k] 238 | 239 | if "positional_embedding" in state_dict.keys(): 240 | if ( 241 | self.positional_embedding.shape 242 | != state_dict["positional_embedding"].shape 243 | ): 244 | print( 245 | f'Resize the pos_embed shape from {state_dict["positional_embedding"].shape} to {self.positional_embedding.shape}' 246 | ) 247 | cls_pos = state_dict["positional_embedding"][0:1, :] 248 | leng = int(state_dict["positional_embedding"][1:,].shape[-2] ** 0.5) 249 | spatial_pos = F.interpolate( 250 | state_dict["positional_embedding"][1:,] 251 | .reshape(1, leng, leng, self.width) 252 | .permute(0, 3, 1, 2), 253 | size=(self.spatial_size, self.spatial_size), 254 | mode="bilinear", 255 | ) 256 | spatial_pos = spatial_pos.reshape( 257 | self.width, self.spatial_size * self.spatial_size 258 | ).permute(1, 0) 259 | positional_embedding = torch.cat([cls_pos, spatial_pos], dim=0) 260 | state_dict["positional_embedding"] = positional_embedding 261 | assert ( 262 | self.positional_embedding.shape 263 | == state_dict["positional_embedding"].shape 264 | ) 265 | conv1 = state_dict["conv1.weight"] 266 | C_o, C_in, H, W = conv1.shape 267 | conv1 = torch.nn.functional.interpolate( 268 | conv1.float(), 269 | size=(self.patch_size, self.patch_size), 270 | mode="bicubic", 271 | align_corners=False, 272 | ) 273 | state_dict["conv1.weight"] = conv1 274 | 275 | u, w = self.load_state_dict(state_dict, False) 276 | print(u, w, "are misaligned params in vision transformer") 277 | 278 | def forward(self, x: torch.Tensor): 279 | x = self.conv1(x) # shape = [*, width, grid, grid] 280 | B, C, H, W = x.shape 281 | x = x.reshape(x.shape[0], x.shape[1], -1) # shape = [*, width, grid ** 2] 282 | x = x.permute(0, 2, 1) # shape = [*, grid ** 2, width] 283 | x = torch.cat( 284 | [ 285 | self.class_embedding.to(x.dtype) 286 | + torch.zeros( 287 | x.shape[0], 1, x.shape[-1], dtype=x.dtype, device=x.device 288 | ), 289 | x, 290 | ], 291 | dim=1, 292 | ) # shape = [*, grid ** 2 + 1, width] 293 | 294 | pos = self.positional_embedding.to(x.dtype) 295 | cls_pos = pos[0, :] + self.class_embedding.to(x.dtype) 296 | spatial_pos = F.interpolate( 297 | pos[1:,] 298 | .reshape(1, self.spatial_size, self.spatial_size, C) 299 | .permute(0, 3, 1, 2), 300 | size=(H, W), 301 | mode="bilinear", 302 | ) 303 | spatial_pos = spatial_pos.reshape(1, C, H * W).permute(0, 2, 1) 304 | pos = torch.cat([cls_pos.reshape(1, 1, C), spatial_pos], dim=1) 305 | x = x + pos 306 | x = self.ln_pre(x) 307 | x = x.permute(1, 0, 2) # NLD -> LND 308 | 309 | features = [] 310 | for i, blk in enumerate(self.transformer.resblocks): 311 | x = blk(x) 312 | if i in self.out_indices: 313 | xp = x.permute(1, 0, 2)[:, 1:, :].permute(0, 2, 1).reshape(B, 
-1, H, W) 314 | features.append(xp.contiguous()) 315 | 316 | if self.get_embeddings: 317 | x = x.permute(1, 0, 2) 318 | x = self.ln_post(x) 319 | x = x @ self.proj 320 | 321 | global_embedding = x[:, 0] 322 | visual_embedding = ( 323 | x[:, 1:].reshape(B, H, W, -1).permute(0, 3, 1, 2) 324 | ) # B C H W 325 | 326 | features.append([global_embedding, visual_embedding]) 327 | 328 | return tuple(features) 329 | -------------------------------------------------------------------------------- /rein/models/backbones/dino_layers/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # 3 | # This source code is licensed under the Apache License, Version 2.0 4 | # found in the LICENSE file in the root directory of this source tree. 5 | 6 | from .dino_head import DINOHead 7 | from .mlp import Mlp 8 | from .patch_embed import PatchEmbed 9 | from .swiglu_ffn import SwiGLUFFN, SwiGLUFFNFused 10 | from .block import NestedTensorBlock,drop_add_residual_stochastic_depth 11 | from .attention import MemEffAttention -------------------------------------------------------------------------------- /rein/models/backbones/dino_layers/attention.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # 3 | # This source code is licensed under the Apache License, Version 2.0 4 | # found in the LICENSE file in the root directory of this source tree. 5 | 6 | # References: 7 | # https://github.com/facebookresearch/dino/blob/master/vision_transformer.py 8 | # https://github.com/rwightman/pytorch-image-models/tree/master/timm/models/vision_transformer.py 9 | 10 | import logging 11 | import os 12 | import warnings 13 | 14 | from torch import Tensor 15 | from torch import nn 16 | 17 | 18 | logger = logging.getLogger("dinov2") 19 | 20 | 21 | XFORMERS_ENABLED = os.environ.get("XFORMERS_DISABLED") is None 22 | try: 23 | if XFORMERS_ENABLED: 24 | from xformers.ops import memory_efficient_attention, unbind 25 | 26 | XFORMERS_AVAILABLE = True 27 | warnings.warn("xFormers is available (Attention)") 28 | else: 29 | warnings.warn("xFormers is disabled (Attention)") 30 | raise ImportError 31 | except ImportError: 32 | XFORMERS_AVAILABLE = False 33 | warnings.warn("xFormers is not available (Attention)") 34 | 35 | 36 | class Attention(nn.Module): 37 | def __init__( 38 | self, 39 | dim: int, 40 | num_heads: int = 8, 41 | qkv_bias: bool = False, 42 | proj_bias: bool = True, 43 | attn_drop: float = 0.0, 44 | proj_drop: float = 0.0, 45 | ) -> None: 46 | super().__init__() 47 | self.num_heads = num_heads 48 | head_dim = dim // num_heads 49 | self.scale = head_dim**-0.5 50 | 51 | self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) 52 | self.attn_drop = nn.Dropout(attn_drop) 53 | self.proj = nn.Linear(dim, dim, bias=proj_bias) 54 | self.proj_drop = nn.Dropout(proj_drop) 55 | 56 | def forward(self, x: Tensor) -> Tensor: 57 | B, N, C = x.shape 58 | qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4) 59 | 60 | q, k, v = qkv[0] * self.scale, qkv[1], qkv[2] 61 | attn = q @ k.transpose(-2, -1) 62 | 63 | attn = attn.softmax(dim=-1) 64 | attn = self.attn_drop(attn) 65 | 66 | x = (attn @ v).transpose(1, 2).reshape(B, N, C) 67 | x = self.proj(x) 68 | x = self.proj_drop(x) 69 | return x 70 | 71 | 72 | class MemEffAttention(Attention): 73 | def forward(self, x: Tensor, attn_bias=None) -> Tensor: 74 | if not 
XFORMERS_AVAILABLE: 75 | if attn_bias is not None: 76 | raise AssertionError("xFormers is required for using nested tensors") 77 | return super().forward(x) 78 | 79 | B, N, C = x.shape 80 | qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads) 81 | 82 | q, k, v = unbind(qkv, 2) 83 | 84 | x = memory_efficient_attention(q, k, v, attn_bias=attn_bias) 85 | x = x.reshape([B, N, C]) 86 | 87 | x = self.proj(x) 88 | x = self.proj_drop(x) 89 | return x 90 | -------------------------------------------------------------------------------- /rein/models/backbones/dino_layers/block.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # 3 | # This source code is licensed under the Apache License, Version 2.0 4 | # found in the LICENSE file in the root directory of this source tree. 5 | 6 | # References: 7 | # https://github.com/facebookresearch/dino/blob/master/vision_transformer.py 8 | # https://github.com/rwightman/pytorch-image-models/tree/master/timm/layers/patch_embed.py 9 | # Copyright (c) Meta Platforms, Inc. and affiliates. 10 | # 11 | # This source code is licensed under the Apache License, Version 2.0 12 | # found in the LICENSE file in the root directory of this source tree. 13 | 14 | # Modified from: https://github.com/huggingface/pytorch-image-models/blob/main/timm/models/vision_transformer.py#L103-L110 15 | 16 | 17 | import logging 18 | import os 19 | from typing import Callable, List, Any, Tuple, Dict, Union 20 | import warnings 21 | 22 | import torch 23 | from torch import nn, Tensor 24 | 25 | from .attention import Attention, MemEffAttention 26 | from .drop_path import DropPath 27 | from .layer_scale import LayerScale 28 | from .mlp import Mlp 29 | 30 | 31 | logger = logging.getLogger("dinov2") 32 | 33 | 34 | XFORMERS_ENABLED = os.environ.get("XFORMERS_DISABLED") is None 35 | try: 36 | if XFORMERS_ENABLED: 37 | from xformers.ops import fmha, scaled_index_add, index_select_cat 38 | 39 | XFORMERS_AVAILABLE = True 40 | warnings.warn("xFormers is available (Block)") 41 | else: 42 | warnings.warn("xFormers is disabled (Block)") 43 | raise ImportError 44 | except ImportError: 45 | XFORMERS_AVAILABLE = False 46 | 47 | warnings.warn("xFormers is not available (Block)") 48 | 49 | 50 | class Block(nn.Module): 51 | def __init__( 52 | self, 53 | dim: int, 54 | num_heads: int, 55 | mlp_ratio: float = 4.0, 56 | qkv_bias: bool = False, 57 | proj_bias: bool = True, 58 | ffn_bias: bool = True, 59 | drop: float = 0.0, 60 | attn_drop: float = 0.0, 61 | init_values=None, 62 | drop_path: float = 0.0, 63 | act_layer: Callable[..., nn.Module] = nn.GELU, 64 | norm_layer: Callable[..., nn.Module] = nn.LayerNorm, 65 | attn_class: Callable[..., nn.Module] = Attention, 66 | ffn_layer: Callable[..., nn.Module] = Mlp, 67 | ) -> None: 68 | super().__init__() 69 | # print(f"biases: qkv: {qkv_bias}, proj: {proj_bias}, ffn: {ffn_bias}") 70 | self.norm1 = norm_layer(dim) 71 | self.attn = attn_class( 72 | dim, 73 | num_heads=num_heads, 74 | qkv_bias=qkv_bias, 75 | proj_bias=proj_bias, 76 | attn_drop=attn_drop, 77 | proj_drop=drop, 78 | ) 79 | self.ls1 = LayerScale(dim, init_values=init_values) if init_values else nn.Identity() 80 | self.drop_path1 = DropPath(drop_path) if drop_path > 0.0 else nn.Identity() 81 | 82 | self.norm2 = norm_layer(dim) 83 | mlp_hidden_dim = int(dim * mlp_ratio) 84 | self.mlp = ffn_layer( 85 | in_features=dim, 86 | hidden_features=mlp_hidden_dim, 87 | act_layer=act_layer, 88 | drop=drop, 
89 | bias=ffn_bias, 90 | ) 91 | self.ls2 = LayerScale(dim, init_values=init_values) if init_values else nn.Identity() 92 | self.drop_path2 = DropPath(drop_path) if drop_path > 0.0 else nn.Identity() 93 | 94 | self.sample_drop_ratio = drop_path 95 | 96 | def forward(self, x: Tensor) -> Tensor: 97 | def attn_residual_func(x: Tensor) -> Tensor: 98 | return self.ls1(self.attn(self.norm1(x))) 99 | 100 | def ffn_residual_func(x: Tensor) -> Tensor: 101 | return self.ls2(self.mlp(self.norm2(x))) 102 | 103 | if self.training and self.sample_drop_ratio > 0.1: 104 | # the overhead is compensated only for a drop path rate larger than 0.1 105 | x = drop_add_residual_stochastic_depth( 106 | x, 107 | residual_func=attn_residual_func, 108 | sample_drop_ratio=self.sample_drop_ratio, 109 | ) 110 | x = drop_add_residual_stochastic_depth( 111 | x, 112 | residual_func=ffn_residual_func, 113 | sample_drop_ratio=self.sample_drop_ratio, 114 | ) 115 | elif self.training and self.sample_drop_ratio > 0.0: 116 | x = x + self.drop_path1(attn_residual_func(x)) 117 | x = x + self.drop_path1(ffn_residual_func(x)) # FIXME: drop_path2 118 | else: 119 | x = x + attn_residual_func(x) 120 | x = x + ffn_residual_func(x) 121 | return x 122 | 123 | 124 | def drop_add_residual_stochastic_depth( 125 | x: Tensor, 126 | residual_func: Callable[[Tensor], Tensor], 127 | sample_drop_ratio: float = 0.0, 128 | ) -> Tensor: 129 | # 1) extract subset using permutation 130 | b, n, d = x.shape 131 | sample_subset_size = max(int(b * (1 - sample_drop_ratio)), 1) 132 | brange = (torch.randperm(b, device=x.device))[:sample_subset_size] 133 | x_subset = x[brange] 134 | 135 | # 2) apply residual_func to get residual 136 | residual = residual_func(x_subset) 137 | 138 | x_flat = x.flatten(1) 139 | residual = residual.flatten(1) 140 | 141 | residual_scale_factor = b / sample_subset_size 142 | 143 | # 3) add the residual 144 | x_plus_residual = torch.index_add(x_flat, 0, brange, residual.to(dtype=x.dtype), alpha=residual_scale_factor) 145 | return x_plus_residual.view_as(x) 146 | 147 | 148 | def get_branges_scales(x, sample_drop_ratio=0.0): 149 | b, n, d = x.shape 150 | sample_subset_size = max(int(b * (1 - sample_drop_ratio)), 1) 151 | brange = (torch.randperm(b, device=x.device))[:sample_subset_size] 152 | residual_scale_factor = b / sample_subset_size 153 | return brange, residual_scale_factor 154 | 155 | 156 | def add_residual(x, brange, residual, residual_scale_factor, scaling_vector=None): 157 | if scaling_vector is None: 158 | x_flat = x.flatten(1) 159 | residual = residual.flatten(1) 160 | x_plus_residual = torch.index_add(x_flat, 0, brange, residual.to(dtype=x.dtype), alpha=residual_scale_factor) 161 | else: 162 | x_plus_residual = scaled_index_add( 163 | x, brange, residual.to(dtype=x.dtype), scaling=scaling_vector, alpha=residual_scale_factor 164 | ) 165 | return x_plus_residual 166 | 167 | 168 | attn_bias_cache: Dict[Tuple, Any] = {} 169 | 170 | 171 | def get_attn_bias_and_cat(x_list, branges=None): 172 | """ 173 | this will perform the index select, cat the tensors, and provide the attn_bias from cache 174 | """ 175 | batch_sizes = [b.shape[0] for b in branges] if branges is not None else [x.shape[0] for x in x_list] 176 | all_shapes = tuple((b, x.shape[1]) for b, x in zip(batch_sizes, x_list)) 177 | if all_shapes not in attn_bias_cache.keys(): 178 | seqlens = [] 179 | for b, x in zip(batch_sizes, x_list): 180 | for _ in range(b): 181 | seqlens.append(x.shape[1]) 182 | attn_bias = fmha.BlockDiagonalMask.from_seqlens(seqlens) 183 | 
attn_bias._batch_sizes = batch_sizes 184 | attn_bias_cache[all_shapes] = attn_bias 185 | 186 | if branges is not None: 187 | cat_tensors = index_select_cat([x.flatten(1) for x in x_list], branges).view(1, -1, x_list[0].shape[-1]) 188 | else: 189 | tensors_bs1 = tuple(x.reshape([1, -1, *x.shape[2:]]) for x in x_list) 190 | cat_tensors = torch.cat(tensors_bs1, dim=1) 191 | 192 | return attn_bias_cache[all_shapes], cat_tensors 193 | 194 | 195 | def drop_add_residual_stochastic_depth_list( 196 | x_list: List[Tensor], 197 | residual_func: Callable[[Tensor, Any], Tensor], 198 | sample_drop_ratio: float = 0.0, 199 | scaling_vector=None, 200 | ) -> Tensor: 201 | # 1) generate random set of indices for dropping samples in the batch 202 | branges_scales = [get_branges_scales(x, sample_drop_ratio=sample_drop_ratio) for x in x_list] 203 | branges = [s[0] for s in branges_scales] 204 | residual_scale_factors = [s[1] for s in branges_scales] 205 | 206 | # 2) get attention bias and index+concat the tensors 207 | attn_bias, x_cat = get_attn_bias_and_cat(x_list, branges) 208 | 209 | # 3) apply residual_func to get residual, and split the result 210 | residual_list = attn_bias.split(residual_func(x_cat, attn_bias=attn_bias)) # type: ignore 211 | 212 | outputs = [] 213 | for x, brange, residual, residual_scale_factor in zip(x_list, branges, residual_list, residual_scale_factors): 214 | outputs.append(add_residual(x, brange, residual, residual_scale_factor, scaling_vector).view_as(x)) 215 | return outputs 216 | 217 | 218 | class NestedTensorBlock(Block): 219 | def forward_nested(self, x_list: List[Tensor]) -> List[Tensor]: 220 | """ 221 | x_list contains a list of tensors to nest together and run 222 | """ 223 | assert isinstance(self.attn, MemEffAttention) 224 | 225 | if self.training and self.sample_drop_ratio > 0.0: 226 | 227 | def attn_residual_func(x: Tensor, attn_bias=None) -> Tensor: 228 | return self.attn(self.norm1(x), attn_bias=attn_bias) 229 | 230 | def ffn_residual_func(x: Tensor, attn_bias=None) -> Tensor: 231 | return self.mlp(self.norm2(x)) 232 | 233 | x_list = drop_add_residual_stochastic_depth_list( 234 | x_list, 235 | residual_func=attn_residual_func, 236 | sample_drop_ratio=self.sample_drop_ratio, 237 | scaling_vector=self.ls1.gamma if isinstance(self.ls1, LayerScale) else None, 238 | ) 239 | x_list = drop_add_residual_stochastic_depth_list( 240 | x_list, 241 | residual_func=ffn_residual_func, 242 | sample_drop_ratio=self.sample_drop_ratio, 243 | scaling_vector=self.ls2.gamma if isinstance(self.ls1, LayerScale) else None, 244 | ) 245 | return x_list 246 | else: 247 | 248 | def attn_residual_func(x: Tensor, attn_bias=None) -> Tensor: 249 | return self.ls1(self.attn(self.norm1(x), attn_bias=attn_bias)) 250 | 251 | def ffn_residual_func(x: Tensor, attn_bias=None) -> Tensor: 252 | return self.ls2(self.mlp(self.norm2(x))) 253 | 254 | attn_bias, x = get_attn_bias_and_cat(x_list) 255 | x = x + attn_residual_func(x, attn_bias=attn_bias) 256 | x = x + ffn_residual_func(x) 257 | return attn_bias.split(x) 258 | 259 | def forward(self, x_or_x_list): 260 | if isinstance(x_or_x_list, Tensor): 261 | return super().forward(x_or_x_list) 262 | elif isinstance(x_or_x_list, list): 263 | if not XFORMERS_AVAILABLE: 264 | raise AssertionError("xFormers is required for using nested tensors") 265 | return self.forward_nested(x_or_x_list) 266 | else: 267 | raise AssertionError 268 | -------------------------------------------------------------------------------- /rein/models/backbones/dino_layers/dino_head.py: 
-------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # 3 | # This source code is licensed under the Apache License, Version 2.0 4 | # found in the LICENSE file in the root directory of this source tree. 5 | 6 | import torch 7 | import torch.nn as nn 8 | from torch.nn.init import trunc_normal_ 9 | from torch.nn.utils import weight_norm 10 | 11 | 12 | class DINOHead(nn.Module): 13 | def __init__( 14 | self, 15 | in_dim, 16 | out_dim, 17 | use_bn=False, 18 | nlayers=3, 19 | hidden_dim=2048, 20 | bottleneck_dim=256, 21 | mlp_bias=True, 22 | ): 23 | super().__init__() 24 | nlayers = max(nlayers, 1) 25 | self.mlp = _build_mlp(nlayers, in_dim, bottleneck_dim, hidden_dim=hidden_dim, use_bn=use_bn, bias=mlp_bias) 26 | self.apply(self._init_weights) 27 | self.last_layer = weight_norm(nn.Linear(bottleneck_dim, out_dim, bias=False)) 28 | self.last_layer.weight_g.data.fill_(1) 29 | 30 | def _init_weights(self, m): 31 | if isinstance(m, nn.Linear): 32 | trunc_normal_(m.weight, std=0.02) 33 | if isinstance(m, nn.Linear) and m.bias is not None: 34 | nn.init.constant_(m.bias, 0) 35 | 36 | def forward(self, x): 37 | x = self.mlp(x) 38 | eps = 1e-6 if x.dtype == torch.float16 else 1e-12 39 | x = nn.functional.normalize(x, dim=-1, p=2, eps=eps) 40 | x = self.last_layer(x) 41 | return x 42 | 43 | 44 | def _build_mlp(nlayers, in_dim, bottleneck_dim, hidden_dim=None, use_bn=False, bias=True): 45 | if nlayers == 1: 46 | return nn.Linear(in_dim, bottleneck_dim, bias=bias) 47 | else: 48 | layers = [nn.Linear(in_dim, hidden_dim, bias=bias)] 49 | if use_bn: 50 | layers.append(nn.BatchNorm1d(hidden_dim)) 51 | layers.append(nn.GELU()) 52 | for _ in range(nlayers - 2): 53 | layers.append(nn.Linear(hidden_dim, hidden_dim, bias=bias)) 54 | if use_bn: 55 | layers.append(nn.BatchNorm1d(hidden_dim)) 56 | layers.append(nn.GELU()) 57 | layers.append(nn.Linear(hidden_dim, bottleneck_dim, bias=bias)) 58 | return nn.Sequential(*layers) 59 | -------------------------------------------------------------------------------- /rein/models/backbones/dino_layers/drop_path.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # 3 | # This source code is licensed under the Apache License, Version 2.0 4 | # found in the LICENSE file in the root directory of this source tree. 
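# Stochastic depth ("drop path") zeroes the residual branch for a random
# subset of samples during training and rescales the surviving samples by
# 1 / keep_prob so the expected output is unchanged; in eval mode it is a
# no-op. A rough, illustrative usage with hypothetical shapes (not part of
# this file's API):
#
#     x = torch.randn(8, 197, 768)              # (batch, tokens, dim) branch output
#     dp = DropPath(drop_prob=0.1)
#     dp.train(); y = dp(x)                     # ~10% of samples have the branch zeroed
#     dp.eval();  assert torch.equal(dp(x), x)  # identity at inference time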
5 | 6 | # References: 7 | # https://github.com/facebookresearch/dino/blob/master/vision_transformer.py 8 | # https://github.com/rwightman/pytorch-image-models/tree/master/timm/layers/drop.py 9 | 10 | 11 | from torch import nn 12 | 13 | 14 | def drop_path(x, drop_prob: float = 0.0, training: bool = False): 15 | if drop_prob == 0.0 or not training: 16 | return x 17 | keep_prob = 1 - drop_prob 18 | shape = (x.shape[0],) + (1,) * (x.ndim - 1) # work with diff dim tensors, not just 2D ConvNets 19 | random_tensor = x.new_empty(shape).bernoulli_(keep_prob) 20 | if keep_prob > 0.0: 21 | random_tensor.div_(keep_prob) 22 | output = x * random_tensor 23 | return output 24 | 25 | 26 | class DropPath(nn.Module): 27 | """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).""" 28 | 29 | def __init__(self, drop_prob=None): 30 | super(DropPath, self).__init__() 31 | self.drop_prob = drop_prob 32 | 33 | def forward(self, x): 34 | return drop_path(x, self.drop_prob, self.training) 35 | -------------------------------------------------------------------------------- /rein/models/backbones/dino_layers/layer_scale.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # 3 | # This source code is licensed under the Apache License, Version 2.0 4 | # found in the LICENSE file in the root directory of this source tree. 5 | 6 | # Modified from: https://github.com/huggingface/pytorch-image-models/blob/main/timm/models/vision_transformer.py#L103-L110 7 | 8 | from typing import Union 9 | 10 | import torch 11 | from torch import Tensor 12 | from torch import nn 13 | 14 | 15 | class LayerScale(nn.Module): 16 | def __init__( 17 | self, 18 | dim: int, 19 | init_values: Union[float, Tensor] = 1e-5, 20 | inplace: bool = False, 21 | ) -> None: 22 | super().__init__() 23 | self.inplace = inplace 24 | self.gamma = nn.Parameter(init_values * torch.ones(dim)) 25 | 26 | def forward(self, x: Tensor) -> Tensor: 27 | return x.mul_(self.gamma) if self.inplace else x * self.gamma 28 | -------------------------------------------------------------------------------- /rein/models/backbones/dino_layers/mlp.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # 3 | # This source code is licensed under the Apache License, Version 2.0 4 | # found in the LICENSE file in the root directory of this source tree. 
5 | 6 | # References: 7 | # https://github.com/facebookresearch/dino/blob/master/vision_transformer.py 8 | # https://github.com/rwightman/pytorch-image-models/tree/master/timm/layers/mlp.py 9 | 10 | 11 | from typing import Callable, Optional 12 | 13 | from torch import Tensor, nn 14 | 15 | 16 | class Mlp(nn.Module): 17 | def __init__( 18 | self, 19 | in_features: int, 20 | hidden_features: Optional[int] = None, 21 | out_features: Optional[int] = None, 22 | act_layer: Callable[..., nn.Module] = nn.GELU, 23 | drop: float = 0.0, 24 | bias: bool = True, 25 | ) -> None: 26 | super().__init__() 27 | out_features = out_features or in_features 28 | hidden_features = hidden_features or in_features 29 | self.fc1 = nn.Linear(in_features, hidden_features, bias=bias) 30 | self.act = act_layer() 31 | self.fc2 = nn.Linear(hidden_features, out_features, bias=bias) 32 | self.drop = nn.Dropout(drop) 33 | 34 | def forward(self, x: Tensor) -> Tensor: 35 | x = self.fc1(x) 36 | x = self.act(x) 37 | x = self.drop(x) 38 | x = self.fc2(x) 39 | x = self.drop(x) 40 | return x 41 | -------------------------------------------------------------------------------- /rein/models/backbones/dino_layers/patch_embed.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # 3 | # This source code is licensed under the Apache License, Version 2.0 4 | # found in the LICENSE file in the root directory of this source tree. 5 | 6 | # References: 7 | # https://github.com/facebookresearch/dino/blob/master/vision_transformer.py 8 | # https://github.com/rwightman/pytorch-image-models/tree/master/timm/layers/patch_embed.py 9 | 10 | from typing import Callable, Optional, Tuple, Union 11 | 12 | from torch import Tensor 13 | import torch.nn as nn 14 | 15 | 16 | def make_2tuple(x): 17 | if isinstance(x, tuple): 18 | assert len(x) == 2 19 | return x 20 | 21 | assert isinstance(x, int) 22 | return (x, x) 23 | 24 | 25 | class PatchEmbed(nn.Module): 26 | """ 27 | 2D image to patch embedding: (B,C,H,W) -> (B,N,D) 28 | 29 | Args: 30 | img_size: Image size. 31 | patch_size: Patch token size. 32 | in_chans: Number of input image channels. 33 | embed_dim: Number of linear projection output channels. 34 | norm_layer: Normalization layer. 
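flatten_embedding: Whether to flatten the patch tokens to (B, N, D); if False, the spatial layout (B, H, W, D) is kept.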
35 | """ 36 | 37 | def __init__( 38 | self, 39 | img_size: Union[int, Tuple[int, int]] = 224, 40 | patch_size: Union[int, Tuple[int, int]] = 16, 41 | in_chans: int = 3, 42 | embed_dim: int = 768, 43 | norm_layer: Optional[Callable] = None, 44 | flatten_embedding: bool = True, 45 | ) -> None: 46 | super().__init__() 47 | 48 | image_HW = make_2tuple(img_size) 49 | patch_HW = make_2tuple(patch_size) 50 | patch_grid_size = ( 51 | image_HW[0] // patch_HW[0], 52 | image_HW[1] // patch_HW[1], 53 | ) 54 | 55 | self.img_size = image_HW 56 | self.patch_size = patch_HW 57 | self.patches_resolution = patch_grid_size 58 | self.num_patches = patch_grid_size[0] * patch_grid_size[1] 59 | 60 | self.in_chans = in_chans 61 | self.embed_dim = embed_dim 62 | 63 | self.flatten_embedding = flatten_embedding 64 | 65 | self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_HW, stride=patch_HW) 66 | self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity() 67 | 68 | def forward(self, x: Tensor) -> Tensor: 69 | _, _, H, W = x.shape 70 | patch_H, patch_W = self.patch_size 71 | 72 | assert H % patch_H == 0, f"Input image height {H} is not a multiple of patch height {patch_H}" 73 | assert W % patch_W == 0, f"Input image width {W} is not a multiple of patch width: {patch_W}" 74 | 75 | x = self.proj(x) # B C H W 76 | H, W = x.size(2), x.size(3) 77 | x = x.flatten(2).transpose(1, 2) # B HW C 78 | x = self.norm(x) 79 | if not self.flatten_embedding: 80 | x = x.reshape(-1, H, W, self.embed_dim) # B H W C 81 | return x 82 | 83 | def flops(self) -> float: 84 | Ho, Wo = self.patches_resolution 85 | flops = Ho * Wo * self.embed_dim * self.in_chans * (self.patch_size[0] * self.patch_size[1]) 86 | if self.norm is not None: 87 | flops += Ho * Wo * self.embed_dim 88 | return flops 89 | -------------------------------------------------------------------------------- /rein/models/backbones/dino_layers/swiglu_ffn.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # 3 | # This source code is licensed under the Apache License, Version 2.0 4 | # found in the LICENSE file in the root directory of this source tree. 
5 | 6 | import os 7 | from typing import Callable, Optional 8 | import warnings 9 | 10 | from torch import Tensor, nn 11 | import torch.nn.functional as F 12 | 13 | 14 | class SwiGLUFFN(nn.Module): 15 | def __init__( 16 | self, 17 | in_features: int, 18 | hidden_features: Optional[int] = None, 19 | out_features: Optional[int] = None, 20 | act_layer: Callable[..., nn.Module] = None, 21 | drop: float = 0.0, 22 | bias: bool = True, 23 | ) -> None: 24 | super().__init__() 25 | out_features = out_features or in_features 26 | hidden_features = hidden_features or in_features 27 | self.w12 = nn.Linear(in_features, 2 * hidden_features, bias=bias) 28 | self.w3 = nn.Linear(hidden_features, out_features, bias=bias) 29 | 30 | def forward(self, x: Tensor) -> Tensor: 31 | x12 = self.w12(x) 32 | x1, x2 = x12.chunk(2, dim=-1) 33 | hidden = F.silu(x1) * x2 34 | return self.w3(hidden) 35 | 36 | 37 | XFORMERS_ENABLED = os.environ.get("XFORMERS_DISABLED") is None 38 | try: 39 | if XFORMERS_ENABLED: 40 | from xformers.ops import SwiGLU 41 | 42 | XFORMERS_AVAILABLE = True 43 | warnings.warn("xFormers is available (SwiGLU)") 44 | else: 45 | warnings.warn("xFormers is disabled (SwiGLU)") 46 | raise ImportError 47 | except ImportError: 48 | SwiGLU = SwiGLUFFN 49 | XFORMERS_AVAILABLE = False 50 | 51 | warnings.warn("xFormers is not available (SwiGLU)") 52 | 53 | 54 | class SwiGLUFFNFused(SwiGLU): 55 | def __init__( 56 | self, 57 | in_features: int, 58 | hidden_features: Optional[int] = None, 59 | out_features: Optional[int] = None, 60 | act_layer: Callable[..., nn.Module] = None, 61 | drop: float = 0.0, 62 | bias: bool = True, 63 | ) -> None: 64 | out_features = out_features or in_features 65 | hidden_features = hidden_features or in_features 66 | hidden_features = (int(hidden_features * 2 / 3) + 7) // 8 * 8 67 | super().__init__( 68 | in_features=in_features, 69 | hidden_features=hidden_features, 70 | out_features=out_features, 71 | bias=bias, 72 | ) 73 | -------------------------------------------------------------------------------- /rein/models/backbones/dino_v2.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # 3 | # This source code is licensed under the Apache License, Version 2.0 4 | # found in the LICENSE file in the root directory of this source tree. 
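# DinoVisionTransformer.forward_features() embeds the image into patches,
# prepends the CLS token, adds an (interpolated) positional embedding, and
# returns the patch tokens of the blocks listed in `out_indices`, reshaped to
# spatial maps of shape (B, embed_dim, H // patch_size, W // patch_size).
# A minimal illustrative instantiation with hypothetical sizes (block_chunks=0
# keeps `self.blocks` flat so `out_indices` addresses individual blocks):
#
#     vit = DinoVisionTransformer(img_size=518, patch_size=14,
#                                 block_chunks=0, out_indices=[2, 5, 8, 11])
#     feats = vit(torch.randn(1, 3, 518, 518))  # 4 maps, each (1, 768, 37, 37)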
5 | 6 | # References: 7 | # https://github.com/facebookresearch/dino/blob/main/vision_transformer.py 8 | # https://github.com/rwightman/pytorch-image-models/tree/master/timm/models/vision_transformer.py 9 | 10 | from functools import partial 11 | import math 12 | from typing import Sequence, Tuple, Union, Callable 13 | 14 | import torch 15 | import torch.nn as nn 16 | import torch.utils.checkpoint 17 | import torch.nn.functional as F 18 | from .dino_layers import ( 19 | Mlp, 20 | PatchEmbed, 21 | SwiGLUFFNFused, 22 | MemEffAttention, 23 | NestedTensorBlock as Block, 24 | ) 25 | 26 | 27 | def named_apply( 28 | fn: Callable, module: nn.Module, name="", depth_first=True, include_root=False 29 | ) -> nn.Module: 30 | if not depth_first and include_root: 31 | fn(module=module, name=name) 32 | for child_name, child_module in module.named_children(): 33 | child_name = ".".join((name, child_name)) if name else child_name 34 | named_apply( 35 | fn=fn, 36 | module=child_module, 37 | name=child_name, 38 | depth_first=depth_first, 39 | include_root=True, 40 | ) 41 | if depth_first and include_root: 42 | fn(module=module, name=name) 43 | return module 44 | 45 | 46 | class BlockChunk(nn.ModuleList): 47 | def forward(self, x): 48 | for b in self: 49 | x = b(x) 50 | return x 51 | 52 | 53 | class DinoVisionTransformer(nn.Module): 54 | def __init__( 55 | self, 56 | img_size=224, 57 | patch_size=16, 58 | in_chans=3, 59 | embed_dim=768, 60 | depth=12, 61 | num_heads=12, 62 | mlp_ratio=4.0, 63 | qkv_bias=True, 64 | ffn_bias=True, 65 | proj_bias=True, 66 | drop_path_rate=0.0, 67 | drop_path_uniform=False, 68 | init_values=None, # for layerscale: None or 0 => no layerscale 69 | embed_layer=PatchEmbed, 70 | act_layer=nn.GELU, 71 | block_fn=partial(Block, attn_class=MemEffAttention), 72 | ffn_layer="mlp", 73 | block_chunks=1, 74 | out_indices=[7, 11, 15, 23], 75 | init_cfg=None, 76 | ): 77 | """ 78 | Args: 79 | img_size (int, tuple): input image size 80 | patch_size (int, tuple): patch size 81 | in_chans (int): number of input channels 82 | embed_dim (int): embedding dimension 83 | depth (int): depth of transformer 84 | num_heads (int): number of attention heads 85 | mlp_ratio (int): ratio of mlp hidden dim to embedding dim 86 | qkv_bias (bool): enable bias for qkv if True 87 | proj_bias (bool): enable bias for proj in attn if True 88 | ffn_bias (bool): enable bias for ffn if True 89 | drop_path_rate (float): stochastic depth rate 90 | drop_path_uniform (bool): apply uniform drop rate across blocks 91 | weight_init (str): weight init scheme 92 | init_values (float): layer-scale init values 93 | embed_layer (nn.Module): patch embedding layer 94 | act_layer (nn.Module): MLP activation layer 95 | block_fn (nn.Module): transformer block class 96 | ffn_layer (str): "mlp", "swiglu", "swiglufused" or "identity" 97 | block_chunks: (int) split block sequence into block_chunks units for FSDP wrap 98 | """ 99 | super().__init__() 100 | norm_layer = partial(nn.LayerNorm, eps=1e-6) 101 | self.out_indices = out_indices 102 | 103 | self.num_features = ( 104 | self.embed_dim 105 | ) = embed_dim # num_features for consistency with other models 106 | self.num_tokens = 1 107 | self.n_blocks = depth 108 | self.num_heads = num_heads 109 | self.patch_size = patch_size 110 | 111 | self.patch_embed = embed_layer( 112 | img_size=img_size, 113 | patch_size=patch_size, 114 | in_chans=in_chans, 115 | embed_dim=embed_dim, 116 | ) 117 | num_patches = self.patch_embed.num_patches 118 | 119 | self.cls_token = nn.Parameter(torch.zeros(1, 1, 
embed_dim)) 120 | self.pos_embed = nn.Parameter( 121 | torch.zeros(1, num_patches + self.num_tokens, embed_dim) 122 | ) 123 | 124 | if drop_path_uniform is True: 125 | dpr = [drop_path_rate] * depth 126 | else: 127 | dpr = [ 128 | x.item() for x in torch.linspace(0, drop_path_rate, depth) 129 | ] # stochastic depth decay rule 130 | 131 | if ffn_layer == "mlp": 132 | ffn_layer = Mlp 133 | elif ffn_layer == "swiglufused" or ffn_layer == "swiglu": 134 | ffn_layer = SwiGLUFFNFused 135 | elif ffn_layer == "identity": 136 | 137 | def f(*args, **kwargs): 138 | return nn.Identity() 139 | 140 | ffn_layer = f 141 | else: 142 | raise NotImplementedError 143 | 144 | blocks_list = [ 145 | block_fn( 146 | dim=embed_dim, 147 | num_heads=num_heads, 148 | mlp_ratio=mlp_ratio, 149 | qkv_bias=qkv_bias, 150 | proj_bias=proj_bias, 151 | ffn_bias=ffn_bias, 152 | drop_path=dpr[i], 153 | norm_layer=norm_layer, 154 | act_layer=act_layer, 155 | ffn_layer=ffn_layer, 156 | init_values=init_values, 157 | ) 158 | for i in range(depth) 159 | ] 160 | if block_chunks > 0: 161 | self.chunked_blocks = True 162 | chunked_blocks = [] 163 | chunksize = depth // block_chunks 164 | for i in range(0, depth, chunksize): 165 | # this is to keep the block index consistent if we chunk the block list 166 | chunked_blocks.append( 167 | [nn.Identity()] * i + blocks_list[i : i + chunksize] 168 | ) 169 | self.blocks = nn.ModuleList([BlockChunk(p) for p in chunked_blocks]) 170 | else: 171 | self.chunked_blocks = False 172 | self.blocks = nn.ModuleList(blocks_list) 173 | 174 | self.norm = norm_layer(embed_dim) 175 | self.head = nn.Identity() 176 | 177 | self.mask_token = nn.Parameter(torch.zeros(1, embed_dim)) 178 | 179 | def interpolate_pos_encoding(self, x, w, h): 180 | previous_dtype = x.dtype 181 | npatch = x.shape[1] - 1 182 | N = self.pos_embed.shape[1] - 1 183 | if npatch == N and w == h: 184 | return self.pos_embed 185 | pos_embed = self.pos_embed.float() 186 | class_pos_embed = pos_embed[:, 0] 187 | patch_pos_embed = pos_embed[:, 1:] 188 | dim = x.shape[-1] 189 | w0 = w // self.patch_size 190 | h0 = h // self.patch_size 191 | # we add a small number to avoid floating point error in the interpolation 192 | # see discussion at https://github.com/facebookresearch/dino/issues/8 193 | w0, h0 = w0 + 0.1, h0 + 0.1 194 | 195 | patch_pos_embed = nn.functional.interpolate( 196 | patch_pos_embed.reshape( 197 | 1, int(math.sqrt(N)), int(math.sqrt(N)), dim 198 | ).permute(0, 3, 1, 2), 199 | scale_factor=(w0 / math.sqrt(N), h0 / math.sqrt(N)), 200 | mode="bicubic", 201 | ) 202 | 203 | assert ( 204 | int(w0) == patch_pos_embed.shape[-2] 205 | and int(h0) == patch_pos_embed.shape[-1] 206 | ) 207 | patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).view(1, -1, dim) 208 | return torch.cat((class_pos_embed.unsqueeze(0), patch_pos_embed), dim=1).to( 209 | previous_dtype 210 | ) 211 | 212 | def prepare_tokens_with_masks(self, x, masks=None): 213 | B, nc, w, h = x.shape 214 | x = self.patch_embed(x) 215 | if masks is not None: 216 | x = torch.where( 217 | masks.unsqueeze(-1), self.mask_token.to(x.dtype).unsqueeze(0), x 218 | ) 219 | 220 | x = torch.cat((self.cls_token.expand(x.shape[0], -1, -1), x), dim=1) 221 | x = x + self.interpolate_pos_encoding(x, w, h) 222 | 223 | return x 224 | 225 | def forward_features_list(self, x_list, masks_list): 226 | x = [ 227 | self.prepare_tokens_with_masks(x, masks) 228 | for x, masks in zip(x_list, masks_list) 229 | ] 230 | for blk in self.blocks: 231 | x = blk(x) 232 | 233 | all_x = x 234 | output = [] 235 | for 
x, masks in zip(all_x, masks_list): 236 | x_norm = self.norm(x) 237 | output.append( 238 | { 239 | "x_norm_clstoken": x_norm[:, 0], 240 | "x_norm_patchtokens": x_norm[:, 1:], 241 | "x_prenorm": x, 242 | "masks": masks, 243 | } 244 | ) 245 | return output 246 | 247 | def forward_features(self, x, masks=None): 248 | B, _, h, w = x.shape 249 | if isinstance(x, list): 250 | return self.forward_features_list(x, masks) 251 | 252 | x = self.prepare_tokens_with_masks(x, masks) 253 | outs = [] 254 | for idx, blk in enumerate(self.blocks): 255 | x = blk(x) 256 | if idx in self.out_indices: 257 | outs.append( 258 | x[:, 1:, :] 259 | .permute(0, 2, 1) 260 | .reshape(B, -1, h // self.patch_size, w // self.patch_size) 261 | .contiguous() 262 | ) 263 | return outs 264 | 265 | def _get_intermediate_layers_not_chunked(self, x, n=1): 266 | x = self.prepare_tokens_with_masks(x) 267 | # If n is an int, take the n last blocks. If it's a list, take them 268 | output, total_block_len = [], len(self.blocks) 269 | blocks_to_take = ( 270 | range(total_block_len - n, total_block_len) if isinstance(n, int) else n 271 | ) 272 | for i, blk in enumerate(self.blocks): 273 | x = blk(x) 274 | if i in blocks_to_take: 275 | output.append(x) 276 | assert len(output) == len( 277 | blocks_to_take 278 | ), f"only {len(output)} / {len(blocks_to_take)} blocks found" 279 | return output 280 | 281 | def _get_intermediate_layers_chunked(self, x, n=1): 282 | x = self.prepare_tokens_with_masks(x) 283 | output, i, total_block_len = [], 0, len(self.blocks[-1]) 284 | # If n is an int, take the n last blocks. If it's a list, take them 285 | blocks_to_take = ( 286 | range(total_block_len - n, total_block_len) if isinstance(n, int) else n 287 | ) 288 | for block_chunk in self.blocks: 289 | for blk in block_chunk[i:]: # Passing the nn.Identity() 290 | x = blk(x) 291 | if i in blocks_to_take: 292 | output.append(x) 293 | i += 1 294 | assert len(output) == len( 295 | blocks_to_take 296 | ), f"only {len(output)} / {len(blocks_to_take)} blocks found" 297 | return output 298 | 299 | def get_intermediate_layers( 300 | self, 301 | x: torch.Tensor, 302 | n: Union[int, Sequence] = 1, # Layers or n last layers to take 303 | reshape: bool = False, 304 | return_class_token: bool = False, 305 | norm=True, 306 | ) -> Tuple[Union[torch.Tensor, Tuple[torch.Tensor]]]: 307 | if self.chunked_blocks: 308 | outputs = self._get_intermediate_layers_chunked(x, n) 309 | else: 310 | outputs = self._get_intermediate_layers_not_chunked(x, n) 311 | if norm: 312 | outputs = [self.norm(out) for out in outputs] 313 | class_tokens = [out[:, 0] for out in outputs] 314 | outputs = [out[:, 1:] for out in outputs] 315 | if reshape: 316 | B, _, w, h = x.shape 317 | outputs = [ 318 | out.reshape(B, w // self.patch_size, h // self.patch_size, -1) 319 | .permute(0, 3, 1, 2) 320 | .contiguous() 321 | for out in outputs 322 | ] 323 | if return_class_token: 324 | return tuple(zip(outputs, class_tokens)) 325 | return tuple(outputs) 326 | 327 | def forward(self, *args, **kwargs): 328 | ret = self.forward_features(*args, **kwargs) 329 | # if isinstance(ret[0], torch.Tensor): 330 | # ret[0] = F.interpolate( 331 | # ret[0], scale_factor=4, mode="bilinear", align_corners=False 332 | # ) 333 | # ret[1] = F.interpolate( 334 | # ret[1], scale_factor=2, mode="bilinear", align_corners=False 335 | # ) 336 | # ret[3] = F.interpolate( 337 | # ret[3], scale_factor=0.5, mode="bilinear", align_corners=False 338 | # ) 339 | # else: 340 | # ret[0][0] = F.interpolate( 341 | # ret[0][0], scale_factor=4, 
mode="bilinear", align_corners=False 342 | # ) 343 | # ret[0][1] = F.interpolate( 344 | # ret[0][1], scale_factor=2, mode="bilinear", align_corners=False 345 | # ) 346 | # ret[0][3] = F.interpolate( 347 | # ret[0][3], scale_factor=0.5, mode="bilinear", align_corners=False 348 | # ) 349 | return ret -------------------------------------------------------------------------------- /rein/models/backbones/eva_02.py: -------------------------------------------------------------------------------- 1 | # -------------------------------------------------------- 2 | # BEIT: BERT Pre-Training of Image Transformers (https://arxiv.org/abs/2106.08254) 3 | # Github source: https://github.com/microsoft/unilm/tree/master/beit 4 | # Copyright (c) 2021 Microsoft 5 | # Licensed under The MIT License [see LICENSE for details] 6 | # By Hangbo Bao 7 | # Based on timm, mmseg, setr, xcit and swin code bases 8 | # https://github.com/rwightman/pytorch-image-models/tree/master/timm 9 | # https://github.com/fudan-zvg/SETR 10 | # https://github.com/facebookresearch/xcit/ 11 | # https://github.com/microsoft/Swin-Transformer 12 | # --------------------------------------------------------' 13 | 14 | import torch 15 | from functools import partial 16 | import torch.nn as nn 17 | import torch.nn.functional as F 18 | import torch.utils.checkpoint as checkpoint 19 | 20 | from timm.models.layers import drop_path, to_2tuple, trunc_normal_ 21 | 22 | from .beit import load_checkpoint 23 | from mmengine.logging import MMLogger 24 | from mmseg.models.builder import BACKBONES 25 | from mmcv.cnn import build_norm_layer 26 | import xformers.ops as xops 27 | # from apex.normalization import FusedLayerNorm 28 | # from apex.normalization import FusedLayerNorm 29 | 30 | 31 | from math import pi 32 | from einops import rearrange, repeat 33 | 34 | 35 | def broadcat(tensors, dim=-1): 36 | num_tensors = len(tensors) 37 | shape_lens = set(list(map(lambda t: len(t.shape), tensors))) 38 | assert len(shape_lens) == 1, "tensors must all have the same number of dimensions" 39 | shape_len = list(shape_lens)[0] 40 | dim = (dim + shape_len) if dim < 0 else dim 41 | dims = list(zip(*map(lambda t: list(t.shape), tensors))) 42 | expandable_dims = [(i, val) for i, val in enumerate(dims) if i != dim] 43 | assert all( 44 | [*map(lambda t: len(set(t[1])) <= 2, expandable_dims)] 45 | ), "invalid dimensions for broadcastable concatentation" 46 | max_dims = list(map(lambda t: (t[0], max(t[1])), expandable_dims)) 47 | expanded_dims = list(map(lambda t: (t[0], (t[1],) * num_tensors), max_dims)) 48 | expanded_dims.insert(dim, (dim, dims[dim])) 49 | expandable_shapes = list(zip(*map(lambda t: t[1], expanded_dims))) 50 | tensors = list(map(lambda t: t[0].expand(*t[1]), zip(tensors, expandable_shapes))) 51 | return torch.cat(tensors, dim=dim) 52 | 53 | 54 | def rotate_half(x): 55 | x = rearrange(x, "... (d r) -> ... d r", r=2) 56 | x1, x2 = x.unbind(dim=-1) 57 | x = torch.stack((-x2, x1), dim=-1) 58 | return rearrange(x, "... d r -> ... 
(d r)") 59 | 60 | 61 | class VisionRotaryEmbedding(nn.Module): 62 | def __init__( 63 | self, 64 | dim, 65 | pt_seq_len, 66 | ft_seq_len=None, 67 | custom_freqs=None, 68 | freqs_for="lang", 69 | theta=10000, 70 | max_freq=10, 71 | num_freqs=1, 72 | ): 73 | super().__init__() 74 | if custom_freqs: 75 | freqs = custom_freqs 76 | elif freqs_for == "lang": 77 | freqs = 1.0 / ( 78 | theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim) 79 | ) 80 | elif freqs_for == "pixel": 81 | freqs = torch.linspace(1.0, max_freq / 2, dim // 2) * pi 82 | elif freqs_for == "constant": 83 | freqs = torch.ones(num_freqs).float() 84 | else: 85 | raise ValueError(f"unknown modality {freqs_for}") 86 | 87 | if ft_seq_len is None: 88 | ft_seq_len = pt_seq_len 89 | t = torch.arange(ft_seq_len) / ft_seq_len * pt_seq_len 90 | 91 | freqs_h = torch.einsum("..., f -> ... f", t, freqs) 92 | freqs_h = repeat(freqs_h, "... n -> ... (n r)", r=2) 93 | 94 | freqs_w = torch.einsum("..., f -> ... f", t, freqs) 95 | freqs_w = repeat(freqs_w, "... n -> ... (n r)", r=2) 96 | 97 | freqs = broadcat((freqs_h[:, None, :], freqs_w[None, :, :]), dim=-1) 98 | 99 | self.register_buffer("freqs_cos", freqs.cos()) 100 | self.register_buffer("freqs_sin", freqs.sin()) 101 | 102 | print("======== shape of rope freq", self.freqs_cos.shape, "========") 103 | 104 | def forward(self, t, start_index=0): 105 | rot_dim = self.freqs_cos.shape[-1] 106 | end_index = start_index + rot_dim 107 | assert ( 108 | rot_dim <= t.shape[-1] 109 | ), f"feature dimension {t.shape[-1]} is not of sufficient size to rotate in all the positions {rot_dim}" 110 | t_left, t, t_right = ( 111 | t[..., :start_index], 112 | t[..., start_index:end_index], 113 | t[..., end_index:], 114 | ) 115 | t = (t * self.freqs_cos) + (rotate_half(t) * self.freqs_sin) 116 | return torch.cat((t_left, t, t_right), dim=-1) 117 | 118 | 119 | class VisionRotaryEmbeddingFast(nn.Module): 120 | def __init__( 121 | self, 122 | dim, 123 | pt_seq_len, 124 | ft_seq_len=None, 125 | custom_freqs=None, 126 | freqs_for="lang", 127 | theta=10000, 128 | max_freq=10, 129 | num_freqs=1, 130 | ): 131 | super().__init__() 132 | if custom_freqs: 133 | freqs = custom_freqs 134 | elif freqs_for == "lang": 135 | freqs = 1.0 / ( 136 | theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim) 137 | ) 138 | elif freqs_for == "pixel": 139 | freqs = torch.linspace(1.0, max_freq / 2, dim // 2) * pi 140 | elif freqs_for == "constant": 141 | freqs = torch.ones(num_freqs).float() 142 | else: 143 | raise ValueError(f"unknown modality {freqs_for}") 144 | 145 | if ft_seq_len is None: 146 | ft_seq_len = pt_seq_len 147 | t = torch.arange(ft_seq_len) / ft_seq_len * pt_seq_len 148 | 149 | freqs = torch.einsum("..., f -> ... f", t, freqs) 150 | freqs = repeat(freqs, "... n -> ... 
(n r)", r=2) 151 | freqs = broadcat((freqs[:, None, :], freqs[None, :, :]), dim=-1) 152 | 153 | freqs_cos = freqs.cos().view(-1, freqs.shape[-1]) 154 | freqs_sin = freqs.sin().view(-1, freqs.shape[-1]) 155 | 156 | self.register_buffer("freqs_cos", freqs_cos) 157 | self.register_buffer("freqs_sin", freqs_sin) 158 | 159 | def forward(self, t): 160 | return t * self.freqs_cos + rotate_half(t) * self.freqs_sin 161 | 162 | 163 | class DropPath(nn.Module): 164 | """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).""" 165 | 166 | def __init__(self, drop_prob=None): 167 | super(DropPath, self).__init__() 168 | self.drop_prob = drop_prob 169 | 170 | def forward(self, x): 171 | return drop_path(x, self.drop_prob, self.training) 172 | 173 | def extra_repr(self) -> str: 174 | return "p={}".format(self.drop_prob) 175 | 176 | 177 | class Mlp(nn.Module): 178 | def __init__( 179 | self, 180 | in_features, 181 | hidden_features=None, 182 | out_features=None, 183 | act_layer=nn.GELU, 184 | drop=0.0, 185 | ): 186 | super().__init__() 187 | out_features = out_features or in_features 188 | hidden_features = hidden_features or in_features 189 | self.fc1 = nn.Linear(in_features, hidden_features) 190 | self.act = act_layer() 191 | self.fc2 = nn.Linear(hidden_features, out_features) 192 | self.drop = nn.Dropout(drop) 193 | 194 | def forward(self, x): 195 | x = self.fc1(x) 196 | x = self.act(x) 197 | # x = self.drop(x) 198 | # commit this for the orignal BERT implement 199 | x = self.fc2(x) 200 | x = self.drop(x) 201 | return x 202 | 203 | 204 | class SwiGLU(nn.Module): 205 | def __init__( 206 | self, 207 | in_features, 208 | hidden_features=None, 209 | out_features=None, 210 | act_layer=nn.SiLU, 211 | drop=0.0, 212 | norm_layer=nn.LayerNorm, 213 | subln=False, 214 | ): 215 | super().__init__() 216 | out_features = out_features or in_features 217 | hidden_features = hidden_features or in_features 218 | 219 | self.w1 = nn.Linear(in_features, hidden_features) 220 | self.w2 = nn.Linear(in_features, hidden_features) 221 | 222 | self.act = act_layer() 223 | if isinstance(norm_layer, dict): 224 | self.ffn_ln = ( 225 | build_norm_layer(norm_layer, hidden_features)[1] 226 | if subln 227 | else nn.Identity() 228 | ) 229 | else: 230 | self.ffn_ln = norm_layer(hidden_features) if subln else nn.Identity() 231 | self.w3 = nn.Linear(hidden_features, out_features) 232 | 233 | self.drop = nn.Dropout(drop) 234 | 235 | def forward(self, x): 236 | x1 = self.w1(x) 237 | x2 = self.w2(x) 238 | hidden = self.act(x1) * x2 239 | x = self.ffn_ln(hidden) 240 | x = self.w3(x) 241 | x = self.drop(x) 242 | return x 243 | 244 | 245 | class Attention(nn.Module): 246 | def __init__( 247 | self, 248 | dim, 249 | num_heads=8, 250 | qkv_bias=False, 251 | qk_scale=None, 252 | attn_drop=0.0, 253 | proj_drop=0.0, 254 | window_size=None, 255 | attn_head_dim=None, 256 | subln=False, 257 | norm_layer=nn.LayerNorm, 258 | xattn=False, 259 | rope=None, 260 | ): 261 | super().__init__() 262 | self.num_heads = num_heads 263 | head_dim = dim // num_heads 264 | if attn_head_dim is not None: 265 | head_dim = attn_head_dim 266 | all_head_dim = head_dim * self.num_heads 267 | # NOTE scale factor was wrong in my original version, can set manually to be compat with prev weights 268 | self.scale = qk_scale or head_dim**-0.5 269 | 270 | self.subln = subln 271 | if self.subln: 272 | self.q_proj = nn.Linear(dim, all_head_dim, bias=False) 273 | self.k_proj = nn.Linear(dim, all_head_dim, bias=False) 274 | self.v_proj = nn.Linear(dim, 
all_head_dim, bias=False) 275 | else: 276 | self.qkv = nn.Linear(dim, all_head_dim * 3, bias=False) 277 | 278 | if qkv_bias: 279 | self.q_bias = nn.Parameter(torch.zeros(all_head_dim)) 280 | self.v_bias = nn.Parameter(torch.zeros(all_head_dim)) 281 | else: 282 | self.q_bias = None 283 | self.v_bias = None 284 | 285 | if window_size: 286 | self.window_size = window_size 287 | self.num_relative_distance = (2 * window_size[0] - 1) * ( 288 | 2 * window_size[1] - 1 289 | ) + 3 290 | self.relative_position_bias_table = nn.Parameter( 291 | torch.zeros(self.num_relative_distance, num_heads) 292 | ) # 2*Wh-1 * 2*Ww-1, nH 293 | # cls to token & token 2 cls & cls to cls 294 | 295 | # get pair-wise relative position index for each token inside the window 296 | coords_h = torch.arange(window_size[0]) 297 | coords_w = torch.arange(window_size[1]) 298 | coords = torch.stack(torch.meshgrid([coords_h, coords_w])) # 2, Wh, Ww 299 | coords_flatten = torch.flatten(coords, 1) # 2, Wh*Ww 300 | relative_coords = ( 301 | coords_flatten[:, :, None] - coords_flatten[:, None, :] 302 | ) # 2, Wh*Ww, Wh*Ww 303 | relative_coords = relative_coords.permute( 304 | 1, 2, 0 305 | ).contiguous() # Wh*Ww, Wh*Ww, 2 306 | relative_coords[:, :, 0] += window_size[0] - 1 # shift to start from 0 307 | relative_coords[:, :, 1] += window_size[1] - 1 308 | relative_coords[:, :, 0] *= 2 * window_size[1] - 1 309 | relative_position_index = torch.zeros( 310 | size=(window_size[0] * window_size[1] + 1,) * 2, 311 | dtype=relative_coords.dtype, 312 | ) 313 | relative_position_index[1:, 1:] = relative_coords.sum(-1) # Wh*Ww, Wh*Ww 314 | relative_position_index[0, 0:] = self.num_relative_distance - 3 315 | relative_position_index[0:, 0] = self.num_relative_distance - 2 316 | relative_position_index[0, 0] = self.num_relative_distance - 1 317 | 318 | self.register_buffer("relative_position_index", relative_position_index) 319 | 320 | # trunc_normal_(self.relative_position_bias_table, std=.0) 321 | else: 322 | self.window_size = None 323 | self.relative_position_bias_table = None 324 | self.relative_position_index = None 325 | 326 | self.attn_drop = nn.Dropout(attn_drop) 327 | self.proj = nn.Linear(all_head_dim, dim) 328 | self.proj_drop = nn.Dropout(proj_drop) 329 | 330 | self.xattn = xattn 331 | self.rope = rope 332 | 333 | def forward(self, x, rel_pos_bias=None): 334 | B, N, C = x.shape 335 | 336 | if self.subln: 337 | q = F.linear(input=x, weight=self.q_proj.weight, bias=self.q_bias) 338 | k = F.linear(input=x, weight=self.k_proj.weight, bias=None) 339 | v = F.linear(input=x, weight=self.v_proj.weight, bias=self.v_bias) 340 | 341 | q = q.reshape(B, N, self.num_heads, -1).permute( 342 | 0, 2, 1, 3 343 | ) # B, num_heads, N, C 344 | k = k.reshape(B, N, self.num_heads, -1).permute(0, 2, 1, 3) 345 | v = v.reshape(B, N, self.num_heads, -1).permute(0, 2, 1, 3) 346 | else: 347 | qkv_bias = None 348 | if self.q_bias is not None: 349 | qkv_bias = torch.cat( 350 | ( 351 | self.q_bias, 352 | torch.zeros_like(self.v_bias, requires_grad=False), 353 | self.v_bias, 354 | ) 355 | ) 356 | qkv = F.linear(input=x, weight=self.qkv.weight, bias=qkv_bias) 357 | qkv = qkv.reshape(B, N, 3, self.num_heads, -1).permute( 358 | 2, 0, 3, 1, 4 359 | ) # 3, B, num_heads, N, C 360 | q, k, v = qkv[0], qkv[1], qkv[2] 361 | 362 | if self.rope: 363 | q_t = q[:, :, 1:, :] 364 | ro_q_t = self.rope(q_t) 365 | q = torch.cat((q[:, :, :1, :], ro_q_t), -2).type_as(v) 366 | 367 | k_t = k[:, :, 1:, :] 368 | ro_k_t = self.rope(k_t) 369 | k = torch.cat((k[:, :, :1, :], ro_k_t), 
-2).type_as(v) 370 | 371 | if self.xattn: 372 | q = q.permute(0, 2, 1, 3) # B, num_heads, N, C -> B, N, num_heads, C 373 | k = k.permute(0, 2, 1, 3) 374 | v = v.permute(0, 2, 1, 3) 375 | 376 | x = xops.memory_efficient_attention(q, k, v) 377 | x = x.reshape(B, N, -1) 378 | x = self.proj(x) 379 | x = self.proj_drop(x) 380 | else: 381 | q = q * self.scale 382 | attn = q @ k.transpose(-2, -1) 383 | 384 | if self.relative_position_bias_table is not None: 385 | relative_position_bias = self.relative_position_bias_table[ 386 | self.relative_position_index.view(-1) 387 | ].view( 388 | self.window_size[0] * self.window_size[1] + 1, 389 | self.window_size[0] * self.window_size[1] + 1, 390 | -1, 391 | ) # Wh*Ww,Wh*Ww,nH 392 | relative_position_bias = relative_position_bias.permute( 393 | 2, 0, 1 394 | ).contiguous() # nH, Wh*Ww, Wh*Ww 395 | attn = attn + relative_position_bias.unsqueeze(0) 396 | 397 | if rel_pos_bias is not None: 398 | attn = attn + rel_pos_bias 399 | 400 | attn = attn.softmax(dim=-1) 401 | attn = self.attn_drop(attn) 402 | 403 | x = (attn @ v).transpose(1, 2).reshape(B, N, -1) 404 | x = self.proj(x) 405 | x = self.proj_drop(x) 406 | 407 | return x 408 | 409 | 410 | class Block(nn.Module): 411 | def __init__( 412 | self, 413 | dim, 414 | num_heads, 415 | mlp_ratio=4.0, 416 | qkv_bias=False, 417 | qk_scale=None, 418 | drop=0.0, 419 | attn_drop=0.0, 420 | drop_path=0.0, 421 | init_values=None, 422 | act_layer=nn.GELU, 423 | norm_layer=nn.LayerNorm, 424 | window_size=None, 425 | attn_head_dim=None, 426 | subln=False, 427 | xattn=False, 428 | naiveswiglu=False, 429 | rope=None, 430 | ): 431 | super().__init__() 432 | if isinstance(norm_layer, dict): 433 | self.norm1 = build_norm_layer(norm_layer, dim)[1] 434 | else: 435 | self.norm1 = norm_layer(dim) 436 | self.attn = Attention( 437 | dim, 438 | num_heads=num_heads, 439 | qkv_bias=qkv_bias, 440 | qk_scale=qk_scale, 441 | attn_drop=attn_drop, 442 | proj_drop=drop, 443 | window_size=window_size, 444 | attn_head_dim=attn_head_dim, 445 | subln=subln, 446 | norm_layer=norm_layer, 447 | xattn=xattn, 448 | rope=rope, 449 | ) 450 | 451 | # NOTE: drop path for stochastic depth, we shall see if this is better than dropout here 452 | self.drop_path = DropPath(drop_path) if drop_path > 0.0 else nn.Identity() 453 | if isinstance(norm_layer, dict): 454 | self.norm2 = build_norm_layer(norm_layer, dim)[1] 455 | else: 456 | self.norm2 = norm_layer(dim) 457 | mlp_hidden_dim = int(dim * mlp_ratio) 458 | 459 | if naiveswiglu: 460 | self.mlp = SwiGLU( 461 | in_features=dim, 462 | hidden_features=mlp_hidden_dim, 463 | subln=subln, 464 | norm_layer=norm_layer, 465 | ) 466 | else: 467 | self.mlp = Mlp( 468 | in_features=dim, 469 | hidden_features=mlp_hidden_dim, 470 | act_layer=act_layer, 471 | drop=drop, 472 | ) 473 | 474 | if init_values is not None: 475 | self.gamma_1 = nn.Parameter( 476 | init_values * torch.ones((dim)), requires_grad=True 477 | ) 478 | self.gamma_2 = nn.Parameter( 479 | init_values * torch.ones((dim)), requires_grad=True 480 | ) 481 | else: 482 | self.gamma_1, self.gamma_2 = None, None 483 | 484 | def forward(self, x, rel_pos_bias=None): 485 | if self.gamma_1 is None: 486 | x = x + self.drop_path(self.attn(self.norm1(x), rel_pos_bias=rel_pos_bias)) 487 | x = x + self.drop_path(self.mlp(self.norm2(x))) 488 | else: 489 | x = x + self.drop_path( 490 | self.gamma_1 * self.attn(self.norm1(x), rel_pos_bias=rel_pos_bias) 491 | ) 492 | x = x + self.drop_path(self.gamma_2 * self.mlp(self.norm2(x))) 493 | return x 494 | 495 | 496 | class 
PatchEmbed(nn.Module): 497 | """Image to Patch Embedding""" 498 | 499 | def __init__(self, img_size=224, patch_size=16, in_chans=3, embed_dim=768): 500 | super().__init__() 501 | img_size = to_2tuple(img_size) 502 | patch_size = to_2tuple(patch_size) 503 | num_patches = (img_size[1] // patch_size[1]) * (img_size[0] // patch_size[0]) 504 | self.patch_shape = (img_size[0] // patch_size[0], img_size[1] // patch_size[1]) 505 | self.img_size = img_size 506 | self.patch_size = patch_size 507 | self.num_patches = num_patches 508 | 509 | self.proj = nn.Conv2d( 510 | in_chans, embed_dim, kernel_size=patch_size, stride=patch_size 511 | ) 512 | 513 | def forward(self, x, **kwargs): 514 | B, C, H, W = x.shape 515 | # FIXME look at relaxing size constraints 516 | # assert H == self.img_size[0] and W == self.img_size[1], \ 517 | # f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})." 518 | x = self.proj(x) 519 | Hp, Wp = x.shape[2], x.shape[3] 520 | 521 | x = x.flatten(2).transpose(1, 2) 522 | return x, (Hp, Wp) 523 | 524 | 525 | class HybridEmbed(nn.Module): 526 | """CNN Feature Map Embedding 527 | Extract feature map from CNN, flatten, project to embedding dim. 528 | """ 529 | 530 | def __init__( 531 | self, backbone, img_size=224, feature_size=None, in_chans=3, embed_dim=768 532 | ): 533 | super().__init__() 534 | assert isinstance(backbone, nn.Module) 535 | img_size = to_2tuple(img_size) 536 | self.img_size = img_size 537 | self.backbone = backbone 538 | if feature_size is None: 539 | with torch.no_grad(): 540 | # FIXME this is hacky, but most reliable way of determining the exact dim of the output feature 541 | # map for all networks, the feature metadata has reliable channel and stride info, but using 542 | # stride to calc feature dim requires info about padding of each stage that isn't captured. 
543 | training = backbone.training 544 | if training: 545 | backbone.eval() 546 | o = self.backbone(torch.zeros(1, in_chans, img_size[0], img_size[1]))[ 547 | -1 548 | ] 549 | feature_size = o.shape[-2:] 550 | feature_dim = o.shape[1] 551 | backbone.train(training) 552 | else: 553 | feature_size = to_2tuple(feature_size) 554 | feature_dim = self.backbone.feature_info.channels()[-1] 555 | self.num_patches = feature_size[0] * feature_size[1] 556 | self.proj = nn.Linear(feature_dim, embed_dim) 557 | 558 | def forward(self, x): 559 | x = self.backbone(x)[-1] 560 | x = x.flatten(2).transpose(1, 2) 561 | x = self.proj(x) 562 | return x 563 | 564 | 565 | class RelativePositionBias(nn.Module): 566 | def __init__(self, window_size, num_heads): 567 | super().__init__() 568 | self.window_size = window_size 569 | self.num_relative_distance = (2 * window_size[0] - 1) * ( 570 | 2 * window_size[1] - 1 571 | ) + 3 572 | self.relative_position_bias_table = nn.Parameter( 573 | torch.zeros(self.num_relative_distance, num_heads) 574 | ) # 2*Wh-1 * 2*Ww-1, nH 575 | # cls to token & token 2 cls & cls to cls 576 | 577 | # get pair-wise relative position index for each token inside the window 578 | coords_h = torch.arange(window_size[0]) 579 | coords_w = torch.arange(window_size[1]) 580 | coords = torch.stack(torch.meshgrid([coords_h, coords_w])) # 2, Wh, Ww 581 | coords_flatten = torch.flatten(coords, 1) # 2, Wh*Ww 582 | relative_coords = ( 583 | coords_flatten[:, :, None] - coords_flatten[:, None, :] 584 | ) # 2, Wh*Ww, Wh*Ww 585 | relative_coords = relative_coords.permute( 586 | 1, 2, 0 587 | ).contiguous() # Wh*Ww, Wh*Ww, 2 588 | relative_coords[:, :, 0] += window_size[0] - 1 # shift to start from 0 589 | relative_coords[:, :, 1] += window_size[1] - 1 590 | relative_coords[:, :, 0] *= 2 * window_size[1] - 1 591 | relative_position_index = torch.zeros( 592 | size=(window_size[0] * window_size[1] + 1,) * 2, dtype=relative_coords.dtype 593 | ) 594 | relative_position_index[1:, 1:] = relative_coords.sum(-1) # Wh*Ww, Wh*Ww 595 | relative_position_index[0, 0:] = self.num_relative_distance - 3 596 | relative_position_index[0:, 0] = self.num_relative_distance - 2 597 | relative_position_index[0, 0] = self.num_relative_distance - 1 598 | 599 | self.register_buffer("relative_position_index", relative_position_index) 600 | 601 | # trunc_normal_(self.relative_position_bias_table, std=.02) 602 | 603 | def forward(self): 604 | relative_position_bias = self.relative_position_bias_table[ 605 | self.relative_position_index.view(-1) 606 | ].view( 607 | self.window_size[0] * self.window_size[1] + 1, 608 | self.window_size[0] * self.window_size[1] + 1, 609 | -1, 610 | ) # Wh*Ww,Wh*Ww,nH 611 | return relative_position_bias.permute(2, 0, 1).contiguous() # nH, Wh*Ww, Wh*Ww 612 | 613 | 614 | @BACKBONES.register_module() 615 | class EVA2(nn.Module): 616 | """Vision Transformer with support for patch or hybrid CNN input stage""" 617 | 618 | def __init__( 619 | self, 620 | img_size=224, 621 | patch_size=16, 622 | in_chans=3, 623 | num_classes=80, 624 | embed_dim=768, 625 | depth=12, 626 | num_heads=12, 627 | mlp_ratio=4 * 2 / 3, # GLU default 628 | qkv_bias=False, 629 | qk_scale=None, 630 | drop_rate=0.0, 631 | attn_drop_rate=0.0, 632 | drop_path_rate=0.0, 633 | hybrid_backbone=None, 634 | norm_layer=None, 635 | init_values=None, 636 | use_checkpoint=False, 637 | use_abs_pos_emb=True, 638 | use_rel_pos_bias=False, 639 | use_shared_rel_pos_bias=False, 640 | out_indices=[3, 5, 7, 11], 641 | subln=True, 642 | xattn=True, 643 | 
naiveswiglu=True, 644 | rope=True, 645 | pt_hw_seq_len=16, 646 | intp_freq=True, 647 | pretrained=None, 648 | ): 649 | super().__init__() 650 | # norm_layer = norm_layer or partial(FusedLayerNorm, eps=1e-6) 651 | self.num_classes = num_classes 652 | self.num_features = ( 653 | self.embed_dim 654 | ) = embed_dim # num_features for consistency with other models 655 | 656 | if hybrid_backbone is not None: 657 | self.patch_embed = HybridEmbed( 658 | hybrid_backbone, 659 | img_size=img_size, 660 | in_chans=in_chans, 661 | embed_dim=embed_dim, 662 | ) 663 | else: 664 | self.patch_embed = PatchEmbed( 665 | img_size=img_size, 666 | patch_size=patch_size, 667 | in_chans=in_chans, 668 | embed_dim=embed_dim, 669 | ) 670 | 671 | num_patches = self.patch_embed.num_patches 672 | self.out_indices = out_indices 673 | 674 | self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim)) 675 | 676 | if use_abs_pos_emb: 677 | self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + 1, embed_dim)) 678 | else: 679 | self.pos_embed = None 680 | 681 | self.pos_drop = nn.Dropout(p=drop_rate) 682 | 683 | if use_shared_rel_pos_bias: 684 | self.rel_pos_bias = RelativePositionBias( 685 | window_size=self.patch_embed.patch_shape, num_heads=num_heads 686 | ) 687 | else: 688 | self.rel_pos_bias = None 689 | 690 | if rope: 691 | half_head_dim = embed_dim // num_heads // 2 692 | hw_seq_len = img_size // patch_size 693 | self.rope = VisionRotaryEmbeddingFast( 694 | dim=half_head_dim, 695 | pt_seq_len=pt_hw_seq_len, 696 | ft_seq_len=hw_seq_len if intp_freq else None, 697 | ) 698 | else: 699 | self.rope = None 700 | 701 | self.naiveswiglu = naiveswiglu 702 | 703 | dpr = [ 704 | x.item() for x in torch.linspace(0, drop_path_rate, depth) 705 | ] # stochastic depth decay rule 706 | self.use_rel_pos_bias = use_rel_pos_bias 707 | self.use_checkpoint = use_checkpoint 708 | self.blocks = nn.ModuleList( 709 | [ 710 | Block( 711 | dim=embed_dim, 712 | num_heads=num_heads, 713 | mlp_ratio=mlp_ratio, 714 | qkv_bias=qkv_bias, 715 | qk_scale=qk_scale, 716 | drop=drop_rate, 717 | attn_drop=attn_drop_rate, 718 | drop_path=dpr[i], 719 | norm_layer=norm_layer, 720 | init_values=init_values, 721 | window_size=self.patch_embed.patch_shape 722 | if use_rel_pos_bias 723 | else None, 724 | subln=subln, 725 | xattn=xattn, 726 | naiveswiglu=naiveswiglu, 727 | rope=self.rope, 728 | ) 729 | for i in range(depth) 730 | ] 731 | ) 732 | 733 | if self.pos_embed is not None: 734 | trunc_normal_(self.pos_embed, std=0.02) 735 | trunc_normal_(self.cls_token, std=0.02) 736 | 737 | # if patch_size == 16: 738 | # self.fpn1 = nn.Sequential( 739 | # nn.ConvTranspose2d(embed_dim, embed_dim, kernel_size=2, stride=2), 740 | # nn.SyncBatchNorm(embed_dim), 741 | # nn.GELU(), 742 | # nn.ConvTranspose2d(embed_dim, embed_dim, kernel_size=2, stride=2), 743 | # ) 744 | 745 | # self.fpn2 = nn.Sequential( 746 | # nn.ConvTranspose2d(embed_dim, embed_dim, kernel_size=2, stride=2), 747 | # ) 748 | 749 | # self.fpn3 = nn.Identity() 750 | 751 | # self.fpn4 = nn.MaxPool2d(kernel_size=2, stride=2) 752 | # elif patch_size == 8: 753 | # self.fpn1 = nn.Sequential( 754 | # nn.ConvTranspose2d(embed_dim, embed_dim, kernel_size=2, stride=2), 755 | # ) 756 | 757 | # self.fpn2 = nn.Identity() 758 | 759 | # self.fpn3 = nn.Sequential( 760 | # nn.MaxPool2d(kernel_size=2, stride=2), 761 | # ) 762 | 763 | # self.fpn4 = nn.Sequential( 764 | # nn.MaxPool2d(kernel_size=4, stride=4), 765 | # ) 766 | # self.init_weights(pretrained) 767 | self.pretrained = pretrained 768 | 769 | def _init_weights(self, 
m): 770 | if isinstance(m, nn.Linear): 771 | trunc_normal_(m.weight, std=0.02) 772 | if isinstance(m, nn.Linear) and m.bias is not None: 773 | nn.init.constant_(m.bias, 0) 774 | elif isinstance(m, nn.LayerNorm): 775 | nn.init.constant_(m.bias, 0) 776 | nn.init.constant_(m.weight, 1.0) 777 | 778 | def init_weights(self): 779 | """Initialize the weights in backbone. 780 | 781 | Args: 782 | pretrained (str, optional): Path to pre-trained weights. 783 | Defaults to None. 784 | """ 785 | pretrained = self.pretrained 786 | 787 | def _init_weights(m): 788 | if isinstance(m, nn.Linear): 789 | trunc_normal_(m.weight, std=0.02) 790 | if isinstance(m, nn.Linear) and m.bias is not None: 791 | nn.init.constant_(m.bias, 0) 792 | elif isinstance(m, nn.LayerNorm): 793 | nn.init.constant_(m.bias, 0) 794 | nn.init.constant_(m.weight, 1.0) 795 | 796 | if isinstance(pretrained, str): 797 | self.apply(_init_weights) 798 | logger = MMLogger.get_current_instance() 799 | load_checkpoint(self, pretrained, strict=False, logger=logger) 800 | elif pretrained is None: 801 | self.apply(_init_weights) 802 | else: 803 | raise TypeError("pretrained must be a str or None") 804 | 805 | def get_num_layers(self): 806 | return len(self.blocks) 807 | 808 | @torch.jit.ignore 809 | def no_weight_decay(self): 810 | return {"pos_embed", "cls_token"} 811 | 812 | def forward_features(self, x): 813 | B, C, H, W = x.shape 814 | x, (Hp, Wp) = self.patch_embed(x) 815 | batch_size, seq_len, _ = x.size() 816 | 817 | cls_tokens = self.cls_token.expand( 818 | batch_size, -1, -1 819 | ) # stole cls_tokens impl from Phil Wang, thanks 820 | x = torch.cat((cls_tokens, x), dim=1) 821 | if self.pos_embed is not None: 822 | x = x + self.pos_embed 823 | x = self.pos_drop(x) 824 | 825 | rel_pos_bias = self.rel_pos_bias() if self.rel_pos_bias is not None else None 826 | features = [] 827 | for i, blk in enumerate(self.blocks): 828 | if self.use_checkpoint: 829 | x = checkpoint.checkpoint(blk, x, rel_pos_bias) 830 | else: 831 | x = blk(x, rel_pos_bias) 832 | if i in self.out_indices: 833 | xp = x[:, 1:, :].permute(0, 2, 1).reshape(B, -1, Hp, Wp) 834 | features.append(xp.contiguous()) 835 | features[0] = F.interpolate( 836 | features[0], scale_factor=4, mode="bilinear", align_corners=False 837 | ) 838 | features[1] = F.interpolate( 839 | features[1], scale_factor=2, mode="bilinear", align_corners=False 840 | ) 841 | features[3] = F.interpolate( 842 | features[3], scale_factor=0.5, mode="bilinear", align_corners=False 843 | ) 844 | 845 | return tuple(features) 846 | 847 | def forward(self, x): 848 | x = self.forward_features(x) 849 | return x 850 | -------------------------------------------------------------------------------- /rein/models/backbones/reins.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | import math 5 | from functools import reduce 6 | from operator import mul 7 | from torch import Tensor 8 | 9 | class Reins(nn.Module): 10 | def __init__( 11 | self, 12 | num_layers: int, 13 | embed_dims: int, 14 | patch_size: int, 15 | query_dims: int = 256, 16 | token_length: int = 100, 17 | use_softmax: bool = True, 18 | link_token_to_query: bool = True, 19 | scale_init: float = 0.001, 20 | zero_mlp_delta_f: bool = False, 21 | ) -> None: 22 | super().__init__() 23 | self.num_layers = num_layers 24 | self.embed_dims = embed_dims 25 | self.patch_size = patch_size 26 | self.query_dims = query_dims 27 | self.token_length = token_length 28 
| self.link_token_to_query = link_token_to_query 29 | self.scale_init = scale_init 30 | self.use_softmax = use_softmax 31 | self.zero_mlp_delta_f = zero_mlp_delta_f 32 | self.create_model() 33 | 34 | def create_model(self): 35 | self.learnable_tokens = nn.Parameter( 36 | torch.empty([self.num_layers, self.token_length, self.embed_dims]) 37 | ) 38 | self.scale = nn.Parameter(torch.tensor(self.scale_init)) 39 | self.mlp_token2feat = nn.Linear(self.embed_dims, self.embed_dims) 40 | self.mlp_delta_f = nn.Linear(self.embed_dims, self.embed_dims) 41 | val = math.sqrt( 42 | 6.0 43 | / float( 44 | 3 * reduce(mul, (self.patch_size, self.patch_size), 1) + self.embed_dims 45 | ) 46 | ) 47 | nn.init.uniform_(self.learnable_tokens.data, -val, val) 48 | nn.init.kaiming_uniform_(self.mlp_delta_f.weight, a=math.sqrt(5)) 49 | nn.init.kaiming_uniform_(self.mlp_token2feat.weight, a=math.sqrt(5)) 50 | self.transform = nn.Linear(self.embed_dims, self.query_dims) 51 | self.merge = nn.Linear(self.query_dims * 3, self.query_dims) 52 | if self.zero_mlp_delta_f: 53 | del self.scale 54 | self.scale = 1.0 55 | nn.init.zeros_(self.mlp_delta_f.weight) 56 | nn.init.zeros_(self.mlp_delta_f.bias) 57 | 58 | def return_auto(self, feats): 59 | if self.link_token_to_query: 60 | tokens = self.transform(self.get_tokens(-1)).permute(1, 2, 0) 61 | tokens = torch.cat( 62 | [ 63 | F.max_pool1d(tokens, kernel_size=self.num_layers), 64 | F.avg_pool1d(tokens, kernel_size=self.num_layers), 65 | tokens[:, :, -1].unsqueeze(-1), 66 | ], 67 | dim=-1, 68 | ) 69 | querys = self.merge(tokens.flatten(-2, -1)) 70 | return feats, querys 71 | else: 72 | return feats 73 | 74 | def get_tokens(self, layer: int) -> Tensor: 75 | if layer == -1: 76 | # return all 77 | return self.learnable_tokens 78 | else: 79 | return self.learnable_tokens[layer] 80 | 81 | def forward( 82 | self, feats: Tensor, layer: int, batch_first=False, has_cls_token=True 83 | ) -> Tensor: 84 | if batch_first: 85 | feats = feats.permute(1, 0, 2) 86 | if has_cls_token: 87 | cls_token, feats = torch.tensor_split(feats, [1], dim=0) 88 | tokens = self.get_tokens(layer) 89 | delta_feat = self.forward_delta_feat( 90 | feats, 91 | tokens, 92 | layer, 93 | ) 94 | delta_feat = delta_feat * self.scale 95 | feats = feats + delta_feat 96 | if has_cls_token: 97 | feats = torch.cat([cls_token, feats], dim=0) 98 | if batch_first: 99 | feats = feats.permute(1, 0, 2) 100 | return feats 101 | 102 | def forward_delta_feat(self, feats: Tensor, tokens: Tensor, layers: int) -> Tensor: 103 | attn = torch.einsum("nbc,mc->nbm", feats, tokens) 104 | if self.use_softmax: 105 | attn = attn * (self.embed_dims**-0.5) 106 | attn = F.softmax(attn, dim=-1) 107 | delta_f = torch.einsum( 108 | "nbm,mc->nbc", 109 | attn[:, :, 1:], 110 | self.mlp_token2feat(tokens[1:, :]), 111 | ) 112 | delta_f = self.mlp_delta_f(delta_f + feats) 113 | return delta_f -------------------------------------------------------------------------------- /rein/models/backbones/reins_dinov2.py: -------------------------------------------------------------------------------- 1 | from .reins import Reins 2 | from .dino_v2 import DinoVisionTransformer 3 | from .utils import set_requires_grad, set_train 4 | 5 | 6 | class ReinsDinoVisionTransformer(DinoVisionTransformer): 7 | def __init__( 8 | self, 9 | **kwargs, 10 | ): 11 | super().__init__(**kwargs) 12 | self.reins = Reins( 13 | num_layers = kwargs['depth'], 14 | embed_dims = kwargs['embed_dim'], 15 | patch_size = kwargs['patch_size'], 16 | ) 17 | 18 | # self.reins2 = Reins( 19 | # num_layers 
= kwargs['depth'], 20 | # embed_dims = kwargs['embed_dim'], 21 | # patch_size = kwargs['patch_size'], 22 | # ) 23 | 24 | def forward_features(self, x, masks=None): 25 | B, _, h, w = x.shape 26 | H, W = h // self.patch_size, w // self.patch_size 27 | x = self.prepare_tokens_with_masks(x, masks) 28 | outs = [] 29 | 30 | for idx, blk in enumerate(self.blocks): 31 | x = blk(x) 32 | x = self.reins.forward( 33 | x, 34 | idx, 35 | batch_first=True, 36 | has_cls_token=True, 37 | ) 38 | return x 39 | 40 | def forward_features_full_rein(self, x, masks=None): 41 | B, _, h, w = x.shape 42 | H, W = h // self.patch_size, w // self.patch_size 43 | x = self.prepare_tokens_with_masks(x, masks) 44 | outs = [] 45 | for idx, blk in enumerate(self.blocks): 46 | x = blk(x) 47 | x = self.reins.forward( 48 | x, 49 | idx, 50 | batch_first=True, 51 | has_cls_token=True, 52 | ) 53 | if idx in self.out_indices: 54 | outs.append( 55 | x[:, 1:, :].permute(0, 2, 1).reshape(B, -1, H, W).contiguous() 56 | ) 57 | return self.reins.return_auto(outs) 58 | 59 | 60 | 61 | def forward_features_no_rein(self, x, masks=None): 62 | B, _, h, w = x.shape 63 | H, W = h // self.patch_size, w // self.patch_size 64 | x = self.prepare_tokens_with_masks(x, masks) 65 | outs = [] 66 | for idx, blk in enumerate(self.blocks): 67 | x = blk(x) 68 | return x 69 | 70 | def train(self, mode: bool = True): 71 | if not mode: 72 | return super().train(mode) 73 | set_requires_grad(self, ["reins", "linear"]) 74 | set_train(self, ["reins", "linear"]) 75 | 76 | 77 | 78 | class ReinsDinoVisionTransformer_3_head(DinoVisionTransformer): 79 | def __init__( 80 | self, 81 | **kwargs, 82 | ): 83 | super().__init__(**kwargs) 84 | self.reins1 = Reins( 85 | num_layers = kwargs['depth'], 86 | embed_dims = kwargs['embed_dim'], 87 | patch_size = kwargs['patch_size'], 88 | ) 89 | 90 | self.reins2 = Reins( 91 | num_layers = kwargs['depth'], 92 | embed_dims = kwargs['embed_dim'], 93 | patch_size = kwargs['patch_size'], 94 | ) 95 | 96 | def forward_features1(self, x, masks=None): 97 | B, _, h, w = x.shape 98 | H, W = h // self.patch_size, w // self.patch_size 99 | x = self.prepare_tokens_with_masks(x, masks) 100 | outs = [] 101 | 102 | for idx, blk in enumerate(self.blocks): 103 | x = blk(x) 104 | x = self.reins1.forward( 105 | x, 106 | idx, 107 | batch_first=True, 108 | has_cls_token=True, 109 | ) 110 | return x 111 | 112 | def forward_features2(self, x, masks=None): 113 | B, _, h, w = x.shape 114 | H, W = h // self.patch_size, w // self.patch_size 115 | x = self.prepare_tokens_with_masks(x, masks) 116 | outs = [] 117 | 118 | for idx, blk in enumerate(self.blocks): 119 | x = blk(x) 120 | x = self.reins2.forward( 121 | x, 122 | idx, 123 | batch_first=True, 124 | has_cls_token=True, 125 | ) 126 | return x 127 | 128 | def forward_features_no_rein(self, x, masks=None): 129 | B, _, h, w = x.shape 130 | H, W = h // self.patch_size, w // self.patch_size 131 | x = self.prepare_tokens_with_masks(x, masks) 132 | outs = [] 133 | for idx, blk in enumerate(self.blocks): 134 | x = blk(x) 135 | return x 136 | 137 | def train(self, mode: bool = True): 138 | if not mode: 139 | return super().train(mode) 140 | set_requires_grad(self, ["reins1", "reins2", "linear"]) 141 | set_train(self, ["reins1", "reins2", "linear"]) -------------------------------------------------------------------------------- /rein/models/backbones/reins_eva_02.py: -------------------------------------------------------------------------------- 1 | from .eva_02 import EVA2 2 | from mmseg.models.builder import 
BACKBONES, MODELS 3 | from .reins import Reins 4 | import torch 5 | import torch.utils.checkpoint as checkpoint 6 | import torch.nn.functional as F 7 | from .utils import set_requires_grad, set_train 8 | 9 | 10 | @BACKBONES.register_module() 11 | class ReinsEVA2(EVA2): 12 | def __init__(self, reins_config=None, **kwargs): 13 | super().__init__(**kwargs) 14 | self.reins: Reins = MODELS.build(reins_config) 15 | 16 | def forward_features(self, x): 17 | B, C, H, W = x.shape 18 | x, (Hp, Wp) = self.patch_embed(x) 19 | batch_size, seq_len, _ = x.size() 20 | 21 | cls_tokens = self.cls_token.expand( 22 | batch_size, -1, -1 23 | ) # stole cls_tokens impl from Phil Wang, thanks 24 | x = torch.cat((cls_tokens, x), dim=1) 25 | if self.pos_embed is not None: 26 | x = x + self.pos_embed 27 | x = self.pos_drop(x) 28 | 29 | rel_pos_bias = self.rel_pos_bias() if self.rel_pos_bias is not None else None 30 | features = [] 31 | for i, blk in enumerate(self.blocks): 32 | if self.use_checkpoint: 33 | x = checkpoint.checkpoint(blk, x, rel_pos_bias) 34 | else: 35 | x = blk(x, rel_pos_bias) 36 | x = self.reins.forward( 37 | x, 38 | i, 39 | batch_first=True, 40 | has_cls_token=True, 41 | ) 42 | if i in self.out_indices: 43 | xp = x[:, 1:, :].permute(0, 2, 1).reshape(B, -1, Hp, Wp) 44 | features.append(xp.contiguous()) 45 | features[0] = F.interpolate( 46 | features[0], scale_factor=4, mode="bilinear", align_corners=False 47 | ) 48 | features[1] = F.interpolate( 49 | features[1], scale_factor=2, mode="bilinear", align_corners=False 50 | ) 51 | features[3] = F.interpolate( 52 | features[3], scale_factor=0.5, mode="bilinear", align_corners=False 53 | ) 54 | return self.reins.return_auto(features) 55 | 56 | def train(self, mode: bool = True): 57 | if not mode: 58 | return super().train(mode) 59 | set_requires_grad(self, ["reins"]) 60 | set_train(self, ["reins"]) 61 | 62 | def state_dict(self, destination, prefix, keep_vars): 63 | state = super().state_dict(destination, prefix, keep_vars) 64 | keys = [k for k in state.keys() if "rein" not in k] 65 | for key in keys: 66 | state.pop(key) 67 | if key in destination: 68 | destination.pop(key) 69 | return state 70 | -------------------------------------------------------------------------------- /rein/models/backbones/reins_resnet.py: -------------------------------------------------------------------------------- 1 | from .reins import Reins 2 | from .utils import set_requires_grad, set_train 3 | from typing import List, Dict 4 | import torch.nn as nn 5 | import timm 6 | from timm.models.resnet import ResNet, Bottleneck 7 | 8 | # Modified from the code of https://github.com/w1oves/Rein/blob/train/rein/models/backbones/reins_resnet.py 9 | class ReinsResNet(ResNet): 10 | def __init__( 11 | self, 12 | **kwargs, 13 | ): 14 | model_args = dict(block=Bottleneck, layers=[3, 4, 6, 3]) 15 | super().__init__(**dict(model_args, **kwargs)) 16 | self.reins: List[Reins] = nn.ModuleList() 17 | self.reins.append(Reins(num_layers=1, embed_dims=256, patch_size=1)) # For layer 1 18 | self.reins.append(Reins(num_layers=1, embed_dims=512, patch_size=1)) # For layer 1 19 | self.reins.append(Reins(num_layers=1, embed_dims=1024, patch_size=1)) # For layer 1 20 | self.reins.append(Reins(num_layers=1, embed_dims=2048, patch_size=1)) # For layer 1 21 | 22 | 23 | print('length of reins: ', len(self.reins)) 24 | def forward(self, x): 25 | x = self.conv1(x) 26 | x = self.bn1(x) 27 | x = self.act1(x) 28 | x = self.maxpool(x) 29 | outs = [] 30 | for i, layer_name in enumerate(['layer1', 'layer2', 
'layer3', 'layer4']): 31 | res_layer = getattr(self, layer_name) 32 | # print(res_layer) 33 | x = res_layer(x) 34 | # print(x.shape) 35 | B, C, H, W = x.shape 36 | x = ( 37 | self.reins[i] 38 | .forward( 39 | x.flatten(-2, -1).permute(0, 2, 1), 40 | 0, 41 | batch_first=True, 42 | has_cls_token=False, 43 | ) 44 | .permute(0, 2, 1) 45 | .reshape(B, C, H, W) 46 | ) 47 | x = self.global_pool(x) 48 | x = self.fc(x) 49 | return x 50 | 51 | def train(self, mode: bool = True): 52 | if not mode: 53 | return super().train(mode) 54 | set_requires_grad(self, ["reins", "fc"]) 55 | set_train(self, ["reins", "fc"]) 56 | -------------------------------------------------------------------------------- /rein/models/backbones/utils.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | from typing import List 3 | # from mmengine.logging import MMLogger 4 | 5 | first_set_requires_grad = True 6 | first_set_train = True 7 | 8 | 9 | def set_requires_grad(model: nn.Module, keywords: List[str]): 10 | """ 11 | notice:key in name! 12 | """ 13 | requires_grad_names = [] 14 | num_params = 0 15 | num_trainable = 0 16 | for name, param in model.named_parameters(): 17 | num_params += param.numel() 18 | if any(key in name for key in keywords): 19 | param.requires_grad = True 20 | requires_grad_names.append(name) 21 | num_trainable += param.numel() 22 | else: 23 | param.requires_grad = False 24 | global first_set_requires_grad 25 | # if first_set_requires_grad: 26 | # # logger = MMLogger.get_current_instance() 27 | # for name in requires_grad_names: 28 | # logger.info(f"set_requires_grad----{name}") 29 | # logger.info( 30 | # f"Total trainable params--{num_trainable}, All params--{num_params}, Ratio--{num_trainable*100/num_params:.1f}%" 31 | # ) 32 | # first_set_requires_grad = False 33 | 34 | 35 | def _set_train(model: nn.Module, keywords: List[str], prefix: str = ""): 36 | train_names = [] 37 | for name, child in model.named_children(): 38 | fullname = ".".join([prefix, name]) 39 | if any(name.startswith(key) for key in keywords): 40 | train_names.append(fullname) 41 | child.train() 42 | else: 43 | train_names += _set_train(child, keywords, prefix=fullname) 44 | return train_names 45 | 46 | 47 | def set_train(model: nn.Module, keywords: List[str]): 48 | """ 49 | notice:sub name startwith key! 
50 | """ 51 | model.train(False) 52 | train_names = _set_train(model, keywords) 53 | # global first_set_train 54 | # if first_set_train: 55 | # logger = MMLogger.get_current_instance() 56 | # for train_name in train_names: 57 | # logger.info(f"set_train----{train_name}") 58 | # first_set_train = False -------------------------------------------------------------------------------- /requirement.txt: -------------------------------------------------------------------------------- 1 | medmnist==3.0.1 2 | numpy==1.24.4 3 | nvidia-cublas-cu11==11.10.3.66 4 | nvidia-cuda-cupti-cu11==11.7.101 5 | nvidia-cuda-nvrtc-cu11==11.7.99 6 | nvidia-cuda-runtime-cu11==11.7.99 7 | nvidia-cudnn-cu11==8.5.0.96 8 | nvidia-cufft-cu11==10.9.0.58 9 | nvidia-curand-cu11==10.2.10.91 10 | nvidia-cusolver-cu11==11.4.0.1 11 | nvidia-cusparse-cu11==11.7.4.91 12 | nvidia-nccl-cu11==2.14.3 13 | nvidia-nvtx-cu11==11.7.91 14 | Pillow==10.0.0 15 | timm==0.6.13 16 | torch==2.0.1 17 | torchaudio==2.0.2 18 | torchdiffeq==0.2.3 19 | torchvision==0.15.2 -------------------------------------------------------------------------------- /train_cufit.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | import torch.nn as nn 4 | 5 | import argparse 6 | import timm 7 | import utils 8 | 9 | import rein 10 | 11 | import dino_variant 12 | 13 | 14 | def train(): 15 | parser = argparse.ArgumentParser() 16 | parser.add_argument('--data', '-d', type=str) 17 | parser.add_argument('--gpu', '-g', default = '0', type=str) 18 | parser.add_argument('--netsize', default='s', type=str) 19 | parser.add_argument('--save_path', '-s', type=str) 20 | parser.add_argument('--noise_rate', '-n', type=float, default=0.2) 21 | args = parser.parse_args() 22 | 23 | config = utils.read_conf('conf/'+args.data+'.json') 24 | device = 'cuda:'+args.gpu 25 | save_path = os.path.join(config['save_path'], args.save_path) 26 | data_path = config['id_dataset'] 27 | batch_size = int(config['batch_size']) 28 | max_epoch = int(config['epoch']) 29 | noise_rate = args.noise_rate 30 | 31 | if not os.path.exists(save_path): 32 | os.mkdir(save_path) 33 | 34 | lr_decay = [int(0.5*max_epoch), int(0.75*max_epoch), int(0.9*max_epoch)] 35 | 36 | if args.data == 'ham10000': 37 | train_loader, valid_loader = utils.get_noise_dataset(data_path, noise_rate=noise_rate, batch_size = batch_size) 38 | elif args.data == 'aptos': 39 | train_loader, valid_loader = utils.get_aptos_noise_dataset(data_path, noise_rate=noise_rate, batch_size = batch_size) 40 | elif 'mnist' in args.data: 41 | train_loader, valid_loader = utils.get_mnist_noise_dataset(args.data, noise_rate=noise_rate, batch_size = batch_size) 42 | elif 'cifar' in args.data: 43 | train_loader, valid_loader = utils.get_cifar_noise_dataset(args.data, data_path, batch_size = batch_size, noise_rate=noise_rate) 44 | 45 | if args.netsize == 's': 46 | model_load = dino_variant._small_dino 47 | variant = dino_variant._small_variant 48 | elif args.netsize == 'b': 49 | model_load = dino_variant._base_dino 50 | variant = dino_variant._base_variant 51 | elif args.netsize == 'l': 52 | model_load = dino_variant._large_dino 53 | variant = dino_variant._large_variant 54 | # model = timm.create_model(network, pretrained=True, num_classes=2) 55 | model = torch.hub.load('facebookresearch/dinov2', model_load) 56 | dino_state_dict = model.state_dict() 57 | 58 | model = rein.ReinsDinoVisionTransformer( 59 | **variant 60 | ) 61 | model.load_state_dict(dino_state_dict, strict=False) 62 | 
model.linear = nn.Linear(variant['embed_dim'], config['num_classes']) 63 | model.linear_rein = nn.Linear(variant['embed_dim'], config['num_classes']) 64 | model.to(device) 65 | criterion = torch.nn.CrossEntropyLoss(reduction='none') 66 | model.eval() 67 | 68 | model2 = rein.ReinsDinoVisionTransformer( 69 | **variant 70 | ) 71 | model2.load_state_dict(dino_state_dict, strict=False) 72 | model2.linear_rein = nn.Linear(variant['embed_dim'], config['num_classes']) 73 | model2.to(device) 74 | 75 | model.eval() 76 | model2.eval() 77 | 78 | optimizer = torch.optim.Adam(model.parameters(), lr=1e-3, weight_decay = 1e-5) 79 | optimizer2 = torch.optim.Adam(model2.parameters(), lr=1e-3, weight_decay = 1e-5) 80 | 81 | scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, lr_decay) 82 | scheduler2 = torch.optim.lr_scheduler.MultiStepLR(optimizer2, lr_decay) 83 | saver = timm.utils.CheckpointSaver(model2, optimizer, checkpoint_dir= save_path, max_history = 1) 84 | print(train_loader.dataset[0][0].shape) 85 | 86 | print('## Trainable parameters') 87 | model2.train() 88 | for n, p in model2.named_parameters(): 89 | if p.requires_grad == True: 90 | print(n) 91 | 92 | avg_accuracy = 0.0 93 | for epoch in range(max_epoch): 94 | ## training 95 | model.train() 96 | model2.train() 97 | total_loss = 0 98 | total = 0 99 | correct = 0 100 | correct2 = 0 101 | correct_linear = 0 102 | for batch_idx, (inputs, targets) in enumerate(train_loader): 103 | inputs, targets = inputs.to(device), targets.to(device) 104 | optimizer.zero_grad() 105 | 106 | features_rein = model.forward_features(inputs) 107 | features_rein = features_rein[:, 0, :] 108 | outputs = model.linear_rein(features_rein) 109 | 110 | features_rein2 = model2.forward_features(inputs) 111 | features_rein2 = features_rein2[:, 0, :] 112 | outputs2 = model2.linear_rein(features_rein2) 113 | 114 | with torch.no_grad(): 115 | features_ = model.forward_features_no_rein(inputs) 116 | features_ = features_[:, 0, :] 117 | outputs_ = model.linear(features_) 118 | # print(outputs.shape, outputs_.shape) 119 | 120 | with torch.no_grad(): 121 | pred = (outputs_).max(1).indices 122 | linear_accurate = (pred==targets) 123 | 124 | pred2 = outputs.max(1).indices 125 | linear_accurate2 = (pred2==targets) 126 | 127 | loss_rein = linear_accurate*criterion(outputs, targets) 128 | loss_rein2 = linear_accurate2*criterion(outputs2, targets) 129 | loss_linear = criterion(outputs_, targets) 130 | loss = loss_linear.mean()+loss_rein.mean() 131 | loss.backward() 132 | optimizer.step() # + outputs_ 133 | 134 | optimizer2.zero_grad() 135 | loss_rein2.mean().backward() 136 | optimizer2.step() 137 | 138 | total_loss += loss 139 | total += targets.size(0) 140 | _, predicted = outputs[:len(targets)].max(1) 141 | correct += predicted.eq(targets).sum().item() 142 | 143 | _, predicted = outputs2[:len(targets)].max(1) 144 | correct2 += predicted.eq(targets).sum().item() 145 | 146 | _, predicted = outputs_[:len(targets)].max(1) 147 | correct_linear += predicted.eq(targets).sum().item() 148 | print('\r', batch_idx, len(train_loader), 'Loss: %.3f | Acc2: %.3f%% | Acc1: %.3f%% | LinearAcc: %.3f%% | (%d/%d)' 149 | % (total_loss/(batch_idx+1), 100.*correct2/total, 100.*correct/total, 100.*correct_linear/total, correct, total), end = '') 150 | train_accuracy = correct/total 151 | train_avg_loss = total_loss/len(train_loader) 152 | print() 153 | 154 | ## validation 155 | model.eval() 156 | model2.eval() 157 | 158 | total_loss = 0 159 | total = 0 160 | correct = 0 161 | valid_accuracy = 
utils.validation_accuracy(model2, valid_loader, device) 162 | valid_accuracy_ = utils.validation_accuracy(model, valid_loader, device) 163 | valid_accuracy_linear = utils.validation_accuracy(model, valid_loader, device, mode='no_rein') 164 | 165 | scheduler.step() 166 | scheduler2.step() 167 | if epoch >= max_epoch-10: 168 | avg_accuracy += valid_accuracy 169 | saver.save_checkpoint(epoch, metric = valid_accuracy) 170 | print('EPOCH {:4}, TRAIN [loss - {:.4f}, acc - {:.4f}], VALID_2 [acc - {:.4f}], VALID_1 [acc - {:.4f}], VALID(linear) [acc - {:.4f}]\n'.format(epoch, train_avg_loss, train_accuracy, valid_accuracy, valid_accuracy_, valid_accuracy_linear)) 171 | print(scheduler.get_last_lr()) 172 | with open(os.path.join(save_path, 'avgacc.txt'), 'w') as f: 173 | f.write(str(avg_accuracy/10)) 174 | if __name__ =='__main__': 175 | train() -------------------------------------------------------------------------------- /train_fully.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | import torch.nn as nn 4 | 5 | import argparse 6 | import timm 7 | import utils 8 | 9 | import dino_variant 10 | 11 | 12 | def train(): 13 | parser = argparse.ArgumentParser() 14 | parser.add_argument('--data', '-d', type=str) 15 | parser.add_argument('--gpu', '-g', default = '0', type=str) 16 | parser.add_argument('--netsize', default='s', type=str) 17 | parser.add_argument('--save_path', '-s', type=str) 18 | parser.add_argument('--noise_rate', '-n', type=float, default=0.2) 19 | args = parser.parse_args() 20 | 21 | config = utils.read_conf('conf/'+args.data+'.json') 22 | device = 'cuda:'+args.gpu 23 | save_path = os.path.join(config['save_path'], args.save_path) 24 | data_path = config['id_dataset'] 25 | batch_size = int(config['batch_size']) 26 | max_epoch = int(config['epoch']) 27 | noise_rate = args.noise_rate 28 | 29 | if not os.path.exists(save_path): 30 | os.mkdir(save_path) 31 | 32 | lr_decay = [int(0.5*max_epoch), int(0.75*max_epoch), int(0.9*max_epoch)] 33 | 34 | if args.data == 'ham10000': 35 | train_loader, valid_loader = utils.get_noise_dataset(data_path, noise_rate=noise_rate, batch_size = batch_size) 36 | elif args.data == 'aptos': 37 | train_loader, valid_loader = utils.get_aptos_noise_dataset(data_path, noise_rate=noise_rate, batch_size = batch_size) 38 | elif 'mnist' in args.data: 39 | train_loader, valid_loader = utils.get_mnist_noise_dataset(args.data, noise_rate=noise_rate, batch_size = batch_size) 40 | elif 'cifar' in args.data: 41 | train_loader, valid_loader = utils.get_cifar_noise_dataset(args.data, data_path, batch_size = batch_size, noise_rate=noise_rate) 42 | 43 | if args.netsize == 's': 44 | model_load = dino_variant._small_dino 45 | variant = dino_variant._small_variant 46 | elif args.netsize == 'b': 47 | model_load = dino_variant._base_dino 48 | variant = dino_variant._base_variant 49 | elif args.netsize == 'l': 50 | model_load = dino_variant._large_dino 51 | variant = dino_variant._large_variant 52 | 53 | model = torch.hub.load('facebookresearch/dinov2', model_load) 54 | model.linear = nn.Linear(variant['embed_dim'], config['num_classes']) 55 | model.to(device) 56 | 57 | criterion = torch.nn.CrossEntropyLoss() 58 | model.eval() 59 | 60 | optimizer = torch.optim.Adam(model.parameters(), lr=1e-4, weight_decay = 1e-5) 61 | 62 | scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, lr_decay) 63 | saver = timm.utils.CheckpointSaver(model, optimizer, checkpoint_dir= save_path, max_history = 1) 64 | 
print(train_loader.dataset[0][0].shape) 65 | 66 | print('## Trainable parameters') 67 | model.train() 68 | for n, p in model.named_parameters(): 69 | if p.requires_grad == True: 70 | print(n) 71 | 72 | avg_accuracy = 0.0 73 | for epoch in range(max_epoch): 74 | ## training 75 | model.train() 76 | total_loss = 0 77 | total = 0 78 | correct = 0 79 | for batch_idx, (inputs, targets) in enumerate(train_loader): 80 | inputs, targets = inputs.to(device), targets.to(device) 81 | optimizer.zero_grad() 82 | 83 | outputs = model(inputs) 84 | outputs = model.linear(outputs) 85 | 86 | loss = criterion(outputs, targets) 87 | loss.backward() 88 | optimizer.step() 89 | 90 | total_loss += loss 91 | total += targets.size(0) 92 | _, predicted = outputs[:len(targets)].max(1) 93 | correct += predicted.eq(targets).sum().item() 94 | print('\r', batch_idx, len(train_loader), 'Loss: %.3f | Acc: %.3f%% (%d/%d)' 95 | % (total_loss/(batch_idx+1), 100.*correct/total, correct, total), end = '') 96 | train_accuracy = correct/total 97 | train_avg_loss = total_loss/len(train_loader) 98 | print() 99 | 100 | ## validation 101 | model.eval() 102 | total_loss = 0 103 | total = 0 104 | correct = 0 105 | valid_accuracy = utils.validation_accuracy(model, valid_loader, device, mode='linear') 106 | scheduler.step() 107 | if epoch >= max_epoch-10: 108 | avg_accuracy += valid_accuracy 109 | saver.save_checkpoint(epoch, metric = valid_accuracy) 110 | print('EPOCH {:4}, TRAIN [loss - {:.4f}, acc - {:.4f}], VALID [acc - {:.4f}]\n'.format(epoch, train_avg_loss, train_accuracy, valid_accuracy)) 111 | print(scheduler.get_last_lr()) 112 | with open(os.path.join(save_path, 'avgacc.txt'), 'w') as f: 113 | f.write(str(avg_accuracy/10)) 114 | 115 | if __name__ =='__main__': 116 | train() -------------------------------------------------------------------------------- /train_linear.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | import torch.nn as nn 4 | 5 | import argparse 6 | import timm 7 | import utils 8 | 9 | import dino_variant 10 | 11 | 12 | def train(): 13 | parser = argparse.ArgumentParser() 14 | parser.add_argument('--data', '-d', type=str) 15 | parser.add_argument('--gpu', '-g', default = '0', type=str) 16 | parser.add_argument('--netsize', default='s', type=str) 17 | parser.add_argument('--save_path', '-s', type=str) 18 | parser.add_argument('--noise_rate', '-n', type=float, default=0.2) 19 | args = parser.parse_args() 20 | 21 | config = utils.read_conf('conf/'+args.data+'.json') 22 | device = 'cuda:'+args.gpu 23 | save_path = os.path.join(config['save_path'], args.save_path) 24 | data_path = config['id_dataset'] 25 | batch_size = int(config['batch_size']) 26 | max_epoch = int(config['epoch']) 27 | noise_rate = args.noise_rate 28 | 29 | if not os.path.exists(save_path): 30 | os.mkdir(save_path) 31 | 32 | lr_decay = [int(0.5*max_epoch), int(0.75*max_epoch), int(0.9*max_epoch)] 33 | 34 | 35 | if args.data == 'ham10000': 36 | train_loader, valid_loader = utils.get_noise_dataset(data_path, noise_rate=noise_rate, batch_size = batch_size) 37 | elif args.data == 'aptos': 38 | train_loader, valid_loader = utils.get_aptos_noise_dataset(data_path, noise_rate=noise_rate, batch_size = batch_size) 39 | elif 'mnist' in args.data: 40 | train_loader, valid_loader = utils.get_mnist_noise_dataset(args.data, noise_rate=noise_rate, batch_size = batch_size) 41 | 42 | if args.netsize == 's': 43 | model_load = dino_variant._small_dino 44 | variant = dino_variant._small_variant 45 | 
elif args.netsize == 'b': 46 | model_load = dino_variant._base_dino 47 | variant = dino_variant._base_variant 48 | elif args.netsize == 'l': 49 | model_load = dino_variant._large_dino 50 | variant = dino_variant._large_variant 51 | 52 | model = torch.hub.load('facebookresearch/dinov2', model_load) 53 | model.linear = nn.Linear(variant['embed_dim'], config['num_classes']) 54 | model.to(device) 55 | 56 | criterion = torch.nn.CrossEntropyLoss() 57 | model.eval() 58 | 59 | for n, p in model.named_parameters(): 60 | if not 'linear' in n: 61 | p.requires_grad = False 62 | optimizer = torch.optim.Adam(model.linear.parameters(), lr=1e-3, weight_decay = 1e-5) 63 | 64 | scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, lr_decay) 65 | saver = timm.utils.CheckpointSaver(model, optimizer, checkpoint_dir= save_path, max_history = 1) 66 | print(train_loader.dataset[0][0].shape) 67 | 68 | print('## Trainable parameters') 69 | model.train() 70 | for n, p in model.named_parameters(): 71 | if p.requires_grad == True: 72 | print(n) 73 | avg_accuracy = 0.0 74 | for epoch in range(max_epoch): 75 | ## training 76 | model.train() 77 | total_loss = 0 78 | total = 0 79 | correct = 0 80 | for batch_idx, (inputs, targets) in enumerate(train_loader): 81 | inputs, targets = inputs.to(device), targets.to(device) 82 | optimizer.zero_grad() 83 | 84 | with torch.no_grad(): 85 | outputs = model(inputs) 86 | outputs = model.linear(outputs) 87 | 88 | loss = criterion(outputs, targets) 89 | loss.backward() 90 | optimizer.step() 91 | 92 | total_loss += loss 93 | total += targets.size(0) 94 | _, predicted = outputs[:len(targets)].max(1) 95 | correct += predicted.eq(targets).sum().item() 96 | print('\r', batch_idx, len(train_loader), 'Loss: %.3f | Acc: %.3f%% (%d/%d)' 97 | % (total_loss/(batch_idx+1), 100.*correct/total, correct, total), end = '') 98 | train_accuracy = correct/total 99 | 100 | train_avg_loss = total_loss/len(train_loader) 101 | print() 102 | 103 | ## validation 104 | model.eval() 105 | total_loss = 0 106 | total = 0 107 | correct = 0 108 | 109 | valid_accuracy = utils.validation_accuracy(model, valid_loader, device, mode='linear') 110 | if epoch >= max_epoch-10: 111 | avg_accuracy += valid_accuracy 112 | scheduler.step() 113 | 114 | saver.save_checkpoint(epoch, metric = valid_accuracy) 115 | print('EPOCH {:4}, TRAIN [loss - {:.4f}, acc - {:.4f}], VALID [acc - {:.4f}]\n'.format(epoch, train_avg_loss, train_accuracy, valid_accuracy)) 116 | print(scheduler.get_last_lr()) 117 | 118 | with open(os.path.join(save_path, 'avgacc.txt'), 'w') as f: 119 | f.write(str(avg_accuracy/10)) 120 | 121 | if __name__ =='__main__': 122 | train() -------------------------------------------------------------------------------- /train_rein.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | import torch.nn as nn 4 | 5 | import argparse 6 | import timm 7 | import utils 8 | 9 | import rein 10 | 11 | import dino_variant 12 | 13 | def train(): 14 | parser = argparse.ArgumentParser() 15 | parser.add_argument('--data', '-d', type=str) 16 | parser.add_argument('--gpu', '-g', default = '0', type=str) 17 | parser.add_argument('--netsize', default='s', type=str) 18 | parser.add_argument('--save_path', '-s', type=str) 19 | parser.add_argument('--noise_rate', '-n', type=float, default=0.2) 20 | args = parser.parse_args() 21 | 22 | config = utils.read_conf('conf/'+args.data+'.json') 23 | device = 'cuda:'+args.gpu 24 | save_path = os.path.join(config['save_path'], 
args.save_path) 25 | data_path = config['id_dataset'] 26 | batch_size = int(config['batch_size']) 27 | max_epoch = int(config['epoch']) 28 | noise_rate = args.noise_rate 29 | 30 | if not os.path.exists(save_path): 31 | os.mkdir(save_path) 32 | 33 | lr_decay = [int(0.5*max_epoch), int(0.75*max_epoch), int(0.9*max_epoch)] 34 | 35 | 36 | if args.data == 'ham10000': 37 | train_loader, valid_loader = utils.get_noise_dataset(data_path, noise_rate=noise_rate, batch_size = batch_size) 38 | elif args.data == 'aptos': 39 | train_loader, valid_loader = utils.get_aptos_noise_dataset(data_path, noise_rate=noise_rate, batch_size = batch_size) 40 | elif 'mnist' in args.data: 41 | train_loader, valid_loader = utils.get_mnist_noise_dataset(args.data, noise_rate=noise_rate, batch_size = batch_size) 42 | elif 'cifar' in args.data: 43 | train_loader, valid_loader = utils.get_cifar_noise_dataset(args.data, data_path, batch_size = batch_size, noise_rate=noise_rate) 44 | 45 | if args.netsize == 's': 46 | model_load = dino_variant._small_dino 47 | variant = dino_variant._small_variant 48 | elif args.netsize == 'b': 49 | model_load = dino_variant._base_dino 50 | variant = dino_variant._base_variant 51 | elif args.netsize == 'l': 52 | model_load = dino_variant._large_dino 53 | variant = dino_variant._large_variant 54 | 55 | 56 | model = torch.hub.load('facebookresearch/dinov2', model_load) 57 | dino_state_dict = model.state_dict() 58 | 59 | model = rein.ReinsDinoVisionTransformer( 60 | **variant 61 | ) 62 | model.load_state_dict(dino_state_dict, strict=False) 63 | model.linear_rein = nn.Linear(variant['embed_dim'], config['num_classes']) 64 | model.to(device) 65 | 66 | criterion = torch.nn.CrossEntropyLoss() 67 | model.eval() 68 | 69 | optimizer = torch.optim.Adam(model.parameters(), lr=1e-3, weight_decay = 1e-5) 70 | scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, lr_decay) 71 | saver = timm.utils.CheckpointSaver(model, optimizer, checkpoint_dir= save_path, max_history = 1) 72 | print(train_loader.dataset[0][0].shape) 73 | 74 | print('## Trainable parameters') 75 | model.train() 76 | for n, p in model.named_parameters(): 77 | if p.requires_grad == True: 78 | print(n) 79 | 80 | avg_accuracy = 0.0 81 | for epoch in range(max_epoch): 82 | ## training 83 | model.train() 84 | total_loss = 0 85 | total = 0 86 | correct = 0 87 | for batch_idx, (inputs, targets) in enumerate(train_loader): 88 | inputs, targets = inputs.to(device), targets.to(device) 89 | optimizer.zero_grad() 90 | 91 | features = model.forward_features(inputs) 92 | features = features[:, 0, :] 93 | outputs = model.linear_rein(features) 94 | loss = criterion(outputs, targets) 95 | loss.backward() 96 | optimizer.step() 97 | 98 | total_loss += loss 99 | total += targets.size(0) 100 | _, predicted = outputs[:len(targets)].max(1) 101 | correct += predicted.eq(targets).sum().item() 102 | print('\r', batch_idx, len(train_loader), 'Loss: %.3f | Acc: %.3f%% (%d/%d)' 103 | % (total_loss/(batch_idx+1), 100.*correct/total, correct, total), end = '') 104 | train_accuracy = correct/total 105 | 106 | train_avg_loss = total_loss/len(train_loader) 107 | print() 108 | 109 | ## validation 110 | model.eval() 111 | total_loss = 0 112 | total = 0 113 | correct = 0 114 | 115 | valid_accuracy = utils.validation_accuracy(model, valid_loader, device) 116 | if epoch >= max_epoch-10: 117 | avg_accuracy += valid_accuracy 118 | scheduler.step() 119 | 120 | saver.save_checkpoint(epoch, metric = valid_accuracy) 121 | print('EPOCH {:4}, TRAIN [loss - {:.4f}, acc - {:.4f}], 
VALID [acc - {:.4f}]\n'.format(epoch, train_avg_loss, train_accuracy, valid_accuracy)) 122 | print(scheduler.get_last_lr()) 123 | 124 | with open(os.path.join(save_path, 'avgacc.txt'), 'w') as f: 125 | f.write(str(avg_accuracy/10)) 126 | 127 | if __name__ =='__main__': 128 | train() -------------------------------------------------------------------------------- /utils/__init__.py: -------------------------------------------------------------------------------- 1 | from .dataset import * 2 | from .metric import * -------------------------------------------------------------------------------- /utils/aptos.py: -------------------------------------------------------------------------------- 1 | import torch.utils.data as data 2 | import os 3 | 4 | from PIL import Image 5 | 6 | class APTOS2019(): 7 | def __init__(self, root_dir, train = True, transforms=None): 8 | """ 9 | Arguments: 10 | csv_file (string): Path to the csv file with annotations. 11 | root_dir (string): Directory with all the images. 12 | transform (callable, optional): Optional transform to be applied 13 | on a sample. 14 | """ 15 | self.root_dir = root_dir 16 | self.transform = transforms 17 | 18 | self.label_txt = os.path.join(root_dir, 'train_1.csv' if train else 'test.csv') 19 | 20 | self.samples = [] 21 | with open(self.label_txt, 'r') as f: 22 | lines = f.readlines() 23 | for line in lines[1:]: 24 | line = line.split(',') 25 | if len(line) == 2: 26 | img_name, label = line 27 | 28 | img_name = os.path.join(root_dir, 'train_images/train_images' if train else 'test_images', img_name+'.png') 29 | label = label.replace('\n', '') 30 | label = int(label) 31 | 32 | self.samples.append([img_name, label]) 33 | 34 | def __len__(self): 35 | return len(self.samples) 36 | 37 | def __getitem__(self, idx): 38 | sample, label = self.samples[idx] 39 | sample = Image.open(sample) 40 | 41 | if self.transform: 42 | sample = self.transform(sample) 43 | 44 | return sample, label 45 | 46 | if __name__ == '__main__': 47 | aptos = APTOS2019('./data/APTOS-2019', True) 48 | print(aptos[0]) -------------------------------------------------------------------------------- /utils/dataset.py: -------------------------------------------------------------------------------- 1 | import json 2 | import numpy as np 3 | import torch 4 | 5 | from PIL import Image 6 | from torchvision import transforms 7 | from torchvision import datasets as dset 8 | import torchvision 9 | 10 | from .aptos import APTOS2019 11 | 12 | def get_transform(transform_type='default', image_size=224, args=None): 13 | 14 | if transform_type == 'default': 15 | IMAGENET_DEFAULT_MEAN = (0.485, 0.456, 0.406) 16 | IMAGENET_DEFAULT_STD = (0.229, 0.224, 0.225) 17 | 18 | mean = IMAGENET_DEFAULT_MEAN 19 | std = IMAGENET_DEFAULT_STD 20 | 21 | 22 | train_transform = transforms.Compose([ 23 | transforms.Resize((256, 256)), 24 | # transforms.Resize((224, 224)), 25 | transforms.RandomHorizontalFlip(p=0.5), 26 | # transforms.RandomVerticalFlip(p=0.5), 27 | # transforms.ColorJitter(), 28 | transforms.RandomCrop(size=(image_size, image_size)), 29 | transforms.ToTensor(), 30 | transforms.Normalize(mean=mean, std=std) 31 | ]) 32 | 33 | test_transform = transforms.Compose([ 34 | transforms.Resize((256, 256)), 35 | transforms.CenterCrop((image_size, image_size)), 36 | transforms.ToTensor(), 37 | transforms.Normalize(mean=mean, std=std) 38 | ]) 39 | return train_transform, test_transform 40 | 41 | def read_conf(json_path): 42 | """ 43 | read json and return the configure as dictionary. 
44 | """ 45 | with open(json_path) as json_file: 46 | config = json.load(json_file) 47 | return config 48 | 49 | def get_noise_dataset(path, noise_rate = 0.2, batch_size = 32, seed = 0): 50 | train_transform, test_transform = get_transform() 51 | train_data = torchvision.datasets.ImageFolder(path + '/train', train_transform) 52 | np.random.seed(seed) 53 | 54 | new_data = [] 55 | for i in range(len(train_data.samples)): 56 | if np.random.rand() > noise_rate: # clean sample: 57 | new_data.append([train_data.samples[i][0], train_data.samples[i][1]]) 58 | else: 59 | label_index = list(range(7)) 60 | label_index.remove(train_data.samples[i][1]) 61 | label_index = np.array(label_index) 62 | label_index = np.reshape(label_index, (-1)) 63 | 64 | new_label = np.random.choice(label_index, 1) 65 | new_label = new_label[0] 66 | 67 | new_data.append([train_data.samples[i][0], new_label]) 68 | train_data.samples = new_data 69 | 70 | # Testing 71 | with open('label.txt', 'w') as f: 72 | for i in range(len(train_data.samples)): 73 | f.write('{}\n'.format(train_data.samples[i][1])) 74 | 75 | valid_data = torchvision.datasets.ImageFolder(path + '/test', test_transform) 76 | 77 | train_loader = torch.utils.data.DataLoader(train_data, batch_size=batch_size, shuffle=True, pin_memory=True, num_workers = 8) 78 | valid_loader = torch.utils.data.DataLoader(valid_data, batch_size=batch_size, shuffle=False, pin_memory=True, num_workers = 8) 79 | return train_loader, valid_loader 80 | 81 | def get_aptos_noise_dataset(path, noise_rate = 0.2, batch_size = 32, seed = 0): 82 | train_transform, test_transform = get_transform() 83 | train_data = APTOS2019(path, train=True, transforms = train_transform) 84 | 85 | np.random.seed(seed) 86 | new_data = [] 87 | for i in range(len(train_data.samples)): 88 | if np.random.rand() > noise_rate: # clean sample: 89 | new_data.append([train_data.samples[i][0], train_data.samples[i][1]]) 90 | else: 91 | label_index = list(range(5)) 92 | label_index.remove(train_data.samples[i][1]) 93 | label_index = np.array(label_index) 94 | label_index = np.reshape(label_index, (-1)) 95 | 96 | new_label = np.random.choice(label_index, 1) 97 | new_label = new_label[0] 98 | 99 | new_data.append([train_data.samples[i][0], new_label]) 100 | train_data.samples = new_data 101 | 102 | # Testing 103 | with open('label.txt', 'w') as f: 104 | for i in range(len(train_data.samples)): 105 | f.write('{}\n'.format(train_data.samples[i][1])) 106 | 107 | valid_data = APTOS2019(path, train=False, transforms = test_transform) 108 | 109 | train_loader = torch.utils.data.DataLoader(train_data, batch_size=batch_size, shuffle=True, pin_memory=True, num_workers = 16) 110 | valid_loader = torch.utils.data.DataLoader(valid_data, batch_size=batch_size, shuffle=False, pin_memory=True, num_workers = 8) 111 | return train_loader, valid_loader 112 | 113 | 114 | def get_mnist_noise_dataset(dataname, noise_rate = 0.2, batch_size = 32, seed = 0): 115 | # from medmnist import NoduleMNIST3D 116 | from medmnist import PathMNIST, BloodMNIST, OCTMNIST, TissueMNIST, OrganCMNIST 117 | train_transform, test_transform = get_transform() 118 | 119 | if dataname == 'pathmnist': 120 | train_data = PathMNIST(split="train", download=True, size=224, transform= train_transform, as_rgb=True) 121 | test_data = PathMNIST(split="test", download=True, size=224, transform= test_transform, as_rgb=True) 122 | num_classes = 9 123 | if dataname == 'bloodmnist': 124 | train_data = BloodMNIST(split="train", download=True, size=224, transform= train_transform, 
as_rgb=True) 125 | test_data = BloodMNIST(split="test", download=True, size=224, transform= test_transform, as_rgb=True) 126 | num_classes = 8 127 | if dataname == 'octmnist': 128 | train_data = OCTMNIST(split="train", download=True, size=224, transform= train_transform, as_rgb=True) 129 | test_data = OCTMNIST(split="test", download=True, size=224, transform= test_transform, as_rgb=True) 130 | num_classes = 4 131 | if dataname == 'tissuemnist': 132 | train_data = TissueMNIST(split="train", download=True, size=224, transform= train_transform, as_rgb=True) 133 | test_data = TissueMNIST(split="test", download=True, size=224, transform= test_transform, as_rgb=True) 134 | num_classes = 8 135 | if dataname == 'organcmnist': 136 | train_data = OrganCMNIST(split="train", download=True, size=224, transform= train_transform, as_rgb=True) 137 | test_data = OrganCMNIST(split="test", download=True, size=224, transform= test_transform, as_rgb=True) 138 | num_classes = 11 139 | 140 | np.random.seed(seed) 141 | # new_imgs = [] 142 | new_labels =[] 143 | for i in range(len(train_data.imgs)): 144 | if np.random.rand() > noise_rate: # clean sample: 145 | # new_imgs.append(train_data.imgs[i]) 146 | new_labels.append(train_data.labels[i][0]) 147 | else: 148 | label_index = list(range(num_classes)) 149 | label_index.remove(train_data.labels[i]) 150 | label_index = np.array(label_index) 151 | label_index = np.reshape(label_index, (-1)) 152 | 153 | new_label = np.random.choice(label_index, 1) 154 | new_label = new_label[0] 155 | 156 | # new_imgs.append(train_data.imgs[i]) 157 | new_labels.append(new_label) 158 | # train_data.imgs = new_imgs 159 | train_data.labels = new_labels 160 | 161 | new_labels = [] 162 | for i in range(len(test_data.labels)): 163 | new_labels.append(test_data.labels[i][0]) 164 | test_data.labels = new_labels 165 | 166 | 167 | train_loader = torch.utils.data.DataLoader(train_data, batch_size=batch_size, shuffle=True, pin_memory=True, num_workers = 16) 168 | valid_loader = torch.utils.data.DataLoader(test_data, batch_size=batch_size, shuffle=False, pin_memory=True, num_workers = 8) 169 | return train_loader, valid_loader -------------------------------------------------------------------------------- /utils/metric.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | recall_level_default = 0.95 4 | 5 | def validation_accuracy(model, loader, device, mode = 'rein'): 6 | total = 0 7 | correct = 0 8 | 9 | def linear(model, inputs): 10 | f = model(inputs) 11 | outputs = model.linear(f) 12 | return outputs 13 | 14 | def rein(model, inputs): 15 | f = model.forward_features(inputs) 16 | f = f[:, 0, :] 17 | outputs = model.linear_rein(f) 18 | return outputs 19 | 20 | def no_rein(model, inputs): 21 | f = model.forward_features_no_rein(inputs) 22 | f = f[:, 0, :] 23 | outputs = model.linear(f) 24 | return outputs 25 | if mode == 'rein': 26 | out = rein 27 | elif mode == 'no_rein': 28 | out = no_rein 29 | else: 30 | out = linear 31 | 32 | model.eval() 33 | with torch.no_grad(): 34 | for batch_idx, (inputs, targets) in enumerate(loader): 35 | inputs, targets = inputs.to(device), targets.to(device) 36 | outputs = out(model, inputs) 37 | _, predicted = outputs.max(1) 38 | correct += predicted.eq(targets).sum().item() 39 | total += targets.size(0) 40 | valid_accuracy = correct/total 41 | return valid_accuracy --------------------------------------------------------------------------------
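
For reference, the per-sample curriculum rule in `train_cufit.py` above boils down to a few lines: the linear probe is trained on every sample with plain cross-entropy, the first Rein adapter only receives loss on samples the linear probe currently classifies correctly, and the second adapter only on samples the first adapter classifies correctly. Below is a minimal sketch of that gating on toy logits; the function and tensor names are illustrative and are not part of the repository.

```python
import torch


def cufit_losses(logits_linear, logits_rein1, logits_rein2, targets):
    """Illustrative sketch of the per-sample gating used in train_cufit.py.

    - linear probe: plain cross-entropy on all samples
    - adapter 1:    loss kept only where the linear probe is currently correct
    - adapter 2:    loss kept only where adapter 1 is currently correct
    """
    ce = torch.nn.CrossEntropyLoss(reduction="none")

    loss_linear = ce(logits_linear, targets)  # trained on every (possibly noisy) sample
    with torch.no_grad():
        gate1 = (logits_linear.argmax(1) == targets).float()  # clean-sample proxy for adapter 1
        gate2 = (logits_rein1.argmax(1) == targets).float()   # clean-sample proxy for adapter 2
    loss_model1 = loss_linear.mean() + (gate1 * ce(logits_rein1, targets)).mean()
    loss_model2 = (gate2 * ce(logits_rein2, targets)).mean()
    return loss_model1, loss_model2  # one loss per optimizer, mirroring train_cufit.py


if __name__ == "__main__":
    torch.manual_seed(0)
    B, C = 8, 7  # HAM10000-style 7-class toy example
    targets = torch.randint(0, C, (B,))
    l_linear, l_rein1, l_rein2 = (torch.randn(B, C) for _ in range(3))
    print(cufit_losses(l_linear, l_rein1, l_rein2, targets))
```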