├── .gitignore
├── README.md
├── assets
│   ├── intro.png
│   ├── logo.png
│   └── poster.png
├── conf
│   ├── aptos.json
│   ├── bloodmnist.json
│   ├── ham10000.json
│   └── organcmnist.json
├── data
│   └── ham10000.py
├── dino_variant.py
├── rein
│   ├── __init__.py
│   └── models
│       ├── __init__.py
│       └── backbones
│           ├── __init__.py
│           ├── beit.py
│           ├── clip.py
│           ├── dino_layers
│           │   ├── __init__.py
│           │   ├── attention.py
│           │   ├── block.py
│           │   ├── dino_head.py
│           │   ├── drop_path.py
│           │   ├── layer_scale.py
│           │   ├── mlp.py
│           │   ├── patch_embed.py
│           │   └── swiglu_ffn.py
│           ├── dino_v2.py
│           ├── eva_02.py
│           ├── reins.py
│           ├── reins_dinov2.py
│           ├── reins_eva_02.py
│           ├── reins_resnet.py
│           └── utils.py
├── requirement.txt
├── train_cufit.py
├── train_fully.py
├── train_linear.py
├── train_rein.py
└── utils
    ├── __init__.py
    ├── aptos.py
    ├── dataset.py
    └── metric.py
/.gitignore:
--------------------------------------------------------------------------------
1 | *.tar
2 | *.pyc
3 | *.pth
4 | *.zip
5 | *.jpg
6 |
7 | data/ham10000/*
8 | data/aptos-2019/*
9 | checkpoints/*
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 |
2 |
3 | Curriculum Fine-tuning of Vision Foundation Model for Medical Image Classification Under Label Noise
4 |
5 | Yeonguk Yu
6 | ·
7 | Minhwan Ko
8 | ·
9 | Sungho Shin
10 | ·
11 | Kangmin Kim
12 | ·
13 | Kyoobin Lee
14 |
15 | Artificial Intelligence LAB
16 | GIST, South Korea
17 |
18 | NeurIPS 2024 - Poster Presentation
19 |
20 |
21 |
22 |
23 |
24 |
25 |
26 |
27 |
28 |
29 |
30 |
31 |
32 |
33 | ---
34 |
35 |
36 |
37 |
38 |
39 |
40 | **TL;DR**: We propose CUFIT, a robust fine-tuning method for vision foundation models under noisy-label conditions that combines the advantages of linear probing and adapters.
41 |
42 |
43 |
44 | Our **CU**rriculum **FI**ne-**T**uning of Vision Foundation Model **(CUFIT)** is a robust training framework for multi-class medical image classification under noisy-label conditions.
45 | Leveraging vision foundation models (VFMs) pretrained on large-scale datasets, CUFIT first handles noisy labels with linear probing, which leaves the feature extractor unmodified. It then applies a curriculum fine-tuning strategy: linear probing provides robustness to noisy samples, after which two adapters are fine-tuned for stronger classification performance. CUFIT outperforms conventional methods across various medical image benchmarks, achieving superior results at multiple noise rates on datasets such as HAM10000 and APTOS-2019, highlighting its ability to address the challenges posed by noisy labels in medical datasets.
46 |
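A minimal sketch of the curriculum idea is shown below. It is not the repository's training code (see `train_cufit.py` and `rein/` for the actual implementation); the agreement-based clean-sample selection and the single `adapter_model` are illustrative assumptions.

```python
import torch
import torch.nn as nn

def curriculum_step(backbone, linear_head, adapter_model, loader, opt_head, opt_adapter):
    """One illustrative pass: linear probing on all samples, adapter training
    only on samples the probe considers clean (assumed selection rule)."""
    ce = nn.CrossEntropyLoss()
    for images, noisy_labels in loader:
        with torch.no_grad():
            feats = backbone(images)          # frozen VFM features

        # Stage 1: linear probing on every (possibly noisy) sample.
        logits_lp = linear_head(feats)
        opt_head.zero_grad()
        ce(logits_lp, noisy_labels).backward()
        opt_head.step()

        # Stage 2 (assumption): fine-tune the adapter branch only on samples
        # whose linear-probe prediction agrees with the given label.
        clean = logits_lp.argmax(dim=1) == noisy_labels
        if clean.any():
            logits_ad = adapter_model(images[clean])
            opt_adapter.zero_grad()
            ce(logits_ad, noisy_labels[clean]).backward()
            opt_adapter.step()
```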
47 |
48 | ## 🚀 Getting Started
49 | ### Clone the Repository
50 | ```bash
51 | git clone https://github.com/gist-ailab/CUFIT.git
52 | cd CUFIT
53 | ```
54 |
55 | ### Environment Setup
56 | This code was tested on Linux 20.04 with Python 3.8.18 and requires the following main packages:
57 |
58 | - [Pytorch](https://pytorch.org/): tested with PyTorch-GPU 2.0.1.
59 | - [torchvision](https://pytorch.org/vision/stable/index.html): installed along with PyTorch; tested with version 0.15.2.
60 | - [MedMNIST](https://medmnist.com/): required for the BloodMNIST and OrganCMNIST experiments; tested with version 3.0.1.
61 |
62 | You may use the following commands to set up the environment:
63 | ```bash
64 | conda create -n cufit python=3.8
65 | conda activate cufit
66 | pip install -r requirement.txt
67 | ```
68 |
69 |
70 | ### Dataset Preparation
71 | Several public datasets must be downloaded before running the experiments.
72 |
73 | HAM10000 preparation
74 |
75 | 1. Download the training data, training ground truth, test data, and test ground truth of Task 3 from this link.
76 |
77 | 2. Place the zip files in the "CUFIT/data" folder and extract them.
78 |
79 | 3. Run the Python script "ham10000.py" in "CUFIT/data".
80 |
81 | 4. This will create a folder named "ham10000" in which images are sorted by their corresponding disease (a quick sanity check is sketched below).
82 |
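After the script finishes, a quick check such as the following (run from the repository root) should list the seven class folders per split with non-zero image counts:

```python
import os

# Count images per class folder created by data/ham10000.py.
for split in ('train', 'test'):
    for cls in ['MEL', 'NV', 'BCC', 'AKIEC', 'BKL', 'DF', 'VASC']:
        folder = os.path.join('data', 'ham10000', split, cls)
        print(split, cls, len(os.listdir(folder)))
```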
83 |
84 |
85 | APTOS-2019 preparation
86 |
87 | 1. Download the zip files by clicking the "Download All" button on the Kaggle site.
88 |
89 | 2. Place the zip files in the "CUFIT/data" folder and extract them.
90 |
91 | 3. Create a folder named "APTOS-2019" in "CUFIT/data".
92 |
93 | 4. Place the extracted files in the "APTOS-2019" folder.
94 |
95 |
96 |
97 | ### The config files may need to be edited to match your dataset and checkpoint paths. For example:
98 | ~~~
99 | # conf/ham10000.json
100 | {
101 | "epoch" : "100",
102 | "id_dataset" : "./data/ham10000", # Your path to dataset
103 | "batch_size" : 32,
104 | "save_path" : "./checkpoints/ham10000", # Your path to checkpoint
105 | "num_classes" : 7
106 | }
107 | ~~~
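
Note that the inline `#` comments in the example above are for illustration only; the actual JSON files must not contain them. For reference, such a config can be read with a few lines of standard-library code (the training scripts may load it differently; `epoch` is stored as a string in the provided configs):

```python
import json

with open('conf/ham10000.json') as f:
    cfg = json.load(f)

num_epochs = int(cfg['epoch'])      # "100" in the provided configs
data_root = cfg['id_dataset']       # path to the prepared dataset
batch_size = cfg['batch_size']
save_path = cfg['save_path']        # where checkpoints are written
num_classes = cfg['num_classes']
```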
108 |
109 |
110 | Place the data and create the checkpoint folders following this directory structure:
111 | ```plaintext
112 | CUFIT/
113 | ├── assets/
114 | ├── checkpoints/
115 | │   ├── HAM10000/
116 | │   └── APTOS-2019/
117 | ├── conf/
118 | │   ├── ham10000.json
119 | │   └── aptos.json
120 | ├── data/
121 | │   ├── ham10000/
122 | │   │   ├── test/
123 | │   │   └── train/
124 | │   └── APTOS-2019/
125 | │       ├── test_images/
126 | │       ├── train_images/
127 | │       ├── val_images/
128 | │       ├── test.csv
129 | │       ├── train_1.csv
130 | │       └── valid.csv
131 | ├── rein/
132 | └── utils/
133 | ```
134 |
135 | ---
136 | ## How to Run
137 | ### - To train a model by linear probing with the DINOv2-small architecture
138 | ~~~
139 | python train_linear.py -d 'data_name' -g 'gpu-num' -n 'noise_rate' -s 'save_name'
140 | ~~~
141 | For example:
142 | ~~~
143 | python train_linear.py -d ham10000 -g 0 -n 0.2 -s dinov2s_linear_0.2
144 | ~~~
145 |
146 |
147 | ### - To train a model with a single rein adapter and the DINOv2-small architecture
148 | ~~~
149 | python train_rein.py -d 'data_name' -g 'gpu-num' -n 'noise_rate' -s 'save_name'
150 | ~~~
151 | For example:
152 | ~~~
153 | python train_rein.py -d ham10000 -g 0 -n 0.2 -s dinov2s_single_rein_0.2
154 | ~~~
155 |
156 |
157 | ### - To train a model with CUFIT and the DINOv2-small architecture
158 | ~~~
159 | python train_cufit.py -d 'data_name' -g 'gpu-num' -n 'noise_rate' -s 'save_name'
160 | ~~~
161 | For example:
162 | ~~~
163 | python train_cufit.py -d ham10000 -g 0 -n 0.2 -s dinov2s_cufit_0.2
164 | ~~~
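
The flags are shared by the three training scripts. As a rough guide (hypothetical sketch; the exact argument definitions inside the scripts may differ), they correspond to:

```python
import argparse

parser = argparse.ArgumentParser()
parser.add_argument('-d', help="dataset/config name, e.g. 'ham10000' or 'aptos'")
parser.add_argument('-g', help='GPU index to run on, e.g. 0')
parser.add_argument('-n', type=float, help='synthetic label-noise rate, e.g. 0.2')
parser.add_argument('-s', help='run name used for the checkpoint sub-folder')
args = parser.parse_args()
```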
165 |
166 |
167 | ## 🤝 Acknowledgements & Support
168 | This work was partly supported by the Institute of Information & Communications Technology Planning & Evaluation (IITP) grant funded by the Korea government (MSIT) (No. RS-2022-II0951, Development of Uncertainty-Aware Agents Learning by Asking Questions, 90%) and by the Institute of Civil Military
169 | Technology Cooperation funded by the Defense Acquisition Program Administration and the Ministry of Trade, Industry and Energy of the Korean government under grant No. 22-CM-GU-08 (10%).
170 |
171 | ### 🌟 License
172 | The source code of this repository is released only for academic use. See the [license](LICENSE) file for details.
173 |
174 | ### 📚 Citation
175 | If you use CUFIT in your research, please consider citing us.
176 | ```bibtex
177 | @inproceedings{
178 | yu2024curriculum,
179 | title={Curriculum Fine-tuning of Vision Foundation Model for Medical Image Classification Under Label Noise},
180 | author={Yeonguk Yu and Minhwan Ko and Sungho Shin and Kangmin Kim and Kyoobin Lee},
181 | booktitle={The Thirty-eighth Annual Conference on Neural Information Processing Systems},
182 | year={2024},
183 | url={https://openreview.net/forum?id=vYUx8j5KK2}
184 | }
185 | ```
186 |
--------------------------------------------------------------------------------
/assets/intro.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/gist-ailab/CUFIT/5a521cdaa41a326962ebb4d20d6a79142f33ac4d/assets/intro.png
--------------------------------------------------------------------------------
/assets/logo.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/gist-ailab/CUFIT/5a521cdaa41a326962ebb4d20d6a79142f33ac4d/assets/logo.png
--------------------------------------------------------------------------------
/assets/poster.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/gist-ailab/CUFIT/5a521cdaa41a326962ebb4d20d6a79142f33ac4d/assets/poster.png
--------------------------------------------------------------------------------
/conf/aptos.json:
--------------------------------------------------------------------------------
1 | {
2 | "epoch" : "100",
3 | "id_dataset" : "../data/APTOS-2019",
4 | "batch_size" : 32,
5 | "save_path" : "./checkpoints/APTOS2019/",
6 | "num_classes" : 5
7 | }
--------------------------------------------------------------------------------
/conf/bloodmnist.json:
--------------------------------------------------------------------------------
1 | {
2 | "epoch" : "100",
3 | "id_dataset" : "./data/bloodmnist",
4 | "batch_size" : 32,
5 | "save_path" : "./checkpoints/BLOODMNIST/",
6 | "num_classes" : 8
7 | }
--------------------------------------------------------------------------------
/conf/ham10000.json:
--------------------------------------------------------------------------------
1 | {
2 | "epoch" : "100",
3 | "id_dataset" : "./data/ham10000",
4 | "batch_size" : 32,
5 | "save_path" : "./checkpoints/HAM10000/",
6 | "num_classes" : 7
7 | }
--------------------------------------------------------------------------------
/conf/organcmnist.json:
--------------------------------------------------------------------------------
1 | {
2 | "epoch" : "100",
3 | "id_dataset" : "./data/organcmnist",
4 | "batch_size" : 32,
5 | "save_path" : "./checkpoints/ORGANCMNIST/",
6 | "num_classes" : 11
7 | }
--------------------------------------------------------------------------------
/data/ham10000.py:
--------------------------------------------------------------------------------
1 | import os
2 | import shutil
3 |
4 | def create_folder(folder_name):
5 | folder = os.path.join('ham10000', folder_name)
6 | class_list = ['MEL', 'NV','BCC','AKIEC','BKL','DF','VASC']
7 |
8 | os.mkdir(folder)
9 | for c in class_list:
10 | folder = os.path.join('ham10000', folder_name, c)
11 | os.mkdir(folder)
12 |
13 | def read_and_move(csv_file, img_folder, is_train=True):
14 | class_list = ['MEL', 'NV','BCC','AKIEC','BKL','DF','VASC']
15 |
16 | if is_train:
17 | dst_folder = 'ham10000/train/'
18 | else:
19 | dst_folder = 'ham10000/test/'
20 |
21 |
22 | with open(csv_file, 'r') as f:
23 | lines = f.readlines()[1:]
24 |
25 | for line in lines:
26 | items = line.split(',')
27 | items[-1] = items[-1].replace('\n', '')
28 | class_info = [float(x) for x in items[1:]]  # one-hot class columns of the CSV row
29 | class_info = class_info.index(1.0)  # index of the labelled class
30 |
31 | img = '{}.jpg'.format(items[0])
32 | label = class_info
33 |
34 | img_path_src = os.path.join(img_folder, img)
35 | img_path_dst = os.path.join(dst_folder, class_list[label], img)
36 |
37 | shutil.copy2(img_path_src, img_path_dst)
38 |
39 |
40 | if __name__ == '__main__':
41 | train_image_folder = 'ISIC2018_Task3_Training_Input'
42 | train_image_gt_csv = 'ISIC2018_Task3_Training_GroundTruth/ISIC2018_Task3_Training_GroundTruth.csv'
43 |
44 | test_image_folder = 'ISIC2018_Task3_Test_Input'
45 | test_image_gt_csv = 'ISIC2018_Task3_Test_GroundTruth/ISIC2018_Task3_Test_GroundTruth.csv'
46 |
47 | os.mkdir('./ham10000')
48 | # Create train folder
49 | create_folder('train')
50 | read_and_move(train_image_gt_csv, train_image_folder)
51 |
52 |
53 | # Create test folder
54 | create_folder('test')
55 | read_and_move(test_image_gt_csv, test_image_folder, is_train=False)
56 |
--------------------------------------------------------------------------------
/dino_variant.py:
--------------------------------------------------------------------------------
1 |
2 | _small_variant = dict(
3 | patch_size=14,
4 | embed_dim=384,
5 | depth=12,
6 | num_heads=6,
7 | mlp_ratio=4,
8 | img_size=518,
9 | ffn_layer="mlp",
10 | init_values=1e-05,
11 | block_chunks=0,
12 | qkv_bias=True,
13 | proj_bias=True,
14 | ffn_bias=True
15 | )
16 | _small_dino = 'dinov2_vits14'
17 |
18 | _base_variant = dict(
19 | patch_size=14,
20 | embed_dim=768,
21 | depth=12,
22 | num_heads=12,
23 | mlp_ratio=4,
24 | img_size=518,
25 | ffn_layer="mlp",
26 | init_values=1e-05,
27 | block_chunks=0,
28 | qkv_bias=True,
29 | proj_bias=True,
30 | ffn_bias=True,
31 | out_indices = [7, 11, 14, 17]
32 | )
33 | _base_dino = 'dinov2_vitb14'
34 |
35 | _large_variant = dict(
36 | patch_size=14,
37 | embed_dim=1024,
38 | depth=24,
39 | num_heads=16,
40 | mlp_ratio=4,
41 | img_size=518,
42 | ffn_layer="mlp",
43 | init_values=1e-05,
44 | block_chunks=0,
45 | qkv_bias=True,
46 | proj_bias=True,
47 | ffn_bias=True
48 | )
49 | _large_dino = 'dinov2_vitl14'
50 |
--------------------------------------------------------------------------------
/rein/__init__.py:
--------------------------------------------------------------------------------
1 | from .models import *
2 |
--------------------------------------------------------------------------------
/rein/models/__init__.py:
--------------------------------------------------------------------------------
1 | from .backbones import ReinsDinoVisionTransformer, ReinsDinoVisionTransformer_3_head
2 | from .backbones import ReinsResNet
3 |
--------------------------------------------------------------------------------
/rein/models/backbones/__init__.py:
--------------------------------------------------------------------------------
1 | # from .dino_v2 import DinoVisionTransformer
2 | from .reins_dinov2 import ReinsDinoVisionTransformer, ReinsDinoVisionTransformer_3_head
3 | from .reins_resnet import ReinsResNet
4 | # from .reins_eva_02 import ReinsEVA2
5 | # from .clip import CLIPVisionTransformer
6 |
7 | __all__ = [
8 |     # Export only the backbones actually imported above; the commented-out
9 |     # variants (CLIP, plain DINOv2, EVA-02) are omitted so star-imports work.
10 |     "ReinsDinoVisionTransformer",
11 |     "ReinsDinoVisionTransformer_3_head",
12 |     "ReinsResNet",
13 | ]
14 |
--------------------------------------------------------------------------------
/rein/models/backbones/beit.py:
--------------------------------------------------------------------------------
1 | # --------------------------------------------------------
2 | # BEIT: BERT Pre-Training of Image Transformers (https://arxiv.org/abs/2106.08254)
3 | # Github source: https://github.com/microsoft/unilm/tree/master/beit
4 | # Copyright (c) 2021 Microsoft
5 | # Licensed under The MIT License [see LICENSE for details]
6 | # By Hangbo Bao
7 | # Based on timm, mmseg, setr, xcit and swin code bases
8 | # https://github.com/rwightman/pytorch-image-models/tree/master/timm
9 | # https://github.com/fudan-zvg/SETR
10 | # https://github.com/facebookresearch/xcit/
11 | # https://github.com/microsoft/Swin-Transformer
12 | # --------------------------------------------------------'
13 | import math
14 | from functools import partial
15 |
16 | import torch
17 | import torch.nn as nn
18 | import torch.nn.functional as F
19 | import torch.utils.checkpoint as cp
20 | # from mmseg.models.builder import BACKBONES
21 |
22 | from mmengine.logging import MMLogger  # used by BEiT.init_weights when loading pretrained weights
23 | from timm.models.layers import drop_path, to_2tuple, trunc_normal_
24 |
25 | # Copyright (c) Open-MMLab. All rights reserved.
26 | import io
27 | import math
28 | import os
29 | import os.path as osp
30 | import pkgutil
31 | import time
32 | import warnings
33 | from collections import OrderedDict
34 | from importlib import import_module
35 | from tempfile import TemporaryDirectory
36 |
37 | import mmcv
38 | import numpy as np
39 | import torch
40 | import torchvision
41 | from mmengine.fileio import FileClient
42 | from mmengine.fileio import load as load_file
43 | from mmengine.dist import get_dist_info
44 | from mmengine.model import is_model_wrapper
45 | from mmengine import mkdir_or_exist
46 | from scipy import interpolate
47 | from torch.nn import functional as F
48 | from torch.optim import Optimizer
49 | from torch.utils import model_zoo
50 |
51 | ENV_MMCV_HOME = "MMCV_HOME"
52 | ENV_XDG_CACHE_HOME = "XDG_CACHE_HOME"
53 | DEFAULT_CACHE_DIR = "~/.cache"
54 |
55 |
56 | def _get_mmcv_home():
57 | mmcv_home = os.path.expanduser(
58 | os.getenv(
59 | ENV_MMCV_HOME,
60 | os.path.join(os.getenv(ENV_XDG_CACHE_HOME, DEFAULT_CACHE_DIR), "mmcv"),
61 | )
62 | )
63 |
64 | mkdir_or_exist(mmcv_home)
65 | return mmcv_home
66 |
67 |
68 | def load_state_dict(module, state_dict, strict=False, logger=None):
69 | """Load state_dict to a module.
70 |
71 | This method is modified from :meth:`torch.nn.Module.load_state_dict`.
72 | Default value for ``strict`` is set to ``False`` and the message for
73 | param mismatch will be shown even if strict is False.
74 | Args:
75 | module (Module): Module that receives the state_dict.
76 | state_dict (OrderedDict): Weights.
77 | strict (bool): whether to strictly enforce that the keys
78 | in :attr:`state_dict` match the keys returned by this module's
79 | :meth:`~torch.nn.Module.state_dict` function. Default: ``False``.
80 | logger (:obj:`logging.Logger`, optional): Logger to log the error
81 | message. If not specified, print function will be used.
82 | """
83 | unexpected_keys = []
84 | all_missing_keys = []
85 | err_msg = []
86 |
87 | metadata = getattr(state_dict, "_metadata", None)
88 | state_dict = state_dict.copy()
89 | if metadata is not None:
90 | state_dict._metadata = metadata
91 |
92 | # use _load_from_state_dict to enable checkpoint version control
93 | def load(module, prefix=""):
94 | # recursively check parallel module in case that the model has a
95 | # complicated structure, e.g., nn.Module(nn.Module(DDP))
96 | if is_model_wrapper(module):
97 | module = module.module
98 | local_metadata = {} if metadata is None else metadata.get(prefix[:-1], {})
99 | module._load_from_state_dict(
100 | state_dict,
101 | prefix,
102 | local_metadata,
103 | True,
104 | all_missing_keys,
105 | unexpected_keys,
106 | err_msg,
107 | )
108 | for name, child in module._modules.items():
109 | if child is not None:
110 | load(child, prefix + name + ".")
111 |
112 | load(module)
113 | load = None # break load->load reference cycle
114 |
115 | # ignore "num_batches_tracked" of BN layers
116 | missing_keys = [key for key in all_missing_keys if "num_batches_tracked" not in key]
117 |
118 | if unexpected_keys:
119 | err_msg.append(
120 | "unexpected key in source " f'state_dict: {", ".join(unexpected_keys)}\n'
121 | )
122 | if missing_keys:
123 | err_msg.append(
124 | f'missing keys in source state_dict: {", ".join(missing_keys)}\n'
125 | )
126 |
127 | rank, _ = get_dist_info()
128 | if len(err_msg) > 0 and rank == 0:
129 | err_msg.insert(0, "The model and loaded state dict do not match exactly\n")
130 | err_msg = "\n".join(err_msg)
131 | if strict:
132 | raise RuntimeError(err_msg)
133 | elif logger is not None:
134 | logger.warning(err_msg)
135 | else:
136 | print(err_msg)
137 |
138 |
139 | def load_url_dist(url, model_dir=None, map_location="cpu"):
140 | """In distributed setting, this function only download checkpoint at local
141 | rank 0."""
142 | rank, world_size = get_dist_info()
143 | rank = int(os.environ.get("LOCAL_RANK", rank))
144 | if rank == 0:
145 | checkpoint = model_zoo.load_url(
146 | url, model_dir=model_dir, map_location=map_location
147 | )
148 | if world_size > 1:
149 | torch.distributed.barrier()
150 | if rank > 0:
151 | checkpoint = model_zoo.load_url(
152 | url, model_dir=model_dir, map_location=map_location
153 | )
154 | return checkpoint
155 |
156 |
157 | def load_pavimodel_dist(model_path, map_location=None):
158 | """In distributed setting, this function only download checkpoint at local
159 | rank 0."""
160 | try:
161 | from pavi import modelcloud
162 | except ImportError:
163 | raise ImportError("Please install pavi to load checkpoint from modelcloud.")
164 | rank, world_size = get_dist_info()
165 | rank = int(os.environ.get("LOCAL_RANK", rank))
166 | if rank == 0:
167 | model = modelcloud.get(model_path)
168 | with TemporaryDirectory() as tmp_dir:
169 | downloaded_file = osp.join(tmp_dir, model.name)
170 | model.download(downloaded_file)
171 | checkpoint = torch.load(downloaded_file, map_location=map_location)
172 | if world_size > 1:
173 | torch.distributed.barrier()
174 | if rank > 0:
175 | model = modelcloud.get(model_path)
176 | with TemporaryDirectory() as tmp_dir:
177 | downloaded_file = osp.join(tmp_dir, model.name)
178 | model.download(downloaded_file)
179 | checkpoint = torch.load(downloaded_file, map_location=map_location)
180 | return checkpoint
181 |
182 |
183 | def load_fileclient_dist(filename, backend, map_location):
184 | """In distributed setting, this function only download checkpoint at local
185 | rank 0."""
186 | rank, world_size = get_dist_info()
187 | rank = int(os.environ.get("LOCAL_RANK", rank))
188 | allowed_backends = ["ceph"]
189 | if backend not in allowed_backends:
190 | raise ValueError(f"Load from Backend {backend} is not supported.")
191 | if rank == 0:
192 | fileclient = FileClient(backend=backend)
193 | buffer = io.BytesIO(fileclient.get(filename))
194 | checkpoint = torch.load(buffer, map_location=map_location)
195 | if world_size > 1:
196 | torch.distributed.barrier()
197 | if rank > 0:
198 | fileclient = FileClient(backend=backend)
199 | buffer = io.BytesIO(fileclient.get(filename))
200 | checkpoint = torch.load(buffer, map_location=map_location)
201 | return checkpoint
202 |
203 |
204 | def get_torchvision_models():
205 | model_urls = dict()
206 | for _, name, ispkg in pkgutil.walk_packages(torchvision.models.__path__):
207 | if ispkg:
208 | continue
209 | _zoo = import_module(f"torchvision.models.{name}")
210 | if hasattr(_zoo, "model_urls"):
211 | _urls = getattr(_zoo, "model_urls")
212 | model_urls.update(_urls)
213 | return model_urls
214 |
215 |
216 | def get_external_models():
217 | mmcv_home = _get_mmcv_home()
218 | default_json_path = osp.join(mmcv.__path__[0], "model_zoo/open_mmlab.json")
219 | default_urls = load_file(default_json_path)
220 | assert isinstance(default_urls, dict)
221 | external_json_path = osp.join(mmcv_home, "open_mmlab.json")
222 | if osp.exists(external_json_path):
223 | external_urls = load_file(external_json_path)
224 | assert isinstance(external_urls, dict)
225 | default_urls.update(external_urls)
226 |
227 | return default_urls
228 |
229 |
230 | def get_mmcls_models():
231 | mmcls_json_path = osp.join(mmcv.__path__[0], "model_zoo/mmcls.json")
232 | mmcls_urls = load_file(mmcls_json_path)
233 |
234 | return mmcls_urls
235 |
236 |
237 | def get_deprecated_model_names():
238 | deprecate_json_path = osp.join(mmcv.__path__[0], "model_zoo/deprecated.json")
239 | deprecate_urls = load_file(deprecate_json_path)
240 | assert isinstance(deprecate_urls, dict)
241 |
242 | return deprecate_urls
243 |
244 |
245 | def _process_mmcls_checkpoint(checkpoint):
246 | state_dict = checkpoint["state_dict"]
247 | new_state_dict = OrderedDict()
248 | for k, v in state_dict.items():
249 | if k.startswith("backbone."):
250 | new_state_dict[k[9:]] = v
251 | new_checkpoint = dict(state_dict=new_state_dict)
252 |
253 | return new_checkpoint
254 |
255 |
256 | def _load_checkpoint(filename, map_location=None):
257 | """Load checkpoint from somewhere (modelzoo, file, url).
258 |
259 | Args:
260 | filename (str): Accept local filepath, URL, ``torchvision://xxx``,
261 | ``open-mmlab://xxx``. Please refer to ``docs/model_zoo.md`` for
262 | details.
263 | map_location (str | None): Same as :func:`torch.load`. Default: None.
264 | Returns:
265 | dict | OrderedDict: The loaded checkpoint. It can be either an
266 | OrderedDict storing model weights or a dict containing other
267 | information, which depends on the checkpoint.
268 | """
269 | if filename.startswith("modelzoo://"):
270 | warnings.warn(
271 | 'The URL scheme of "modelzoo://" is deprecated, please '
272 | 'use "torchvision://" instead'
273 | )
274 | model_urls = get_torchvision_models()
275 | model_name = filename[11:]
276 | checkpoint = load_url_dist(model_urls[model_name])
277 | elif filename.startswith("torchvision://"):
278 | model_urls = get_torchvision_models()
279 | model_name = filename[14:]
280 | checkpoint = load_url_dist(model_urls[model_name])
281 | elif filename.startswith("open-mmlab://"):
282 | model_urls = get_external_models()
283 | model_name = filename[13:]
284 | deprecated_urls = get_deprecated_model_names()
285 | if model_name in deprecated_urls:
286 | warnings.warn(
287 | f"open-mmlab://{model_name} is deprecated in favor "
288 | f"of open-mmlab://{deprecated_urls[model_name]}"
289 | )
290 | model_name = deprecated_urls[model_name]
291 | model_url = model_urls[model_name]
292 | # check if is url
293 | if model_url.startswith(("http://", "https://")):
294 | checkpoint = load_url_dist(model_url)
295 | else:
296 | filename = osp.join(_get_mmcv_home(), model_url)
297 | if not osp.isfile(filename):
298 | raise IOError(f"{filename} is not a checkpoint file")
299 | checkpoint = torch.load(filename, map_location=map_location)
300 | elif filename.startswith("mmcls://"):
301 | model_urls = get_mmcls_models()
302 | model_name = filename[8:]
303 | checkpoint = load_url_dist(model_urls[model_name])
304 | checkpoint = _process_mmcls_checkpoint(checkpoint)
305 | elif filename.startswith(("http://", "https://")):
306 | checkpoint = load_url_dist(filename)
307 | elif filename.startswith("pavi://"):
308 | model_path = filename[7:]
309 | checkpoint = load_pavimodel_dist(model_path, map_location=map_location)
310 | elif filename.startswith("s3://"):
311 | checkpoint = load_fileclient_dist(
312 | filename, backend="ceph", map_location=map_location
313 | )
314 | else:
315 | if not osp.isfile(filename):
316 | raise IOError(f"{filename} is not a checkpoint file")
317 | checkpoint = torch.load(filename, map_location=map_location)
318 | return checkpoint
319 |
320 |
321 | def cosine_scheduler(
322 | base_value,
323 | final_value,
324 | epochs,
325 | niter_per_ep,
326 | warmup_epochs=0,
327 | start_warmup_value=0,
328 | warmup_steps=-1,
329 | ):
330 | warmup_schedule = np.array([])
331 | warmup_iters = warmup_epochs * niter_per_ep
332 | if warmup_steps > 0:
333 | warmup_iters = warmup_steps
334 | print("Set warmup steps = %d" % warmup_iters)
335 | if warmup_epochs > 0:
336 | warmup_schedule = np.linspace(start_warmup_value, base_value, warmup_iters)
337 |
338 | iters = np.arange(epochs * niter_per_ep - warmup_iters)
339 | schedule = np.array(
340 | [
341 | final_value
342 | + 0.5
343 | * (base_value - final_value)
344 | * (1 + math.cos(math.pi * i / (len(iters))))
345 | for i in iters
346 | ]
347 | )
348 |
349 | schedule = np.concatenate((warmup_schedule, schedule))
350 |
351 | assert len(schedule) == epochs * niter_per_ep
352 | return schedule
353 |
354 |
355 | def load_checkpoint(model, filename, map_location="cpu", strict=False, logger=None):
356 | """Load checkpoint from a file or URI.
357 |
358 | Args:
359 | model (Module): Module to load checkpoint.
360 | filename (str): Accept local filepath, URL, ``torchvision://xxx``,
361 | ``open-mmlab://xxx``. Please refer to ``docs/model_zoo.md`` for
362 | details.
363 | map_location (str): Same as :func:`torch.load`.
364 | strict (bool): Whether to allow different params for the model and
365 | checkpoint.
366 | logger (:mod:`logging.Logger` or None): The logger for error message.
367 | Returns:
368 | dict or OrderedDict: The loaded checkpoint.
369 | """
370 | checkpoint = _load_checkpoint(filename, map_location)
371 | # OrderedDict is a subclass of dict
372 | if not isinstance(checkpoint, dict):
373 | raise RuntimeError(f"No state_dict found in checkpoint file {filename}")
374 | # get state_dict from checkpoint
375 | if "state_dict" in checkpoint:
376 | state_dict = checkpoint["state_dict"]
377 | elif "model" in checkpoint:
378 | state_dict = checkpoint["model"]
379 | elif "module" in checkpoint:
380 | state_dict = checkpoint["module"]
381 | else:
382 | state_dict = checkpoint
383 | # strip prefix of state_dict
384 | if list(state_dict.keys())[0].startswith("module."):
385 | state_dict = {k[7:]: v for k, v in state_dict.items()}
386 |
387 | # for MoBY, load model of online branch
388 | if sorted(list(state_dict.keys()))[0].startswith("encoder"):
389 | state_dict = {
390 | k.replace("encoder.", ""): v
391 | for k, v in state_dict.items()
392 | if k.startswith("encoder.")
393 | }
394 |
395 | # reshape absolute position embedding for Swin
396 | if state_dict.get("absolute_pos_embed") is not None:
397 | absolute_pos_embed = state_dict["absolute_pos_embed"]
398 | N1, L, C1 = absolute_pos_embed.size()
399 | N2, C2, H, W = model.absolute_pos_embed.size()
400 | if N1 != N2 or C1 != C2 or L != H * W:
401 | logger.warning("Error in loading absolute_pos_embed, pass")
402 | else:
403 | state_dict["absolute_pos_embed"] = absolute_pos_embed.view(
404 | N2, H, W, C2
405 | ).permute(0, 3, 1, 2)
406 |
407 | rank, _ = get_dist_info()
408 | if "rel_pos_bias.relative_position_bias_table" in state_dict:
409 | if rank == 0:
410 | print("Expand the shared relative position embedding to each layers. ")
411 | num_layers = model.get_num_layers()
412 | rel_pos_bias = state_dict["rel_pos_bias.relative_position_bias_table"]
413 | for i in range(num_layers):
414 | state_dict[
415 | "blocks.%d.attn.relative_position_bias_table" % i
416 | ] = rel_pos_bias.clone()
417 |
418 | state_dict.pop("rel_pos_bias.relative_position_bias_table")
419 |
420 | all_keys = list(state_dict.keys())
421 | for key in all_keys:
422 | if "relative_position_index" in key:
423 | state_dict.pop(key)
424 |
425 | if "relative_position_bias_table" in key:
426 | rel_pos_bias = state_dict[key]
427 | src_num_pos, num_attn_heads = rel_pos_bias.size()
428 | dst_num_pos, _ = model.state_dict()[key].size()
429 | dst_patch_shape = model.patch_embed.patch_shape
430 | if dst_patch_shape[0] != dst_patch_shape[1]:
431 | raise NotImplementedError()
432 | num_extra_tokens = dst_num_pos - (dst_patch_shape[0] * 2 - 1) * (
433 | dst_patch_shape[1] * 2 - 1
434 | )
435 | src_size = int((src_num_pos - num_extra_tokens) ** 0.5)
436 | dst_size = int((dst_num_pos - num_extra_tokens) ** 0.5)
437 | if src_size != dst_size:
438 | if rank == 0:
439 | print(
440 | "Position interpolate for %s from %dx%d to %dx%d"
441 | % (key, src_size, src_size, dst_size, dst_size)
442 | )
443 | extra_tokens = rel_pos_bias[-num_extra_tokens:, :]
444 | rel_pos_bias = rel_pos_bias[:-num_extra_tokens, :]
445 |
446 | def geometric_progression(a, r, n):
447 | return a * (1.0 - r**n) / (1.0 - r)
448 |
449 | left, right = 1.01, 1.5
450 | while right - left > 1e-6:
451 | q = (left + right) / 2.0
452 | gp = geometric_progression(1, q, src_size // 2)
453 | if gp > dst_size // 2:
454 | right = q
455 | else:
456 | left = q
457 |
458 | # if q > 1.13492:
459 | # q = 1.13492
460 |
461 | dis = []
462 | cur = 1
463 | for i in range(src_size // 2):
464 | dis.append(cur)
465 | cur += q ** (i + 1)
466 |
467 | r_ids = [-_ for _ in reversed(dis)]
468 |
469 | x = r_ids + [0] + dis
470 | y = r_ids + [0] + dis
471 |
472 | t = dst_size // 2.0
473 | dx = np.arange(-t, t + 0.1, 1.0)
474 | dy = np.arange(-t, t + 0.1, 1.0)
475 | if rank == 0:
476 | print("x = {}".format(x))
477 | print("dx = {}".format(dx))
478 |
479 | all_rel_pos_bias = []
480 |
481 | for i in range(num_attn_heads):
482 | z = rel_pos_bias[:, i].view(src_size, src_size).float().numpy()
483 | f = interpolate.interp2d(x, y, z, kind="cubic")
484 | all_rel_pos_bias.append(
485 | torch.Tensor(f(dx, dy))
486 | .contiguous()
487 | .view(-1, 1)
488 | .to(rel_pos_bias.device)
489 | )
490 |
491 | rel_pos_bias = torch.cat(all_rel_pos_bias, dim=-1)
492 | new_rel_pos_bias = torch.cat((rel_pos_bias, extra_tokens), dim=0)
493 | state_dict[key] = new_rel_pos_bias
494 |
495 | if "pos_embed" in state_dict:
496 | pos_embed_checkpoint = state_dict["pos_embed"]
497 | embedding_size = pos_embed_checkpoint.shape[-1]
498 | num_patches = model.patch_embed.num_patches
499 | num_extra_tokens = model.pos_embed.shape[-2] - num_patches
500 | # height (== width) for the checkpoint position embedding
501 | orig_size = int((pos_embed_checkpoint.shape[-2] - num_extra_tokens) ** 0.5)
502 | # height (== width) for the new position embedding
503 | new_size = int(num_patches**0.5)
504 | # class_token and dist_token are kept unchanged
505 | if orig_size != new_size:
506 | if rank == 0:
507 | print(
508 | "Position interpolate from %dx%d to %dx%d"
509 | % (orig_size, orig_size, new_size, new_size)
510 | )
511 | extra_tokens = pos_embed_checkpoint[:, :num_extra_tokens]
512 | # only the position tokens are interpolated
513 | pos_tokens = pos_embed_checkpoint[:, num_extra_tokens:]
514 | pos_tokens = pos_tokens.reshape(
515 | -1, orig_size, orig_size, embedding_size
516 | ).permute(0, 3, 1, 2)
517 | pos_tokens = torch.nn.functional.interpolate(
518 | pos_tokens,
519 | size=(new_size, new_size),
520 | mode="bicubic",
521 | align_corners=False,
522 | )
523 | pos_tokens = pos_tokens.permute(0, 2, 3, 1).flatten(1, 2)
524 | new_pos_embed = torch.cat((extra_tokens, pos_tokens), dim=1)
525 | state_dict["pos_embed"] = new_pos_embed
526 |
527 | # interpolate position bias table if needed
528 | relative_position_bias_table_keys = [
529 | k for k in state_dict.keys() if "relative_position_bias_table" in k
530 | ]
531 | for table_key in relative_position_bias_table_keys:
532 | table_pretrained = state_dict[table_key]
533 | table_current = model.state_dict()[table_key]
534 | L1, nH1 = table_pretrained.size()
535 | L2, nH2 = table_current.size()
536 | if nH1 != nH2:
537 | logger.warning(f"Error in loading {table_key}, pass")
538 | else:
539 | if L1 != L2:
540 | S1 = int(L1**0.5)
541 | S2 = int(L2**0.5)
542 | table_pretrained_resized = F.interpolate(
543 | table_pretrained.permute(1, 0).view(1, nH1, S1, S1),
544 | size=(S2, S2),
545 | mode="bicubic",
546 | )
547 | state_dict[table_key] = table_pretrained_resized.view(nH2, L2).permute(
548 | 1, 0
549 | )
550 |
551 | # load state_dict
552 | load_state_dict(model, state_dict, strict, logger)
553 | return checkpoint
554 |
555 |
556 | def weights_to_cpu(state_dict):
557 | """Copy a model state_dict to cpu.
558 |
559 | Args:
560 | state_dict (OrderedDict): Model weights on GPU.
561 | Returns:
562 | OrderedDict: Model weights on CPU.
563 | """
564 | state_dict_cpu = OrderedDict()
565 | for key, val in state_dict.items():
566 | state_dict_cpu[key] = val.cpu()
567 | return state_dict_cpu
568 |
569 |
570 | def _save_to_state_dict(module, destination, prefix, keep_vars):
571 | """Saves module state to `destination` dictionary.
572 |
573 | This method is modified from :meth:`torch.nn.Module._save_to_state_dict`.
574 | Args:
575 | module (nn.Module): The module to generate state_dict.
576 | destination (dict): A dict where state will be stored.
577 | prefix (str): The prefix for parameters and buffers used in this
578 | module.
579 | """
580 | for name, param in module._parameters.items():
581 | if param is not None:
582 | destination[prefix + name] = param if keep_vars else param.detach()
583 | for name, buf in module._buffers.items():
584 | # remove check of _non_persistent_buffers_set to allow nn.BatchNorm2d
585 | if buf is not None:
586 | destination[prefix + name] = buf if keep_vars else buf.detach()
587 |
588 |
589 | def get_state_dict(module, destination=None, prefix="", keep_vars=False):
590 | """Returns a dictionary containing a whole state of the module.
591 |
592 | Both parameters and persistent buffers (e.g. running averages) are
593 | included. Keys are corresponding parameter and buffer names.
594 | This method is modified from :meth:`torch.nn.Module.state_dict` to
595 | recursively check parallel module in case that the model has a complicated
596 | structure, e.g., nn.Module(nn.Module(DDP)).
597 | Args:
598 | module (nn.Module): The module to generate state_dict.
599 | destination (OrderedDict): Returned dict for the state of the
600 | module.
601 | prefix (str): Prefix of the key.
602 | keep_vars (bool): Whether to keep the variable property of the
603 | parameters. Default: False.
604 | Returns:
605 | dict: A dictionary containing a whole state of the module.
606 | """
607 | # recursively check parallel module in case that the model has a
608 | # complicated structure, e.g., nn.Module(nn.Module(DDP))
609 | if is_model_wrapper(module):
610 | module = module.module
611 |
612 | # below is the same as torch.nn.Module.state_dict()
613 | if destination is None:
614 | destination = OrderedDict()
615 | destination._metadata = OrderedDict()
616 | destination._metadata[prefix[:-1]] = local_metadata = dict(version=module._version)
617 | _save_to_state_dict(module, destination, prefix, keep_vars)
618 | for name, child in module._modules.items():
619 | if child is not None:
620 | get_state_dict(child, destination, prefix + name + ".", keep_vars=keep_vars)
621 | for hook in module._state_dict_hooks.values():
622 | hook_result = hook(module, destination, prefix, local_metadata)
623 | if hook_result is not None:
624 | destination = hook_result
625 | return destination
626 |
627 |
628 | def save_checkpoint(model, filename, optimizer=None, meta=None):
629 | """Save checkpoint to file.
630 |
631 | The checkpoint will have 3 fields: ``meta``, ``state_dict`` and
632 | ``optimizer``. By default ``meta`` will contain version and time info.
633 | Args:
634 | model (Module): Module whose params are to be saved.
635 | filename (str): Checkpoint filename.
636 | optimizer (:obj:`Optimizer`, optional): Optimizer to be saved.
637 | meta (dict, optional): Metadata to be saved in checkpoint.
638 | """
639 | if meta is None:
640 | meta = {}
641 | elif not isinstance(meta, dict):
642 | raise TypeError(f"meta must be a dict or None, but got {type(meta)}")
643 | meta.update(mmcv_version=mmcv.__version__, time=time.asctime())
644 |
645 | if is_model_wrapper(model):
646 | model = model.module
647 |
648 | if hasattr(model, "CLASSES") and model.CLASSES is not None:
649 | # save class name to the meta
650 | meta.update(CLASSES=model.CLASSES)
651 |
652 | checkpoint = {"meta": meta, "state_dict": weights_to_cpu(get_state_dict(model))}
653 | # save optimizer state dict in the checkpoint
654 | if isinstance(optimizer, Optimizer):
655 | checkpoint["optimizer"] = optimizer.state_dict()
656 | elif isinstance(optimizer, dict):
657 | checkpoint["optimizer"] = {}
658 | for name, optim in optimizer.items():
659 | checkpoint["optimizer"][name] = optim.state_dict()
660 |
661 | if filename.startswith("pavi://"):
662 | try:
663 | from pavi import modelcloud
664 | from pavi.exception import NodeNotFoundError
665 | except ImportError:
666 | raise ImportError("Please install pavi to load checkpoint from modelcloud.")
667 | model_path = filename[7:]
668 | root = modelcloud.Folder()
669 | model_dir, model_name = osp.split(model_path)
670 | try:
671 | model = modelcloud.get(model_dir)
672 | except NodeNotFoundError:
673 | model = root.create_training_model(model_dir)
674 | with TemporaryDirectory() as tmp_dir:
675 | checkpoint_file = osp.join(tmp_dir, model_name)
676 | with open(checkpoint_file, "wb") as f:
677 | torch.save(checkpoint, f)
678 | f.flush()
679 | model.create_file(checkpoint_file, name=model_name)
680 | else:
681 | mmcv.mkdir_or_exist(osp.dirname(filename))
682 | # immediately flush buffer
683 | with open(filename, "wb") as f:
684 | torch.save(checkpoint, f)
685 | f.flush()
686 |
687 |
688 | class DropPath(nn.Module):
689 | """Drop paths (Stochastic Depth) per sample (when applied in main path of
690 | residual blocks)."""
691 |
692 | def __init__(self, drop_prob=None):
693 | super(DropPath, self).__init__()
694 | self.drop_prob = drop_prob
695 |
696 | def forward(self, x):
697 | return drop_path(x, self.drop_prob, self.training)
698 |
699 | def extra_repr(self) -> str:
700 | return "p={}".format(self.drop_prob)
701 |
702 |
703 | class Mlp(nn.Module):
704 | def __init__(
705 | self,
706 | in_features,
707 | hidden_features=None,
708 | out_features=None,
709 | act_layer=nn.GELU,
710 | drop=0.0,
711 | ):
712 | super().__init__()
713 | out_features = out_features or in_features
714 | hidden_features = hidden_features or in_features
715 | self.fc1 = nn.Linear(in_features, hidden_features)
716 | self.act = act_layer()
717 | self.fc2 = nn.Linear(hidden_features, out_features)
718 | self.drop = nn.Dropout(drop)
719 |
720 | def forward(self, x):
721 | x = self.fc1(x)
722 | x = self.act(x)
723 | # x = self.drop(x)
724 | # commit this for the original BERT implement
725 | x = self.fc2(x)
726 | x = self.drop(x)
727 | return x
728 |
729 |
730 | class Attention(nn.Module):
731 | def __init__(
732 | self,
733 | dim,
734 | num_heads=8,
735 | qkv_bias=False,
736 | qk_scale=None,
737 | attn_drop=0.0,
738 | proj_drop=0.0,
739 | window_size=None,
740 | attn_head_dim=None,
741 | ):
742 | super().__init__()
743 | self.num_heads = num_heads
744 | head_dim = dim // num_heads
745 | if attn_head_dim is not None:
746 | head_dim = attn_head_dim
747 | all_head_dim = head_dim * self.num_heads
748 | # NOTE scale factor was wrong in my original version, can set manually to be compat with prev weights
749 | self.scale = qk_scale or head_dim**-0.5
750 |
751 | self.qkv = nn.Linear(dim, all_head_dim * 3, bias=False)
752 | if qkv_bias:
753 | self.q_bias = nn.Parameter(torch.zeros(all_head_dim))
754 | self.v_bias = nn.Parameter(torch.zeros(all_head_dim))
755 | else:
756 | self.q_bias = None
757 | self.v_bias = None
758 |
759 | if window_size:
760 | self.window_size = window_size
761 | self.num_relative_distance = (2 * window_size[0] - 1) * (
762 | 2 * window_size[1] - 1
763 | ) + 3
764 | self.relative_position_bias_table = nn.Parameter(
765 | torch.zeros(self.num_relative_distance, num_heads)
766 | ) # 2*Wh-1 * 2*Ww-1, nH
767 | # cls to token & token 2 cls & cls to cls
768 |
769 | # get pair-wise relative position index for each token inside the window
770 | coords_h = torch.arange(window_size[0])
771 | coords_w = torch.arange(window_size[1])
772 | coords = torch.stack(torch.meshgrid([coords_h, coords_w])) # 2, Wh, Ww
773 | coords_flatten = torch.flatten(coords, 1) # 2, Wh*Ww
774 | relative_coords = (
775 | coords_flatten[:, :, None] - coords_flatten[:, None, :]
776 | ) # 2, Wh*Ww, Wh*Ww
777 | relative_coords = relative_coords.permute(
778 | 1, 2, 0
779 | ).contiguous() # Wh*Ww, Wh*Ww, 2
780 | relative_coords[:, :, 0] += window_size[0] - 1 # shift to start from 0
781 | relative_coords[:, :, 1] += window_size[1] - 1
782 | relative_coords[:, :, 0] *= 2 * window_size[1] - 1
783 | relative_position_index = torch.zeros(
784 | size=(window_size[0] * window_size[1] + 1,) * 2,
785 | dtype=relative_coords.dtype,
786 | )
787 | relative_position_index[1:, 1:] = relative_coords.sum(-1) # Wh*Ww, Wh*Ww
788 | relative_position_index[0, 0:] = self.num_relative_distance - 3
789 | relative_position_index[0:, 0] = self.num_relative_distance - 2
790 | relative_position_index[0, 0] = self.num_relative_distance - 1
791 | self.register_buffer("relative_position_index", relative_position_index)
792 |
793 | # trunc_normal_(self.relative_position_bias_table, std=.0)
794 | else:
795 | self.window_size = None
796 | self.relative_position_bias_table = None
797 | self.relative_position_index = None
798 |
799 | self.attn_drop = nn.Dropout(attn_drop)
800 | self.proj = nn.Linear(all_head_dim, dim)
801 | self.proj_drop = nn.Dropout(proj_drop)
802 |
803 | def forward(self, x, rel_pos_bias=None):
804 | B, N, C = x.shape
805 | qkv_bias = None
806 | if self.q_bias is not None:
807 | qkv_bias = torch.cat(
808 | (
809 | self.q_bias,
810 | torch.zeros_like(self.v_bias, requires_grad=False),
811 | self.v_bias,
812 | )
813 | )
814 | # qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
815 | qkv = F.linear(input=x, weight=self.qkv.weight, bias=qkv_bias)
816 | qkv = qkv.reshape(B, N, 3, self.num_heads, -1).permute(2, 0, 3, 1, 4)
817 | #qkv: B,N,3,K,C->3,B,K,N,C
818 | q, k, v = (
819 | qkv[0],
820 | qkv[1],
821 | qkv[2],
822 | ) # make torchscript happy (cannot use tensor as tuple)
823 |
824 | q = q * self.scale
825 | attn = q @ k.transpose(-2, -1)
826 | # attn : B,K,N,C@B,K,N,C->B,K,N,N
827 |
828 | if self.relative_position_bias_table is not None:
829 | relative_position_bias = self.relative_position_bias_table[
830 | self.relative_position_index.view(-1)
831 | ].view(
832 | self.window_size[0] * self.window_size[1] + 1,
833 | self.window_size[0] * self.window_size[1] + 1,
834 | -1,
835 | ) # Wh*Ww,Wh*Ww,nH
836 | relative_position_bias = relative_position_bias.permute(
837 | 2, 0, 1
838 | ).contiguous() # nH, Wh*Ww, Wh*Ww
839 | # relative_position_bias = relative_position_bias[:, 1:, 1:]
840 | attn = attn + relative_position_bias.unsqueeze(0)
841 |
842 | if rel_pos_bias is not None:
843 | attn = attn + rel_pos_bias
844 |
845 | attn = attn.softmax(dim=-1)
846 | attn = self.attn_drop(attn)
847 |
848 | x = (attn @ v).transpose(1, 2).reshape(B, N, -1)
849 | x = self.proj(x)
850 | x = self.proj_drop(x)
851 | return x
852 |
853 |
854 | class Block(nn.Module):
855 | def __init__(
856 | self,
857 | dim,
858 | num_heads,
859 | mlp_ratio=4.0,
860 | qkv_bias=False,
861 | qk_scale=None,
862 | drop=0.0,
863 | attn_drop=0.0,
864 | drop_path=0.0,
865 | init_values=None,
866 | act_layer=nn.GELU,
867 | norm_layer=nn.LayerNorm,
868 | window_size=None,
869 | attn_head_dim=None,
870 | with_cp=False,
871 | ):
872 | super().__init__()
873 | self.with_cp = with_cp
874 | self.norm1 = norm_layer(dim)
875 | self.attn = Attention(
876 | dim,
877 | num_heads=num_heads,
878 | qkv_bias=qkv_bias,
879 | qk_scale=qk_scale,
880 | attn_drop=attn_drop,
881 | proj_drop=drop,
882 | window_size=window_size,
883 | attn_head_dim=attn_head_dim,
884 | )
885 | # NOTE: drop path for stochastic depth, we shall see if this is better than dropout here
886 | self.drop_path = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()
887 | self.norm2 = norm_layer(dim)
888 | mlp_hidden_dim = int(dim * mlp_ratio)
889 | self.mlp = Mlp(
890 | in_features=dim,
891 | hidden_features=mlp_hidden_dim,
892 | act_layer=act_layer,
893 | drop=drop,
894 | )
895 |
896 | if init_values is not None:
897 | self.gamma_1 = nn.Parameter(
898 | init_values * torch.ones((dim)), requires_grad=True
899 | )
900 | self.gamma_2 = nn.Parameter(
901 | init_values * torch.ones((dim)), requires_grad=True
902 | )
903 | else:
904 | self.gamma_1, self.gamma_2 = None, None
905 |
906 | def forward(self, x, H, W, rel_pos_bias=None):
907 | def _inner_forward(x):
908 | if self.gamma_1 is None:
909 | x = x + self.drop_path(
910 | self.attn(self.norm1(x), rel_pos_bias=rel_pos_bias)
911 | )
912 | x = x + self.drop_path(self.mlp(self.norm2(x)))
913 | else:
914 | x = x + self.drop_path(
915 | self.gamma_1 * self.attn(self.norm1(x), rel_pos_bias=rel_pos_bias)
916 | )
917 | x = x + self.drop_path(self.gamma_2 * self.mlp(self.norm2(x)))
918 | return x
919 |
920 | if self.with_cp and x.requires_grad:
921 | x = cp.checkpoint(_inner_forward, x)
922 | else:
923 | x = _inner_forward(x)
924 | return x
925 |
926 |
927 | class PatchEmbed(nn.Module):
928 | """Image to Patch Embedding"""
929 |
930 | def __init__(self, img_size=224, patch_size=16, in_chans=3, embed_dim=768):
931 | super().__init__()
932 | img_size = to_2tuple(img_size)
933 | patch_size = to_2tuple(patch_size)
934 | num_patches = (img_size[1] // patch_size[1]) * (img_size[0] // patch_size[0])
935 | self.patch_shape = (img_size[0] // patch_size[0], img_size[1] // patch_size[1])
936 | self.img_size = img_size
937 | self.patch_size = patch_size
938 | self.num_patches = num_patches
939 |
940 | self.proj = nn.Conv2d(
941 | in_chans, embed_dim, kernel_size=patch_size, stride=patch_size
942 | )
943 |
944 | def forward(self, x, **kwargs):
945 | B, C, H, W = x.shape
946 | # FIXME look at relaxing size constraints
947 | # assert H == self.img_size[0] and W == self.img_size[1], \
948 | # f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})."
949 | x = self.proj(x)
950 | Hp, Wp = x.shape[2], x.shape[3]
951 |
952 | x = x.flatten(2).transpose(1, 2)
953 | return x, Hp, Wp
954 |
955 |
956 | class HybridEmbed(nn.Module):
957 | """CNN Feature Map Embedding
958 | Extract feature map from CNN, flatten, project to embedding dim.
959 | """
960 |
961 | def __init__(
962 | self, backbone, img_size=224, feature_size=None, in_chans=3, embed_dim=768
963 | ):
964 | super().__init__()
965 | assert isinstance(backbone, nn.Module)
966 | img_size = to_2tuple(img_size)
967 | self.img_size = img_size
968 | self.backbone = backbone
969 | if feature_size is None:
970 | with torch.no_grad():
971 | # FIXME this is hacky, but most reliable way of determining the exact dim of the output feature
972 | # map for all networks, the feature metadata has reliable channel and stride info, but using
973 | # stride to calc feature dim requires info about padding of each stage that isn't captured.
974 | training = backbone.training
975 | if training:
976 | backbone.eval()
977 | o = self.backbone(torch.zeros(1, in_chans, img_size[0], img_size[1]))[
978 | -1
979 | ]
980 | feature_size = o.shape[-2:]
981 | feature_dim = o.shape[1]
982 | backbone.train(training)
983 | else:
984 | feature_size = to_2tuple(feature_size)
985 | feature_dim = self.backbone.feature_info.channels()[-1]
986 | self.num_patches = feature_size[0] * feature_size[1]
987 | self.proj = nn.Linear(feature_dim, embed_dim)
988 |
989 | def forward(self, x):
990 | x = self.backbone(x)[-1]
991 | x = x.flatten(2).transpose(1, 2)
992 | x = self.proj(x)
993 | return x
994 |
995 |
996 | class RelativePositionBias(nn.Module):
997 | def __init__(self, window_size, num_heads):
998 | super().__init__()
999 | self.window_size = window_size
1000 | self.num_relative_distance = (2 * window_size[0] - 1) * (
1001 | 2 * window_size[1] - 1
1002 | ) + 3
1003 | self.relative_position_bias_table = nn.Parameter(
1004 | torch.zeros(self.num_relative_distance, num_heads)
1005 | ) # 2*Wh-1 * 2*Ww-1, nH
1006 | # cls to token & token 2 cls & cls to cls
1007 |
1008 | # get pair-wise relative position index for each token inside the window
1009 | coords_h = torch.arange(window_size[0])
1010 | coords_w = torch.arange(window_size[1])
1011 | coords = torch.stack(torch.meshgrid([coords_h, coords_w])) # 2, Wh, Ww
1012 | coords_flatten = torch.flatten(coords, 1) # 2, Wh*Ww
1013 | relative_coords = (
1014 | coords_flatten[:, :, None] - coords_flatten[:, None, :]
1015 | ) # 2, Wh*Ww, Wh*Ww
1016 | relative_coords = relative_coords.permute(
1017 | 1, 2, 0
1018 | ).contiguous() # Wh*Ww, Wh*Ww, 2
1019 | relative_coords[:, :, 0] += window_size[0] - 1 # shift to start from 0
1020 | relative_coords[:, :, 1] += window_size[1] - 1
1021 | relative_coords[:, :, 0] *= 2 * window_size[1] - 1
1022 | relative_position_index = torch.zeros(
1023 | size=(window_size[0] * window_size[1] + 1,) * 2, dtype=relative_coords.dtype
1024 | )
1025 | relative_position_index[1:, 1:] = relative_coords.sum(-1) # Wh*Ww, Wh*Ww
1026 | relative_position_index[0, 0:] = self.num_relative_distance - 3
1027 | relative_position_index[0:, 0] = self.num_relative_distance - 2
1028 | relative_position_index[0, 0] = self.num_relative_distance - 1
1029 |
1030 | self.register_buffer("relative_position_index", relative_position_index)
1031 |
1032 | # trunc_normal_(self.relative_position_bias_table, std=.02)
1033 |
1034 | def forward(self):
1035 | relative_position_bias = self.relative_position_bias_table[
1036 | self.relative_position_index.view(-1)
1037 | ].view(
1038 | self.window_size[0] * self.window_size[1] + 1,
1039 | self.window_size[0] * self.window_size[1] + 1,
1040 | -1,
1041 | ) # Wh*Ww,Wh*Ww,nH
1042 | return relative_position_bias.permute(2, 0, 1).contiguous() # nH, Wh*Ww, Wh*Ww
1043 |
1044 |
1045 | # @BACKBONES.register_module()
1046 | class BEiT(nn.Module):
1047 | """Vision Transformer with support for patch or hybrid CNN input stage"""
1048 |
1049 | def __init__(
1050 | self,
1051 | img_size=512,
1052 | patch_size=16,
1053 | in_chans=3,
1054 | num_classes=80,
1055 | embed_dim=768,
1056 | depth=12,
1057 | num_heads=12,
1058 | mlp_ratio=4.0,
1059 | qkv_bias=False,
1060 | qk_scale=None,
1061 | drop_rate=0.0,
1062 | attn_drop_rate=0.0,
1063 | drop_path_rate=0.0,
1064 | hybrid_backbone=None,
1065 | norm_layer=None,
1066 | init_values=None,
1067 | use_checkpoint=False,
1068 | use_abs_pos_emb=False,
1069 | use_rel_pos_bias=True,
1070 | use_shared_rel_pos_bias=False,
1071 | pretrained=None,
1072 | with_cp=False,
1073 | ):
1074 | super().__init__()
1075 | norm_layer = norm_layer or partial(nn.LayerNorm, eps=1e-6)
1076 | self.norm_layer = norm_layer
1077 | self.num_classes = num_classes
1078 | self.num_features = (
1079 | self.embed_dim
1080 | ) = embed_dim # num_features for consistency with other models
1081 | self.drop_path_rate = drop_path_rate
1082 | if hybrid_backbone is not None:
1083 | self.patch_embed = HybridEmbed(
1084 | hybrid_backbone,
1085 | img_size=img_size,
1086 | in_chans=in_chans,
1087 | embed_dim=embed_dim,
1088 | )
1089 | else:
1090 | self.patch_embed = PatchEmbed(
1091 | img_size=img_size,
1092 | patch_size=patch_size,
1093 | in_chans=in_chans,
1094 | embed_dim=embed_dim,
1095 | )
1096 | num_patches = self.patch_embed.num_patches
1097 |
1098 | self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
1099 | # self.mask_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
1100 | if use_abs_pos_emb:
1101 | self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + 1, embed_dim))
1102 | else:
1103 | self.pos_embed = None
1104 | self.pos_drop = nn.Dropout(p=drop_rate)
1105 |
1106 | if use_shared_rel_pos_bias:
1107 | self.rel_pos_bias = RelativePositionBias(
1108 | window_size=self.patch_embed.patch_shape, num_heads=num_heads
1109 | )
1110 | else:
1111 | self.rel_pos_bias = None
1112 |
1113 | dpr = [
1114 | x.item() for x in torch.linspace(0, drop_path_rate, depth)
1115 | ] # stochastic depth decay rule
1116 | self.use_rel_pos_bias = use_rel_pos_bias
1117 | self.use_checkpoint = use_checkpoint
1118 | self.blocks = nn.ModuleList(
1119 | [
1120 | Block(
1121 | dim=embed_dim,
1122 | num_heads=num_heads,
1123 | mlp_ratio=mlp_ratio,
1124 | qkv_bias=qkv_bias,
1125 | qk_scale=qk_scale,
1126 | drop=drop_rate,
1127 | attn_drop=attn_drop_rate,
1128 | drop_path=dpr[i],
1129 | norm_layer=norm_layer,
1130 | with_cp=with_cp,
1131 | init_values=init_values,
1132 | window_size=self.patch_embed.patch_shape
1133 | if use_rel_pos_bias
1134 | else None,
1135 | )
1136 | for i in range(depth)
1137 | ]
1138 | )
1139 |
1140 | # if self.pos_embed is not None:
1141 | # trunc_normal_(self.pos_embed, std=.02)
1142 | trunc_normal_(self.cls_token, std=0.02)
1143 | self.apply(self._init_weights)
1144 | self.init_weights(pretrained)
1145 |
1146 | # self.fix_init_weight()
1147 |
1148 | def init_weights(self, pretrained=None):
1149 | """Initialize the weights in backbone.
1150 |
1151 | Args:
1152 | pretrained (str, optional): Path to pre-trained weights.
1153 | Defaults to None.
1154 | """
1155 | # pretrained = 'pretrained/beit_large_patch16_512_pt22k_ft22kto1k.pth'
1156 | if isinstance(pretrained, str):
1157 | logger = MMLogger.get_current_instance()
1158 | load_checkpoint(self, pretrained, strict=False, logger=logger)
1159 |
1160 | def fix_init_weight(self):
1161 | def rescale(param, layer_id):
1162 | param.div_(math.sqrt(2.0 * layer_id))
1163 |
1164 | for layer_id, layer in enumerate(self.blocks):
1165 | rescale(layer.attn.proj.weight.data, layer_id + 1)
1166 | rescale(layer.mlp.fc2.weight.data, layer_id + 1)
1167 |
1168 | def _init_weights(self, m):
1169 | if isinstance(m, nn.Linear):
1170 | trunc_normal_(m.weight, std=0.02)
1171 | if isinstance(m, nn.Linear) and m.bias is not None:
1172 | nn.init.constant_(m.bias, 0)
1173 | elif isinstance(m, nn.LayerNorm):
1174 | nn.init.constant_(m.bias, 0)
1175 | nn.init.constant_(m.weight, 1.0)
1176 |
1177 | def get_num_layers(self):
1178 | return len(self.blocks)
1179 |
--------------------------------------------------------------------------------
/rein/models/backbones/clip.py:
--------------------------------------------------------------------------------
1 | from collections import OrderedDict
2 | import torch
3 | import torch.nn.functional as F
4 | from torch import nn
5 | from timm.models.layers import drop_path, trunc_normal_
6 | from mmseg.models.builder import BACKBONES
7 |
8 |
9 | class LayerNorm(nn.LayerNorm):
10 | """Subclass torch's LayerNorm to handle fp16."""
11 |
12 | def forward(self, x: torch.Tensor):
13 | orig_type = x.dtype
14 | ret = super().forward(x.type(torch.float32))
15 | return ret.type(orig_type)
16 |
17 |
18 | class QuickGELU(nn.Module):
19 | def forward(self, x: torch.Tensor):
20 | return x * torch.sigmoid(1.702 * x)
21 |
22 |
23 | class DropPath(nn.Module):
24 | """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks)."""
25 |
26 | def __init__(self, drop_prob=None):
27 | super(DropPath, self).__init__()
28 | self.drop_prob = drop_prob
29 |
30 | def forward(self, x):
31 | return drop_path(x, self.drop_prob, self.training)
32 |
33 | def extra_repr(self) -> str:
34 | return "p={}".format(self.drop_prob)
35 |
36 |
37 | class ResidualAttentionBlock(nn.Module):
38 | def __init__(
39 | self, d_model: int, n_head: int, attn_mask: torch.Tensor = None, drop_path=0.0
40 | ):
41 | super().__init__()
42 |
43 | self.attn = nn.MultiheadAttention(d_model, n_head)
44 | self.ln_1 = LayerNorm(d_model)
45 | self.mlp = nn.Sequential(
46 | OrderedDict(
47 | [
48 | ("c_fc", nn.Linear(d_model, d_model * 4)),
49 | ("gelu", QuickGELU()),
50 | ("c_proj", nn.Linear(d_model * 4, d_model)),
51 | ]
52 | )
53 | )
54 | self.ln_2 = LayerNorm(d_model)
55 | self.attn_mask = attn_mask
56 |
57 | self.drop_path = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()
58 |
59 | def attention(self, x: torch.Tensor):
60 | self.attn_mask = (
61 | self.attn_mask.to(dtype=x.dtype, device=x.device)
62 | if self.attn_mask is not None
63 | else None
64 | )
65 | return self.attn(x, x, x, need_weights=False, attn_mask=self.attn_mask)[0]
66 |
67 | def forward(self, x: torch.Tensor):
68 | x = x + self.drop_path(self.attention(self.ln_1(x)))
69 | x = x + self.drop_path(self.mlp(self.ln_2(x)))
70 | return x
71 |
72 |
73 | class Transformer(nn.Module):
74 | def __init__(
75 | self,
76 | width: int,
77 | layers: int,
78 | heads: int,
79 | attn_mask: torch.Tensor = None,
80 | drop_path_rate=0.0,
81 | ):
82 | super().__init__()
83 | self.width = width
84 | self.layers = layers
85 | dpr = [
86 | x.item() for x in torch.linspace(0, drop_path_rate, layers)
87 | ] # stochastic depth decay rule
88 | self.resblocks = nn.Sequential(
89 | *[
90 | ResidualAttentionBlock(width, heads, attn_mask, dpr[i])
91 | for i in range(layers)
92 | ]
93 | )
94 |
95 | def forward(self, x: torch.Tensor):
96 | return self.resblocks(x)
97 |
98 |
99 | class Attention(nn.Module):
100 | def __init__(
101 | self,
102 | dim,
103 | num_heads=8,
104 | qkv_bias=False,
105 | qk_scale=None,
106 | attn_drop=0.0,
107 | proj_drop=0.0,
108 | ):
109 | super().__init__()
110 | self.num_heads = num_heads
111 | head_dim = dim // num_heads
112 | # NOTE scale factor was wrong in my original version, can set manually to be compat with prev weights
113 | self.scale = qk_scale or head_dim**-0.5
114 |
115 | self.q_proj = nn.Linear(dim, dim, bias=qkv_bias)
116 | self.k_proj = nn.Linear(dim, dim, bias=qkv_bias)
117 | self.v_proj = nn.Linear(dim, dim, bias=qkv_bias)
118 |
119 | self.attn_drop = nn.Dropout(attn_drop)
120 | self.proj = nn.Linear(dim, dim)
121 | self.proj_drop = nn.Dropout(proj_drop)
122 |
123 | def forward(self, q, k, v):
124 | B, N, C = q.shape
125 | assert k.shape == v.shape
126 | B, M, C = k.shape
127 | q = self.q_proj(q).reshape(B, N, self.num_heads, C // self.num_heads)
128 | k = self.k_proj(k).reshape(B, M, self.num_heads, C // self.num_heads)
129 | v = self.v_proj(v).reshape(B, M, self.num_heads, C // self.num_heads)
130 |
131 | attn = torch.einsum("bnkc,bmkc->bknm", q, k) * self.scale
132 |
133 | attn = attn.softmax(dim=-1)
134 |
135 | x = torch.einsum("bknm,bmkc->bnkc", attn, v).reshape(B, N, C)
136 |
137 | x = self.proj(x)
138 | x = self.proj_drop(x)
139 | return x
140 |
141 |
142 | class TransformerDecoderLayer(nn.Module):
143 | def __init__(
144 | self,
145 | d_model,
146 | nhead,
147 | dropout=0.1,
148 | ):
149 | super().__init__()
150 | self.self_attn = Attention(d_model, nhead, proj_drop=dropout)
151 | self.cross_attn = Attention(d_model, nhead, proj_drop=dropout)
152 |
153 | self.norm1 = nn.LayerNorm(d_model)
154 | self.norm2 = nn.LayerNorm(d_model)
155 | self.norm3 = nn.LayerNorm(d_model)
156 | self.dropout = nn.Dropout(dropout)
157 |
158 | self.mlp = nn.Sequential(
159 | nn.Linear(d_model, d_model * 4),
160 | nn.GELU(),
161 | nn.Dropout(dropout),
162 | nn.Linear(d_model * 4, d_model),
163 | )
164 |
165 | def forward(self, x, mem):
166 | q = k = v = self.norm1(x)
167 | x = x + self.self_attn(q, k, v)
168 | q = self.norm2(x)
169 | x = x + self.cross_attn(q, mem, mem)
170 | x = x + self.dropout(self.mlp(self.norm3(x)))
171 | return x
172 |
173 |
174 | @BACKBONES.register_module()
175 | class CLIPVisionTransformer(nn.Module):
176 | def __init__(
177 | self,
178 | input_resolution=224,
179 | patch_size=32,
180 | width=768,
181 | layers=12,
182 | heads=12,
183 | output_dim=512,
184 | drop_path_rate=0.0,
185 | out_indices=[3, 5, 7, 11],
186 | pretrained=None,
187 | get_embeddings=False,
188 | **kwargs,
189 | ):
190 | super().__init__()
191 | self.pretrained = pretrained
192 | self.input_resolution = input_resolution
193 | self.output_dim = output_dim
194 | self.patch_size = patch_size
195 | self.conv1 = nn.Conv2d(
196 | in_channels=3,
197 | out_channels=width,
198 | kernel_size=patch_size,
199 | stride=patch_size,
200 | bias=False,
201 | )
202 |
203 | scale = width**-0.5
204 | self.width = width
205 | self.class_embedding = nn.Parameter(scale * torch.randn(width))
206 | self.positional_embedding = nn.Parameter(
207 | scale * torch.randn((input_resolution // patch_size) ** 2 + 1, width)
208 | )
209 | self.spatial_size = input_resolution // patch_size
210 | self.ln_pre = LayerNorm(width)
211 | self.get_embeddings = get_embeddings
212 |
213 | self.transformer = Transformer(
214 | width, layers, heads, drop_path_rate=drop_path_rate
215 | )
216 |
217 | self.out_indices = out_indices
218 |
219 | if get_embeddings:
220 | self.ln_post = LayerNorm(width)
221 | self.proj = nn.Parameter(scale * torch.randn(width, output_dim))
222 |
223 | embed_dim = width
224 |
225 | def init_weights(self, pretrained=None):
226 | pretrained = pretrained or self.pretrained
227 | if isinstance(pretrained, str):
228 | checkpoint = (
229 | torch.jit.load(pretrained, map_location="cpu").float().state_dict()
230 | )
231 |
232 | state_dict = {}
233 |
234 | for k in checkpoint.keys():
235 | if k.startswith("visual."):
236 | new_k = k.replace("visual.", "")
237 | state_dict[new_k] = checkpoint[k]
238 |
239 | if "positional_embedding" in state_dict.keys():
240 | if (
241 | self.positional_embedding.shape
242 | != state_dict["positional_embedding"].shape
243 | ):
244 | print(
245 | f'Resize the pos_embed shape from {state_dict["positional_embedding"].shape} to {self.positional_embedding.shape}'
246 | )
247 | cls_pos = state_dict["positional_embedding"][0:1, :]
248 | leng = int(state_dict["positional_embedding"][1:,].shape[-2] ** 0.5)
249 | spatial_pos = F.interpolate(
250 | state_dict["positional_embedding"][1:,]
251 | .reshape(1, leng, leng, self.width)
252 | .permute(0, 3, 1, 2),
253 | size=(self.spatial_size, self.spatial_size),
254 | mode="bilinear",
255 | )
256 | spatial_pos = spatial_pos.reshape(
257 | self.width, self.spatial_size * self.spatial_size
258 | ).permute(1, 0)
259 | positional_embedding = torch.cat([cls_pos, spatial_pos], dim=0)
260 | state_dict["positional_embedding"] = positional_embedding
261 | assert (
262 | self.positional_embedding.shape
263 | == state_dict["positional_embedding"].shape
264 | )
265 | conv1 = state_dict["conv1.weight"]
266 | C_o, C_in, H, W = conv1.shape
267 | conv1 = torch.nn.functional.interpolate(
268 | conv1.float(),
269 | size=(self.patch_size, self.patch_size),
270 | mode="bicubic",
271 | align_corners=False,
272 | )
273 | state_dict["conv1.weight"] = conv1
274 |
275 | u, w = self.load_state_dict(state_dict, False)
276 | print(u, w, "are misaligned params in vision transformer")
277 |
278 | def forward(self, x: torch.Tensor):
279 | x = self.conv1(x) # shape = [*, width, grid, grid]
280 | B, C, H, W = x.shape
281 | x = x.reshape(x.shape[0], x.shape[1], -1) # shape = [*, width, grid ** 2]
282 | x = x.permute(0, 2, 1) # shape = [*, grid ** 2, width]
283 | x = torch.cat(
284 | [
285 | self.class_embedding.to(x.dtype)
286 | + torch.zeros(
287 | x.shape[0], 1, x.shape[-1], dtype=x.dtype, device=x.device
288 | ),
289 | x,
290 | ],
291 | dim=1,
292 | ) # shape = [*, grid ** 2 + 1, width]
293 |
294 | pos = self.positional_embedding.to(x.dtype)
295 | cls_pos = pos[0, :] + self.class_embedding.to(x.dtype)
296 | spatial_pos = F.interpolate(
297 | pos[1:,]
298 | .reshape(1, self.spatial_size, self.spatial_size, C)
299 | .permute(0, 3, 1, 2),
300 | size=(H, W),
301 | mode="bilinear",
302 | )
303 | spatial_pos = spatial_pos.reshape(1, C, H * W).permute(0, 2, 1)
304 | pos = torch.cat([cls_pos.reshape(1, 1, C), spatial_pos], dim=1)
305 | x = x + pos
306 | x = self.ln_pre(x)
307 | x = x.permute(1, 0, 2) # NLD -> LND
308 |
309 | features = []
310 | for i, blk in enumerate(self.transformer.resblocks):
311 | x = blk(x)
312 | if i in self.out_indices:
313 | xp = x.permute(1, 0, 2)[:, 1:, :].permute(0, 2, 1).reshape(B, -1, H, W)
314 | features.append(xp.contiguous())
315 |
316 | if self.get_embeddings:
317 | x = x.permute(1, 0, 2)
318 | x = self.ln_post(x)
319 | x = x @ self.proj
320 |
321 | global_embedding = x[:, 0]
322 | visual_embedding = (
323 | x[:, 1:].reshape(B, H, W, -1).permute(0, 3, 1, 2)
324 | ) # B C H W
325 |
326 | features.append([global_embedding, visual_embedding])
327 |
328 | return tuple(features)
329 |
--------------------------------------------------------------------------------
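
A hedged usage sketch for `CLIPVisionTransformer` above, assuming the dependencies from `requirement.txt` (mmseg, timm, ...) are installed so the module imports. The sizes follow the file's defaults and are only illustrative; the forward pass returns one patch-token feature map per entry in `out_indices`:

```python
import torch
from rein.models.backbones.clip import CLIPVisionTransformer

vit = CLIPVisionTransformer(
    input_resolution=224, patch_size=32, width=768, layers=12, heads=12,
    out_indices=[3, 5, 7, 11],          # blocks whose patch tokens are collected
)
vit.eval()

with torch.no_grad():
    feats = vit(torch.randn(1, 3, 224, 224))

for f in feats:                          # one map per out_index
    print(f.shape)                       # torch.Size([1, 768, 7, 7]) with patch_size=32
```
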
/rein/models/backbones/dino_layers/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | #
3 | # This source code is licensed under the Apache License, Version 2.0
4 | # found in the LICENSE file in the root directory of this source tree.
5 |
6 | from .dino_head import DINOHead
7 | from .mlp import Mlp
8 | from .patch_embed import PatchEmbed
9 | from .swiglu_ffn import SwiGLUFFN, SwiGLUFFNFused
10 | from .block import NestedTensorBlock,drop_add_residual_stochastic_depth
11 | from .attention import MemEffAttention
--------------------------------------------------------------------------------
/rein/models/backbones/dino_layers/attention.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | #
3 | # This source code is licensed under the Apache License, Version 2.0
4 | # found in the LICENSE file in the root directory of this source tree.
5 |
6 | # References:
7 | # https://github.com/facebookresearch/dino/blob/master/vision_transformer.py
8 | # https://github.com/rwightman/pytorch-image-models/tree/master/timm/models/vision_transformer.py
9 |
10 | import logging
11 | import os
12 | import warnings
13 |
14 | from torch import Tensor
15 | from torch import nn
16 |
17 |
18 | logger = logging.getLogger("dinov2")
19 |
20 |
21 | XFORMERS_ENABLED = os.environ.get("XFORMERS_DISABLED") is None
22 | try:
23 | if XFORMERS_ENABLED:
24 | from xformers.ops import memory_efficient_attention, unbind
25 |
26 | XFORMERS_AVAILABLE = True
27 | warnings.warn("xFormers is available (Attention)")
28 | else:
29 | warnings.warn("xFormers is disabled (Attention)")
30 | raise ImportError
31 | except ImportError:
32 | XFORMERS_AVAILABLE = False
33 | warnings.warn("xFormers is not available (Attention)")
34 |
35 |
36 | class Attention(nn.Module):
37 | def __init__(
38 | self,
39 | dim: int,
40 | num_heads: int = 8,
41 | qkv_bias: bool = False,
42 | proj_bias: bool = True,
43 | attn_drop: float = 0.0,
44 | proj_drop: float = 0.0,
45 | ) -> None:
46 | super().__init__()
47 | self.num_heads = num_heads
48 | head_dim = dim // num_heads
49 | self.scale = head_dim**-0.5
50 |
51 | self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
52 | self.attn_drop = nn.Dropout(attn_drop)
53 | self.proj = nn.Linear(dim, dim, bias=proj_bias)
54 | self.proj_drop = nn.Dropout(proj_drop)
55 |
56 | def forward(self, x: Tensor) -> Tensor:
57 | B, N, C = x.shape
58 | qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
59 |
60 | q, k, v = qkv[0] * self.scale, qkv[1], qkv[2]
61 | attn = q @ k.transpose(-2, -1)
62 |
63 | attn = attn.softmax(dim=-1)
64 | attn = self.attn_drop(attn)
65 |
66 | x = (attn @ v).transpose(1, 2).reshape(B, N, C)
67 | x = self.proj(x)
68 | x = self.proj_drop(x)
69 | return x
70 |
71 |
72 | class MemEffAttention(Attention):
73 | def forward(self, x: Tensor, attn_bias=None) -> Tensor:
74 | if not XFORMERS_AVAILABLE:
75 | if attn_bias is not None:
76 | raise AssertionError("xFormers is required for using nested tensors")
77 | return super().forward(x)
78 |
79 | B, N, C = x.shape
80 | qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads)
81 |
82 | q, k, v = unbind(qkv, 2)
83 |
84 | x = memory_efficient_attention(q, k, v, attn_bias=attn_bias)
85 | x = x.reshape([B, N, C])
86 |
87 | x = self.proj(x)
88 | x = self.proj_drop(x)
89 | return x
90 |
--------------------------------------------------------------------------------
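
A minimal sketch of `MemEffAttention` above, assuming the package imports resolve. Exporting `XFORMERS_DISABLED` before the import forces the plain-PyTorch fallback path (the module checks the variable at import time), so the snippet also runs on CPU; with xFormers installed and the variable unset, the same call dispatches to `memory_efficient_attention` instead. Shapes are illustrative:

```python
import os
os.environ["XFORMERS_DISABLED"] = "1"    # checked at import time; forces the fallback

import torch
from rein.models.backbones.dino_layers.attention import MemEffAttention

attn = MemEffAttention(dim=384, num_heads=6, qkv_bias=True)
x = torch.randn(2, 257, 384)             # [batch, tokens (256 patches + CLS), dim]
print(attn(x).shape)                     # torch.Size([2, 257, 384])
```
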
/rein/models/backbones/dino_layers/block.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | #
3 | # This source code is licensed under the Apache License, Version 2.0
4 | # found in the LICENSE file in the root directory of this source tree.
5 |
6 | # References:
7 | # https://github.com/facebookresearch/dino/blob/master/vision_transformer.py
8 | # https://github.com/rwightman/pytorch-image-models/tree/master/timm/layers/patch_embed.py
9 | # Copyright (c) Meta Platforms, Inc. and affiliates.
10 | #
11 | # This source code is licensed under the Apache License, Version 2.0
12 | # found in the LICENSE file in the root directory of this source tree.
13 |
14 | # Modified from: https://github.com/huggingface/pytorch-image-models/blob/main/timm/models/vision_transformer.py#L103-L110
15 |
16 |
17 | import logging
18 | import os
19 | from typing import Callable, List, Any, Tuple, Dict, Union
20 | import warnings
21 |
22 | import torch
23 | from torch import nn, Tensor
24 |
25 | from .attention import Attention, MemEffAttention
26 | from .drop_path import DropPath
27 | from .layer_scale import LayerScale
28 | from .mlp import Mlp
29 |
30 |
31 | logger = logging.getLogger("dinov2")
32 |
33 |
34 | XFORMERS_ENABLED = os.environ.get("XFORMERS_DISABLED") is None
35 | try:
36 | if XFORMERS_ENABLED:
37 | from xformers.ops import fmha, scaled_index_add, index_select_cat
38 |
39 | XFORMERS_AVAILABLE = True
40 | warnings.warn("xFormers is available (Block)")
41 | else:
42 | warnings.warn("xFormers is disabled (Block)")
43 | raise ImportError
44 | except ImportError:
45 | XFORMERS_AVAILABLE = False
46 |
47 | warnings.warn("xFormers is not available (Block)")
48 |
49 |
50 | class Block(nn.Module):
51 | def __init__(
52 | self,
53 | dim: int,
54 | num_heads: int,
55 | mlp_ratio: float = 4.0,
56 | qkv_bias: bool = False,
57 | proj_bias: bool = True,
58 | ffn_bias: bool = True,
59 | drop: float = 0.0,
60 | attn_drop: float = 0.0,
61 | init_values=None,
62 | drop_path: float = 0.0,
63 | act_layer: Callable[..., nn.Module] = nn.GELU,
64 | norm_layer: Callable[..., nn.Module] = nn.LayerNorm,
65 | attn_class: Callable[..., nn.Module] = Attention,
66 | ffn_layer: Callable[..., nn.Module] = Mlp,
67 | ) -> None:
68 | super().__init__()
69 | # print(f"biases: qkv: {qkv_bias}, proj: {proj_bias}, ffn: {ffn_bias}")
70 | self.norm1 = norm_layer(dim)
71 | self.attn = attn_class(
72 | dim,
73 | num_heads=num_heads,
74 | qkv_bias=qkv_bias,
75 | proj_bias=proj_bias,
76 | attn_drop=attn_drop,
77 | proj_drop=drop,
78 | )
79 | self.ls1 = LayerScale(dim, init_values=init_values) if init_values else nn.Identity()
80 | self.drop_path1 = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()
81 |
82 | self.norm2 = norm_layer(dim)
83 | mlp_hidden_dim = int(dim * mlp_ratio)
84 | self.mlp = ffn_layer(
85 | in_features=dim,
86 | hidden_features=mlp_hidden_dim,
87 | act_layer=act_layer,
88 | drop=drop,
89 | bias=ffn_bias,
90 | )
91 | self.ls2 = LayerScale(dim, init_values=init_values) if init_values else nn.Identity()
92 | self.drop_path2 = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()
93 |
94 | self.sample_drop_ratio = drop_path
95 |
96 | def forward(self, x: Tensor) -> Tensor:
97 | def attn_residual_func(x: Tensor) -> Tensor:
98 | return self.ls1(self.attn(self.norm1(x)))
99 |
100 | def ffn_residual_func(x: Tensor) -> Tensor:
101 | return self.ls2(self.mlp(self.norm2(x)))
102 |
103 | if self.training and self.sample_drop_ratio > 0.1:
104 | # the overhead is compensated only for a drop path rate larger than 0.1
105 | x = drop_add_residual_stochastic_depth(
106 | x,
107 | residual_func=attn_residual_func,
108 | sample_drop_ratio=self.sample_drop_ratio,
109 | )
110 | x = drop_add_residual_stochastic_depth(
111 | x,
112 | residual_func=ffn_residual_func,
113 | sample_drop_ratio=self.sample_drop_ratio,
114 | )
115 | elif self.training and self.sample_drop_ratio > 0.0:
116 | x = x + self.drop_path1(attn_residual_func(x))
117 |             x = x + self.drop_path2(ffn_residual_func(x))
118 | else:
119 | x = x + attn_residual_func(x)
120 | x = x + ffn_residual_func(x)
121 | return x
122 |
123 |
124 | def drop_add_residual_stochastic_depth(
125 | x: Tensor,
126 | residual_func: Callable[[Tensor], Tensor],
127 | sample_drop_ratio: float = 0.0,
128 | ) -> Tensor:
129 | # 1) extract subset using permutation
130 | b, n, d = x.shape
131 | sample_subset_size = max(int(b * (1 - sample_drop_ratio)), 1)
132 | brange = (torch.randperm(b, device=x.device))[:sample_subset_size]
133 | x_subset = x[brange]
134 |
135 | # 2) apply residual_func to get residual
136 | residual = residual_func(x_subset)
137 |
138 | x_flat = x.flatten(1)
139 | residual = residual.flatten(1)
140 |
141 | residual_scale_factor = b / sample_subset_size
142 |
143 | # 3) add the residual
144 | x_plus_residual = torch.index_add(x_flat, 0, brange, residual.to(dtype=x.dtype), alpha=residual_scale_factor)
145 | return x_plus_residual.view_as(x)
146 |
147 |
148 | def get_branges_scales(x, sample_drop_ratio=0.0):
149 | b, n, d = x.shape
150 | sample_subset_size = max(int(b * (1 - sample_drop_ratio)), 1)
151 | brange = (torch.randperm(b, device=x.device))[:sample_subset_size]
152 | residual_scale_factor = b / sample_subset_size
153 | return brange, residual_scale_factor
154 |
155 |
156 | def add_residual(x, brange, residual, residual_scale_factor, scaling_vector=None):
157 | if scaling_vector is None:
158 | x_flat = x.flatten(1)
159 | residual = residual.flatten(1)
160 | x_plus_residual = torch.index_add(x_flat, 0, brange, residual.to(dtype=x.dtype), alpha=residual_scale_factor)
161 | else:
162 | x_plus_residual = scaled_index_add(
163 | x, brange, residual.to(dtype=x.dtype), scaling=scaling_vector, alpha=residual_scale_factor
164 | )
165 | return x_plus_residual
166 |
167 |
168 | attn_bias_cache: Dict[Tuple, Any] = {}
169 |
170 |
171 | def get_attn_bias_and_cat(x_list, branges=None):
172 | """
173 | this will perform the index select, cat the tensors, and provide the attn_bias from cache
174 | """
175 | batch_sizes = [b.shape[0] for b in branges] if branges is not None else [x.shape[0] for x in x_list]
176 | all_shapes = tuple((b, x.shape[1]) for b, x in zip(batch_sizes, x_list))
177 | if all_shapes not in attn_bias_cache.keys():
178 | seqlens = []
179 | for b, x in zip(batch_sizes, x_list):
180 | for _ in range(b):
181 | seqlens.append(x.shape[1])
182 | attn_bias = fmha.BlockDiagonalMask.from_seqlens(seqlens)
183 | attn_bias._batch_sizes = batch_sizes
184 | attn_bias_cache[all_shapes] = attn_bias
185 |
186 | if branges is not None:
187 | cat_tensors = index_select_cat([x.flatten(1) for x in x_list], branges).view(1, -1, x_list[0].shape[-1])
188 | else:
189 | tensors_bs1 = tuple(x.reshape([1, -1, *x.shape[2:]]) for x in x_list)
190 | cat_tensors = torch.cat(tensors_bs1, dim=1)
191 |
192 | return attn_bias_cache[all_shapes], cat_tensors
193 |
194 |
195 | def drop_add_residual_stochastic_depth_list(
196 | x_list: List[Tensor],
197 | residual_func: Callable[[Tensor, Any], Tensor],
198 | sample_drop_ratio: float = 0.0,
199 | scaling_vector=None,
200 | ) -> Tensor:
201 | # 1) generate random set of indices for dropping samples in the batch
202 | branges_scales = [get_branges_scales(x, sample_drop_ratio=sample_drop_ratio) for x in x_list]
203 | branges = [s[0] for s in branges_scales]
204 | residual_scale_factors = [s[1] for s in branges_scales]
205 |
206 | # 2) get attention bias and index+concat the tensors
207 | attn_bias, x_cat = get_attn_bias_and_cat(x_list, branges)
208 |
209 | # 3) apply residual_func to get residual, and split the result
210 | residual_list = attn_bias.split(residual_func(x_cat, attn_bias=attn_bias)) # type: ignore
211 |
212 | outputs = []
213 | for x, brange, residual, residual_scale_factor in zip(x_list, branges, residual_list, residual_scale_factors):
214 | outputs.append(add_residual(x, brange, residual, residual_scale_factor, scaling_vector).view_as(x))
215 | return outputs
216 |
217 |
218 | class NestedTensorBlock(Block):
219 | def forward_nested(self, x_list: List[Tensor]) -> List[Tensor]:
220 | """
221 | x_list contains a list of tensors to nest together and run
222 | """
223 | assert isinstance(self.attn, MemEffAttention)
224 |
225 | if self.training and self.sample_drop_ratio > 0.0:
226 |
227 | def attn_residual_func(x: Tensor, attn_bias=None) -> Tensor:
228 | return self.attn(self.norm1(x), attn_bias=attn_bias)
229 |
230 | def ffn_residual_func(x: Tensor, attn_bias=None) -> Tensor:
231 | return self.mlp(self.norm2(x))
232 |
233 | x_list = drop_add_residual_stochastic_depth_list(
234 | x_list,
235 | residual_func=attn_residual_func,
236 | sample_drop_ratio=self.sample_drop_ratio,
237 | scaling_vector=self.ls1.gamma if isinstance(self.ls1, LayerScale) else None,
238 | )
239 | x_list = drop_add_residual_stochastic_depth_list(
240 | x_list,
241 | residual_func=ffn_residual_func,
242 | sample_drop_ratio=self.sample_drop_ratio,
243 |                 scaling_vector=self.ls2.gamma if isinstance(self.ls2, LayerScale) else None,
244 | )
245 | return x_list
246 | else:
247 |
248 | def attn_residual_func(x: Tensor, attn_bias=None) -> Tensor:
249 | return self.ls1(self.attn(self.norm1(x), attn_bias=attn_bias))
250 |
251 | def ffn_residual_func(x: Tensor, attn_bias=None) -> Tensor:
252 | return self.ls2(self.mlp(self.norm2(x)))
253 |
254 | attn_bias, x = get_attn_bias_and_cat(x_list)
255 | x = x + attn_residual_func(x, attn_bias=attn_bias)
256 | x = x + ffn_residual_func(x)
257 | return attn_bias.split(x)
258 |
259 | def forward(self, x_or_x_list):
260 | if isinstance(x_or_x_list, Tensor):
261 | return super().forward(x_or_x_list)
262 | elif isinstance(x_or_x_list, list):
263 | if not XFORMERS_AVAILABLE:
264 | raise AssertionError("xFormers is required for using nested tensors")
265 | return self.forward_nested(x_or_x_list)
266 | else:
267 | raise AssertionError
268 |
--------------------------------------------------------------------------------
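
The batch-level stochastic depth above runs the residual branch on a random subset of the batch and rescales its output by `b / sample_subset_size`, so the update matches full residual addition in expectation. A small sketch with a dummy residual function (illustrative only, assuming the repo's requirements are installed):

```python
import torch
from rein.models.backbones.dino_layers import drop_add_residual_stochastic_depth

def residual_func(t):                    # stand-in for the attention / FFN branch
    return torch.ones_like(t)

torch.manual_seed(0)
x = torch.zeros(8, 4, 16)                # [batch, tokens, dim]
out = drop_add_residual_stochastic_depth(x, residual_func, sample_drop_ratio=0.5)
# 4 of the 8 samples get a residual of 1 scaled by 8 / 4 = 2; the others stay 0
print(out[:, 0, 0])
```
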
/rein/models/backbones/dino_layers/dino_head.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | #
3 | # This source code is licensed under the Apache License, Version 2.0
4 | # found in the LICENSE file in the root directory of this source tree.
5 |
6 | import torch
7 | import torch.nn as nn
8 | from torch.nn.init import trunc_normal_
9 | from torch.nn.utils import weight_norm
10 |
11 |
12 | class DINOHead(nn.Module):
13 | def __init__(
14 | self,
15 | in_dim,
16 | out_dim,
17 | use_bn=False,
18 | nlayers=3,
19 | hidden_dim=2048,
20 | bottleneck_dim=256,
21 | mlp_bias=True,
22 | ):
23 | super().__init__()
24 | nlayers = max(nlayers, 1)
25 | self.mlp = _build_mlp(nlayers, in_dim, bottleneck_dim, hidden_dim=hidden_dim, use_bn=use_bn, bias=mlp_bias)
26 | self.apply(self._init_weights)
27 | self.last_layer = weight_norm(nn.Linear(bottleneck_dim, out_dim, bias=False))
28 | self.last_layer.weight_g.data.fill_(1)
29 |
30 | def _init_weights(self, m):
31 | if isinstance(m, nn.Linear):
32 | trunc_normal_(m.weight, std=0.02)
33 | if isinstance(m, nn.Linear) and m.bias is not None:
34 | nn.init.constant_(m.bias, 0)
35 |
36 | def forward(self, x):
37 | x = self.mlp(x)
38 | eps = 1e-6 if x.dtype == torch.float16 else 1e-12
39 | x = nn.functional.normalize(x, dim=-1, p=2, eps=eps)
40 | x = self.last_layer(x)
41 | return x
42 |
43 |
44 | def _build_mlp(nlayers, in_dim, bottleneck_dim, hidden_dim=None, use_bn=False, bias=True):
45 | if nlayers == 1:
46 | return nn.Linear(in_dim, bottleneck_dim, bias=bias)
47 | else:
48 | layers = [nn.Linear(in_dim, hidden_dim, bias=bias)]
49 | if use_bn:
50 | layers.append(nn.BatchNorm1d(hidden_dim))
51 | layers.append(nn.GELU())
52 | for _ in range(nlayers - 2):
53 | layers.append(nn.Linear(hidden_dim, hidden_dim, bias=bias))
54 | if use_bn:
55 | layers.append(nn.BatchNorm1d(hidden_dim))
56 | layers.append(nn.GELU())
57 | layers.append(nn.Linear(hidden_dim, bottleneck_dim, bias=bias))
58 | return nn.Sequential(*layers)
59 |
--------------------------------------------------------------------------------
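
Shape sketch for `DINOHead` above: MLP bottleneck, L2 normalization, then a weight-normalized prototype layer. The dimensions are illustrative, not values used by CUFIT:

```python
import torch
from rein.models.backbones.dino_layers import DINOHead

head = DINOHead(in_dim=384, out_dim=4096, hidden_dim=2048, bottleneck_dim=256)
logits = head(torch.randn(8, 384))       # e.g. backbone CLS features -> prototype scores
print(logits.shape)                      # torch.Size([8, 4096])
```
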
/rein/models/backbones/dino_layers/drop_path.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | #
3 | # This source code is licensed under the Apache License, Version 2.0
4 | # found in the LICENSE file in the root directory of this source tree.
5 |
6 | # References:
7 | # https://github.com/facebookresearch/dino/blob/master/vision_transformer.py
8 | # https://github.com/rwightman/pytorch-image-models/tree/master/timm/layers/drop.py
9 |
10 |
11 | from torch import nn
12 |
13 |
14 | def drop_path(x, drop_prob: float = 0.0, training: bool = False):
15 | if drop_prob == 0.0 or not training:
16 | return x
17 | keep_prob = 1 - drop_prob
18 | shape = (x.shape[0],) + (1,) * (x.ndim - 1) # work with diff dim tensors, not just 2D ConvNets
19 | random_tensor = x.new_empty(shape).bernoulli_(keep_prob)
20 | if keep_prob > 0.0:
21 | random_tensor.div_(keep_prob)
22 | output = x * random_tensor
23 | return output
24 |
25 |
26 | class DropPath(nn.Module):
27 | """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks)."""
28 |
29 | def __init__(self, drop_prob=None):
30 | super(DropPath, self).__init__()
31 | self.drop_prob = drop_prob
32 |
33 | def forward(self, x):
34 | return drop_path(x, self.drop_prob, self.training)
35 |
--------------------------------------------------------------------------------
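
Sketch of per-sample drop path above: dropped samples are zeroed and survivors are rescaled by `1 / keep_prob`, keeping the expected activation unchanged:

```python
import torch
from rein.models.backbones.dino_layers.drop_path import drop_path

torch.manual_seed(0)
x = torch.ones(6, 3, 8)                  # [batch, tokens, dim]
out = drop_path(x, drop_prob=0.5, training=True)
print(out[:, 0, 0])                      # each sample is either 0.0 or 2.0
```
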
/rein/models/backbones/dino_layers/layer_scale.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | #
3 | # This source code is licensed under the Apache License, Version 2.0
4 | # found in the LICENSE file in the root directory of this source tree.
5 |
6 | # Modified from: https://github.com/huggingface/pytorch-image-models/blob/main/timm/models/vision_transformer.py#L103-L110
7 |
8 | from typing import Union
9 |
10 | import torch
11 | from torch import Tensor
12 | from torch import nn
13 |
14 |
15 | class LayerScale(nn.Module):
16 | def __init__(
17 | self,
18 | dim: int,
19 | init_values: Union[float, Tensor] = 1e-5,
20 | inplace: bool = False,
21 | ) -> None:
22 | super().__init__()
23 | self.inplace = inplace
24 | self.gamma = nn.Parameter(init_values * torch.ones(dim))
25 |
26 | def forward(self, x: Tensor) -> Tensor:
27 | return x.mul_(self.gamma) if self.inplace else x * self.gamma
28 |
--------------------------------------------------------------------------------
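
`LayerScale` above is a learnable per-channel gain on the residual branch; initializing it to a small value keeps each block close to identity at the start of training. A one-line sanity check with illustrative dimensions:

```python
import torch
from rein.models.backbones.dino_layers.layer_scale import LayerScale

ls = LayerScale(dim=384, init_values=1e-5)
y = ls(torch.randn(2, 257, 384))         # gamma broadcasts over batch and tokens
print(y.abs().max())                     # tiny at init, since every gamma == 1e-5
```
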
/rein/models/backbones/dino_layers/mlp.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | #
3 | # This source code is licensed under the Apache License, Version 2.0
4 | # found in the LICENSE file in the root directory of this source tree.
5 |
6 | # References:
7 | # https://github.com/facebookresearch/dino/blob/master/vision_transformer.py
8 | # https://github.com/rwightman/pytorch-image-models/tree/master/timm/layers/mlp.py
9 |
10 |
11 | from typing import Callable, Optional
12 |
13 | from torch import Tensor, nn
14 |
15 |
16 | class Mlp(nn.Module):
17 | def __init__(
18 | self,
19 | in_features: int,
20 | hidden_features: Optional[int] = None,
21 | out_features: Optional[int] = None,
22 | act_layer: Callable[..., nn.Module] = nn.GELU,
23 | drop: float = 0.0,
24 | bias: bool = True,
25 | ) -> None:
26 | super().__init__()
27 | out_features = out_features or in_features
28 | hidden_features = hidden_features or in_features
29 | self.fc1 = nn.Linear(in_features, hidden_features, bias=bias)
30 | self.act = act_layer()
31 | self.fc2 = nn.Linear(hidden_features, out_features, bias=bias)
32 | self.drop = nn.Dropout(drop)
33 |
34 | def forward(self, x: Tensor) -> Tensor:
35 | x = self.fc1(x)
36 | x = self.act(x)
37 | x = self.drop(x)
38 | x = self.fc2(x)
39 | x = self.drop(x)
40 | return x
41 |
--------------------------------------------------------------------------------
/rein/models/backbones/dino_layers/patch_embed.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | #
3 | # This source code is licensed under the Apache License, Version 2.0
4 | # found in the LICENSE file in the root directory of this source tree.
5 |
6 | # References:
7 | # https://github.com/facebookresearch/dino/blob/master/vision_transformer.py
8 | # https://github.com/rwightman/pytorch-image-models/tree/master/timm/layers/patch_embed.py
9 |
10 | from typing import Callable, Optional, Tuple, Union
11 |
12 | from torch import Tensor
13 | import torch.nn as nn
14 |
15 |
16 | def make_2tuple(x):
17 | if isinstance(x, tuple):
18 | assert len(x) == 2
19 | return x
20 |
21 | assert isinstance(x, int)
22 | return (x, x)
23 |
24 |
25 | class PatchEmbed(nn.Module):
26 | """
27 | 2D image to patch embedding: (B,C,H,W) -> (B,N,D)
28 |
29 | Args:
30 | img_size: Image size.
31 | patch_size: Patch token size.
32 | in_chans: Number of input image channels.
33 | embed_dim: Number of linear projection output channels.
34 | norm_layer: Normalization layer.
35 | """
36 |
37 | def __init__(
38 | self,
39 | img_size: Union[int, Tuple[int, int]] = 224,
40 | patch_size: Union[int, Tuple[int, int]] = 16,
41 | in_chans: int = 3,
42 | embed_dim: int = 768,
43 | norm_layer: Optional[Callable] = None,
44 | flatten_embedding: bool = True,
45 | ) -> None:
46 | super().__init__()
47 |
48 | image_HW = make_2tuple(img_size)
49 | patch_HW = make_2tuple(patch_size)
50 | patch_grid_size = (
51 | image_HW[0] // patch_HW[0],
52 | image_HW[1] // patch_HW[1],
53 | )
54 |
55 | self.img_size = image_HW
56 | self.patch_size = patch_HW
57 | self.patches_resolution = patch_grid_size
58 | self.num_patches = patch_grid_size[0] * patch_grid_size[1]
59 |
60 | self.in_chans = in_chans
61 | self.embed_dim = embed_dim
62 |
63 | self.flatten_embedding = flatten_embedding
64 |
65 | self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_HW, stride=patch_HW)
66 | self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity()
67 |
68 | def forward(self, x: Tensor) -> Tensor:
69 | _, _, H, W = x.shape
70 | patch_H, patch_W = self.patch_size
71 |
72 | assert H % patch_H == 0, f"Input image height {H} is not a multiple of patch height {patch_H}"
73 | assert W % patch_W == 0, f"Input image width {W} is not a multiple of patch width: {patch_W}"
74 |
75 | x = self.proj(x) # B C H W
76 | H, W = x.size(2), x.size(3)
77 | x = x.flatten(2).transpose(1, 2) # B HW C
78 | x = self.norm(x)
79 | if not self.flatten_embedding:
80 | x = x.reshape(-1, H, W, self.embed_dim) # B H W C
81 | return x
82 |
83 | def flops(self) -> float:
84 | Ho, Wo = self.patches_resolution
85 | flops = Ho * Wo * self.embed_dim * self.in_chans * (self.patch_size[0] * self.patch_size[1])
86 | if self.norm is not None:
87 | flops += Ho * Wo * self.embed_dim
88 | return flops
89 |
--------------------------------------------------------------------------------
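
Shape sketch for `PatchEmbed` above with DINOv2-style sizes (patch size 14 and 224-pixel inputs are assumptions for illustration; CUFIT's actual resolution comes from its dataset configs):

```python
import torch
from rein.models.backbones.dino_layers import PatchEmbed

embed = PatchEmbed(img_size=224, patch_size=14, in_chans=3, embed_dim=384)
tokens = embed(torch.randn(2, 3, 224, 224))
print(tokens.shape)                      # torch.Size([2, 256, 384]); 256 = (224 / 14) ** 2
```
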
/rein/models/backbones/dino_layers/swiglu_ffn.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | #
3 | # This source code is licensed under the Apache License, Version 2.0
4 | # found in the LICENSE file in the root directory of this source tree.
5 |
6 | import os
7 | from typing import Callable, Optional
8 | import warnings
9 |
10 | from torch import Tensor, nn
11 | import torch.nn.functional as F
12 |
13 |
14 | class SwiGLUFFN(nn.Module):
15 | def __init__(
16 | self,
17 | in_features: int,
18 | hidden_features: Optional[int] = None,
19 | out_features: Optional[int] = None,
20 | act_layer: Callable[..., nn.Module] = None,
21 | drop: float = 0.0,
22 | bias: bool = True,
23 | ) -> None:
24 | super().__init__()
25 | out_features = out_features or in_features
26 | hidden_features = hidden_features or in_features
27 | self.w12 = nn.Linear(in_features, 2 * hidden_features, bias=bias)
28 | self.w3 = nn.Linear(hidden_features, out_features, bias=bias)
29 |
30 | def forward(self, x: Tensor) -> Tensor:
31 | x12 = self.w12(x)
32 | x1, x2 = x12.chunk(2, dim=-1)
33 | hidden = F.silu(x1) * x2
34 | return self.w3(hidden)
35 |
36 |
37 | XFORMERS_ENABLED = os.environ.get("XFORMERS_DISABLED") is None
38 | try:
39 | if XFORMERS_ENABLED:
40 | from xformers.ops import SwiGLU
41 |
42 | XFORMERS_AVAILABLE = True
43 | warnings.warn("xFormers is available (SwiGLU)")
44 | else:
45 | warnings.warn("xFormers is disabled (SwiGLU)")
46 | raise ImportError
47 | except ImportError:
48 | SwiGLU = SwiGLUFFN
49 | XFORMERS_AVAILABLE = False
50 |
51 | warnings.warn("xFormers is not available (SwiGLU)")
52 |
53 |
54 | class SwiGLUFFNFused(SwiGLU):
55 | def __init__(
56 | self,
57 | in_features: int,
58 | hidden_features: Optional[int] = None,
59 | out_features: Optional[int] = None,
60 | act_layer: Callable[..., nn.Module] = None,
61 | drop: float = 0.0,
62 | bias: bool = True,
63 | ) -> None:
64 | out_features = out_features or in_features
65 | hidden_features = hidden_features or in_features
66 | hidden_features = (int(hidden_features * 2 / 3) + 7) // 8 * 8
67 | super().__init__(
68 | in_features=in_features,
69 | hidden_features=hidden_features,
70 | out_features=out_features,
71 | bias=bias,
72 | )
73 |
--------------------------------------------------------------------------------
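
Two quick checks on the file above: the pure-PyTorch `SwiGLUFFN` preserves the input width, and `SwiGLUFFNFused` shrinks a requested hidden width to 2/3 and pads it up to a multiple of 8 (roughly matching the parameter count of a plain MLP). Sizes are illustrative ViT-S-like values:

```python
import torch
from rein.models.backbones.dino_layers import SwiGLUFFN

ffn = SwiGLUFFN(in_features=384, hidden_features=1024)
print(ffn(torch.randn(2, 257, 384)).shape)     # torch.Size([2, 257, 384])

requested = 4 * 384                             # mlp_ratio * embed_dim = 1536
print((int(requested * 2 / 3) + 7) // 8 * 8)    # 1024 -- what SwiGLUFFNFused allocates
```
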
/rein/models/backbones/dino_v2.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | #
3 | # This source code is licensed under the Apache License, Version 2.0
4 | # found in the LICENSE file in the root directory of this source tree.
5 |
6 | # References:
7 | # https://github.com/facebookresearch/dino/blob/main/vision_transformer.py
8 | # https://github.com/rwightman/pytorch-image-models/tree/master/timm/models/vision_transformer.py
9 |
10 | from functools import partial
11 | import math
12 | from typing import Sequence, Tuple, Union, Callable
13 |
14 | import torch
15 | import torch.nn as nn
16 | import torch.utils.checkpoint
17 | import torch.nn.functional as F
18 | from .dino_layers import (
19 | Mlp,
20 | PatchEmbed,
21 | SwiGLUFFNFused,
22 | MemEffAttention,
23 | NestedTensorBlock as Block,
24 | )
25 |
26 |
27 | def named_apply(
28 | fn: Callable, module: nn.Module, name="", depth_first=True, include_root=False
29 | ) -> nn.Module:
30 | if not depth_first and include_root:
31 | fn(module=module, name=name)
32 | for child_name, child_module in module.named_children():
33 | child_name = ".".join((name, child_name)) if name else child_name
34 | named_apply(
35 | fn=fn,
36 | module=child_module,
37 | name=child_name,
38 | depth_first=depth_first,
39 | include_root=True,
40 | )
41 | if depth_first and include_root:
42 | fn(module=module, name=name)
43 | return module
44 |
45 |
46 | class BlockChunk(nn.ModuleList):
47 | def forward(self, x):
48 | for b in self:
49 | x = b(x)
50 | return x
51 |
52 |
53 | class DinoVisionTransformer(nn.Module):
54 | def __init__(
55 | self,
56 | img_size=224,
57 | patch_size=16,
58 | in_chans=3,
59 | embed_dim=768,
60 | depth=12,
61 | num_heads=12,
62 | mlp_ratio=4.0,
63 | qkv_bias=True,
64 | ffn_bias=True,
65 | proj_bias=True,
66 | drop_path_rate=0.0,
67 | drop_path_uniform=False,
68 | init_values=None, # for layerscale: None or 0 => no layerscale
69 | embed_layer=PatchEmbed,
70 | act_layer=nn.GELU,
71 | block_fn=partial(Block, attn_class=MemEffAttention),
72 | ffn_layer="mlp",
73 | block_chunks=1,
74 | out_indices=[7, 11, 15, 23],
75 | init_cfg=None,
76 | ):
77 | """
78 | Args:
79 | img_size (int, tuple): input image size
80 | patch_size (int, tuple): patch size
81 | in_chans (int): number of input channels
82 | embed_dim (int): embedding dimension
83 | depth (int): depth of transformer
84 | num_heads (int): number of attention heads
85 | mlp_ratio (int): ratio of mlp hidden dim to embedding dim
86 | qkv_bias (bool): enable bias for qkv if True
87 | proj_bias (bool): enable bias for proj in attn if True
88 | ffn_bias (bool): enable bias for ffn if True
89 | drop_path_rate (float): stochastic depth rate
90 | drop_path_uniform (bool): apply uniform drop rate across blocks
91 | weight_init (str): weight init scheme
92 | init_values (float): layer-scale init values
93 | embed_layer (nn.Module): patch embedding layer
94 | act_layer (nn.Module): MLP activation layer
95 | block_fn (nn.Module): transformer block class
96 | ffn_layer (str): "mlp", "swiglu", "swiglufused" or "identity"
97 | block_chunks: (int) split block sequence into block_chunks units for FSDP wrap
98 | """
99 | super().__init__()
100 | norm_layer = partial(nn.LayerNorm, eps=1e-6)
101 | self.out_indices = out_indices
102 |
103 | self.num_features = (
104 | self.embed_dim
105 | ) = embed_dim # num_features for consistency with other models
106 | self.num_tokens = 1
107 | self.n_blocks = depth
108 | self.num_heads = num_heads
109 | self.patch_size = patch_size
110 |
111 | self.patch_embed = embed_layer(
112 | img_size=img_size,
113 | patch_size=patch_size,
114 | in_chans=in_chans,
115 | embed_dim=embed_dim,
116 | )
117 | num_patches = self.patch_embed.num_patches
118 |
119 | self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
120 | self.pos_embed = nn.Parameter(
121 | torch.zeros(1, num_patches + self.num_tokens, embed_dim)
122 | )
123 |
124 | if drop_path_uniform is True:
125 | dpr = [drop_path_rate] * depth
126 | else:
127 | dpr = [
128 | x.item() for x in torch.linspace(0, drop_path_rate, depth)
129 | ] # stochastic depth decay rule
130 |
131 | if ffn_layer == "mlp":
132 | ffn_layer = Mlp
133 | elif ffn_layer == "swiglufused" or ffn_layer == "swiglu":
134 | ffn_layer = SwiGLUFFNFused
135 | elif ffn_layer == "identity":
136 |
137 | def f(*args, **kwargs):
138 | return nn.Identity()
139 |
140 | ffn_layer = f
141 | else:
142 | raise NotImplementedError
143 |
144 | blocks_list = [
145 | block_fn(
146 | dim=embed_dim,
147 | num_heads=num_heads,
148 | mlp_ratio=mlp_ratio,
149 | qkv_bias=qkv_bias,
150 | proj_bias=proj_bias,
151 | ffn_bias=ffn_bias,
152 | drop_path=dpr[i],
153 | norm_layer=norm_layer,
154 | act_layer=act_layer,
155 | ffn_layer=ffn_layer,
156 | init_values=init_values,
157 | )
158 | for i in range(depth)
159 | ]
160 | if block_chunks > 0:
161 | self.chunked_blocks = True
162 | chunked_blocks = []
163 | chunksize = depth // block_chunks
164 | for i in range(0, depth, chunksize):
165 | # this is to keep the block index consistent if we chunk the block list
166 | chunked_blocks.append(
167 | [nn.Identity()] * i + blocks_list[i : i + chunksize]
168 | )
169 | self.blocks = nn.ModuleList([BlockChunk(p) for p in chunked_blocks])
170 | else:
171 | self.chunked_blocks = False
172 | self.blocks = nn.ModuleList(blocks_list)
173 |
174 | self.norm = norm_layer(embed_dim)
175 | self.head = nn.Identity()
176 |
177 | self.mask_token = nn.Parameter(torch.zeros(1, embed_dim))
178 |
179 | def interpolate_pos_encoding(self, x, w, h):
180 | previous_dtype = x.dtype
181 | npatch = x.shape[1] - 1
182 | N = self.pos_embed.shape[1] - 1
183 | if npatch == N and w == h:
184 | return self.pos_embed
185 | pos_embed = self.pos_embed.float()
186 | class_pos_embed = pos_embed[:, 0]
187 | patch_pos_embed = pos_embed[:, 1:]
188 | dim = x.shape[-1]
189 | w0 = w // self.patch_size
190 | h0 = h // self.patch_size
191 | # we add a small number to avoid floating point error in the interpolation
192 | # see discussion at https://github.com/facebookresearch/dino/issues/8
193 | w0, h0 = w0 + 0.1, h0 + 0.1
194 |
195 | patch_pos_embed = nn.functional.interpolate(
196 | patch_pos_embed.reshape(
197 | 1, int(math.sqrt(N)), int(math.sqrt(N)), dim
198 | ).permute(0, 3, 1, 2),
199 | scale_factor=(w0 / math.sqrt(N), h0 / math.sqrt(N)),
200 | mode="bicubic",
201 | )
202 |
203 | assert (
204 | int(w0) == patch_pos_embed.shape[-2]
205 | and int(h0) == patch_pos_embed.shape[-1]
206 | )
207 | patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).view(1, -1, dim)
208 | return torch.cat((class_pos_embed.unsqueeze(0), patch_pos_embed), dim=1).to(
209 | previous_dtype
210 | )
211 |
212 | def prepare_tokens_with_masks(self, x, masks=None):
213 | B, nc, w, h = x.shape
214 | x = self.patch_embed(x)
215 | if masks is not None:
216 | x = torch.where(
217 | masks.unsqueeze(-1), self.mask_token.to(x.dtype).unsqueeze(0), x
218 | )
219 |
220 | x = torch.cat((self.cls_token.expand(x.shape[0], -1, -1), x), dim=1)
221 | x = x + self.interpolate_pos_encoding(x, w, h)
222 |
223 | return x
224 |
225 | def forward_features_list(self, x_list, masks_list):
226 | x = [
227 | self.prepare_tokens_with_masks(x, masks)
228 | for x, masks in zip(x_list, masks_list)
229 | ]
230 | for blk in self.blocks:
231 | x = blk(x)
232 |
233 | all_x = x
234 | output = []
235 | for x, masks in zip(all_x, masks_list):
236 | x_norm = self.norm(x)
237 | output.append(
238 | {
239 | "x_norm_clstoken": x_norm[:, 0],
240 | "x_norm_patchtokens": x_norm[:, 1:],
241 | "x_prenorm": x,
242 | "masks": masks,
243 | }
244 | )
245 | return output
246 |
247 | def forward_features(self, x, masks=None):
248 |         if isinstance(x, list):
249 |             return self.forward_features_list(x, masks)
250 |         B, _, h, w = x.shape
251 |
252 | x = self.prepare_tokens_with_masks(x, masks)
253 | outs = []
254 | for idx, blk in enumerate(self.blocks):
255 | x = blk(x)
256 | if idx in self.out_indices:
257 | outs.append(
258 | x[:, 1:, :]
259 | .permute(0, 2, 1)
260 | .reshape(B, -1, h // self.patch_size, w // self.patch_size)
261 | .contiguous()
262 | )
263 | return outs
264 |
265 | def _get_intermediate_layers_not_chunked(self, x, n=1):
266 | x = self.prepare_tokens_with_masks(x)
267 | # If n is an int, take the n last blocks. If it's a list, take them
268 | output, total_block_len = [], len(self.blocks)
269 | blocks_to_take = (
270 | range(total_block_len - n, total_block_len) if isinstance(n, int) else n
271 | )
272 | for i, blk in enumerate(self.blocks):
273 | x = blk(x)
274 | if i in blocks_to_take:
275 | output.append(x)
276 | assert len(output) == len(
277 | blocks_to_take
278 | ), f"only {len(output)} / {len(blocks_to_take)} blocks found"
279 | return output
280 |
281 | def _get_intermediate_layers_chunked(self, x, n=1):
282 | x = self.prepare_tokens_with_masks(x)
283 | output, i, total_block_len = [], 0, len(self.blocks[-1])
284 | # If n is an int, take the n last blocks. If it's a list, take them
285 | blocks_to_take = (
286 | range(total_block_len - n, total_block_len) if isinstance(n, int) else n
287 | )
288 | for block_chunk in self.blocks:
289 | for blk in block_chunk[i:]: # Passing the nn.Identity()
290 | x = blk(x)
291 | if i in blocks_to_take:
292 | output.append(x)
293 | i += 1
294 | assert len(output) == len(
295 | blocks_to_take
296 | ), f"only {len(output)} / {len(blocks_to_take)} blocks found"
297 | return output
298 |
299 | def get_intermediate_layers(
300 | self,
301 | x: torch.Tensor,
302 | n: Union[int, Sequence] = 1, # Layers or n last layers to take
303 | reshape: bool = False,
304 | return_class_token: bool = False,
305 | norm=True,
306 | ) -> Tuple[Union[torch.Tensor, Tuple[torch.Tensor]]]:
307 | if self.chunked_blocks:
308 | outputs = self._get_intermediate_layers_chunked(x, n)
309 | else:
310 | outputs = self._get_intermediate_layers_not_chunked(x, n)
311 | if norm:
312 | outputs = [self.norm(out) for out in outputs]
313 | class_tokens = [out[:, 0] for out in outputs]
314 | outputs = [out[:, 1:] for out in outputs]
315 | if reshape:
316 | B, _, w, h = x.shape
317 | outputs = [
318 | out.reshape(B, w // self.patch_size, h // self.patch_size, -1)
319 | .permute(0, 3, 1, 2)
320 | .contiguous()
321 | for out in outputs
322 | ]
323 | if return_class_token:
324 | return tuple(zip(outputs, class_tokens))
325 | return tuple(outputs)
326 |
327 | def forward(self, *args, **kwargs):
328 | ret = self.forward_features(*args, **kwargs)
329 | # if isinstance(ret[0], torch.Tensor):
330 | # ret[0] = F.interpolate(
331 | # ret[0], scale_factor=4, mode="bilinear", align_corners=False
332 | # )
333 | # ret[1] = F.interpolate(
334 | # ret[1], scale_factor=2, mode="bilinear", align_corners=False
335 | # )
336 | # ret[3] = F.interpolate(
337 | # ret[3], scale_factor=0.5, mode="bilinear", align_corners=False
338 | # )
339 | # else:
340 | # ret[0][0] = F.interpolate(
341 | # ret[0][0], scale_factor=4, mode="bilinear", align_corners=False
342 | # )
343 | # ret[0][1] = F.interpolate(
344 | # ret[0][1], scale_factor=2, mode="bilinear", align_corners=False
345 | # )
346 | # ret[0][3] = F.interpolate(
347 | # ret[0][3], scale_factor=0.5, mode="bilinear", align_corners=False
348 | # )
349 | return ret
--------------------------------------------------------------------------------
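
A hedged end-to-end sketch for `DinoVisionTransformer` above, with ViT-S-like sizes chosen purely for illustration and assuming the full `requirement.txt` stack is installed (the training scripts load pretrained DINOv2 checkpoints separately). `block_chunks=0` keeps one block per `ModuleList` entry so `out_indices` line up with block indices, and `XFORMERS_DISABLED` keeps the sketch on the pure-PyTorch attention path:

```python
import os
os.environ["XFORMERS_DISABLED"] = "1"    # force the fallback attention (CPU-friendly)

import torch
from rein.models.backbones.dino_v2 import DinoVisionTransformer

vit = DinoVisionTransformer(
    img_size=224, patch_size=14, embed_dim=384, depth=12, num_heads=6,
    out_indices=[2, 5, 8, 11], block_chunks=0,
)
vit.eval()

with torch.no_grad():
    feats = vit(torch.randn(1, 3, 224, 224))

for f in feats:                          # one map per index in out_indices
    print(f.shape)                       # torch.Size([1, 384, 16, 16])
```
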
/rein/models/backbones/eva_02.py:
--------------------------------------------------------------------------------
1 | # --------------------------------------------------------
2 | # BEIT: BERT Pre-Training of Image Transformers (https://arxiv.org/abs/2106.08254)
3 | # Github source: https://github.com/microsoft/unilm/tree/master/beit
4 | # Copyright (c) 2021 Microsoft
5 | # Licensed under The MIT License [see LICENSE for details]
6 | # By Hangbo Bao
7 | # Based on timm, mmseg, setr, xcit and swin code bases
8 | # https://github.com/rwightman/pytorch-image-models/tree/master/timm
9 | # https://github.com/fudan-zvg/SETR
10 | # https://github.com/facebookresearch/xcit/
11 | # https://github.com/microsoft/Swin-Transformer
12 | # --------------------------------------------------------
13 |
14 | import torch
15 | from functools import partial
16 | import torch.nn as nn
17 | import torch.nn.functional as F
18 | import torch.utils.checkpoint as checkpoint
19 |
20 | from timm.models.layers import drop_path, to_2tuple, trunc_normal_
21 |
22 | from .beit import load_checkpoint
23 | from mmengine.logging import MMLogger
24 | from mmseg.models.builder import BACKBONES
25 | from mmcv.cnn import build_norm_layer
26 | import xformers.ops as xops
27 | # from apex.normalization import FusedLayerNorm
28 | # from apex.normalization import FusedLayerNorm
29 |
30 |
31 | from math import pi
32 | from einops import rearrange, repeat
33 |
34 |
35 | def broadcat(tensors, dim=-1):
36 | num_tensors = len(tensors)
37 | shape_lens = set(list(map(lambda t: len(t.shape), tensors)))
38 | assert len(shape_lens) == 1, "tensors must all have the same number of dimensions"
39 | shape_len = list(shape_lens)[0]
40 | dim = (dim + shape_len) if dim < 0 else dim
41 | dims = list(zip(*map(lambda t: list(t.shape), tensors)))
42 | expandable_dims = [(i, val) for i, val in enumerate(dims) if i != dim]
43 | assert all(
44 | [*map(lambda t: len(set(t[1])) <= 2, expandable_dims)]
45 |     ), "invalid dimensions for broadcastable concatenation"
46 | max_dims = list(map(lambda t: (t[0], max(t[1])), expandable_dims))
47 | expanded_dims = list(map(lambda t: (t[0], (t[1],) * num_tensors), max_dims))
48 | expanded_dims.insert(dim, (dim, dims[dim]))
49 | expandable_shapes = list(zip(*map(lambda t: t[1], expanded_dims)))
50 | tensors = list(map(lambda t: t[0].expand(*t[1]), zip(tensors, expandable_shapes)))
51 | return torch.cat(tensors, dim=dim)
52 |
53 |
54 | def rotate_half(x):
55 | x = rearrange(x, "... (d r) -> ... d r", r=2)
56 | x1, x2 = x.unbind(dim=-1)
57 | x = torch.stack((-x2, x1), dim=-1)
58 | return rearrange(x, "... d r -> ... (d r)")
59 |
60 |
61 | class VisionRotaryEmbedding(nn.Module):
62 | def __init__(
63 | self,
64 | dim,
65 | pt_seq_len,
66 | ft_seq_len=None,
67 | custom_freqs=None,
68 | freqs_for="lang",
69 | theta=10000,
70 | max_freq=10,
71 | num_freqs=1,
72 | ):
73 | super().__init__()
74 | if custom_freqs:
75 | freqs = custom_freqs
76 | elif freqs_for == "lang":
77 | freqs = 1.0 / (
78 | theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim)
79 | )
80 | elif freqs_for == "pixel":
81 | freqs = torch.linspace(1.0, max_freq / 2, dim // 2) * pi
82 | elif freqs_for == "constant":
83 | freqs = torch.ones(num_freqs).float()
84 | else:
85 | raise ValueError(f"unknown modality {freqs_for}")
86 |
87 | if ft_seq_len is None:
88 | ft_seq_len = pt_seq_len
89 | t = torch.arange(ft_seq_len) / ft_seq_len * pt_seq_len
90 |
91 | freqs_h = torch.einsum("..., f -> ... f", t, freqs)
92 | freqs_h = repeat(freqs_h, "... n -> ... (n r)", r=2)
93 |
94 | freqs_w = torch.einsum("..., f -> ... f", t, freqs)
95 | freqs_w = repeat(freqs_w, "... n -> ... (n r)", r=2)
96 |
97 | freqs = broadcat((freqs_h[:, None, :], freqs_w[None, :, :]), dim=-1)
98 |
99 | self.register_buffer("freqs_cos", freqs.cos())
100 | self.register_buffer("freqs_sin", freqs.sin())
101 |
102 | print("======== shape of rope freq", self.freqs_cos.shape, "========")
103 |
104 | def forward(self, t, start_index=0):
105 | rot_dim = self.freqs_cos.shape[-1]
106 | end_index = start_index + rot_dim
107 | assert (
108 | rot_dim <= t.shape[-1]
109 | ), f"feature dimension {t.shape[-1]} is not of sufficient size to rotate in all the positions {rot_dim}"
110 | t_left, t, t_right = (
111 | t[..., :start_index],
112 | t[..., start_index:end_index],
113 | t[..., end_index:],
114 | )
115 | t = (t * self.freqs_cos) + (rotate_half(t) * self.freqs_sin)
116 | return torch.cat((t_left, t, t_right), dim=-1)
117 |
118 |
119 | class VisionRotaryEmbeddingFast(nn.Module):
120 | def __init__(
121 | self,
122 | dim,
123 | pt_seq_len,
124 | ft_seq_len=None,
125 | custom_freqs=None,
126 | freqs_for="lang",
127 | theta=10000,
128 | max_freq=10,
129 | num_freqs=1,
130 | ):
131 | super().__init__()
132 | if custom_freqs:
133 | freqs = custom_freqs
134 | elif freqs_for == "lang":
135 | freqs = 1.0 / (
136 | theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim)
137 | )
138 | elif freqs_for == "pixel":
139 | freqs = torch.linspace(1.0, max_freq / 2, dim // 2) * pi
140 | elif freqs_for == "constant":
141 | freqs = torch.ones(num_freqs).float()
142 | else:
143 | raise ValueError(f"unknown modality {freqs_for}")
144 |
145 | if ft_seq_len is None:
146 | ft_seq_len = pt_seq_len
147 | t = torch.arange(ft_seq_len) / ft_seq_len * pt_seq_len
148 |
149 | freqs = torch.einsum("..., f -> ... f", t, freqs)
150 | freqs = repeat(freqs, "... n -> ... (n r)", r=2)
151 | freqs = broadcat((freqs[:, None, :], freqs[None, :, :]), dim=-1)
152 |
153 | freqs_cos = freqs.cos().view(-1, freqs.shape[-1])
154 | freqs_sin = freqs.sin().view(-1, freqs.shape[-1])
155 |
156 | self.register_buffer("freqs_cos", freqs_cos)
157 | self.register_buffer("freqs_sin", freqs_sin)
158 |
159 | def forward(self, t):
160 | return t * self.freqs_cos + rotate_half(t) * self.freqs_sin
161 |
162 |
163 | class DropPath(nn.Module):
164 | """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks)."""
165 |
166 | def __init__(self, drop_prob=None):
167 | super(DropPath, self).__init__()
168 | self.drop_prob = drop_prob
169 |
170 | def forward(self, x):
171 | return drop_path(x, self.drop_prob, self.training)
172 |
173 | def extra_repr(self) -> str:
174 | return "p={}".format(self.drop_prob)
175 |
176 |
177 | class Mlp(nn.Module):
178 | def __init__(
179 | self,
180 | in_features,
181 | hidden_features=None,
182 | out_features=None,
183 | act_layer=nn.GELU,
184 | drop=0.0,
185 | ):
186 | super().__init__()
187 | out_features = out_features or in_features
188 | hidden_features = hidden_features or in_features
189 | self.fc1 = nn.Linear(in_features, hidden_features)
190 | self.act = act_layer()
191 | self.fc2 = nn.Linear(hidden_features, out_features)
192 | self.drop = nn.Dropout(drop)
193 |
194 | def forward(self, x):
195 | x = self.fc1(x)
196 | x = self.act(x)
197 | # x = self.drop(x)
198 |         # dropout above is commented out to match the original BERT implementation
199 | x = self.fc2(x)
200 | x = self.drop(x)
201 | return x
202 |
203 |
204 | class SwiGLU(nn.Module):
205 | def __init__(
206 | self,
207 | in_features,
208 | hidden_features=None,
209 | out_features=None,
210 | act_layer=nn.SiLU,
211 | drop=0.0,
212 | norm_layer=nn.LayerNorm,
213 | subln=False,
214 | ):
215 | super().__init__()
216 | out_features = out_features or in_features
217 | hidden_features = hidden_features or in_features
218 |
219 | self.w1 = nn.Linear(in_features, hidden_features)
220 | self.w2 = nn.Linear(in_features, hidden_features)
221 |
222 | self.act = act_layer()
223 | if isinstance(norm_layer, dict):
224 | self.ffn_ln = (
225 | build_norm_layer(norm_layer, hidden_features)[1]
226 | if subln
227 | else nn.Identity()
228 | )
229 | else:
230 | self.ffn_ln = norm_layer(hidden_features) if subln else nn.Identity()
231 | self.w3 = nn.Linear(hidden_features, out_features)
232 |
233 | self.drop = nn.Dropout(drop)
234 |
235 | def forward(self, x):
236 | x1 = self.w1(x)
237 | x2 = self.w2(x)
238 | hidden = self.act(x1) * x2
239 | x = self.ffn_ln(hidden)
240 | x = self.w3(x)
241 | x = self.drop(x)
242 | return x
243 |
244 |
245 | class Attention(nn.Module):
246 | def __init__(
247 | self,
248 | dim,
249 | num_heads=8,
250 | qkv_bias=False,
251 | qk_scale=None,
252 | attn_drop=0.0,
253 | proj_drop=0.0,
254 | window_size=None,
255 | attn_head_dim=None,
256 | subln=False,
257 | norm_layer=nn.LayerNorm,
258 | xattn=False,
259 | rope=None,
260 | ):
261 | super().__init__()
262 | self.num_heads = num_heads
263 | head_dim = dim // num_heads
264 | if attn_head_dim is not None:
265 | head_dim = attn_head_dim
266 | all_head_dim = head_dim * self.num_heads
267 | # NOTE scale factor was wrong in my original version, can set manually to be compat with prev weights
268 | self.scale = qk_scale or head_dim**-0.5
269 |
270 | self.subln = subln
271 | if self.subln:
272 | self.q_proj = nn.Linear(dim, all_head_dim, bias=False)
273 | self.k_proj = nn.Linear(dim, all_head_dim, bias=False)
274 | self.v_proj = nn.Linear(dim, all_head_dim, bias=False)
275 | else:
276 | self.qkv = nn.Linear(dim, all_head_dim * 3, bias=False)
277 |
278 | if qkv_bias:
279 | self.q_bias = nn.Parameter(torch.zeros(all_head_dim))
280 | self.v_bias = nn.Parameter(torch.zeros(all_head_dim))
281 | else:
282 | self.q_bias = None
283 | self.v_bias = None
284 |
285 | if window_size:
286 | self.window_size = window_size
287 | self.num_relative_distance = (2 * window_size[0] - 1) * (
288 | 2 * window_size[1] - 1
289 | ) + 3
290 | self.relative_position_bias_table = nn.Parameter(
291 | torch.zeros(self.num_relative_distance, num_heads)
292 | ) # 2*Wh-1 * 2*Ww-1, nH
293 |             # cls to token & token to cls & cls to cls
294 |
295 | # get pair-wise relative position index for each token inside the window
296 | coords_h = torch.arange(window_size[0])
297 | coords_w = torch.arange(window_size[1])
298 | coords = torch.stack(torch.meshgrid([coords_h, coords_w])) # 2, Wh, Ww
299 | coords_flatten = torch.flatten(coords, 1) # 2, Wh*Ww
300 | relative_coords = (
301 | coords_flatten[:, :, None] - coords_flatten[:, None, :]
302 | ) # 2, Wh*Ww, Wh*Ww
303 | relative_coords = relative_coords.permute(
304 | 1, 2, 0
305 | ).contiguous() # Wh*Ww, Wh*Ww, 2
306 | relative_coords[:, :, 0] += window_size[0] - 1 # shift to start from 0
307 | relative_coords[:, :, 1] += window_size[1] - 1
308 | relative_coords[:, :, 0] *= 2 * window_size[1] - 1
309 | relative_position_index = torch.zeros(
310 | size=(window_size[0] * window_size[1] + 1,) * 2,
311 | dtype=relative_coords.dtype,
312 | )
313 | relative_position_index[1:, 1:] = relative_coords.sum(-1) # Wh*Ww, Wh*Ww
314 | relative_position_index[0, 0:] = self.num_relative_distance - 3
315 | relative_position_index[0:, 0] = self.num_relative_distance - 2
316 | relative_position_index[0, 0] = self.num_relative_distance - 1
317 |
318 | self.register_buffer("relative_position_index", relative_position_index)
319 |
320 | # trunc_normal_(self.relative_position_bias_table, std=.0)
321 | else:
322 | self.window_size = None
323 | self.relative_position_bias_table = None
324 | self.relative_position_index = None
325 |
326 | self.attn_drop = nn.Dropout(attn_drop)
327 | self.proj = nn.Linear(all_head_dim, dim)
328 | self.proj_drop = nn.Dropout(proj_drop)
329 |
330 | self.xattn = xattn
331 | self.rope = rope
332 |
333 | def forward(self, x, rel_pos_bias=None):
334 | B, N, C = x.shape
335 |
336 | if self.subln:
337 | q = F.linear(input=x, weight=self.q_proj.weight, bias=self.q_bias)
338 | k = F.linear(input=x, weight=self.k_proj.weight, bias=None)
339 | v = F.linear(input=x, weight=self.v_proj.weight, bias=self.v_bias)
340 |
341 | q = q.reshape(B, N, self.num_heads, -1).permute(
342 | 0, 2, 1, 3
343 | ) # B, num_heads, N, C
344 | k = k.reshape(B, N, self.num_heads, -1).permute(0, 2, 1, 3)
345 | v = v.reshape(B, N, self.num_heads, -1).permute(0, 2, 1, 3)
346 | else:
347 | qkv_bias = None
348 | if self.q_bias is not None:
349 | qkv_bias = torch.cat(
350 | (
351 | self.q_bias,
352 | torch.zeros_like(self.v_bias, requires_grad=False),
353 | self.v_bias,
354 | )
355 | )
356 | qkv = F.linear(input=x, weight=self.qkv.weight, bias=qkv_bias)
357 | qkv = qkv.reshape(B, N, 3, self.num_heads, -1).permute(
358 | 2, 0, 3, 1, 4
359 | ) # 3, B, num_heads, N, C
360 | q, k, v = qkv[0], qkv[1], qkv[2]
361 |
362 | if self.rope:
363 | q_t = q[:, :, 1:, :]
364 | ro_q_t = self.rope(q_t)
365 | q = torch.cat((q[:, :, :1, :], ro_q_t), -2).type_as(v)
366 |
367 | k_t = k[:, :, 1:, :]
368 | ro_k_t = self.rope(k_t)
369 | k = torch.cat((k[:, :, :1, :], ro_k_t), -2).type_as(v)
370 |
371 | if self.xattn:
372 | q = q.permute(0, 2, 1, 3) # B, num_heads, N, C -> B, N, num_heads, C
373 | k = k.permute(0, 2, 1, 3)
374 | v = v.permute(0, 2, 1, 3)
375 |
376 | x = xops.memory_efficient_attention(q, k, v)
377 | x = x.reshape(B, N, -1)
378 | x = self.proj(x)
379 | x = self.proj_drop(x)
380 | else:
381 | q = q * self.scale
382 | attn = q @ k.transpose(-2, -1)
383 |
384 | if self.relative_position_bias_table is not None:
385 | relative_position_bias = self.relative_position_bias_table[
386 | self.relative_position_index.view(-1)
387 | ].view(
388 | self.window_size[0] * self.window_size[1] + 1,
389 | self.window_size[0] * self.window_size[1] + 1,
390 | -1,
391 | ) # Wh*Ww,Wh*Ww,nH
392 | relative_position_bias = relative_position_bias.permute(
393 | 2, 0, 1
394 | ).contiguous() # nH, Wh*Ww, Wh*Ww
395 | attn = attn + relative_position_bias.unsqueeze(0)
396 |
397 | if rel_pos_bias is not None:
398 | attn = attn + rel_pos_bias
399 |
400 | attn = attn.softmax(dim=-1)
401 | attn = self.attn_drop(attn)
402 |
403 | x = (attn @ v).transpose(1, 2).reshape(B, N, -1)
404 | x = self.proj(x)
405 | x = self.proj_drop(x)
406 |
407 | return x
408 |
409 |
410 | class Block(nn.Module):
411 | def __init__(
412 | self,
413 | dim,
414 | num_heads,
415 | mlp_ratio=4.0,
416 | qkv_bias=False,
417 | qk_scale=None,
418 | drop=0.0,
419 | attn_drop=0.0,
420 | drop_path=0.0,
421 | init_values=None,
422 | act_layer=nn.GELU,
423 | norm_layer=nn.LayerNorm,
424 | window_size=None,
425 | attn_head_dim=None,
426 | subln=False,
427 | xattn=False,
428 | naiveswiglu=False,
429 | rope=None,
430 | ):
431 | super().__init__()
432 | if isinstance(norm_layer, dict):
433 | self.norm1 = build_norm_layer(norm_layer, dim)[1]
434 | else:
435 | self.norm1 = norm_layer(dim)
436 | self.attn = Attention(
437 | dim,
438 | num_heads=num_heads,
439 | qkv_bias=qkv_bias,
440 | qk_scale=qk_scale,
441 | attn_drop=attn_drop,
442 | proj_drop=drop,
443 | window_size=window_size,
444 | attn_head_dim=attn_head_dim,
445 | subln=subln,
446 | norm_layer=norm_layer,
447 | xattn=xattn,
448 | rope=rope,
449 | )
450 |
451 | # NOTE: drop path for stochastic depth, we shall see if this is better than dropout here
452 | self.drop_path = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()
453 | if isinstance(norm_layer, dict):
454 | self.norm2 = build_norm_layer(norm_layer, dim)[1]
455 | else:
456 | self.norm2 = norm_layer(dim)
457 | mlp_hidden_dim = int(dim * mlp_ratio)
458 |
459 | if naiveswiglu:
460 | self.mlp = SwiGLU(
461 | in_features=dim,
462 | hidden_features=mlp_hidden_dim,
463 | subln=subln,
464 | norm_layer=norm_layer,
465 | )
466 | else:
467 | self.mlp = Mlp(
468 | in_features=dim,
469 | hidden_features=mlp_hidden_dim,
470 | act_layer=act_layer,
471 | drop=drop,
472 | )
473 |
474 | if init_values is not None:
475 | self.gamma_1 = nn.Parameter(
476 | init_values * torch.ones((dim)), requires_grad=True
477 | )
478 | self.gamma_2 = nn.Parameter(
479 | init_values * torch.ones((dim)), requires_grad=True
480 | )
481 | else:
482 | self.gamma_1, self.gamma_2 = None, None
483 |
484 | def forward(self, x, rel_pos_bias=None):
485 | if self.gamma_1 is None:
486 | x = x + self.drop_path(self.attn(self.norm1(x), rel_pos_bias=rel_pos_bias))
487 | x = x + self.drop_path(self.mlp(self.norm2(x)))
488 | else:
489 | x = x + self.drop_path(
490 | self.gamma_1 * self.attn(self.norm1(x), rel_pos_bias=rel_pos_bias)
491 | )
492 | x = x + self.drop_path(self.gamma_2 * self.mlp(self.norm2(x)))
493 | return x
494 |
495 |
496 | class PatchEmbed(nn.Module):
497 | """Image to Patch Embedding"""
498 |
499 | def __init__(self, img_size=224, patch_size=16, in_chans=3, embed_dim=768):
500 | super().__init__()
501 | img_size = to_2tuple(img_size)
502 | patch_size = to_2tuple(patch_size)
503 | num_patches = (img_size[1] // patch_size[1]) * (img_size[0] // patch_size[0])
504 | self.patch_shape = (img_size[0] // patch_size[0], img_size[1] // patch_size[1])
505 | self.img_size = img_size
506 | self.patch_size = patch_size
507 | self.num_patches = num_patches
508 |
509 | self.proj = nn.Conv2d(
510 | in_chans, embed_dim, kernel_size=patch_size, stride=patch_size
511 | )
512 |
513 | def forward(self, x, **kwargs):
514 | B, C, H, W = x.shape
515 | # FIXME look at relaxing size constraints
516 | # assert H == self.img_size[0] and W == self.img_size[1], \
517 | # f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})."
518 | x = self.proj(x)
519 | Hp, Wp = x.shape[2], x.shape[3]
520 |
521 | x = x.flatten(2).transpose(1, 2)
522 | return x, (Hp, Wp)
523 |
524 |
525 | class HybridEmbed(nn.Module):
526 | """CNN Feature Map Embedding
527 | Extract feature map from CNN, flatten, project to embedding dim.
528 | """
529 |
530 | def __init__(
531 | self, backbone, img_size=224, feature_size=None, in_chans=3, embed_dim=768
532 | ):
533 | super().__init__()
534 | assert isinstance(backbone, nn.Module)
535 | img_size = to_2tuple(img_size)
536 | self.img_size = img_size
537 | self.backbone = backbone
538 | if feature_size is None:
539 | with torch.no_grad():
540 | # FIXME this is hacky, but most reliable way of determining the exact dim of the output feature
541 | # map for all networks, the feature metadata has reliable channel and stride info, but using
542 | # stride to calc feature dim requires info about padding of each stage that isn't captured.
543 | training = backbone.training
544 | if training:
545 | backbone.eval()
546 | o = self.backbone(torch.zeros(1, in_chans, img_size[0], img_size[1]))[
547 | -1
548 | ]
549 | feature_size = o.shape[-2:]
550 | feature_dim = o.shape[1]
551 | backbone.train(training)
552 | else:
553 | feature_size = to_2tuple(feature_size)
554 | feature_dim = self.backbone.feature_info.channels()[-1]
555 | self.num_patches = feature_size[0] * feature_size[1]
556 | self.proj = nn.Linear(feature_dim, embed_dim)
557 |
558 | def forward(self, x):
559 | x = self.backbone(x)[-1]
560 | x = x.flatten(2).transpose(1, 2)
561 | x = self.proj(x)
562 | return x
563 |
564 |
565 | class RelativePositionBias(nn.Module):
566 | def __init__(self, window_size, num_heads):
567 | super().__init__()
568 | self.window_size = window_size
569 | self.num_relative_distance = (2 * window_size[0] - 1) * (
570 | 2 * window_size[1] - 1
571 | ) + 3
572 | self.relative_position_bias_table = nn.Parameter(
573 | torch.zeros(self.num_relative_distance, num_heads)
574 | ) # 2*Wh-1 * 2*Ww-1, nH
575 |         # cls to token & token to cls & cls to cls
576 |
577 | # get pair-wise relative position index for each token inside the window
578 | coords_h = torch.arange(window_size[0])
579 | coords_w = torch.arange(window_size[1])
580 | coords = torch.stack(torch.meshgrid([coords_h, coords_w])) # 2, Wh, Ww
581 | coords_flatten = torch.flatten(coords, 1) # 2, Wh*Ww
582 | relative_coords = (
583 | coords_flatten[:, :, None] - coords_flatten[:, None, :]
584 | ) # 2, Wh*Ww, Wh*Ww
585 | relative_coords = relative_coords.permute(
586 | 1, 2, 0
587 | ).contiguous() # Wh*Ww, Wh*Ww, 2
588 | relative_coords[:, :, 0] += window_size[0] - 1 # shift to start from 0
589 | relative_coords[:, :, 1] += window_size[1] - 1
590 | relative_coords[:, :, 0] *= 2 * window_size[1] - 1
591 | relative_position_index = torch.zeros(
592 | size=(window_size[0] * window_size[1] + 1,) * 2, dtype=relative_coords.dtype
593 | )
594 | relative_position_index[1:, 1:] = relative_coords.sum(-1) # Wh*Ww, Wh*Ww
595 | relative_position_index[0, 0:] = self.num_relative_distance - 3
596 | relative_position_index[0:, 0] = self.num_relative_distance - 2
597 | relative_position_index[0, 0] = self.num_relative_distance - 1
598 |
599 | self.register_buffer("relative_position_index", relative_position_index)
600 |
601 | # trunc_normal_(self.relative_position_bias_table, std=.02)
602 |
603 | def forward(self):
604 | relative_position_bias = self.relative_position_bias_table[
605 | self.relative_position_index.view(-1)
606 | ].view(
607 | self.window_size[0] * self.window_size[1] + 1,
608 | self.window_size[0] * self.window_size[1] + 1,
609 | -1,
610 | ) # Wh*Ww,Wh*Ww,nH
611 | return relative_position_bias.permute(2, 0, 1).contiguous() # nH, Wh*Ww, Wh*Ww
612 |
613 |
614 | @BACKBONES.register_module()
615 | class EVA2(nn.Module):
616 | """Vision Transformer with support for patch or hybrid CNN input stage"""
617 |
618 | def __init__(
619 | self,
620 | img_size=224,
621 | patch_size=16,
622 | in_chans=3,
623 | num_classes=80,
624 | embed_dim=768,
625 | depth=12,
626 | num_heads=12,
627 | mlp_ratio=4 * 2 / 3, # GLU default
628 | qkv_bias=False,
629 | qk_scale=None,
630 | drop_rate=0.0,
631 | attn_drop_rate=0.0,
632 | drop_path_rate=0.0,
633 | hybrid_backbone=None,
634 | norm_layer=None,
635 | init_values=None,
636 | use_checkpoint=False,
637 | use_abs_pos_emb=True,
638 | use_rel_pos_bias=False,
639 | use_shared_rel_pos_bias=False,
640 | out_indices=[3, 5, 7, 11],
641 | subln=True,
642 | xattn=True,
643 | naiveswiglu=True,
644 | rope=True,
645 | pt_hw_seq_len=16,
646 | intp_freq=True,
647 | pretrained=None,
648 | ):
649 | super().__init__()
650 | # norm_layer = norm_layer or partial(FusedLayerNorm, eps=1e-6)
651 | self.num_classes = num_classes
652 | self.num_features = (
653 | self.embed_dim
654 | ) = embed_dim # num_features for consistency with other models
655 |
656 | if hybrid_backbone is not None:
657 | self.patch_embed = HybridEmbed(
658 | hybrid_backbone,
659 | img_size=img_size,
660 | in_chans=in_chans,
661 | embed_dim=embed_dim,
662 | )
663 | else:
664 | self.patch_embed = PatchEmbed(
665 | img_size=img_size,
666 | patch_size=patch_size,
667 | in_chans=in_chans,
668 | embed_dim=embed_dim,
669 | )
670 |
671 | num_patches = self.patch_embed.num_patches
672 | self.out_indices = out_indices
673 |
674 | self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
675 |
676 | if use_abs_pos_emb:
677 | self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + 1, embed_dim))
678 | else:
679 | self.pos_embed = None
680 |
681 | self.pos_drop = nn.Dropout(p=drop_rate)
682 |
683 | if use_shared_rel_pos_bias:
684 | self.rel_pos_bias = RelativePositionBias(
685 | window_size=self.patch_embed.patch_shape, num_heads=num_heads
686 | )
687 | else:
688 | self.rel_pos_bias = None
689 |
690 | if rope:
691 | half_head_dim = embed_dim // num_heads // 2
692 | hw_seq_len = img_size // patch_size
693 | self.rope = VisionRotaryEmbeddingFast(
694 | dim=half_head_dim,
695 | pt_seq_len=pt_hw_seq_len,
696 | ft_seq_len=hw_seq_len if intp_freq else None,
697 | )
698 | else:
699 | self.rope = None
700 |
701 | self.naiveswiglu = naiveswiglu
702 |
703 | dpr = [
704 | x.item() for x in torch.linspace(0, drop_path_rate, depth)
705 | ] # stochastic depth decay rule
706 | self.use_rel_pos_bias = use_rel_pos_bias
707 | self.use_checkpoint = use_checkpoint
708 | self.blocks = nn.ModuleList(
709 | [
710 | Block(
711 | dim=embed_dim,
712 | num_heads=num_heads,
713 | mlp_ratio=mlp_ratio,
714 | qkv_bias=qkv_bias,
715 | qk_scale=qk_scale,
716 | drop=drop_rate,
717 | attn_drop=attn_drop_rate,
718 | drop_path=dpr[i],
719 | norm_layer=norm_layer,
720 | init_values=init_values,
721 | window_size=self.patch_embed.patch_shape
722 | if use_rel_pos_bias
723 | else None,
724 | subln=subln,
725 | xattn=xattn,
726 | naiveswiglu=naiveswiglu,
727 | rope=self.rope,
728 | )
729 | for i in range(depth)
730 | ]
731 | )
732 |
733 | if self.pos_embed is not None:
734 | trunc_normal_(self.pos_embed, std=0.02)
735 | trunc_normal_(self.cls_token, std=0.02)
736 |
737 | # if patch_size == 16:
738 | # self.fpn1 = nn.Sequential(
739 | # nn.ConvTranspose2d(embed_dim, embed_dim, kernel_size=2, stride=2),
740 | # nn.SyncBatchNorm(embed_dim),
741 | # nn.GELU(),
742 | # nn.ConvTranspose2d(embed_dim, embed_dim, kernel_size=2, stride=2),
743 | # )
744 |
745 | # self.fpn2 = nn.Sequential(
746 | # nn.ConvTranspose2d(embed_dim, embed_dim, kernel_size=2, stride=2),
747 | # )
748 |
749 | # self.fpn3 = nn.Identity()
750 |
751 | # self.fpn4 = nn.MaxPool2d(kernel_size=2, stride=2)
752 | # elif patch_size == 8:
753 | # self.fpn1 = nn.Sequential(
754 | # nn.ConvTranspose2d(embed_dim, embed_dim, kernel_size=2, stride=2),
755 | # )
756 |
757 | # self.fpn2 = nn.Identity()
758 |
759 | # self.fpn3 = nn.Sequential(
760 | # nn.MaxPool2d(kernel_size=2, stride=2),
761 | # )
762 |
763 | # self.fpn4 = nn.Sequential(
764 | # nn.MaxPool2d(kernel_size=4, stride=4),
765 | # )
766 | # self.init_weights(pretrained)
767 | self.pretrained = pretrained
768 |
769 | def _init_weights(self, m):
770 | if isinstance(m, nn.Linear):
771 | trunc_normal_(m.weight, std=0.02)
772 | if isinstance(m, nn.Linear) and m.bias is not None:
773 | nn.init.constant_(m.bias, 0)
774 | elif isinstance(m, nn.LayerNorm):
775 | nn.init.constant_(m.bias, 0)
776 | nn.init.constant_(m.weight, 1.0)
777 |
778 | def init_weights(self):
779 |         """Initialize the weights in the backbone.
780 | 
781 |         Weights are initialized with truncated-normal / constant init; if
782 |         ``self.pretrained`` is a checkpoint path, the pretrained weights are
783 |         then loaded on top. Otherwise ``self.pretrained`` must be ``None``.
784 |         """
785 | pretrained = self.pretrained
786 |
787 | def _init_weights(m):
788 | if isinstance(m, nn.Linear):
789 | trunc_normal_(m.weight, std=0.02)
790 | if isinstance(m, nn.Linear) and m.bias is not None:
791 | nn.init.constant_(m.bias, 0)
792 | elif isinstance(m, nn.LayerNorm):
793 | nn.init.constant_(m.bias, 0)
794 | nn.init.constant_(m.weight, 1.0)
795 |
796 | if isinstance(pretrained, str):
797 | self.apply(_init_weights)
798 | logger = MMLogger.get_current_instance()
799 | load_checkpoint(self, pretrained, strict=False, logger=logger)
800 | elif pretrained is None:
801 | self.apply(_init_weights)
802 | else:
803 | raise TypeError("pretrained must be a str or None")
804 |
805 | def get_num_layers(self):
806 | return len(self.blocks)
807 |
808 | @torch.jit.ignore
809 | def no_weight_decay(self):
810 | return {"pos_embed", "cls_token"}
811 |
812 | def forward_features(self, x):
813 | B, C, H, W = x.shape
814 | x, (Hp, Wp) = self.patch_embed(x)
815 | batch_size, seq_len, _ = x.size()
816 |
817 | cls_tokens = self.cls_token.expand(
818 | batch_size, -1, -1
819 | ) # stole cls_tokens impl from Phil Wang, thanks
820 | x = torch.cat((cls_tokens, x), dim=1)
821 | if self.pos_embed is not None:
822 | x = x + self.pos_embed
823 | x = self.pos_drop(x)
824 |
825 | rel_pos_bias = self.rel_pos_bias() if self.rel_pos_bias is not None else None
826 | features = []
827 | for i, blk in enumerate(self.blocks):
828 | if self.use_checkpoint:
829 | x = checkpoint.checkpoint(blk, x, rel_pos_bias)
830 | else:
831 | x = blk(x, rel_pos_bias)
832 | if i in self.out_indices:
833 | xp = x[:, 1:, :].permute(0, 2, 1).reshape(B, -1, Hp, Wp)
834 | features.append(xp.contiguous())
835 | features[0] = F.interpolate(
836 | features[0], scale_factor=4, mode="bilinear", align_corners=False
837 | )
838 | features[1] = F.interpolate(
839 | features[1], scale_factor=2, mode="bilinear", align_corners=False
840 | )
841 | features[3] = F.interpolate(
842 | features[3], scale_factor=0.5, mode="bilinear", align_corners=False
843 | )
844 |
845 | return tuple(features)
846 |
847 | def forward(self, x):
848 | x = self.forward_features(x)
849 | return x
850 |
--------------------------------------------------------------------------------
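
The `Attention` and `RelativePositionBias` classes above build a `(2*Wh-1)*(2*Ww-1)+3`-entry bias table and an index tensor that maps every (query, key) pair, including the cls token, to a row of that table. Below is a minimal self-contained sketch (the standalone `relative_position_index` helper is hypothetical, mirroring the in-class code) that reproduces the index construction for a tiny 2×2 window:

```python
# Standalone sketch of the relative-position index construction used above
# (hypothetical helper; mirrors Attention/RelativePositionBias for a 2x2 window).
import torch

def relative_position_index(window_size=(2, 2)):
    Wh, Ww = window_size
    coords = torch.stack(torch.meshgrid([torch.arange(Wh), torch.arange(Ww)]))  # 2, Wh, Ww
    coords_flatten = torch.flatten(coords, 1)                                   # 2, Wh*Ww
    rel = coords_flatten[:, :, None] - coords_flatten[:, None, :]               # 2, Wh*Ww, Wh*Ww
    rel = rel.permute(1, 2, 0).contiguous()                                     # Wh*Ww, Wh*Ww, 2
    rel[:, :, 0] += Wh - 1                 # shift row offsets so they start at 0
    rel[:, :, 1] += Ww - 1                 # shift column offsets so they start at 0
    rel[:, :, 0] *= 2 * Ww - 1             # flatten (row, col) offsets into a single id
    num_relative_distance = (2 * Wh - 1) * (2 * Ww - 1) + 3
    index = torch.zeros((Wh * Ww + 1,) * 2, dtype=rel.dtype)
    index[1:, 1:] = rel.sum(-1)                   # patch-to-patch entries
    index[0, 0:] = num_relative_distance - 3      # cls -> patch
    index[0:, 0] = num_relative_distance - 2      # patch -> cls
    index[0, 0] = num_relative_distance - 1       # cls -> cls
    return index

print(relative_position_index())  # 5x5 tensor of ids into the (12, num_heads) bias table
```

Indexing the learned bias table with this tensor and adding the result to the attention logits is exactly what the non-xattn branch of `Attention.forward` does.
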
/rein/models/backbones/reins.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 | import math
5 | from functools import reduce
6 | from operator import mul
7 | from torch import Tensor
8 |
9 | class Reins(nn.Module):
10 | def __init__(
11 | self,
12 | num_layers: int,
13 | embed_dims: int,
14 | patch_size: int,
15 | query_dims: int = 256,
16 | token_length: int = 100,
17 | use_softmax: bool = True,
18 | link_token_to_query: bool = True,
19 | scale_init: float = 0.001,
20 | zero_mlp_delta_f: bool = False,
21 | ) -> None:
22 | super().__init__()
23 | self.num_layers = num_layers
24 | self.embed_dims = embed_dims
25 | self.patch_size = patch_size
26 | self.query_dims = query_dims
27 | self.token_length = token_length
28 | self.link_token_to_query = link_token_to_query
29 | self.scale_init = scale_init
30 | self.use_softmax = use_softmax
31 | self.zero_mlp_delta_f = zero_mlp_delta_f
32 | self.create_model()
33 |
34 | def create_model(self):
35 | self.learnable_tokens = nn.Parameter(
36 | torch.empty([self.num_layers, self.token_length, self.embed_dims])
37 | )
38 | self.scale = nn.Parameter(torch.tensor(self.scale_init))
39 | self.mlp_token2feat = nn.Linear(self.embed_dims, self.embed_dims)
40 | self.mlp_delta_f = nn.Linear(self.embed_dims, self.embed_dims)
41 | val = math.sqrt(
42 | 6.0
43 | / float(
44 | 3 * reduce(mul, (self.patch_size, self.patch_size), 1) + self.embed_dims
45 | )
46 | )
47 | nn.init.uniform_(self.learnable_tokens.data, -val, val)
48 | nn.init.kaiming_uniform_(self.mlp_delta_f.weight, a=math.sqrt(5))
49 | nn.init.kaiming_uniform_(self.mlp_token2feat.weight, a=math.sqrt(5))
50 | self.transform = nn.Linear(self.embed_dims, self.query_dims)
51 | self.merge = nn.Linear(self.query_dims * 3, self.query_dims)
52 | if self.zero_mlp_delta_f:
53 | del self.scale
54 | self.scale = 1.0
55 | nn.init.zeros_(self.mlp_delta_f.weight)
56 | nn.init.zeros_(self.mlp_delta_f.bias)
57 |
58 | def return_auto(self, feats):
59 | if self.link_token_to_query:
60 | tokens = self.transform(self.get_tokens(-1)).permute(1, 2, 0)
61 | tokens = torch.cat(
62 | [
63 | F.max_pool1d(tokens, kernel_size=self.num_layers),
64 | F.avg_pool1d(tokens, kernel_size=self.num_layers),
65 | tokens[:, :, -1].unsqueeze(-1),
66 | ],
67 | dim=-1,
68 | )
69 | querys = self.merge(tokens.flatten(-2, -1))
70 | return feats, querys
71 | else:
72 | return feats
73 |
74 | def get_tokens(self, layer: int) -> Tensor:
75 | if layer == -1:
76 | # return all
77 | return self.learnable_tokens
78 | else:
79 | return self.learnable_tokens[layer]
80 |
81 | def forward(
82 | self, feats: Tensor, layer: int, batch_first=False, has_cls_token=True
83 | ) -> Tensor:
84 | if batch_first:
85 | feats = feats.permute(1, 0, 2)
86 | if has_cls_token:
87 | cls_token, feats = torch.tensor_split(feats, [1], dim=0)
88 | tokens = self.get_tokens(layer)
89 | delta_feat = self.forward_delta_feat(
90 | feats,
91 | tokens,
92 | layer,
93 | )
94 | delta_feat = delta_feat * self.scale
95 | feats = feats + delta_feat
96 | if has_cls_token:
97 | feats = torch.cat([cls_token, feats], dim=0)
98 | if batch_first:
99 | feats = feats.permute(1, 0, 2)
100 | return feats
101 |
102 | def forward_delta_feat(self, feats: Tensor, tokens: Tensor, layers: int) -> Tensor:
103 | attn = torch.einsum("nbc,mc->nbm", feats, tokens)
104 | if self.use_softmax:
105 | attn = attn * (self.embed_dims**-0.5)
106 | attn = F.softmax(attn, dim=-1)
107 | delta_f = torch.einsum(
108 | "nbm,mc->nbc",
109 | attn[:, :, 1:],
110 | self.mlp_token2feat(tokens[1:, :]),
111 | )
112 | delta_f = self.mlp_delta_f(delta_f + feats)
113 | return delta_f
--------------------------------------------------------------------------------
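
`Reins` keeps `token_length` learnable tokens per transformer layer; each layer's tokens attend to the patch features, and the resulting delta is added back as a residual scaled by `scale` (0.001 at init), so the pretrained features are barely perturbed at the start of training. A minimal sketch, assuming the repo root is on `PYTHONPATH` so the module above is importable, shows that the adapter is shape-preserving:

```python
# Minimal sketch; assumes the CUFIT repo root is on PYTHONPATH.
import torch
from rein.models.backbones.reins import Reins

reins = Reins(num_layers=12, embed_dims=384, patch_size=14)  # ViT-S/14-like sizes
x = torch.randn(2, 1 + 16 * 16, 384)                         # [cls] + 16x16 patch tokens

# In the backbones the adapter is applied after every block; here, layer 0 only.
out = reins(x, layer=0, batch_first=True, has_cls_token=True)
print(out.shape)    # torch.Size([2, 257, 384]) -- same shape as the input
print(reins.scale)  # Parameter containing 0.001: the residual starts close to zero
```
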
/rein/models/backbones/reins_dinov2.py:
--------------------------------------------------------------------------------
1 | from .reins import Reins
2 | from .dino_v2 import DinoVisionTransformer
3 | from .utils import set_requires_grad, set_train
4 |
5 |
6 | class ReinsDinoVisionTransformer(DinoVisionTransformer):
7 | def __init__(
8 | self,
9 | **kwargs,
10 | ):
11 | super().__init__(**kwargs)
12 | self.reins = Reins(
13 | num_layers = kwargs['depth'],
14 | embed_dims = kwargs['embed_dim'],
15 | patch_size = kwargs['patch_size'],
16 | )
17 |
18 | # self.reins2 = Reins(
19 | # num_layers = kwargs['depth'],
20 | # embed_dims = kwargs['embed_dim'],
21 | # patch_size = kwargs['patch_size'],
22 | # )
23 |
24 | def forward_features(self, x, masks=None):
25 | B, _, h, w = x.shape
26 | H, W = h // self.patch_size, w // self.patch_size
27 | x = self.prepare_tokens_with_masks(x, masks)
28 | outs = []
29 |
30 | for idx, blk in enumerate(self.blocks):
31 | x = blk(x)
32 | x = self.reins.forward(
33 | x,
34 | idx,
35 | batch_first=True,
36 | has_cls_token=True,
37 | )
38 | return x
39 |
40 | def forward_features_full_rein(self, x, masks=None):
41 | B, _, h, w = x.shape
42 | H, W = h // self.patch_size, w // self.patch_size
43 | x = self.prepare_tokens_with_masks(x, masks)
44 | outs = []
45 | for idx, blk in enumerate(self.blocks):
46 | x = blk(x)
47 | x = self.reins.forward(
48 | x,
49 | idx,
50 | batch_first=True,
51 | has_cls_token=True,
52 | )
53 | if idx in self.out_indices:
54 | outs.append(
55 | x[:, 1:, :].permute(0, 2, 1).reshape(B, -1, H, W).contiguous()
56 | )
57 | return self.reins.return_auto(outs)
58 |
59 |
60 |
61 | def forward_features_no_rein(self, x, masks=None):
62 | B, _, h, w = x.shape
63 | H, W = h // self.patch_size, w // self.patch_size
64 | x = self.prepare_tokens_with_masks(x, masks)
65 | outs = []
66 | for idx, blk in enumerate(self.blocks):
67 | x = blk(x)
68 | return x
69 |
70 | def train(self, mode: bool = True):
71 | if not mode:
72 | return super().train(mode)
73 | set_requires_grad(self, ["reins", "linear"])
74 | set_train(self, ["reins", "linear"])
75 |
76 |
77 |
78 | class ReinsDinoVisionTransformer_3_head(DinoVisionTransformer):
79 | def __init__(
80 | self,
81 | **kwargs,
82 | ):
83 | super().__init__(**kwargs)
84 | self.reins1 = Reins(
85 | num_layers = kwargs['depth'],
86 | embed_dims = kwargs['embed_dim'],
87 | patch_size = kwargs['patch_size'],
88 | )
89 |
90 | self.reins2 = Reins(
91 | num_layers = kwargs['depth'],
92 | embed_dims = kwargs['embed_dim'],
93 | patch_size = kwargs['patch_size'],
94 | )
95 |
96 | def forward_features1(self, x, masks=None):
97 | B, _, h, w = x.shape
98 | H, W = h // self.patch_size, w // self.patch_size
99 | x = self.prepare_tokens_with_masks(x, masks)
100 | outs = []
101 |
102 | for idx, blk in enumerate(self.blocks):
103 | x = blk(x)
104 | x = self.reins1.forward(
105 | x,
106 | idx,
107 | batch_first=True,
108 | has_cls_token=True,
109 | )
110 | return x
111 |
112 | def forward_features2(self, x, masks=None):
113 | B, _, h, w = x.shape
114 | H, W = h // self.patch_size, w // self.patch_size
115 | x = self.prepare_tokens_with_masks(x, masks)
116 | outs = []
117 |
118 | for idx, blk in enumerate(self.blocks):
119 | x = blk(x)
120 | x = self.reins2.forward(
121 | x,
122 | idx,
123 | batch_first=True,
124 | has_cls_token=True,
125 | )
126 | return x
127 |
128 | def forward_features_no_rein(self, x, masks=None):
129 | B, _, h, w = x.shape
130 | H, W = h // self.patch_size, w // self.patch_size
131 | x = self.prepare_tokens_with_masks(x, masks)
132 | outs = []
133 | for idx, blk in enumerate(self.blocks):
134 | x = blk(x)
135 | return x
136 |
137 | def train(self, mode: bool = True):
138 | if not mode:
139 | return super().train(mode)
140 | set_requires_grad(self, ["reins1", "reins2", "linear"])
141 | set_train(self, ["reins1", "reins2", "linear"])
--------------------------------------------------------------------------------
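
The training scripts use this backbone through two paths: `forward_features` runs the blocks with the Rein adapter and feeds the cls token to `linear_rein`, while `forward_features_no_rein` returns the frozen DINOv2 features for the linear-probing head. A minimal sketch mirroring `train_cufit.py` (pretrained weights are omitted here, and the 7-class heads are the HAM10000 setting):

```python
# Minimal sketch mirroring train_cufit.py; run from the repo root.
# Pretrained DINOv2 weights are omitted, so the features are randomly initialized.
import torch
import torch.nn as nn
import rein
import dino_variant

variant = dino_variant._small_variant
model = rein.ReinsDinoVisionTransformer(**variant)
model.linear = nn.Linear(variant['embed_dim'], 7)       # frozen-path (linear probing) head
model.linear_rein = nn.Linear(variant['embed_dim'], 7)  # adapter-path head

x = torch.randn(2, 3, 224, 224)
cls_rein = model.forward_features(x)[:, 0, :]            # cls token after Rein-modulated blocks
cls_frozen = model.forward_features_no_rein(x)[:, 0, :]  # cls token from the frozen DINOv2 path
print(model.linear_rein(cls_rein).shape)                 # torch.Size([2, 7])

model.train()  # keeps only submodules whose names match "reins"/"linear" trainable
```
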
/rein/models/backbones/reins_eva_02.py:
--------------------------------------------------------------------------------
1 | from .eva_02 import EVA2
2 | from mmseg.models.builder import BACKBONES, MODELS
3 | from .reins import Reins
4 | import torch
5 | import torch.utils.checkpoint as checkpoint
6 | import torch.nn.functional as F
7 | from .utils import set_requires_grad, set_train
8 |
9 |
10 | @BACKBONES.register_module()
11 | class ReinsEVA2(EVA2):
12 | def __init__(self, reins_config=None, **kwargs):
13 | super().__init__(**kwargs)
14 | self.reins: Reins = MODELS.build(reins_config)
15 |
16 | def forward_features(self, x):
17 | B, C, H, W = x.shape
18 | x, (Hp, Wp) = self.patch_embed(x)
19 | batch_size, seq_len, _ = x.size()
20 |
21 | cls_tokens = self.cls_token.expand(
22 | batch_size, -1, -1
23 | ) # stole cls_tokens impl from Phil Wang, thanks
24 | x = torch.cat((cls_tokens, x), dim=1)
25 | if self.pos_embed is not None:
26 | x = x + self.pos_embed
27 | x = self.pos_drop(x)
28 |
29 | rel_pos_bias = self.rel_pos_bias() if self.rel_pos_bias is not None else None
30 | features = []
31 | for i, blk in enumerate(self.blocks):
32 | if self.use_checkpoint:
33 | x = checkpoint.checkpoint(blk, x, rel_pos_bias)
34 | else:
35 | x = blk(x, rel_pos_bias)
36 | x = self.reins.forward(
37 | x,
38 | i,
39 | batch_first=True,
40 | has_cls_token=True,
41 | )
42 | if i in self.out_indices:
43 | xp = x[:, 1:, :].permute(0, 2, 1).reshape(B, -1, Hp, Wp)
44 | features.append(xp.contiguous())
45 | features[0] = F.interpolate(
46 | features[0], scale_factor=4, mode="bilinear", align_corners=False
47 | )
48 | features[1] = F.interpolate(
49 | features[1], scale_factor=2, mode="bilinear", align_corners=False
50 | )
51 | features[3] = F.interpolate(
52 | features[3], scale_factor=0.5, mode="bilinear", align_corners=False
53 | )
54 | return self.reins.return_auto(features)
55 |
56 | def train(self, mode: bool = True):
57 | if not mode:
58 | return super().train(mode)
59 | set_requires_grad(self, ["reins"])
60 | set_train(self, ["reins"])
61 |
62 |     def state_dict(self, destination=None, prefix="", keep_vars=False):
63 |         state = super().state_dict(destination=destination, prefix=prefix, keep_vars=keep_vars)
64 |         keys = [k for k in state.keys() if "rein" not in k]
65 |         for key in keys:
66 |             state.pop(key)
67 |             if destination is not None and key in destination:
68 |                 destination.pop(key)
69 |         return state
70 | 
--------------------------------------------------------------------------------
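
The `state_dict` override above shrinks checkpoints by dropping everything except the adapter weights, since the frozen EVA-02 backbone can always be re-loaded from its original checkpoint. The effect, sketched on a toy dict:

```python
# Toy illustration of the key filtering performed by ReinsEVA2.state_dict.
full_state = {
    "blocks.0.attn.qkv.weight": "...",  # frozen backbone weight (dropped)
    "reins.learnable_tokens": "...",    # adapter parameter (kept)
    "reins.mlp_delta_f.weight": "...",  # adapter parameter (kept)
}
rein_only = {k: v for k, v in full_state.items() if "rein" in k}
print(sorted(rein_only))  # ['reins.learnable_tokens', 'reins.mlp_delta_f.weight']
```
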
/rein/models/backbones/reins_resnet.py:
--------------------------------------------------------------------------------
1 | from .reins import Reins
2 | from .utils import set_requires_grad, set_train
3 | from typing import List, Dict
4 | import torch.nn as nn
5 | import timm
6 | from timm.models.resnet import ResNet, Bottleneck
7 |
8 | # Modified from the code of https://github.com/w1oves/Rein/blob/train/rein/models/backbones/reins_resnet.py
9 | class ReinsResNet(ResNet):
10 | def __init__(
11 | self,
12 | **kwargs,
13 | ):
14 | model_args = dict(block=Bottleneck, layers=[3, 4, 6, 3])
15 | super().__init__(**dict(model_args, **kwargs))
16 | self.reins: List[Reins] = nn.ModuleList()
17 |         self.reins.append(Reins(num_layers=1, embed_dims=256, patch_size=1))   # For layer1
18 |         self.reins.append(Reins(num_layers=1, embed_dims=512, patch_size=1))   # For layer2
19 |         self.reins.append(Reins(num_layers=1, embed_dims=1024, patch_size=1))  # For layer3
20 |         self.reins.append(Reins(num_layers=1, embed_dims=2048, patch_size=1))  # For layer4
21 |         print('length of reins: ', len(self.reins))
22 | 
23 | 
24 |     def forward(self, x):
25 | x = self.conv1(x)
26 | x = self.bn1(x)
27 | x = self.act1(x)
28 | x = self.maxpool(x)
29 | outs = []
30 | for i, layer_name in enumerate(['layer1', 'layer2', 'layer3', 'layer4']):
31 | res_layer = getattr(self, layer_name)
32 | # print(res_layer)
33 | x = res_layer(x)
34 | # print(x.shape)
35 | B, C, H, W = x.shape
36 | x = (
37 | self.reins[i]
38 | .forward(
39 | x.flatten(-2, -1).permute(0, 2, 1),
40 | 0,
41 | batch_first=True,
42 | has_cls_token=False,
43 | )
44 | .permute(0, 2, 1)
45 | .reshape(B, C, H, W)
46 | )
47 | x = self.global_pool(x)
48 | x = self.fc(x)
49 | return x
50 |
51 | def train(self, mode: bool = True):
52 | if not mode:
53 | return super().train(mode)
54 | set_requires_grad(self, ["reins", "fc"])
55 | set_train(self, ["reins", "fc"])
56 |
--------------------------------------------------------------------------------
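
`ReinsResNet.forward` reuses the token-based adapter on CNN feature maps by flattening each stage output from `(B, C, H, W)` to a `(B, H*W, C)` token sequence (one token per spatial location, no cls token) and reshaping back afterwards. A self-contained sketch of that round trip:

```python
# The flatten/permute round trip used in ReinsResNet.forward.
import torch

x = torch.randn(2, 256, 56, 56)                       # e.g. layer1 output of a ResNet-50
B, C, H, W = x.shape
tokens = x.flatten(-2, -1).permute(0, 2, 1)           # (B, H*W, C): one token per location
# ...tokens would pass through self.reins[i] here, which preserves the shape...
x_back = tokens.permute(0, 2, 1).reshape(B, C, H, W)  # back to a feature map
assert torch.equal(x, x_back)
```
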
/rein/models/backbones/utils.py:
--------------------------------------------------------------------------------
1 | import torch.nn as nn
2 | from typing import List
3 | # from mmengine.logging import MMLogger
4 |
5 | first_set_requires_grad = True
6 | first_set_train = True
7 |
8 |
9 | def set_requires_grad(model: nn.Module, keywords: List[str]):
10 |     """
11 |     Enable gradients only for parameters whose name contains any of the keywords.
12 |     """
13 | requires_grad_names = []
14 | num_params = 0
15 | num_trainable = 0
16 | for name, param in model.named_parameters():
17 | num_params += param.numel()
18 | if any(key in name for key in keywords):
19 | param.requires_grad = True
20 | requires_grad_names.append(name)
21 | num_trainable += param.numel()
22 | else:
23 | param.requires_grad = False
24 | global first_set_requires_grad
25 | # if first_set_requires_grad:
26 | # # logger = MMLogger.get_current_instance()
27 | # for name in requires_grad_names:
28 | # logger.info(f"set_requires_grad----{name}")
29 | # logger.info(
30 | # f"Total trainable params--{num_trainable}, All params--{num_params}, Ratio--{num_trainable*100/num_params:.1f}%"
31 | # )
32 | # first_set_requires_grad = False
33 |
34 |
35 | def _set_train(model: nn.Module, keywords: List[str], prefix: str = ""):
36 | train_names = []
37 | for name, child in model.named_children():
38 | fullname = ".".join([prefix, name])
39 | if any(name.startswith(key) for key in keywords):
40 | train_names.append(fullname)
41 | child.train()
42 | else:
43 | train_names += _set_train(child, keywords, prefix=fullname)
44 | return train_names
45 |
46 |
47 | def set_train(model: nn.Module, keywords: List[str]):
48 |     """
49 |     Keep only child modules whose name starts with one of the keywords in train mode.
50 |     """
51 | model.train(False)
52 | train_names = _set_train(model, keywords)
53 | # global first_set_train
54 | # if first_set_train:
55 | # logger = MMLogger.get_current_instance()
56 | # for train_name in train_names:
57 | # logger.info(f"set_train----{train_name}")
58 | # first_set_train = False
--------------------------------------------------------------------------------
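
`set_requires_grad` freezes every parameter whose name does not contain one of the keywords, and `set_train` puts the whole model in eval mode before switching only the matching child modules back to train mode; together they implement the partial fine-tuning used by the Rein backbones. A minimal sketch on a toy module, assuming the helpers above are importable from the repo root:

```python
# Minimal sketch; assumes the CUFIT repo root is on PYTHONPATH.
import torch.nn as nn
from rein.models.backbones.utils import set_requires_grad, set_train

class Toy(nn.Module):
    def __init__(self):
        super().__init__()
        self.backbone = nn.Linear(8, 8)  # no keyword match: frozen and kept in eval mode
        self.reins = nn.Linear(8, 8)     # matches "reins"
        self.linear = nn.Linear(8, 2)    # matches "linear"

model = Toy()
set_requires_grad(model, ["reins", "linear"])
set_train(model, ["reins", "linear"])
print([n for n, p in model.named_parameters() if p.requires_grad])
# ['reins.weight', 'reins.bias', 'linear.weight', 'linear.bias']
print(model.backbone.training, model.reins.training)  # False True
```
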
/requirement.txt:
--------------------------------------------------------------------------------
1 | medmnist==3.0.1
2 | numpy==1.24.4
3 | nvidia-cublas-cu11==11.10.3.66
4 | nvidia-cuda-cupti-cu11==11.7.101
5 | nvidia-cuda-nvrtc-cu11==11.7.99
6 | nvidia-cuda-runtime-cu11==11.7.99
7 | nvidia-cudnn-cu11==8.5.0.96
8 | nvidia-cufft-cu11==10.9.0.58
9 | nvidia-curand-cu11==10.2.10.91
10 | nvidia-cusolver-cu11==11.4.0.1
11 | nvidia-cusparse-cu11==11.7.4.91
12 | nvidia-nccl-cu11==2.14.3
13 | nvidia-nvtx-cu11==11.7.91
14 | Pillow==10.0.0
15 | timm==0.6.13
16 | torch==2.0.1
17 | torchaudio==2.0.2
18 | torchdiffeq==0.2.3
19 | torchvision==0.15.2
--------------------------------------------------------------------------------
/train_cufit.py:
--------------------------------------------------------------------------------
1 | import os
2 | import torch
3 | import torch.nn as nn
4 |
5 | import argparse
6 | import timm
7 | import utils
8 |
9 | import rein
10 |
11 | import dino_variant
12 |
13 |
14 | def train():
15 | parser = argparse.ArgumentParser()
16 | parser.add_argument('--data', '-d', type=str)
17 | parser.add_argument('--gpu', '-g', default = '0', type=str)
18 | parser.add_argument('--netsize', default='s', type=str)
19 | parser.add_argument('--save_path', '-s', type=str)
20 | parser.add_argument('--noise_rate', '-n', type=float, default=0.2)
21 | args = parser.parse_args()
22 |
23 | config = utils.read_conf('conf/'+args.data+'.json')
24 | device = 'cuda:'+args.gpu
25 | save_path = os.path.join(config['save_path'], args.save_path)
26 | data_path = config['id_dataset']
27 | batch_size = int(config['batch_size'])
28 | max_epoch = int(config['epoch'])
29 | noise_rate = args.noise_rate
30 |
31 | if not os.path.exists(save_path):
32 | os.mkdir(save_path)
33 |
34 | lr_decay = [int(0.5*max_epoch), int(0.75*max_epoch), int(0.9*max_epoch)]
35 |
36 | if args.data == 'ham10000':
37 | train_loader, valid_loader = utils.get_noise_dataset(data_path, noise_rate=noise_rate, batch_size = batch_size)
38 | elif args.data == 'aptos':
39 | train_loader, valid_loader = utils.get_aptos_noise_dataset(data_path, noise_rate=noise_rate, batch_size = batch_size)
40 | elif 'mnist' in args.data:
41 | train_loader, valid_loader = utils.get_mnist_noise_dataset(args.data, noise_rate=noise_rate, batch_size = batch_size)
42 | elif 'cifar' in args.data:
43 | train_loader, valid_loader = utils.get_cifar_noise_dataset(args.data, data_path, batch_size = batch_size, noise_rate=noise_rate)
44 |
45 | if args.netsize == 's':
46 | model_load = dino_variant._small_dino
47 | variant = dino_variant._small_variant
48 | elif args.netsize == 'b':
49 | model_load = dino_variant._base_dino
50 | variant = dino_variant._base_variant
51 | elif args.netsize == 'l':
52 | model_load = dino_variant._large_dino
53 | variant = dino_variant._large_variant
54 | # model = timm.create_model(network, pretrained=True, num_classes=2)
55 | model = torch.hub.load('facebookresearch/dinov2', model_load)
56 | dino_state_dict = model.state_dict()
57 |
58 | model = rein.ReinsDinoVisionTransformer(
59 | **variant
60 | )
61 | model.load_state_dict(dino_state_dict, strict=False)
62 | model.linear = nn.Linear(variant['embed_dim'], config['num_classes'])
63 | model.linear_rein = nn.Linear(variant['embed_dim'], config['num_classes'])
64 | model.to(device)
65 | criterion = torch.nn.CrossEntropyLoss(reduction='none')
66 | model.eval()
67 |
68 | model2 = rein.ReinsDinoVisionTransformer(
69 | **variant
70 | )
71 | model2.load_state_dict(dino_state_dict, strict=False)
72 | model2.linear_rein = nn.Linear(variant['embed_dim'], config['num_classes'])
73 | model2.to(device)
74 |
75 | model.eval()
76 | model2.eval()
77 |
78 | optimizer = torch.optim.Adam(model.parameters(), lr=1e-3, weight_decay = 1e-5)
79 | optimizer2 = torch.optim.Adam(model2.parameters(), lr=1e-3, weight_decay = 1e-5)
80 |
81 | scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, lr_decay)
82 | scheduler2 = torch.optim.lr_scheduler.MultiStepLR(optimizer2, lr_decay)
83 | saver = timm.utils.CheckpointSaver(model2, optimizer, checkpoint_dir= save_path, max_history = 1)
84 | print(train_loader.dataset[0][0].shape)
85 |
86 | print('## Trainable parameters')
87 | model2.train()
88 | for n, p in model2.named_parameters():
89 |         if p.requires_grad:
90 | print(n)
91 |
92 | avg_accuracy = 0.0
93 | for epoch in range(max_epoch):
94 | ## training
95 | model.train()
96 | model2.train()
97 | total_loss = 0
98 | total = 0
99 | correct = 0
100 | correct2 = 0
101 | correct_linear = 0
102 | for batch_idx, (inputs, targets) in enumerate(train_loader):
103 | inputs, targets = inputs.to(device), targets.to(device)
104 | optimizer.zero_grad()
105 |
106 | features_rein = model.forward_features(inputs)
107 | features_rein = features_rein[:, 0, :]
108 | outputs = model.linear_rein(features_rein)
109 |
110 | features_rein2 = model2.forward_features(inputs)
111 | features_rein2 = features_rein2[:, 0, :]
112 | outputs2 = model2.linear_rein(features_rein2)
113 |
114 | with torch.no_grad():
115 | features_ = model.forward_features_no_rein(inputs)
116 | features_ = features_[:, 0, :]
117 | outputs_ = model.linear(features_)
118 | # print(outputs.shape, outputs_.shape)
119 |
120 | with torch.no_grad():
121 | pred = (outputs_).max(1).indices
122 | linear_accurate = (pred==targets)
123 |
124 | pred2 = outputs.max(1).indices
125 | linear_accurate2 = (pred2==targets)
126 |
127 | loss_rein = linear_accurate*criterion(outputs, targets)
128 | loss_rein2 = linear_accurate2*criterion(outputs2, targets)
129 | loss_linear = criterion(outputs_, targets)
130 | loss = loss_linear.mean()+loss_rein.mean()
131 | loss.backward()
132 |             optimizer.step()
133 |
134 | optimizer2.zero_grad()
135 | loss_rein2.mean().backward()
136 | optimizer2.step()
137 |
138 | total_loss += loss
139 | total += targets.size(0)
140 | _, predicted = outputs[:len(targets)].max(1)
141 | correct += predicted.eq(targets).sum().item()
142 |
143 | _, predicted = outputs2[:len(targets)].max(1)
144 | correct2 += predicted.eq(targets).sum().item()
145 |
146 | _, predicted = outputs_[:len(targets)].max(1)
147 | correct_linear += predicted.eq(targets).sum().item()
148 | print('\r', batch_idx, len(train_loader), 'Loss: %.3f | Acc2: %.3f%% | Acc1: %.3f%% | LinearAcc: %.3f%% | (%d/%d)'
149 | % (total_loss/(batch_idx+1), 100.*correct2/total, 100.*correct/total, 100.*correct_linear/total, correct, total), end = '')
150 | train_accuracy = correct/total
151 | train_avg_loss = total_loss/len(train_loader)
152 | print()
153 |
154 | ## validation
155 | model.eval()
156 | model2.eval()
157 |
158 | total_loss = 0
159 | total = 0
160 | correct = 0
161 | valid_accuracy = utils.validation_accuracy(model2, valid_loader, device)
162 | valid_accuracy_ = utils.validation_accuracy(model, valid_loader, device)
163 | valid_accuracy_linear = utils.validation_accuracy(model, valid_loader, device, mode='no_rein')
164 |
165 | scheduler.step()
166 | scheduler2.step()
167 | if epoch >= max_epoch-10:
168 | avg_accuracy += valid_accuracy
169 | saver.save_checkpoint(epoch, metric = valid_accuracy)
170 | print('EPOCH {:4}, TRAIN [loss - {:.4f}, acc - {:.4f}], VALID_2 [acc - {:.4f}], VALID_1 [acc - {:.4f}], VALID(linear) [acc - {:.4f}]\n'.format(epoch, train_avg_loss, train_accuracy, valid_accuracy, valid_accuracy_, valid_accuracy_linear))
171 | print(scheduler.get_last_lr())
172 | with open(os.path.join(save_path, 'avgacc.txt'), 'w') as f:
173 | f.write(str(avg_accuracy/10))
174 | if __name__ =='__main__':
175 | train()
--------------------------------------------------------------------------------
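
The core of the curriculum step above is per-sample loss masking: the per-sample cross-entropy of each adapter is multiplied by the correctness mask of the previous, more conservative module (linear probe → adapter 1 → adapter 2), so an adapter only learns from samples the earlier stage already classifies correctly, which filters out most noisy labels once the linear probe has converged. A self-contained sketch of the selection rule with hypothetical logits:

```python
# Sample-selection rule used in the training loop above (hypothetical logits).
import torch
import torch.nn.functional as F

targets = torch.tensor([0, 1, 2, 1])
logits_linear = torch.randn(4, 7)    # frozen-backbone linear probe (noise-robust)
logits_adapter1 = torch.randn(4, 7)  # first Rein adapter (model / linear_rein)
logits_adapter2 = torch.randn(4, 7)  # second Rein adapter (model2 / linear_rein)

loss_lin = F.cross_entropy(logits_linear, targets, reduction='none')
loss_a1 = F.cross_entropy(logits_adapter1, targets, reduction='none')
loss_a2 = F.cross_entropy(logits_adapter2, targets, reduction='none')

mask1 = (logits_linear.argmax(1) == targets).float()    # samples the linear probe gets right
mask2 = (logits_adapter1.argmax(1) == targets).float()  # samples the first adapter gets right

loss_model1 = loss_lin.mean() + (mask1 * loss_a1).mean()  # drives optimizer
loss_model2 = (mask2 * loss_a2).mean()                    # drives optimizer2
print(loss_model1.item(), loss_model2.item())
```
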
/train_fully.py:
--------------------------------------------------------------------------------
1 | import os
2 | import torch
3 | import torch.nn as nn
4 |
5 | import argparse
6 | import timm
7 | import utils
8 |
9 | import dino_variant
10 |
11 |
12 | def train():
13 | parser = argparse.ArgumentParser()
14 | parser.add_argument('--data', '-d', type=str)
15 | parser.add_argument('--gpu', '-g', default = '0', type=str)
16 | parser.add_argument('--netsize', default='s', type=str)
17 | parser.add_argument('--save_path', '-s', type=str)
18 | parser.add_argument('--noise_rate', '-n', type=float, default=0.2)
19 | args = parser.parse_args()
20 |
21 | config = utils.read_conf('conf/'+args.data+'.json')
22 | device = 'cuda:'+args.gpu
23 | save_path = os.path.join(config['save_path'], args.save_path)
24 | data_path = config['id_dataset']
25 | batch_size = int(config['batch_size'])
26 | max_epoch = int(config['epoch'])
27 | noise_rate = args.noise_rate
28 |
29 | if not os.path.exists(save_path):
30 | os.mkdir(save_path)
31 |
32 | lr_decay = [int(0.5*max_epoch), int(0.75*max_epoch), int(0.9*max_epoch)]
33 |
34 | if args.data == 'ham10000':
35 | train_loader, valid_loader = utils.get_noise_dataset(data_path, noise_rate=noise_rate, batch_size = batch_size)
36 | elif args.data == 'aptos':
37 | train_loader, valid_loader = utils.get_aptos_noise_dataset(data_path, noise_rate=noise_rate, batch_size = batch_size)
38 | elif 'mnist' in args.data:
39 | train_loader, valid_loader = utils.get_mnist_noise_dataset(args.data, noise_rate=noise_rate, batch_size = batch_size)
40 | elif 'cifar' in args.data:
41 | train_loader, valid_loader = utils.get_cifar_noise_dataset(args.data, data_path, batch_size = batch_size, noise_rate=noise_rate)
42 |
43 | if args.netsize == 's':
44 | model_load = dino_variant._small_dino
45 | variant = dino_variant._small_variant
46 | elif args.netsize == 'b':
47 | model_load = dino_variant._base_dino
48 | variant = dino_variant._base_variant
49 | elif args.netsize == 'l':
50 | model_load = dino_variant._large_dino
51 | variant = dino_variant._large_variant
52 |
53 | model = torch.hub.load('facebookresearch/dinov2', model_load)
54 | model.linear = nn.Linear(variant['embed_dim'], config['num_classes'])
55 | model.to(device)
56 |
57 | criterion = torch.nn.CrossEntropyLoss()
58 | model.eval()
59 |
60 | optimizer = torch.optim.Adam(model.parameters(), lr=1e-4, weight_decay = 1e-5)
61 |
62 | scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, lr_decay)
63 | saver = timm.utils.CheckpointSaver(model, optimizer, checkpoint_dir= save_path, max_history = 1)
64 | print(train_loader.dataset[0][0].shape)
65 |
66 | print('## Trainable parameters')
67 | model.train()
68 | for n, p in model.named_parameters():
69 |         if p.requires_grad:
70 | print(n)
71 |
72 | avg_accuracy = 0.0
73 | for epoch in range(max_epoch):
74 | ## training
75 | model.train()
76 | total_loss = 0
77 | total = 0
78 | correct = 0
79 | for batch_idx, (inputs, targets) in enumerate(train_loader):
80 | inputs, targets = inputs.to(device), targets.to(device)
81 | optimizer.zero_grad()
82 |
83 | outputs = model(inputs)
84 | outputs = model.linear(outputs)
85 |
86 | loss = criterion(outputs, targets)
87 | loss.backward()
88 | optimizer.step()
89 |
90 | total_loss += loss
91 | total += targets.size(0)
92 | _, predicted = outputs[:len(targets)].max(1)
93 | correct += predicted.eq(targets).sum().item()
94 | print('\r', batch_idx, len(train_loader), 'Loss: %.3f | Acc: %.3f%% (%d/%d)'
95 | % (total_loss/(batch_idx+1), 100.*correct/total, correct, total), end = '')
96 | train_accuracy = correct/total
97 | train_avg_loss = total_loss/len(train_loader)
98 | print()
99 |
100 | ## validation
101 | model.eval()
102 | total_loss = 0
103 | total = 0
104 | correct = 0
105 | valid_accuracy = utils.validation_accuracy(model, valid_loader, device, mode='linear')
106 | scheduler.step()
107 | if epoch >= max_epoch-10:
108 | avg_accuracy += valid_accuracy
109 | saver.save_checkpoint(epoch, metric = valid_accuracy)
110 | print('EPOCH {:4}, TRAIN [loss - {:.4f}, acc - {:.4f}], VALID [acc - {:.4f}]\n'.format(epoch, train_avg_loss, train_accuracy, valid_accuracy))
111 | print(scheduler.get_last_lr())
112 | with open(os.path.join(save_path, 'avgacc.txt'), 'w') as f:
113 | f.write(str(avg_accuracy/10))
114 |
115 | if __name__ =='__main__':
116 | train()
--------------------------------------------------------------------------------
/train_linear.py:
--------------------------------------------------------------------------------
1 | import os
2 | import torch
3 | import torch.nn as nn
4 |
5 | import argparse
6 | import timm
7 | import utils
8 |
9 | import dino_variant
10 |
11 |
12 | def train():
13 | parser = argparse.ArgumentParser()
14 | parser.add_argument('--data', '-d', type=str)
15 | parser.add_argument('--gpu', '-g', default = '0', type=str)
16 | parser.add_argument('--netsize', default='s', type=str)
17 | parser.add_argument('--save_path', '-s', type=str)
18 | parser.add_argument('--noise_rate', '-n', type=float, default=0.2)
19 | args = parser.parse_args()
20 |
21 | config = utils.read_conf('conf/'+args.data+'.json')
22 | device = 'cuda:'+args.gpu
23 | save_path = os.path.join(config['save_path'], args.save_path)
24 | data_path = config['id_dataset']
25 | batch_size = int(config['batch_size'])
26 | max_epoch = int(config['epoch'])
27 | noise_rate = args.noise_rate
28 |
29 | if not os.path.exists(save_path):
30 | os.mkdir(save_path)
31 |
32 | lr_decay = [int(0.5*max_epoch), int(0.75*max_epoch), int(0.9*max_epoch)]
33 |
34 |
35 | if args.data == 'ham10000':
36 | train_loader, valid_loader = utils.get_noise_dataset(data_path, noise_rate=noise_rate, batch_size = batch_size)
37 | elif args.data == 'aptos':
38 | train_loader, valid_loader = utils.get_aptos_noise_dataset(data_path, noise_rate=noise_rate, batch_size = batch_size)
39 | elif 'mnist' in args.data:
40 | train_loader, valid_loader = utils.get_mnist_noise_dataset(args.data, noise_rate=noise_rate, batch_size = batch_size)
41 |
42 | if args.netsize == 's':
43 | model_load = dino_variant._small_dino
44 | variant = dino_variant._small_variant
45 | elif args.netsize == 'b':
46 | model_load = dino_variant._base_dino
47 | variant = dino_variant._base_variant
48 | elif args.netsize == 'l':
49 | model_load = dino_variant._large_dino
50 | variant = dino_variant._large_variant
51 |
52 | model = torch.hub.load('facebookresearch/dinov2', model_load)
53 | model.linear = nn.Linear(variant['embed_dim'], config['num_classes'])
54 | model.to(device)
55 |
56 | criterion = torch.nn.CrossEntropyLoss()
57 | model.eval()
58 |
59 | for n, p in model.named_parameters():
60 |         if 'linear' not in n:
61 | p.requires_grad = False
62 | optimizer = torch.optim.Adam(model.linear.parameters(), lr=1e-3, weight_decay = 1e-5)
63 |
64 | scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, lr_decay)
65 | saver = timm.utils.CheckpointSaver(model, optimizer, checkpoint_dir= save_path, max_history = 1)
66 | print(train_loader.dataset[0][0].shape)
67 |
68 | print('## Trainable parameters')
69 | model.train()
70 | for n, p in model.named_parameters():
71 |         if p.requires_grad:
72 | print(n)
73 | avg_accuracy = 0.0
74 | for epoch in range(max_epoch):
75 | ## training
76 | model.train()
77 | total_loss = 0
78 | total = 0
79 | correct = 0
80 | for batch_idx, (inputs, targets) in enumerate(train_loader):
81 | inputs, targets = inputs.to(device), targets.to(device)
82 | optimizer.zero_grad()
83 |
84 | with torch.no_grad():
85 | outputs = model(inputs)
86 | outputs = model.linear(outputs)
87 |
88 | loss = criterion(outputs, targets)
89 | loss.backward()
90 | optimizer.step()
91 |
92 | total_loss += loss
93 | total += targets.size(0)
94 | _, predicted = outputs[:len(targets)].max(1)
95 | correct += predicted.eq(targets).sum().item()
96 | print('\r', batch_idx, len(train_loader), 'Loss: %.3f | Acc: %.3f%% (%d/%d)'
97 | % (total_loss/(batch_idx+1), 100.*correct/total, correct, total), end = '')
98 | train_accuracy = correct/total
99 |
100 | train_avg_loss = total_loss/len(train_loader)
101 | print()
102 |
103 | ## validation
104 | model.eval()
105 | total_loss = 0
106 | total = 0
107 | correct = 0
108 |
109 | valid_accuracy = utils.validation_accuracy(model, valid_loader, device, mode='linear')
110 | if epoch >= max_epoch-10:
111 | avg_accuracy += valid_accuracy
112 | scheduler.step()
113 |
114 | saver.save_checkpoint(epoch, metric = valid_accuracy)
115 | print('EPOCH {:4}, TRAIN [loss - {:.4f}, acc - {:.4f}], VALID [acc - {:.4f}]\n'.format(epoch, train_avg_loss, train_accuracy, valid_accuracy))
116 | print(scheduler.get_last_lr())
117 |
118 | with open(os.path.join(save_path, 'avgacc.txt'), 'w') as f:
119 | f.write(str(avg_accuracy/10))
120 |
121 | if __name__ =='__main__':
122 | train()
--------------------------------------------------------------------------------
/train_rein.py:
--------------------------------------------------------------------------------
1 | import os
2 | import torch
3 | import torch.nn as nn
4 |
5 | import argparse
6 | import timm
7 | import utils
8 |
9 | import rein
10 |
11 | import dino_variant
12 |
13 | def train():
14 | parser = argparse.ArgumentParser()
15 | parser.add_argument('--data', '-d', type=str)
16 | parser.add_argument('--gpu', '-g', default = '0', type=str)
17 | parser.add_argument('--netsize', default='s', type=str)
18 | parser.add_argument('--save_path', '-s', type=str)
19 | parser.add_argument('--noise_rate', '-n', type=float, default=0.2)
20 | args = parser.parse_args()
21 |
22 | config = utils.read_conf('conf/'+args.data+'.json')
23 | device = 'cuda:'+args.gpu
24 | save_path = os.path.join(config['save_path'], args.save_path)
25 | data_path = config['id_dataset']
26 | batch_size = int(config['batch_size'])
27 | max_epoch = int(config['epoch'])
28 | noise_rate = args.noise_rate
29 |
30 | if not os.path.exists(save_path):
31 | os.mkdir(save_path)
32 |
33 | lr_decay = [int(0.5*max_epoch), int(0.75*max_epoch), int(0.9*max_epoch)]
34 |
35 |
36 | if args.data == 'ham10000':
37 | train_loader, valid_loader = utils.get_noise_dataset(data_path, noise_rate=noise_rate, batch_size = batch_size)
38 | elif args.data == 'aptos':
39 | train_loader, valid_loader = utils.get_aptos_noise_dataset(data_path, noise_rate=noise_rate, batch_size = batch_size)
40 | elif 'mnist' in args.data:
41 | train_loader, valid_loader = utils.get_mnist_noise_dataset(args.data, noise_rate=noise_rate, batch_size = batch_size)
42 | elif 'cifar' in args.data:
43 | train_loader, valid_loader = utils.get_cifar_noise_dataset(args.data, data_path, batch_size = batch_size, noise_rate=noise_rate)
44 |
45 | if args.netsize == 's':
46 | model_load = dino_variant._small_dino
47 | variant = dino_variant._small_variant
48 | elif args.netsize == 'b':
49 | model_load = dino_variant._base_dino
50 | variant = dino_variant._base_variant
51 | elif args.netsize == 'l':
52 | model_load = dino_variant._large_dino
53 | variant = dino_variant._large_variant
54 |
55 |
56 | model = torch.hub.load('facebookresearch/dinov2', model_load)
57 | dino_state_dict = model.state_dict()
58 |
59 | model = rein.ReinsDinoVisionTransformer(
60 | **variant
61 | )
62 | model.load_state_dict(dino_state_dict, strict=False)
63 | model.linear_rein = nn.Linear(variant['embed_dim'], config['num_classes'])
64 | model.to(device)
65 |
66 | criterion = torch.nn.CrossEntropyLoss()
67 | model.eval()
68 |
69 | optimizer = torch.optim.Adam(model.parameters(), lr=1e-3, weight_decay = 1e-5)
70 | scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, lr_decay)
71 | saver = timm.utils.CheckpointSaver(model, optimizer, checkpoint_dir= save_path, max_history = 1)
72 | print(train_loader.dataset[0][0].shape)
73 |
74 | print('## Trainable parameters')
75 | model.train()
76 | for n, p in model.named_parameters():
77 |         if p.requires_grad:
78 | print(n)
79 |
80 | avg_accuracy = 0.0
81 | for epoch in range(max_epoch):
82 | ## training
83 | model.train()
84 | total_loss = 0
85 | total = 0
86 | correct = 0
87 | for batch_idx, (inputs, targets) in enumerate(train_loader):
88 | inputs, targets = inputs.to(device), targets.to(device)
89 | optimizer.zero_grad()
90 |
91 | features = model.forward_features(inputs)
92 | features = features[:, 0, :]
93 | outputs = model.linear_rein(features)
94 | loss = criterion(outputs, targets)
95 | loss.backward()
96 | optimizer.step()
97 |
98 | total_loss += loss
99 | total += targets.size(0)
100 | _, predicted = outputs[:len(targets)].max(1)
101 | correct += predicted.eq(targets).sum().item()
102 | print('\r', batch_idx, len(train_loader), 'Loss: %.3f | Acc: %.3f%% (%d/%d)'
103 | % (total_loss/(batch_idx+1), 100.*correct/total, correct, total), end = '')
104 | train_accuracy = correct/total
105 |
106 | train_avg_loss = total_loss/len(train_loader)
107 | print()
108 |
109 | ## validation
110 | model.eval()
111 | total_loss = 0
112 | total = 0
113 | correct = 0
114 |
115 | valid_accuracy = utils.validation_accuracy(model, valid_loader, device)
116 | if epoch >= max_epoch-10:
117 | avg_accuracy += valid_accuracy
118 | scheduler.step()
119 |
120 | saver.save_checkpoint(epoch, metric = valid_accuracy)
121 | print('EPOCH {:4}, TRAIN [loss - {:.4f}, acc - {:.4f}], VALID [acc - {:.4f}]\n'.format(epoch, train_avg_loss, train_accuracy, valid_accuracy))
122 | print(scheduler.get_last_lr())
123 |
124 | with open(os.path.join(save_path, 'avgacc.txt'), 'w') as f:
125 | f.write(str(avg_accuracy/10))
126 |
127 | if __name__ =='__main__':
128 | train()
--------------------------------------------------------------------------------
/utils/__init__.py:
--------------------------------------------------------------------------------
1 | from .dataset import *
2 | from .metric import *
--------------------------------------------------------------------------------
/utils/aptos.py:
--------------------------------------------------------------------------------
1 | import torch.utils.data as data
2 | import os
3 |
4 | from PIL import Image
5 |
6 | class APTOS2019():
7 | def __init__(self, root_dir, train = True, transforms=None):
8 |         """
9 |         Arguments:
10 |             root_dir (string): Directory with the images and the train/test csv files.
11 |             train (bool): Load the training split if True, else the test split.
12 |             transforms (callable, optional): Optional transform to be applied
13 |                 on a sample.
14 |         """
15 | self.root_dir = root_dir
16 | self.transform = transforms
17 |
18 | self.label_txt = os.path.join(root_dir, 'train_1.csv' if train else 'test.csv')
19 |
20 | self.samples = []
21 | with open(self.label_txt, 'r') as f:
22 | lines = f.readlines()
23 | for line in lines[1:]:
24 | line = line.split(',')
25 | if len(line) == 2:
26 | img_name, label = line
27 |
28 | img_name = os.path.join(root_dir, 'train_images/train_images' if train else 'test_images', img_name+'.png')
29 | label = label.replace('\n', '')
30 | label = int(label)
31 |
32 | self.samples.append([img_name, label])
33 |
34 | def __len__(self):
35 | return len(self.samples)
36 |
37 | def __getitem__(self, idx):
38 | sample, label = self.samples[idx]
39 | sample = Image.open(sample)
40 |
41 | if self.transform:
42 | sample = self.transform(sample)
43 |
44 | return sample, label
45 |
46 | if __name__ == '__main__':
47 | aptos = APTOS2019('./data/APTOS-2019', True)
48 | print(aptos[0])
--------------------------------------------------------------------------------
/utils/dataset.py:
--------------------------------------------------------------------------------
1 | import json
2 | import numpy as np
3 | import torch
4 |
5 | from PIL import Image
6 | from torchvision import transforms
7 | from torchvision import datasets as dset
8 | import torchvision
9 |
10 | from .aptos import APTOS2019
11 |
12 | def get_transform(transform_type='default', image_size=224, args=None):
13 |
14 | if transform_type == 'default':
15 | IMAGENET_DEFAULT_MEAN = (0.485, 0.456, 0.406)
16 | IMAGENET_DEFAULT_STD = (0.229, 0.224, 0.225)
17 |
18 | mean = IMAGENET_DEFAULT_MEAN
19 | std = IMAGENET_DEFAULT_STD
20 |
21 |
22 | train_transform = transforms.Compose([
23 | transforms.Resize((256, 256)),
24 | # transforms.Resize((224, 224)),
25 | transforms.RandomHorizontalFlip(p=0.5),
26 | # transforms.RandomVerticalFlip(p=0.5),
27 | # transforms.ColorJitter(),
28 | transforms.RandomCrop(size=(image_size, image_size)),
29 | transforms.ToTensor(),
30 | transforms.Normalize(mean=mean, std=std)
31 | ])
32 |
33 | test_transform = transforms.Compose([
34 | transforms.Resize((256, 256)),
35 | transforms.CenterCrop((image_size, image_size)),
36 | transforms.ToTensor(),
37 | transforms.Normalize(mean=mean, std=std)
38 | ])
39 | return train_transform, test_transform
40 |
41 | def read_conf(json_path):
42 | """
43 | read json and return the configure as dictionary.
44 | """
45 | with open(json_path) as json_file:
46 | config = json.load(json_file)
47 | return config
48 |
49 | def get_noise_dataset(path, noise_rate = 0.2, batch_size = 32, seed = 0):
50 | train_transform, test_transform = get_transform()
51 | train_data = torchvision.datasets.ImageFolder(path + '/train', train_transform)
52 | np.random.seed(seed)
53 |
54 | new_data = []
55 | for i in range(len(train_data.samples)):
56 | if np.random.rand() > noise_rate: # clean sample:
57 | new_data.append([train_data.samples[i][0], train_data.samples[i][1]])
58 | else:
59 | label_index = list(range(7))
60 | label_index.remove(train_data.samples[i][1])
61 | label_index = np.array(label_index)
62 | label_index = np.reshape(label_index, (-1))
63 |
64 | new_label = np.random.choice(label_index, 1)
65 | new_label = new_label[0]
66 |
67 | new_data.append([train_data.samples[i][0], new_label])
68 | train_data.samples = new_data
69 |
70 | # Testing
71 | with open('label.txt', 'w') as f:
72 | for i in range(len(train_data.samples)):
73 | f.write('{}\n'.format(train_data.samples[i][1]))
74 |
75 | valid_data = torchvision.datasets.ImageFolder(path + '/test', test_transform)
76 |
77 | train_loader = torch.utils.data.DataLoader(train_data, batch_size=batch_size, shuffle=True, pin_memory=True, num_workers = 8)
78 | valid_loader = torch.utils.data.DataLoader(valid_data, batch_size=batch_size, shuffle=False, pin_memory=True, num_workers = 8)
79 | return train_loader, valid_loader
80 |
81 | def get_aptos_noise_dataset(path, noise_rate = 0.2, batch_size = 32, seed = 0):
82 | train_transform, test_transform = get_transform()
83 | train_data = APTOS2019(path, train=True, transforms = train_transform)
84 |
85 | np.random.seed(seed)
86 | new_data = []
87 | for i in range(len(train_data.samples)):
88 | if np.random.rand() > noise_rate: # clean sample:
89 | new_data.append([train_data.samples[i][0], train_data.samples[i][1]])
90 | else:
91 | label_index = list(range(5))
92 | label_index.remove(train_data.samples[i][1])
93 | label_index = np.array(label_index)
94 | label_index = np.reshape(label_index, (-1))
95 |
96 | new_label = np.random.choice(label_index, 1)
97 | new_label = new_label[0]
98 |
99 | new_data.append([train_data.samples[i][0], new_label])
100 | train_data.samples = new_data
101 |
102 | # Debug: dump the (possibly noisy) training labels to label.txt for inspection
103 | with open('label.txt', 'w') as f:
104 | for i in range(len(train_data.samples)):
105 | f.write('{}\n'.format(train_data.samples[i][1]))
106 |
107 | valid_data = APTOS2019(path, train=False, transforms = test_transform)
108 |
109 | train_loader = torch.utils.data.DataLoader(train_data, batch_size=batch_size, shuffle=True, pin_memory=True, num_workers = 16)
110 | valid_loader = torch.utils.data.DataLoader(valid_data, batch_size=batch_size, shuffle=False, pin_memory=True, num_workers = 8)
111 | return train_loader, valid_loader
112 |
113 |
114 | def get_mnist_noise_dataset(dataname, noise_rate = 0.2, batch_size = 32, seed = 0):
115 | # from medmnist import NoduleMNIST3D
116 | from medmnist import PathMNIST, BloodMNIST, OCTMNIST, TissueMNIST, OrganCMNIST
117 | train_transform, test_transform = get_transform()
118 |
119 | if dataname == 'pathmnist':
120 | train_data = PathMNIST(split="train", download=True, size=224, transform= train_transform, as_rgb=True)
121 | test_data = PathMNIST(split="test", download=True, size=224, transform= test_transform, as_rgb=True)
122 | num_classes = 9
123 | elif dataname == 'bloodmnist':
124 | train_data = BloodMNIST(split="train", download=True, size=224, transform= train_transform, as_rgb=True)
125 | test_data = BloodMNIST(split="test", download=True, size=224, transform= test_transform, as_rgb=True)
126 | num_classes = 8
127 | elif dataname == 'octmnist':
128 | train_data = OCTMNIST(split="train", download=True, size=224, transform= train_transform, as_rgb=True)
129 | test_data = OCTMNIST(split="test", download=True, size=224, transform= test_transform, as_rgb=True)
130 | num_classes = 4
131 | elif dataname == 'tissuemnist':
132 | train_data = TissueMNIST(split="train", download=True, size=224, transform= train_transform, as_rgb=True)
133 | test_data = TissueMNIST(split="test", download=True, size=224, transform= test_transform, as_rgb=True)
134 | num_classes = 8
135 | elif dataname == 'organcmnist':
136 | train_data = OrganCMNIST(split="train", download=True, size=224, transform= train_transform, as_rgb=True)
137 | test_data = OrganCMNIST(split="test", download=True, size=224, transform= test_transform, as_rgb=True)
138 | num_classes = 11
139 |
140 | np.random.seed(seed)
141 | # new_imgs = []
142 | new_labels =[]
143 | for i in range(len(train_data.imgs)):
144 | if np.random.rand() > noise_rate: # clean sample:
145 | # new_imgs.append(train_data.imgs[i])
146 | new_labels.append(train_data.labels[i][0])
147 | else:
148 | label_index = list(range(num_classes))
149 | label_index.remove(train_data.labels[i][0])  # MedMNIST labels are stored as length-1 arrays
150 | label_index = np.array(label_index)
151 | label_index = np.reshape(label_index, (-1))
152 |
153 | new_label = np.random.choice(label_index, 1)
154 | new_label = new_label[0]
155 |
156 | # new_imgs.append(train_data.imgs[i])
157 | new_labels.append(new_label)
158 | # train_data.imgs = new_imgs
159 | train_data.labels = new_labels
160 |
161 | new_labels = []
162 | for i in range(len(test_data.labels)):
163 | new_labels.append(test_data.labels[i][0])
164 | test_data.labels = new_labels
165 |
166 |
167 | train_loader = torch.utils.data.DataLoader(train_data, batch_size=batch_size, shuffle=True, pin_memory=True, num_workers = 16)
168 | valid_loader = torch.utils.data.DataLoader(test_data, batch_size=batch_size, shuffle=False, pin_memory=True, num_workers = 8)
169 | return train_loader, valid_loader
--------------------------------------------------------------------------------
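
All three loader factories in `dataset.py` inject symmetric label noise: `np.random.seed(seed)` fixes the corruption pattern, and each training sample keeps its label with probability `1 - noise_rate`; otherwise the label is replaced by one drawn uniformly from the remaining classes. The class counts are hard-coded per dataset (7 for the HAM10000 `ImageFolder` variant, 5 for APTOS-2019, and the listed values for the MedMNIST datasets), and the corrupted training set is then wrapped in ordinary `DataLoader`s. The per-sample rule, written as a standalone helper for clarity (the helper name is ours, not part of the repository):

```python
import numpy as np

def flip_label_symmetric(label: int, num_classes: int, noise_rate: float,
                         rng: np.random.RandomState) -> int:
    """With probability `noise_rate`, replace `label` by a different class
    drawn uniformly at random; otherwise keep it (symmetric label noise)."""
    if rng.rand() > noise_rate:
        return label                        # clean sample
    candidates = [c for c in range(num_classes) if c != label]
    return int(rng.choice(candidates))      # corrupted sample

rng = np.random.RandomState(0)
noisy = [flip_label_symmetric(2, num_classes=7, noise_rate=0.4, rng=rng) for _ in range(10)]
print(noisy)  # roughly 60% of the entries stay 2; the rest spread over the other six classes
```
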
/utils/metric.py:
--------------------------------------------------------------------------------
1 | import torch
2 |
3 | recall_level_default = 0.95
4 |
5 | def validation_accuracy(model, loader, device, mode = 'rein'):
6 | total = 0
7 | correct = 0
8 |
9 | def linear(model, inputs):
10 | f = model(inputs)
11 | outputs = model.linear(f)
12 | return outputs
13 |
14 | def rein(model, inputs):
15 | f = model.forward_features(inputs)
16 | f = f[:, 0, :]
17 | outputs = model.linear_rein(f)
18 | return outputs
19 |
20 | def no_rein(model, inputs):
21 | f = model.forward_features_no_rein(inputs)
22 | f = f[:, 0, :]
23 | outputs = model.linear(f)
24 | return outputs
25 | if mode == 'rein':
26 | out = rein
27 | elif mode == 'no_rein':
28 | out = no_rein
29 | else:
30 | out = linear
31 |
32 | model.eval()
33 | with torch.no_grad():
34 | for batch_idx, (inputs, targets) in enumerate(loader):
35 | inputs, targets = inputs.to(device), targets.to(device)
36 | outputs = out(model, inputs)
37 | _, predicted = outputs.max(1)
38 | correct += predicted.eq(targets).sum().item()
39 | total += targets.size(0)
40 | valid_accuracy = correct/total
41 | return valid_accuracy
--------------------------------------------------------------------------------
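
`validation_accuracy` evaluates top-1 accuracy through one of three forward paths: `mode='rein'` runs `forward_features` and the `linear_rein` head (adapter branch), `mode='no_rein'` runs `forward_features_no_rein` and the `linear` head (both transformer paths take the class-token feature `f[:, 0, :]`), and any other value treats the model as a plain feature extractor followed by `model.linear`. The self-contained sketch below exercises that last path with a toy model and random data, purely to illustrate the calling convention; the real models come from the `rein` backbones used in the training scripts:

```python
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, TensorDataset

from utils.metric import validation_accuracy

class ToyFeatureModel(nn.Module):
    """Illustrative stand-in: forward() returns features, .linear maps them to logits."""
    def __init__(self, in_dim=8, feat_dim=16, num_classes=3):
        super().__init__()
        self.backbone = nn.Linear(in_dim, feat_dim)      # stand-in feature extractor
        self.linear = nn.Linear(feat_dim, num_classes)   # head used by the plain-linear path

    def forward(self, x):
        return self.backbone(x)

loader = DataLoader(TensorDataset(torch.randn(64, 8), torch.randint(0, 3, (64,))), batch_size=16)
acc = validation_accuracy(ToyFeatureModel(), loader, device='cpu', mode='linear')
print(f'accuracy on random data: {acc:.4f}')  # around 1/3 for an untrained model
```
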