├── .gitignore
├── LICENSE
├── README.md
├── dataset.py
├── demo_classification.py
├── evaluate_classification.py
├── loss
│   ├── DiceLoss.py
│   ├── FocalLoss.py
│   ├── WeightDiceLoss.py
│   ├── metric.py
│   └── ssim.py
├── models
│   ├── Dpt.py
│   ├── __init__.py
│   ├── adapter.py
│   ├── dpt
│   │   ├── __init__.py
│   │   ├── base_model.py
│   │   ├── blocks.py
│   │   ├── layers
│   │   │   ├── __init__.py
│   │   │   ├── attention.py
│   │   │   ├── block.py
│   │   │   ├── dino_head.py
│   │   │   ├── drop_path.py
│   │   │   ├── layer_scale.py
│   │   │   ├── mlp.py
│   │   │   ├── patch_embed.py
│   │   │   └── swiglu_ffn.py
│   │   ├── midas_net.py
│   │   ├── models.py
│   │   ├── transforms.py
│   │   └── vit.py
│   ├── layers
│   │   ├── __init__.py
│   │   ├── attention.py
│   │   ├── block.py
│   │   ├── dino_head.py
│   │   ├── drop_path.py
│   │   ├── layer_scale.py
│   │   ├── mlp.py
│   │   ├── patch_embed.py
│   │   └── swiglu_ffn.py
│   ├── unet.py
│   ├── vision_transformer.py
│   └── vision_transformer_lora.py
├── requirements.txt
└── run
    ├── mla_crater.sh
    ├── mla_das.sh
    ├── mla_facies.sh
    ├── mla_fault.sh
    └── mla_salt.sh
/.gitignore:
--------------------------------------------------------------------------------
1 | __pycache__/
2 | *.pyc
3 | *.dat
4 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2024 Zhixiang Guo
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # 🌏 Cross-Domain Foundation Model Adaptation: Pioneering Computer Vision Models for Geophysical Data Analysis
2 |
3 |
4 | 🏢 [Computational Interpretation Group (CIG)](https://cig.ustc.edu.cn/main.htm)
5 |
6 | [Zhixiang Guo1](https://cig.ustc.edu.cn/guo/list.htm),
7 | [Xinming Wu1*](https://cig.ustc.edu.cn/xinming/list.htm),
8 | [Luming Liang2](https://www.microsoft.com/en-us/research/people/lulian/),
9 | [Hanlin Sheng1](https://cig.ustc.edu.cn/hanlin/list.htm),
10 | [Nuo Chen1](https://cig.ustc.edu.cn/nuo/list.htm),
11 | [Zhengfa Bi3](https://profiles.lbl.gov/416831-zhengfa-bi)
12 |
13 | School of Earth and Space Sciences, University of Science and Technology of China, Hefei, China
14 |
15 |
16 |
17 | Microsoft Applied Sciences Group, Redmond, WA 98052, United States
18 |
19 |
20 | Lawrence Berkeley National Laboratory, 1 Cyclotron Rd, CA 94707, USA
21 |
22 |
23 | ## :mega: News
24 | :flying_saucer: The dataset, model, code, and demo are coming soon!
25 |
26 | :collision: [2025.02.23]: The paper has been accepted for publication in [[JGR: Machine Learning and Computation](https://agupubs.onlinelibrary.wiley.com/doi/pdf/10.1029/2025JH000601)]
27 |
28 | :collision: [2024.09.01]: The code has been uploaded.
29 |
30 | :collision: [2024.08.23]: The paper has been submitted to arXiv: https://arxiv.org/pdf/2408.12396
31 |
32 | :collision: [2024.07.23]: Upload the [dataset](https://github.com/ProgrammerZXG/Cross-Domain-Foundation-Model-Adaptation/blob/master/README.md#package-dataset).
33 |
34 | :collision: [2024.07.07]: Github Repository Initialization.
35 |
36 | ## :sparkles: Introduction
37 |
38 | Workflow for adapting pre-trained foundation models to geophysics.
39 | First, we prepare geophysical training datasets (1st column),
40 | which involves collecting and processing relevant geophysical data
41 | to ensure it is suitable for adaptation fine-tuning. Next, we load the pre-trained
42 | foundation model as the data feature encoder (2nd column)
43 | and fine-tune the model to make it adaptable to geophysical data.
44 | To map the encoder features to the task-specific targets,
45 | we explore suitable decoders
46 | (3rd column) for geophysical downstream adaptation. Finally, the adapted model
47 | is applied to various downstream tasks within the geophysics
48 | field (4th column).
49 |
50 |
51 |
52 |
53 |
54 |
55 |
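The snippet below is a minimal sketch of this workflow using the adapters defined in this repository (`models/adapter.py`): a pre-trained DINOv2 ViT-S/14 encoder is combined with the MLA decoder and applied to a dummy input shaped like the seismic facies data. The sizes follow the settings used in `demo_classification.py`; running it will download the DINOv2 pre-trained weights.

```python
import torch
from models.adapter import dinov2_mla

# Encoder: DINOv2 ViT-S/14 (pre-trained weights fetched from the DINOv2 release);
# decoder: SETR-style MLA head mapping encoder features to class logits.
# The string arguments follow the conventions used in models/adapter.py.
net = dinov2_mla(num_classes=6, pretrain="True", vit_type="small",
                 frozen=False, finetune_method="unfrozen")

# ViT inputs are resized to multiples of the 14-pixel patch size
# (72*14 x 56*14 for the 1006 x 782 seismic facies sections);
# the decoder upsamples the prediction back to the original data size.
x = torch.randn(1, 3, 72 * 14, 56 * 14)
logits = net(x, size=(1006, 782))  # -> (1, 6, 1006, 782)
print(logits.shape)
```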
56 | ## 🚀 Quick Start
57 |
58 | ### 1. Clone the repository
59 | Our code provides demos corresponding to the data mentioned in the paper,
60 | including seismic facies, geological bodies, DAS, faults, and craters.
61 | You can run them by following the steps below:
62 |
63 | First, clone the repository to your local machine:
64 |
65 | ```bash
66 |
67 | git clone git@github.com:ProgrammerZXG/Cross-Domain-Foundation-Model-Adaptation.git
68 | cd Cross-Domain-Foundation-Model-Adaptation
69 |
70 | ```
71 |
72 | ### 2. Install dependencies
73 |
74 | ```bash
75 |
76 | pip install -r requirements.txt
77 |
78 | ```
79 |
80 | ### 3. Download the dataset
81 |
82 | Before running the code, you need to download the dataset.
83 | You can download the dataset from [Zenodo](https://zenodo.org/records/12798750) and put it in the `data/` directory.
84 |
85 | ### 4. Run the code
86 |
87 | ```bash
88 |
89 | cd run
90 | bash mla_facies.sh
91 |
92 | ```
93 | If you instead run `bash run/mla_facies.sh` from the repository root, make sure the dataset paths used by `dataset.py` still resolve correctly (see the expected layout below).
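The layout below is a sketch of what `dataset.py` expects, inferred from its relative `../data/...` paths (resolved from the `run/` working directory); samples are raw `.dat` files named `0.dat`, `1.dat`, and so on. Note that the seismic facies (`seam`) paths are currently hard-coded to an absolute location in `dataset.py` and may need to be edited.

```
data/
├── geobody/
│   ├── train/input, train/target
│   └── valid/input, valid/target
├── deepFault/
│   ├── train/image, train/label
│   └── valid/image, valid/label
├── crater/
│   ├── train/image, train/label
│   └── valid/image, valid/label
└── das/
    ├── train/image, train/label
    └── valid/image, valid/label
```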
94 |
95 | ## :stars: Results
96 |
97 |
98 | ### Quantitative Metrics for Downstream Tasks
99 |
100 | #### Mean Intersection over Union (mIoU)
101 |
102 | | Network | Seismic Facies Classification | Seismic Geobody Identification | Crater Detection | DAS Seismic Event Detection | Deep Fault Detection |
103 | |---------------|:------------:|:------------:|:------------:|:------------:|:------------:|
104 | | Unet | 0.5490 | 0.8636 | 0.5812 | 0.7271 | 0.6858 |
105 | | DINOv2-LINEAR | 0.6565 | 0.8965 | 0.6857 | 0.8112 | 0.6372 |
106 | | DINOv2-PUP | **0.6885** | 0.8935 | 0.6937 | 0.8487 | 0.7088 |
107 | | DINOv2-DPT | 0.6709 | 0.8912 | 0.6917 | **0.8672** | 0.7334 |
108 | | DINOv2-MLA | 0.6826 | **0.8969** | **0.6949** | 0.8591 | **0.7613** |
109 |
110 |
111 | #### Mean Pixel Accuracy (mPA)
112 |
113 | | Network | Seismic Facies Classification | Seismic Geobody Identification | Crater Detection | DAS Seismic Event Detection | Deep Fault Detection |
114 | |---------------|:------------:|:------------:|:------------:|:------------:|:------------:|
115 | | Unet | 0.7693 | 0.9112 | 0.6265 | 0.7865 | 0.7439 |
116 | | DINOv2-LINEAR | 0.8732 | 0.9374 | 0.7481 | 0.9033 | 0.7519 |
117 | | DINOv2-PUP | **0.9102** | 0.9357 | 0.7529 | 0.9210 | 0.7793 |
118 | | DINOv2-DPT | 0.8826 | 0.9377 | 0.7462 | 0.9119 | 0.7985 |
119 | | DINOv2-MLA | 0.8975 | **0.9383** | **0.7476** |**0.9222** | **0.8195** |
120 |
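For reference, the scores above follow the standard per-class segmentation definitions (see `evaluate_classification.py` and `loss/metric.py` for the exact computation used in this repository), with $TP_c$, $FP_c$, and $FN_c$ counted over all pixels of class $c$ among $C$ classes:

$$\mathrm{mIoU} = \frac{1}{C}\sum_{c=1}^{C}\frac{TP_c}{TP_c + FP_c + FN_c}, \qquad \mathrm{mPA} = \frac{1}{C}\sum_{c=1}^{C}\frac{TP_c}{TP_c + FN_c}$$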
121 | ## :package: Dataset
122 | All data is available at [Zenodo](https://zenodo.org/records/12798750).
123 |
124 | [DOI: 10.5281/zenodo.12798750](https://doi.org/10.5281/zenodo.12798750)
125 |
126 | | Task | Data Sources | Data Size | Training Number | Test Number |
127 | |------------------------------|-----------------------------------------------|--------------|-----------------|-------------|
128 | | Seismic Facies Classification | provided by [(SEAM, 2020)](https://www.aicrowd.com/challenges/seismic-facies-identification-challenge/discussion) | 1006 × 782 | 250 | 45 |
129 | | Salt Body Identification | provided by [(Addison Howard et al., 2018)](https://www.kaggle.com/competitions/tgs-salt-identification-challenge) | 224 × 224 | 3000 | 1000 |
130 | | Crater Detection | original data provided by [CAS](https://moon.bao.ac.cn/), labelled by authors | 1022 × 1022 | 1000 | 199 |
131 | | DAS Seismic Event Detection | provided by [(Biondi et al., 2023)](https://zenodo.org/records/8270895) | 512 × 512 | 115 | 28 |
132 | | Deep Fault Detection | original data provided from field surveys, labelled by authors | 896 × 896 | 1081 | 269 |
133 |
134 | ## :bookmark: Citation
135 |
136 | If you find this work useful, please consider citing our paper:
137 |
138 | ```bibtex
139 |
140 | @misc{guo2024crossdomainfoundationmodeladaptation,
141 | title={Cross-Domain Foundation Model Adaptation: Pioneering Computer Vision Models for Geophysical Data Analysis},
142 | author={Zhixiang Guo and Xinming Wu and Luming Liang and Hanlin Sheng and Nuo Chen and Zhengfa Bi},
143 | year={2024},
144 | eprint={2408.12396},
145 | archivePrefix={arXiv},
146 | primaryClass={cs.CV},
147 | url={https://arxiv.org/abs/2408.12396},
148 | }
149 | ```
150 |
151 | ## :memo: Acknowledgment
152 | This study was strongly supported by the Supercomputing
153 | Center of the University of Science and Technology of China,
154 | particularly through the provision of NVIDIA 80 GB A100 GPUs,
155 | which were crucial for our experiments.
156 | We also thank [SEAM](https://seg.org/SEAM) for providing the seismic facies classification dataset,
157 | [TGS](https://www.kaggle.com/competitions/tgs-salt-identification-challenge) for the geobody identification dataset,
158 | [CAS](https://moon.bao.ac.cn) for the crater detection dataset,
159 | [Biondi](https://www.science.org/doi/full/10.1126/sciadv.adi9878) for the DAS seismic event detection dataset,
160 | and [CIG](https://cig.ustc.edu.cn/main.htm) for the deep fault detection dataset.
161 |
162 | ## :postbox: Contact
163 | If you have any questions about this work,
164 | please feel free to contact xinmwu@ustc.edu.cn or zxg3@mail.ustc.edu.cn.
165 |
--------------------------------------------------------------------------------
/dataset.py:
--------------------------------------------------------------------------------
1 | import os
2 | import numpy as np
3 | from torch.utils.data import Dataset
4 | from PIL import Image
5 | import torchvision.transforms as T
6 |
7 | class BasicDataset(Dataset):
8 |
9 | def __init__(self,patch_h,patch_w,datasetName,netType,train_mode = False):
10 |
11 | self.patch_h = patch_h
12 | self.patch_w = patch_w
13 |
14 | if netType == 'unet' or netType == 'deeplabv3plus':
15 | self.imgTrans = False
16 | else:
17 | self.imgTrans = True
18 |
19 | self.transform = T.Compose([
20 | T.Resize((patch_h * 14, patch_w * 14)),
21 | T.ToTensor(),
22 | ])
23 |
24 | self.dataset = datasetName
25 |
26 | if datasetName == 'seam':
27 | self.n1 = 1006
28 | self.n2 = 782
29 | # self.train_data_dir = '../data/seismicFace/train/input'
30 | # self.train_label_dir = '../data/seismicFace/train/target'
31 | # self.valid_data_dir = '../data/seismicFace/valid/input'
32 | # self.valid_label_dir = '../data/seismicFace/valid/target'
33 | self.train_data_dir = '/home/zxguo/data/seamai_1006x782/seamaiForTrain/input'
34 | self.train_label_dir = '/home/zxguo/data/seamai_1006x782/seamaiForTrain/target'
35 | self.valid_data_dir = '/home/zxguo/data/seamai_1006x782/seamaiForVal/input'
36 | self.valid_label_dir = '/home/zxguo/data/seamai_1006x782/seamaiForVal/target'
37 | elif datasetName == 'salt':
38 | self.n1 = 224
39 | self.n2 = 224
40 | self.train_data_dir = '../data/geobody/train/input'
41 | self.train_label_dir = '../data/geobody/train/target'
42 | self.valid_data_dir = '../data/geobody/valid/input'
43 | self.valid_label_dir = '../data/geobody/valid/target'
44 | elif datasetName == 'fault':
45 | self.n1 = 896
46 | self.n2 = 896
47 | self.train_data_dir = '../data/deepFault/train/image'
48 | self.train_label_dir = '../data/deepFault/train/label'
49 | self.valid_data_dir = '../data/deepFault/valid/image'
50 | self.valid_label_dir = '../data/deepFault/valid/label'
51 | elif datasetName == 'crater':
52 | self.n1 = 1022
53 | self.n2 = 1022
54 | self.train_data_dir = '../data/crater/train/image'
55 | self.train_label_dir = '../data/crater/train/label'
56 | self.valid_data_dir = '../data/crater/valid/image'
57 | self.valid_label_dir = '../data/crater/valid/label'
58 | elif datasetName == 'das':
59 | self.n1 = 512
60 | self.n2 = 512
61 | self.train_data_dir = '../data/das/train/image'
62 | self.train_label_dir = '../data/das/train/label'
63 | self.valid_data_dir = '../data/das/valid/image'
64 | self.valid_label_dir = '../data/das/valid/label'
65 | else:
66 | print("Dataset error!!")
67 | print('netType:' + netType)
68 | print('dataset:' + datasetName)
69 | print('patch_h:' + str(patch_h))
70 | print('patch_w:' + str(patch_w))
71 |
72 | if train_mode:
73 | self.data_dir = self.train_data_dir
74 | self.label_dir = self.train_label_dir
75 | else:
76 | self.data_dir = self.valid_data_dir
77 | self.label_dir = self.valid_label_dir
78 |
79 | self.ids = len(os.listdir(self.data_dir))
80 | def __len__(self):
81 | return self.ids
82 |
83 | def __getitem__(self,index):
84 |
85 | dPath = self.data_dir+'/'+str(index)+'.dat'
86 | tPath = self.label_dir+'/'+str(index)+'.dat'
87 | data = np.fromfile(dPath,np.float32).reshape(self.n1,self.n2)
88 | label = np.fromfile(tPath,np.int8).reshape(self.n1,self.n2)
89 |
90 | data = np.reshape(data,(1,1,self.n1,self.n2))
91 | data = np.concatenate([data,self.data_aug(data)],axis=0)
92 | label = np.reshape(label,(1,1,self.n1,self.n2))
93 | label = np.concatenate([label,self.data_aug(label)],axis=0)
94 |
95 | if self.imgTrans:
96 | img_tensor = np.zeros([2,1,self.patch_h*14,self.patch_w*14],np.float32)
97 | for i in range(data.shape[0]):
98 | img = Image.fromarray(np.uint8(data[i,0]))
99 | img_tensor[i,0] = self.transform(img)
100 | data = img_tensor
101 | data = data.repeat(3,axis=1)
102 | elif not self.imgTrans:
103 | data = data/255
104 |
105 | return data,label
106 |
107 | def data_aug(self,data):
108 | b,c,h,w = data.shape
109 | data_fliplr = np.fliplr(np.squeeze(data))
110 | return data_fliplr.reshape(b,c,h,w)
111 |
112 | if __name__ == '__main__':
113 |
114 |     train_set = BasicDataset(72, 56, 'seam', 'setr1', train_mode=True)
115 | print(train_set.__getitem__(0)[1].shape)
116 | print(len(train_set))
117 |
--------------------------------------------------------------------------------
/demo_classification.py:
--------------------------------------------------------------------------------
1 | import logging
2 | import os
3 | import torch
4 | import torch.nn as nn
5 | from torch import optim
6 | from torch.utils.data import DataLoader
7 | from tqdm import tqdm
8 | from dataset import BasicDataset
9 | from models.adapter import dinov2_mla,dinov2_pup,dinov2_linear
10 | from models.Dpt import dinov2_dpt
11 | from models.unet import U_Net
12 | import numpy as np
13 | from tensorboardX import SummaryWriter
14 | from torchmetrics.classification import JaccardIndex
15 | from loss.FocalLoss import Focal_Loss
16 | from loss.DiceLoss import DiceLoss
17 | from loss.WeightDiceLoss import WeightedDiceLoss
18 | import random
19 | import argparse
20 | # from logger import Logger
21 | import loralib as lora
22 |
23 | random.seed(1234)
24 | np.random.seed(1234)
25 | torch.manual_seed(1234)
26 | torch.cuda.manual_seed(1234)
27 | torch.cuda.manual_seed_all(1234)
28 |
29 |
30 | def main(args,logger):
31 | dir_checkpoint = '../checkpoint/' + args.dataset + "/" + args.loss + "/" +args.netType
32 | if not os.path.exists(dir_checkpoint):
33 | os.makedirs(dir_checkpoint)
34 |
35 | if args.dataset == 'seam':
36 | args.n1, args.n2 = 1006, 782
37 | args.classes = 6
38 | args.patch_h = 72
39 | args.patch_w = 56
40 | args.batch_size = 3
41 | elif args.dataset == 'salt':
42 | args.n1, args.n2 = 224, 224
43 | args.classes = 2
44 | args.patch_h = 20
45 | args.patch_w = 20
46 | args.batch_size = 32
47 | elif args.dataset == 'crater':
48 | args.n1, args.n2 = 1022, 1022
49 | args.classes = 2
50 | args.patch_h = 73
51 | args.patch_w = 73
52 | args.batch_size = 3
53 | elif args.dataset == 'das':
54 | args.n1, args.n2 = 512, 512
55 | args.classes = 2
56 | args.patch_h = 37
57 | args.patch_w = 37
58 | args.batch_size = 6
59 | elif args.dataset == 'fault':
60 | args.n1, args.n2 = 896, 896
61 | args.classes = 2
62 | args.patch_h = 64
63 | args.patch_w = 64
64 | args.batch_size = 6
65 |
66 | if args.checkpointName in ["unfrozen","lora"]:
67 | frozen = False
68 | elif args.checkpointName == "frozen":
69 | frozen = True
70 |
71 | if args.netType == "unet":
72 | net = U_Net(1,args.classes)
73 | elif args.netType == "linear":
74 | net = dinov2_linear(args.classes, pretrain=args.dpt, vit_type=args.vt,frozen=frozen,finetune_method=args.checkpointName)
75 | elif args.netType == "mla":
76 | net = dinov2_mla(args.classes, pretrain=args.dpt, vit_type=args.vt,frozen=frozen,finetune_method=args.checkpointName)
77 | elif args.netType == "pup":
78 | net = dinov2_pup(args.classes, pretrain=args.dpt, vit_type=args.vt,frozen=frozen,finetune_method=args.checkpointName)
79 | elif args.netType == "dpt":
80 | net = dinov2_dpt(args.classes, pretrain=args.dpt, vit_type=args.vt,frozen=frozen,finetune_method=args.checkpointName)
81 |
82 | logger.info(f'\t{args.netType} NetWork:\n'
83 | f'\t{args.classes } num classes\n'
84 | f'\t{args.dataset} dataset\n'
85 | f'\t{args.vt} vitType\n'
86 | f'\t{args.loss} loss\n')
87 | # net = torch.nn.DataParallel(net, device_ids=range(device_count))
88 | goTrain(args,
89 | dir_checkpoint,
90 | net=net,
91 | patch_h = args.patch_h,
92 | patch_w = args.patch_w,
93 | epochs=args.epochs,
94 | batch_size= int(args.batch_size),
95 | learning_rate= args.lr,
96 | num_classes = args.classes,
97 | save_checkpoint=args.save_checkpoint
98 | )
99 | def goTrain(args,
100 | dir_checkpoint,
101 | net,
102 | patch_h,
103 | patch_w,
104 | num_classes : int,
105 | epochs:int = 5,
106 | batch_size: int = 1,
107 | learning_rate: float = 1e-4,
108 | save_checkpoint: bool = True):
109 |
110 | net.to(device)
111 | get_parameter_number(net)
112 |
113 | # Create dataset
114 | train_set = BasicDataset(patch_h, patch_w, args.dataset,args.netType, train_mode=True)
115 | valid_set = BasicDataset(patch_h, patch_w, args.dataset,args.netType, train_mode=False)
116 |
117 | #Create data loaders
118 | train_loader= DataLoader(dataset = train_set,batch_size = batch_size, shuffle=True)
119 | valid_loader= DataLoader(dataset = valid_set,batch_size = batch_size, shuffle=False)
120 |
121 | logger.info(f'''Starting training:
122 | Epochs: {epochs}
123 | Batch size: {batch_size}
124 | Learning rate: {learning_rate}
125 | Training size: {len(train_set)}
126 | Validation size: {len(valid_set)}
127 | Checkpoints: {save_checkpoint}
128 | ''')
129 |
130 | jaccard = JaccardIndex(task='multiclass',num_classes=num_classes).to(device)
131 | # Set up the optimizer, the loss, the learning rate scheduler and the loss scaling
132 | # optimizer = optim.Adam(net.parameters(), lr=learning_rate, weight_decay=1e-8)
133 | # optimizer = optim.AdamW(net.parameters(), lr=learning_rate, weight_decay=0.05)
134 | optimizer = optim.AdamW(net.parameters(), lr=learning_rate, weight_decay=0.01,betas=[0.7,0.999])
135 | if args.al:
136 | scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, args.epochs)
137 | if args.loss == "ce":
138 | criterion = nn.CrossEntropyLoss()
139 | elif args.loss == "bce":
140 | criterion = nn.BCEWithLogitsLoss()
141 | elif args.loss == "focal":
142 | criterion = Focal_Loss(args.classes,device=args.device)
143 | elif args.loss == "dice":
144 | criterion = DiceLoss(args.classes)
145 | elif args.loss == "wdice":
146 | criterion = WeightedDiceLoss(args.classes,device=args.device)
147 | elif args.loss == "bace":
148 | if args.dataset == "seam":
149 | weight = torch.tensor([1.216,0.395,3.673,0.573,14.193,1.798]).reshape(-1,1).to(args.device)
150 | criterion = nn.CrossEntropyLoss(weight=weight)
151 |
152 | #Tensorboard open
153 | writer = SummaryWriter('../Tensorboard/'+args.dataset+'/' + args.loss + '/')
154 |
155 | # Begin training
156 | train_loss = []
157 | valid_loss=[]
158 | train_iou = []
159 | valid_iou = []
160 | train_pa = []
161 | valid_pa = []
162 | MaxTrainIoU = 0
163 | MaxValidIoU = 0
164 | MinTrainLoss = 1e7
165 | MinValidLoss = 1e7
166 | net.train()
167 | warmup_steps = 10
168 | ini_lr = learning_rate*10
169 | for epoch in range(1,epochs+1):
170 | if args.al=="True":
171 | if epoch < warmup_steps:
172 | warmup_percent_done = epoch/warmup_steps
173 | optimizer.param_groups[0]['lr'] = ini_lr * warmup_percent_done
174 | else:
175 | scheduler.step()
176 | total_train_loss = []
177 | total_valid_loss = []
178 | total_train_iou = []
179 | total_valid_iou = []
180 | total_train_pa = []
181 | total_valid_pa = []
182 | with tqdm(total = len(train_set),desc=f'Epoch {epoch}/{epochs}',unit = 'img') as t:
183 | for data,label in train_loader:
184 | b1,b2,c,h,w = data.shape
185 | data = data.to(device).reshape(b1*b2,c,h,w)
186 | b1,b2,c,h,w = label.shape
187 | label = label.to(device).reshape(b1*b2,h,w)
188 | optimizer.zero_grad()
189 | outputs = net(data,(args.n1,args.n2))
190 | if args.loss == "bce":
191 | loss = criterion(outputs,label.unsqueeze(1).expand(-1, 2, -1, -1).float())
192 | else:
193 | loss = criterion(outputs,label.long())
194 | _, preds = torch.max(outputs, 1)
195 | iou_tmp = jaccard(preds,label.long()).detach().cpu().numpy()
196 | pa_tmp = ((preds == label).sum().item() / (b1*b2*h*w))
197 | loss.backward()
198 | optimizer.step()
199 | t.update(batch_size)
200 | t.set_postfix(**{'train_loss': loss.item(),'iou': iou_tmp,'accuracy':pa_tmp,'lr': optimizer.param_groups[0]['lr']})
201 | total_train_loss.append(loss.item())
202 | total_train_iou.append(iou_tmp)
203 | total_train_pa.append(pa_tmp)
204 | train_loss.append(np.mean(total_train_loss))
205 | train_iou.append(np.mean(total_train_iou))
206 | train_pa.append(np.mean(total_train_pa))
207 | logger.info(f"Epoch {epoch} - TrainSet - Loss: {train_loss[-1]}, IoU: {train_iou[-1]}, Accuracy: {train_pa[-1]}")
208 |
209 | # if save_checkpoint and epoch%5==0:
210 | # torch.save(net.state_dict(), dir_checkpoint + "/"+args.checkpointName + "_" + args.vt+"_epoch"+str(epoch)+"_train.pth")
211 | if train_iou[-1]>MaxTrainIoU:
212 | torch.save(net.state_dict(), dir_checkpoint + "/"+args.checkpointName + "_" + args.vt+"_maxiou_train.pth")
213 | if args.checkpointName=="lora":
214 | torch.save(lora.lora_state_dict(net), dir_checkpoint + "/"+args.checkpointName + "_" + args.vt+"_maxiou_train_lora.pth")
215 | MaxTrainIoU = train_iou[-1]
216 | logger.info(f'max_train_iou saved!')
217 |         if train_loss[-1] < MinTrainLoss:
253 |         if valid_iou[-1] > MaxValidIoU:
254 | torch.save(net.state_dict(), dir_checkpoint + "/"+args.checkpointName + "_" + args.vt+"_maxiou_valid.pth")
255 | if args.checkpointName=="lora":
256 | torch.save(lora.lora_state_dict(net), dir_checkpoint + "/"+args.checkpointName + "_" + args.vt+"_maxiou_valid_lora.pth")
257 | MaxValidIoU = valid_iou[-1]
258 | logger.info(f'max_valid_iou saved!')
259 |         if valid_loss[-1] < MinValidLoss:
--------------------------------------------------------------------------------
/loss/metric.py:
--------------------------------------------------------------------------------
37 |         mask = (imgLabel >= 0) & (imgLabel < self.numClass)
38 | label = self.numClass * imgLabel[mask] + imgPredict[mask]
39 | count = np.bincount(label, minlength=self.numClass**2)
40 | confusionMatrix = count.reshape(self.numClass, self.numClass)
41 | return confusionMatrix
42 |
43 | def Frequency_Weighted_Intersection_over_Union(self):
44 | # FWIOU = [(TP+FN)/(TP+FP+TN+FN)] *[TP / (TP + FP + FN)]
45 | freq = np.sum(self.confusion_matrix, axis=1) / np.sum(self.confusion_matrix)
46 | iu = np.diag(self.confusion_matrix) / (
47 | np.sum(self.confusion_matrix, axis=1) + np.sum(self.confusion_matrix, axis=0) -
48 | np.diag(self.confusion_matrix))
49 | FWIoU = (freq[freq > 0] * iu[freq > 0]).sum()
50 | return FWIoU
51 |
52 |
53 | def addBatch(self, imgPredict, imgLabel):
54 | assert imgPredict.shape == imgLabel.shape
55 | self.confusionMatrix += self.genConfusionMatrix(imgPredict, imgLabel)
56 |
57 | def reset(self):
58 | self.confusionMatrix = np.zeros((self.numClass, self.numClass))
--------------------------------------------------------------------------------
/loss/ssim.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 | from torch.nn import _reduction as _Reduction  # legacy_get_string() handles the deprecated size_average/reduce args below
5 | def _fspecial_gaussian(size, channel, sigma):
6 | coords = torch.tensor([(x - (size - 1.) / 2.) for x in range(size)])
7 | coords = -coords ** 2 / (2. * sigma ** 2)
8 | grid = coords.view(1, -1) + coords.view(-1, 1)
9 | grid = grid.view(1, -1)
10 | grid = grid.softmax(-1)
11 | kernel = grid.view(1, 1, size, size)
12 | kernel = kernel.expand(channel, 1, size, size).contiguous()
13 | return kernel
14 |
15 | # zfbi
16 | def _fspecial_gaussian3d(size, channel, sigma):
17 | coords = torch.tensor([(x - (size - 1.) / 2.) for x in range(size)])
18 | coords = -coords ** 2 / (2. * sigma ** 2)
19 | grid = coords.view(1, -1, 1) + coords.view(-1, 1, 1) + coords.view(1, 1, -1)
20 | grid = grid.view(1, -1)
21 | grid = grid.softmax(-1)
22 | kernel = grid.view(1, 1, size, size, size)
23 | kernel = kernel.expand(channel, 1, size, size, size).contiguous()
24 | return kernel
25 |
26 | def _ssim(output, target, max_val, k1, k2, channel, kernel):
27 | c1 = (k1 * max_val) ** 2
28 | c2 = (k2 * max_val) ** 2
29 |
30 | mu1 = F.conv2d(output, kernel, groups=channel)
31 | mu2 = F.conv2d(target, kernel, groups=channel)
32 |
33 | mu1_sq = mu1 ** 2
34 | mu2_sq = mu2 ** 2
35 | mu1_mu2 = mu1 * mu2
36 |
37 | sigma1_sq = F.conv2d(output * output, kernel, groups=channel) - mu1_sq
38 | sigma2_sq = F.conv2d(target * target, kernel, groups=channel) - mu2_sq
39 | sigma12 = F.conv2d(output * target, kernel, groups=channel) - mu1_mu2
40 |
41 | v1 = 2 * sigma12 + c2
42 | v2 = sigma1_sq + sigma2_sq + c2
43 |
44 | ssim = ((2 * mu1_mu2 + c1) * v1) / ((mu1_sq + mu2_sq + c1) * v2)
45 | return ssim, v1 / v2
46 |
47 | # zfbi
48 | def _ssim3d(input, target, max_val, k1, k2, channel, kernel):
49 | c1 = (k1 * max_val) ** 2
50 | c2 = (k2 * max_val) ** 2
51 |
52 | mu1 = F.conv3d(input, kernel, groups=channel)
53 | mu2 = F.conv3d(target, kernel, groups=channel)
54 |
55 | mu1_sq = mu1 ** 2
56 | mu2_sq = mu2 ** 2
57 | mu1_mu2 = mu1 * mu2
58 |
59 | sigma1_sq = F.conv3d(input * input, kernel, groups=channel) - mu1_sq
60 | sigma2_sq = F.conv3d(target * target, kernel, groups=channel) - mu2_sq
61 | sigma12 = F.conv3d(input * target, kernel, groups=channel) - mu1_mu2
62 |
63 | v1 = 2 * sigma12 + c2
64 | v2 = sigma1_sq + sigma2_sq + c2
65 |
66 | ssim = ((2 * mu1_mu2 + c1) * v1) / ((mu1_sq + mu2_sq + c1) * v2)
67 | return ssim, v1 / v2
68 |
69 |
70 | def ssim_loss(input, target, max_val, filter_size=7, k1=0.01, k2=0.03,
71 | sigma=1.5, kernel=None, size_average=None, reduce=None, reduction='mean'):
72 |
73 | if input.size() != target.size():
74 | raise ValueError('Expected input size ({}) to match target size ({}).'
75 | .format(input.size(0), target.size(0)))
76 |
77 | if size_average is not None or reduce is not None:
78 | reduction = _Reduction.legacy_get_string(size_average, reduce)
79 |
80 | dim = input.dim()
81 | if dim == 2:
82 |         input = input.expand(1, 1, input.shape[-2], input.shape[-1])
83 |         target = target.expand(1, 1, target.shape[-2], target.shape[-1])
84 |     elif dim == 3:
85 |         input = input.expand(1, input.shape[-3], input.shape[-2], input.shape[-1])
86 |         target = target.expand(1, target.shape[-3], target.shape[-2], target.shape[-1])
87 | elif dim != 4:
88 | raise ValueError('Expected 2, 3, or 4 dimensions (got {})'.format(dim))
89 |
90 | _, channel, _, _ = input.size()
91 |
92 | if kernel is None:
93 | kernel = _fspecial_gaussian(filter_size, channel, sigma)
94 | kernel = kernel.to(device=input.device)
95 |
96 | ret, _ = _ssim(input, target, max_val, k1, k2, channel, kernel)
97 |
98 | if reduction != 'none':
99 | ret = torch.mean(ret) if reduction == 'mean' else torch.sum(ret)
100 | return ret
101 |
102 | def ssim_loss3d(input, target, max_val, filter_size=7, k1=0.01, k2=0.03,
103 | sigma=1.5, kernel=None, size_average=None, reduce=None, reduction='mean'):
104 |
105 | if input.size() != target.size():
106 | raise ValueError('Expected input size ({}) to match target size ({}).'
107 | .format(input.size(0), target.size(0)))
108 |
109 | if size_average is not None or reduce is not None:
110 | reduction = _Reduction.legacy_get_string(size_average, reduce)
111 |
112 | dim = input.dim()
113 | if dim == 2:
114 |         input = input.expand(1, 1, 1, input.shape[-2], input.shape[-1])
115 |         target = target.expand(1, 1, 1, target.shape[-2], target.shape[-1])
116 |     elif dim == 3:
117 |         input = input.expand(1, 1, input.shape[-3], input.shape[-2], input.shape[-1])
118 |         target = target.expand(1, 1, target.shape[-3], target.shape[-2], target.shape[-1])
119 |     elif dim == 4:
120 |         input = input.expand(1, input.shape[-4], input.shape[-3], input.shape[-2], input.shape[-1])
121 |         target = target.expand(1, target.shape[-4], target.shape[-3], target.shape[-2], target.shape[-1])
122 | elif dim != 5:
123 | raise ValueError('Expected 2, 3, 4, or 5 dimensions (got {})'.format(dim))
124 |
125 | _, channel, _, _, _ = input.size()
126 |
127 | if kernel is None:
128 | kernel = _fspecial_gaussian3d(filter_size, channel, sigma)
129 | kernel = kernel.to(device=input.device)
130 |
131 | ret, _ = _ssim3d(input, target, max_val, k1, k2, channel, kernel)
132 |
133 | if reduction != 'none':
134 | ret = torch.mean(ret) if reduction == 'mean' else torch.sum(ret)
135 | return ret
136 |
137 | def ms_ssim_loss(input, target, max_val, filter_size=7, k1=0.01, k2=0.03,
138 | sigma=1.5, kernel=None, weights=None, size_average=None, reduce=None, reduction='mean'):
139 |
140 | if input.size() != target.size():
141 | raise ValueError('Expected input size ({}) to match target size ({}).'
142 | .format(input.size(0), target.size(0)))
143 |
144 | if size_average is not None or reduce is not None:
145 | reduction = _Reduction.legacy_get_string(size_average, reduce)
146 |
147 | dim = input.dim()
148 | if dim == 2:
149 | input = input.expand(1, 1, input.shape[-2], input.shape[-1])
150 | target = target.expand(1, 1, target.shape[-2], target.shape[-1])
151 | elif dim == 3:
152 |         input = input.expand(1, input.shape[-3], input.shape[-2], input.shape[-1])
153 |         target = target.expand(1, target.shape[-3], target.shape[-2], target.shape[-1])
154 | elif dim != 4:
155 | raise ValueError('Expected 2, 3, or 4 dimensions (got {})'.format(dim))
156 |
157 | _, channel, _, _ = input.size()
158 |
159 | if kernel is None:
160 | kernel = _fspecial_gaussian(filter_size, channel, sigma)
161 | kernel = kernel.to(device=input.device)
162 |
163 | if weights is None:
164 | weights = [0.0448, 0.2856, 0.3001, 0.2363, 0.1333]
165 | weights = torch.tensor(weights, device=input.device)
166 | weights = weights.unsqueeze(-1).unsqueeze(-1)
167 | levels = weights.size(0)
168 | mssim = []
169 | mcs = []
170 | for _ in range(levels):
171 | ssim, cs = _ssim(input, target, max_val, k1, k2, channel, kernel)
172 | ssim = ssim.mean((2, 3))
173 | cs = cs.mean((2, 3))
174 | mssim.append(ssim)
175 | mcs.append(cs)
176 |
177 | input = F.avg_pool2d(input, (2, 2))
178 | target = F.avg_pool2d(target, (2, 2))
179 |
180 | mssim = torch.stack(mssim)
181 | mcs = torch.stack(mcs)
182 | # Normalize
183 | mssim = (mssim + 1) / 2
184 | mcs = (mcs + 1) / 2
185 | p1 = mcs ** weights
186 | p2 = mssim ** weights
187 |
188 | ret = torch.prod(p1[:-1], 0) * p2[-1]
189 |
190 | if reduction != 'none':
191 | ret = torch.mean(ret) if reduction == 'mean' else torch.sum(ret)
192 | return ret
193 |
194 |
195 | # zfbi
196 | def ms_ssim_loss3d(input, target, max_val, filter_size=7, k1=0.01, k2=0.03,
197 | sigma=1.5, kernel=None, weights=None, size_average=None, reduce=None, reduction='mean'):
198 |
199 | if input.size() != target.size():
200 | raise ValueError('Expected input size ({}) to match target size ({}).'
201 | .format(input.size(0), target.size(0)))
202 |
203 | if size_average is not None or reduce is not None:
204 | reduction = _Reduction.legacy_get_string(size_average, reduce)
205 |
206 | dim = input.dim()
207 | if dim == 2:
208 |         input = input.expand(1, 1, 1, input.shape[-2], input.shape[-1])
209 |         target = target.expand(1, 1, 1, target.shape[-2], target.shape[-1])
210 |     elif dim == 3:
211 |         input = input.expand(1, 1, input.shape[-3], input.shape[-2], input.shape[-1])
212 |         target = target.expand(1, 1, target.shape[-3], target.shape[-2], target.shape[-1])
213 |     elif dim == 4:
214 |         input = input.expand(1, input.shape[-4], input.shape[-3], input.shape[-2], input.shape[-1])
215 |         target = target.expand(1, target.shape[-4], target.shape[-3], target.shape[-2], target.shape[-1])
216 | elif dim != 5:
217 | raise ValueError('Expected 2, 3, 4, or 5 dimensions (got {})'.format(dim))
218 |
219 | _, channel, _, _, _ = input.size()
220 |
221 | if kernel is None:
222 | kernel = _fspecial_gaussian3d(filter_size, channel, sigma)
223 | kernel = kernel.to(device=input.device)
224 |
225 | if weights is None:
226 | weights = [0.0448, 0.2856, 0.3001, 0.2363, 0.1333]
227 | weights = torch.tensor(weights, device=input.device)
228 | weights = weights.unsqueeze(-1).unsqueeze(-1).unsqueeze(-1)
229 | levels = weights.size(0)
230 | mssim = []
231 | mcs = []
232 | for _ in range(levels):
233 | ssim, cs = _ssim3d(input, target, max_val, k1, k2, channel, kernel)
234 | ssim = ssim.mean((2, 3, 4))
235 | cs = cs.mean((2, 3, 4))
236 | mssim.append(ssim)
237 | mcs.append(cs)
238 |
239 | input = F.avg_pool3d(input, (2, 2, 2))
240 | target = F.avg_pool3d(target, (2, 2, 2))
241 |
242 | mssim = torch.stack(mssim)
243 | mcs = torch.stack(mcs)
244 | # Normalize
245 | mssim = (mssim + 1) / 2
246 | mcs = (mcs + 1) / 2
247 | p1 = mcs ** weights
248 | p2 = mssim ** weights
249 |
250 | ret = torch.prod(p1[:-1], 0) * p2[-1]
251 |
252 | if reduction != 'none':
253 | ret = torch.mean(ret) if reduction == 'mean' else torch.sum(ret)
254 | return ret
255 |
256 | class _Loss(torch.nn.Module):
257 | def __init__(self, size_average=None, reduce=None, reduction='mean'):
258 | super(_Loss, self).__init__()
259 | if size_average is not None or reduce is not None:
260 | self.reduction = _Reduction.legacy_get_string(size_average, reduce)
261 | else:
262 | self.reduction = reduction
263 |
264 | class SSIMLoss(_Loss):
265 |
266 | __constants__ = ['filter_size', 'k1', 'k2', 'sigma', 'kernel', 'reduction']
267 |
268 | def __init__(self, channel=3, filter_size=7, k1=0.01, k2=0.03, sigma=1.5, size_average=None, reduce=None, reduction='mean'):
269 | super(SSIMLoss, self).__init__(size_average, reduce, reduction)
270 | self.filter_size = filter_size
271 | self.k1 = k1
272 | self.k2 = k2
273 | self.sigma = sigma
274 | self.kernel = _fspecial_gaussian(filter_size, channel, sigma)
275 |
276 | def forward(self, input, target, max_val=1.):
277 | return ssim_loss(input, target, max_val=max_val, filter_size=self.filter_size, k1=self.k1, k2=self.k2,
278 | sigma=self.sigma, reduction=self.reduction, kernel=self.kernel)
279 |
280 | class SSIMLoss3D(_Loss):
281 |
282 | __constants__ = ['filter_size', 'k1', 'k2', 'sigma', 'kernel', 'reduction']
283 |
284 | def __init__(self, channel=3, filter_size=7, k1=0.01, k2=0.03, sigma=1.5, size_average=None, reduce=None, reduction='mean'):
285 | super(SSIMLoss3D, self).__init__(size_average, reduce, reduction)
286 | self.filter_size = filter_size
287 | self.k1 = k1
288 | self.k2 = k2
289 | self.sigma = sigma
290 | self.kernel = _fspecial_gaussian3d(filter_size, channel, sigma)
291 |
292 | def forward(self, input, target, max_val=1.):
293 | return ssim_loss3d(input, target, max_val=max_val, filter_size=self.filter_size, k1=self.k1, k2=self.k2,
294 | sigma=self.sigma, reduction=self.reduction, kernel=self.kernel)
295 |
296 | class MultiScaleSSIMLoss(_Loss):
297 |
298 | __constants__ = ['filter_size', 'k1', 'k2', 'sigma', 'kernel', 'reduction']
299 |
300 | def __init__(self, channel=3, filter_size=7, k1=0.01, k2=0.03, sigma=1.5, size_average=None, reduce=None, reduction='mean'):
301 | super(MultiScaleSSIMLoss, self).__init__(size_average, reduce, reduction)
302 | self.filter_size = filter_size
303 | self.k1 = k1
304 | self.k2 = k2
305 | self.sigma = sigma
306 | self.kernel = _fspecial_gaussian(filter_size, channel, sigma)
307 |
308 | def forward(self, input, target, weights=[0.0448, 0.2856, 0.3001, 0.2363, 0.1333], max_val=1.):
309 | return ms_ssim_loss(input, target, max_val=max_val, k1=self.k1, k2=self.k2, sigma=self.sigma, kernel=self.kernel,
310 | weights=weights, filter_size=self.filter_size, reduction=self.reduction)
311 | # zfbi
312 | class MultiScaleSSIMLoss3D(_Loss):
313 |
314 | __constants__ = ['filter_size', 'k1', 'k2', 'sigma', 'kernel', 'reduction']
315 |
316 | def __init__(self, channel=3, filter_size=7, k1=0.01, k2=0.03, sigma=1.5, size_average=None, reduce=None, reduction='mean'):
317 | super(MultiScaleSSIMLoss3D, self).__init__(size_average, reduce, reduction)
318 | self.filter_size = filter_size
319 | self.k1 = k1
320 | self.k2 = k2
321 | self.sigma = sigma
322 | self.kernel = _fspecial_gaussian3d(filter_size, channel, sigma)
323 |
324 | def forward(self, input, target, weights=[0.0448, 0.2856, 0.3001, 0.2363, 0.1333], max_val=1.):
325 | return ms_ssim_loss3d(input, target, max_val=max_val, k1=self.k1, k2=self.k2, sigma=self.sigma, kernel=self.kernel,
326 | weights=weights, filter_size=self.filter_size, reduction=self.reduction)
327 |
328 | class MSSIMLoss(nn.Module):
329 | def __init__(self, channel, filter_size):
330 | super(MSSIMLoss, self).__init__()
331 | self.mssim = MultiScaleSSIMLoss(channel=channel, filter_size=filter_size)
332 | def forward(self, output, target):
333 | loss = (1 - self.mssim(output, target))
334 | return loss
335 |
336 | class NSSIMLoss(nn.Module):
337 | def __init__(self, channel, filter_size):
338 | super(NSSIMLoss, self).__init__()
339 | self.ssim = SSIMLoss(channel=channel, filter_size=filter_size)
340 | def forward(self, output, target):
341 | loss = (1 - self.ssim(output, target))
342 | return loss
--------------------------------------------------------------------------------
/models/Dpt.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 | from models.vision_transformer_lora import vit_small_lora,vit_base_lora
5 | from models.vision_transformer import vit_small,vit_base
6 | from models.dpt import _make_fusion_block,_make_scratch
7 | import logging
8 | import loralib as lora
9 | ########################################################################################################################
10 |
11 | _DINOV2_BASE_URL = "https://dl.fbaipublicfiles.com/dinov2"
12 |
13 | def load_pretrained_weights(model, pretrained_weights, checkpoint_key):
14 | logger = logging.getLogger("dinov2")
15 | state_dict = torch.load(pretrained_weights, map_location="cpu")
16 | if checkpoint_key is not None and checkpoint_key in state_dict:
17 | logger.info(f"Take key {checkpoint_key} in provided checkpoint dict")
18 | state_dict = state_dict[checkpoint_key]
19 | # remove `module.` prefix
20 | state_dict = {k.replace("module.", ""): v for k, v in state_dict.items()}
21 | # remove `backbone.` prefix induced by multicrop wrapper
22 | state_dict = {k.replace("backbone.", ""): v for k, v in state_dict.items()}
23 | msg = model.load_state_dict(state_dict, strict=False)
24 | logger.info("Pretrained weights found at {} and loaded with msg: {}".format(pretrained_weights, msg))
25 |
26 | def make_dinov2_model_name(arch_name: str, patch_size: int) -> str:
27 | compact_arch_name = arch_name.replace("_", "")[:4]
28 | return f"dinov2_{compact_arch_name}{patch_size}"
29 |
30 | def make_vit_encoder(dino_pretrain="False",vit_type="small",finetune_method="unfrozen"):
31 | vit_kwargs = dict(
32 | in_chans = 3,
33 | img_size=224,
34 | patch_size=14,
35 | init_values=1.0e-05,
36 | ffn_layer="mlp",
37 | block_chunks=0,
38 | qkv_bias=True,
39 | proj_bias=True,
40 | ffn_bias=True
41 | )
42 | if dino_pretrain == "True":
43 | model_name = make_dinov2_model_name("vit_"+vit_type, 14)
44 | url = _DINOV2_BASE_URL + f"/{model_name}/{model_name}_pretrain.pth"
45 | pretrained_weights = torch.hub.load_state_dict_from_url(url, map_location="cpu")
46 | if finetune_method == "unfrozen" or finetune_method == "frozen":
47 | if vit_type == "small":
48 | encoder = vit_small(**vit_kwargs)
49 | emb = 384
50 | strict = True
51 | elif vit_type == "base":
52 | encoder = vit_base(**vit_kwargs)
53 | emb = 768
54 | strict = True
55 | else:
56 | print("Error in vit_type!!!")
57 | elif finetune_method == "lora":
58 | if vit_type == "small":
59 | encoder = vit_small_lora(**vit_kwargs)
60 | emb = 384
61 | strict = False
62 | elif vit_type == "base":
63 | encoder = vit_base_lora(**vit_kwargs)
64 | emb = 768
65 | strict = False
66 | else:
67 | print("Error in vit_type!!!")
68 | if dino_pretrain == "True":
69 | encoder.load_state_dict(pretrained_weights, strict=strict)
70 | return encoder,emb
71 |
72 | class dinov2_dpt(nn.Module):
73 | def __init__(self, num_classes, pretrain = True, vit_type="small",frozen=False,finetune_method="unfrozen"):
74 | super(dinov2_dpt,self).__init__()
75 |
76 | features = 256
77 |
78 | self.encoder, self.emb = make_vit_encoder(pretrain,vit_type,finetune_method)
79 | self.scratch = _make_scratch([self.emb,self.emb,self.emb,self.emb],
80 | out_shape=features)
81 | self.scratch.refinenet1 = _make_fusion_block(features, use_bn=True)
82 | self.scratch.refinenet2 = _make_fusion_block(features, use_bn=True)
83 | self.scratch.refinenet3 = _make_fusion_block(features, use_bn=True)
84 | self.scratch.refinenet4 = _make_fusion_block(features, use_bn=True)
85 |
86 | self.scratch.single_conv = nn.Conv2d(features, features // 2, kernel_size=3, stride=1, padding=1)
87 | self.scratch.output_conv = nn.Sequential(
88 | nn.Conv2d(features // 2, 32, kernel_size=3, stride=1, padding=1),
89 | nn.ReLU(True),
90 | nn.Conv2d(32, num_classes, kernel_size=1, stride=1, padding=0),
91 | nn.ReLU(True),
92 | nn.Identity(),
93 | )
94 |
95 | if frozen:
96 | for param in self.encoder.parameters():
97 | param.requires_grad = False
98 | else:
99 | if finetune_method == "unfrozen":
100 | for param in self.encoder.parameters():
101 | param.requires_grad = True
102 | elif finetune_method == "lora":
103 | lora.mark_only_lora_as_trainable(self.encoder)
104 |
105 | def forward(self,x,size):
106 | B,_,H,W = x.shape
107 | _, x_middle = self.encoder.forward_features(x)
108 | xm = []
109 | for k,x in x_middle.items():
110 | x = x.view(
111 | x.size(0),
112 | int(H / 14),
113 | int(W / 14),
114 | self.emb,
115 | )
116 | x = x.permute(0, 3, 1, 2).contiguous()
117 | xm.append(x)
118 | layer_1, layer_2, layer_3, layer_4 = xm
119 | layer_1_rn = self.scratch.layer1_rn(layer_1)
120 | layer_2_rn = self.scratch.layer2_rn(layer_2)
121 | layer_3_rn = self.scratch.layer3_rn(layer_3)
122 | layer_4_rn = self.scratch.layer4_rn(layer_4)
123 |
124 | path_4 = self.scratch.refinenet4((size[0]//16,size[1]//16), layer_4_rn)
125 | path_3 = self.scratch.refinenet3((size[0]//8,size[1]//8),path_4, layer_3_rn)
126 | path_2 = self.scratch.refinenet2((size[0]//4,size[1]//4),path_3, layer_2_rn)
127 | path_1 = self.scratch.refinenet1((size[0]//2,size[1]//2),path_2, layer_1_rn)
128 |
129 | out = self.scratch.single_conv(path_1)
130 | out = F.interpolate(out,size=size)
131 | out = self.scratch.output_conv(out)
132 | return out
133 |
134 | if __name__ == "__main__":
135 | device = torch.device('cuda:1' if torch.cuda.is_available() else 'cpu')
136 | model = dinov2_dpt(1).to(device=device)
137 | total_num = sum(p.numel() for p in model.parameters())
138 | trainable_num = sum(p.numel() for p in model.parameters() if p.requires_grad)
139 | print('Model Total: %d'%total_num)
140 | print('Model Trainable: %d'%trainable_num)
141 | x1 = torch.Tensor(1,3,434,994).to(device=device,dtype=torch.float32)
142 | y1 = model(x1,size=(434,994))
143 | print(x1.shape)
144 | print(y1.shape)
--------------------------------------------------------------------------------
/models/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | # All rights reserved.
3 | #
4 | # This source code is licensed under the license found in the
5 | # LICENSE file in the root directory of this source tree.
6 |
7 | from .vision_transformer import vit_small
8 | from .layers import *
9 |
--------------------------------------------------------------------------------
/models/adapter.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 | import torchvision.transforms as T
5 | from models.vision_transformer_lora import vit_small_lora,vit_base_lora
6 | from models.vision_transformer import vit_small,vit_base
7 | import fvcore.nn.weight_init as weight_init
8 |
9 | import logging
10 | import loralib as lora
11 |
12 | _DINOV2_BASE_URL = "https://dl.fbaipublicfiles.com/dinov2"
13 |
14 | def load_pretrained_weights(model, pretrained_weights, checkpoint_key):
15 | logger = logging.getLogger("dinov2")
16 | state_dict = torch.load(pretrained_weights, map_location="cpu")
17 | if checkpoint_key is not None and checkpoint_key in state_dict:
18 | logger.info(f"Take key {checkpoint_key} in provided checkpoint dict")
19 | state_dict = state_dict[checkpoint_key]
20 | # remove `module.` prefix
21 | state_dict = {k.replace("module.", ""): v for k, v in state_dict.items()}
22 | # remove `backbone.` prefix induced by multicrop wrapper
23 | state_dict = {k.replace("backbone.", ""): v for k, v in state_dict.items()}
24 | msg = model.load_state_dict(state_dict, strict=False)
25 | logger.info("Pretrained weights found at {} and loaded with msg: {}".format(pretrained_weights, msg))
26 |
27 | def make_dinov2_model_name(arch_name: str, patch_size: int) -> str:
28 | compact_arch_name = arch_name.replace("_", "")[:4]
29 | return f"dinov2_{compact_arch_name}{patch_size}"
30 |
31 | def make_vit_encoder(dino_pretrain="False",vit_type="small",finetune_method="unfrozen"):
32 | vit_kwargs = dict(
33 | in_chans = 3,
34 | img_size=224,
35 | patch_size=14,
36 | init_values=1.0e-05,
37 | ffn_layer="mlp",
38 | block_chunks=0,
39 | qkv_bias=True,
40 | proj_bias=True,
41 | ffn_bias=True
42 | )
43 | if dino_pretrain == "True":
44 | model_name = make_dinov2_model_name("vit_"+vit_type, 14)
45 | url = _DINOV2_BASE_URL + f"/{model_name}/{model_name}_pretrain.pth"
46 | pretrained_weights = torch.hub.load_state_dict_from_url(url, map_location="cpu")
47 | if finetune_method == "unfrozen" or finetune_method == "frozen":
48 | if vit_type == "small":
49 | encoder = vit_small(**vit_kwargs)
50 | emb = 384
51 | strict = True
52 | elif vit_type == "base":
53 | encoder = vit_base(**vit_kwargs)
54 | emb = 768
55 | strict = True
56 | else:
57 | print("Error in vit_type!!!")
58 | elif finetune_method == "lora":
59 | if vit_type == "small":
60 | encoder = vit_small_lora(**vit_kwargs)
61 | emb = 384
62 | strict = False
63 | elif vit_type == "base":
64 | encoder = vit_base_lora(**vit_kwargs)
65 | emb = 768
66 | strict = False
67 | else:
68 | print("Error in vit_type!!!")
69 | if dino_pretrain == "True":
70 | encoder.load_state_dict(pretrained_weights, strict=strict)
71 | return encoder,emb
72 |
73 | class IntermediateSequential(nn.Sequential):
74 | def __init__(self, *args, return_intermediate=True):
75 | super().__init__(*args)
76 | self.return_intermediate = return_intermediate
77 |
78 | def forward(self, input):
79 | if not self.return_intermediate:
80 | return super().forward(input)
81 |
82 | intermediate_outputs = {}
83 | output = input
84 | for name, module in self.named_children():
85 | output = intermediate_outputs[name] = module(output)
86 |
87 | return output, intermediate_outputs
88 |
89 |
90 | class SETR_PUP(nn.Module):
91 | def __init__(self,embedding_dim,num_classes):
92 | super(SETR_PUP,self).__init__()
93 |
94 | self.embedding_dim = embedding_dim
95 | self.num_classes = num_classes
96 |
97 | extra_in_channels = int(self.embedding_dim/4)
98 | in_channels = [
99 | self.embedding_dim,
100 | extra_in_channels,
101 | extra_in_channels,
102 | extra_in_channels,
103 | ]
104 | out_channels = [
105 | extra_in_channels,
106 | extra_in_channels,
107 | extra_in_channels,
108 | extra_in_channels,
109 | ]
110 |
111 | modules = []
112 | for i, (in_channel, out_channel) in enumerate(
113 | zip(in_channels, out_channels)
114 | ):
115 | modules.append(
116 | self.conv_block(in_channel,out_channel)
117 | )
118 | modules.append(nn.Upsample(size=(1//(2**(3-i)),1//(2**(3-i))), mode='bilinear'))
119 |
120 | modules.append(
121 | nn.Conv2d(
122 | in_channels=out_channels[-1], out_channels=self.num_classes,
123 | kernel_size=1, stride=1,
124 | padding=self._get_padding('VALID', (1, 1),),
125 | ))
126 | self.decode_net = IntermediateSequential(
127 | *modules, return_intermediate=False
128 | )
129 |
130 | def forward(self,x,size):
131 | n1,n2 = size
132 | self.decode_net[1] = nn.Upsample(size=(n1//(2**(3)),n2//(2**(3))), mode='bilinear')
133 | self.decode_net[3] = nn.Upsample(size=(n1//(2**(2)),n2//(2**(2))), mode='bilinear')
134 | self.decode_net[5] = nn.Upsample(size=(n1//(2**(1)),n2//(2**(1))), mode='bilinear')
135 | self.decode_net[7] = nn.Upsample(size=(n1,n2), mode='bilinear')
136 | return self.decode_net(x)
137 |
138 | def conv_block(self,in_channels, out_channels):
139 | conv = nn.Sequential(
140 | nn.Conv2d(
141 | int(in_channels), int(out_channels), 3, 1,
142 | padding=self._get_padding('SAME', (3, 3),),
143 | ),
144 | nn.BatchNorm2d(int(out_channels)),
145 | nn.ReLU(inplace=True),
146 |
147 | nn.Conv2d(
148 | int(out_channels), int(out_channels), 3, 1,
149 | padding=self._get_padding('SAME', (3, 3),),
150 | ),
151 | nn.BatchNorm2d(int(out_channels)),
152 | nn.ReLU(inplace=True)
153 | )
154 | return conv
155 |
156 | def _get_padding(self, padding_type, kernel_size):
157 | assert padding_type in ['SAME', 'VALID']
158 | if padding_type == 'SAME':
159 | _list = [(k - 1) // 2 for k in kernel_size]
160 | return tuple(_list)
161 | return tuple(0 for _ in kernel_size)
162 |
163 | class SETR_MLA(nn.Module):
164 | def __init__(self,embedding_dim,num_classes):
165 | super(SETR_MLA,self).__init__()
166 |
167 | self.embedding_dim = embedding_dim
168 | self.num_classes = num_classes
169 |
170 | self.net1_in, self.net1_intmd, self.net1_out = self._define_agg_net()
171 | self.net2_in, self.net2_intmd, self.net2_out = self._define_agg_net()
172 | self.net3_in, self.net3_intmd, self.net3_out = self._define_agg_net()
173 | self.net4_in, self.net4_intmd, self.net4_out = self._define_agg_net()
174 |
175 | self.output_net = IntermediateSequential(return_intermediate=False)
176 | self.output_net.add_module(
177 | "conv_1",
178 | nn.Conv2d(
179 | in_channels=self.embedding_dim, out_channels=self.num_classes,
180 | kernel_size=1, stride=1,
181 | padding=self._get_padding('VALID', (1, 1),),
182 | )
183 | )
184 | self.output_net.add_module(
185 | "upsample_1",
186 | nn.Upsample(size = (1,1), mode='bilinear')
187 | )
188 |
189 | def forward(self,x,size):
190 | n1,n2 = size
191 | self.output_net[-1] = nn.Upsample(size = (n1,n2), mode='bilinear')
192 | x3,x6,x9,x12 = x
193 |
194 | x12_intmd_in = self.net1_in(x12)
195 | x12_out = self.net1_out(x12_intmd_in)
196 |
197 | x9_in = self.net2_in(x9)
198 | x9_intmd_in = x9_in + x12_intmd_in
199 | x9_intmd_out = self.net2_intmd(x9_intmd_in)
200 | x9_out = self.net2_out(x9_intmd_out)
201 |
202 | x6_in = self.net3_in(x6)
203 | x6_intmd_in = x6_in + x9_intmd_in
204 | x6_intmd_out = self.net3_intmd(x6_intmd_in)
205 | x6_out = self.net3_out(x6_intmd_out)
206 |
207 | x3_in = self.net4_in(x3)
208 | x3_intmd_in = x3_in + x6_intmd_in
209 | x3_intmd_out = self.net4_intmd(x3_intmd_in)
210 | x3_out = self.net4_out(x3_intmd_out)
211 |
212 | out = torch.cat((x12_out, x9_out, x6_out, x3_out), dim=1)
213 | out = self.output_net(out)
214 |
215 | return out
216 |
217 | def conv_block(self,in_channels, out_channels):
218 | conv = nn.Sequential(
219 | nn.Conv2d(
220 | int(in_channels), int(out_channels), 3, 1,
221 | padding=self._get_padding('SAME', (3, 3),),
222 | ),
223 | nn.BatchNorm2d(int(out_channels)),
224 | nn.ReLU(inplace=True),
225 |
226 | nn.Conv2d(
227 | int(out_channels), int(out_channels), 3, 1,
228 | padding=self._get_padding('SAME', (3, 3),),
229 | ),
230 | nn.BatchNorm2d(int(out_channels)),
231 | nn.ReLU(inplace=True)
232 | )
233 | return conv
234 |
235 | def _define_agg_net(self):
236 | model_in = IntermediateSequential(return_intermediate=False)
237 | model_in.add_module(
238 | "layer_1",
239 | self.conv_block(self.embedding_dim,int(self.embedding_dim/2))
240 | )
241 |
242 | model_intmd = IntermediateSequential(return_intermediate=False)
243 | model_intmd.add_module(
244 | "layer_intmd",
245 | self.conv_block(int(self.embedding_dim/2),int(self.embedding_dim/2))
246 | )
247 |
248 | model_out = IntermediateSequential(return_intermediate=False)
249 | model_out.add_module(
250 | "layer_2",
251 | self.conv_block(int(self.embedding_dim/2),int(self.embedding_dim/2))
252 | )
253 | model_out.add_module(
254 | "layer_3",
255 | self.conv_block(int(self.embedding_dim/2),int(self.embedding_dim/4))
256 | )
257 | model_out.add_module(
258 | "upsample", nn.Upsample(scale_factor=4, mode='bilinear')
259 | )
260 | model_out.add_module(
261 | "layer_4",
262 | self.conv_block(int(self.embedding_dim/4),int(self.embedding_dim/4))
263 | )
264 | return model_in, model_intmd, model_out
265 |
266 | def _get_padding(self, padding_type, kernel_size):
267 | assert padding_type in ['SAME', 'VALID']
268 | if padding_type == 'SAME':
269 | _list = [(k - 1) // 2 for k in kernel_size]
270 | return tuple(_list)
271 | return tuple(0 for _ in kernel_size)
272 |
273 | class dinov2_pup(nn.Module):
274 | def __init__(self, num_classes, pretrain = True, vit_type="small",frozen=False,finetune_method="unfrozen"):
275 | super(dinov2_pup,self).__init__()
276 |
277 | self.encoder, self.emb = make_vit_encoder(pretrain,vit_type,finetune_method)
278 | self.decoder = SETR_PUP(self.emb, num_classes)
279 |
280 | if frozen:
281 | for param in self.encoder.parameters():
282 | param.requires_grad = False
283 | else:
284 | if finetune_method == "unfrozen":
285 | for param in self.encoder.parameters():
286 | param.requires_grad = True
287 | elif finetune_method == "lora":
288 | lora.mark_only_lora_as_trainable(self.encoder)
289 |
290 | def forward(self,x,size):
291 | B,_,H,W = x.shape
292 | features,_ = self.encoder.forward_features(x)
293 | fea_img = features['x_norm_patchtokens']
294 | fea_img = fea_img.view(fea_img.size(0),int(H / 14),int(W / 14),self.emb)
295 | fea_img = fea_img.permute(0, 3, 1, 2).contiguous()
296 | out = self.decoder(fea_img,size)
297 | return out
298 |
299 | class dinov2_mla(nn.Module):
300 | def __init__(self, num_classes, pretrain = True, vit_type="small",frozen=False,finetune_method="unfrozen"):
301 | super(dinov2_mla,self).__init__()
302 |
303 | self.encoder, self.emb = make_vit_encoder(pretrain,vit_type,finetune_method)
304 | self.decoder = SETR_MLA(self.emb, num_classes)
305 | if frozen:
306 | for param in self.encoder.parameters():
307 | param.requires_grad = False
308 | else:
309 | if finetune_method == "unfrozen":
310 | for param in self.encoder.parameters():
311 | param.requires_grad = True
312 | elif finetune_method == "lora":
313 | lora.mark_only_lora_as_trainable(self.encoder)
314 |
315 | def forward(self,x,size):
316 | B,_,H,W = x.shape
317 | _, x_middle = self.encoder.forward_features(x)
318 | xm = []
319 | for k,x in x_middle.items():
320 | x = x.view(
321 | x.size(0),
322 | int(H / 14),
323 | int(W / 14),
324 | self.emb,
325 | )
326 | x = x.permute(0, 3, 1, 2).contiguous()
327 | xm.append(x)
328 | out = self.decoder(xm,size)
329 | return out
330 |
331 | class dinov2_linear(nn.Module):
332 | def __init__(self, num_classes, pretrain = True, vit_type="small",frozen=False,finetune_method="unfrozen"):
333 | super(dinov2_linear,self).__init__()
334 |
335 | self.encoder, self.emb = make_vit_encoder(pretrain,vit_type,finetune_method)
336 | self.decoder = nn.Conv2d(self.emb, num_classes, kernel_size=1)
337 |
338 | if frozen:
339 | for param in self.encoder.parameters():
340 | param.requires_grad = False
341 | else:
342 | if finetune_method == "unfrozen":
343 | for param in self.encoder.parameters():
344 | param.requires_grad = True
345 | elif finetune_method == "lora":
346 | lora.mark_only_lora_as_trainable(self.encoder)
347 |
348 | def forward(self,x,size):
349 | B,_,H,W = x.shape
350 | features,_ = self.encoder.forward_features(x)
351 | fea_img = features['x_norm_patchtokens']
352 | fea_img = fea_img.view(fea_img.size(0),int(H / 14),int(W / 14),self.emb)
353 | fea_img = fea_img.permute(0, 3, 1, 2).contiguous()
354 | out = self.decoder(fea_img)
355 | out = F.interpolate(out,size=size)
356 | return out
357 |
358 |
--------------------------------------------------------------------------------
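A minimal usage sketch for the DINOv2 decoder heads defined above, assuming these classes live in `models/Dpt.py` as the repository tree suggests and that `make_vit_encoder` can fetch the DINOv2 backbone (e.g. via torch.hub):

```python
import torch
from models.Dpt import dinov2_mla  # assumed import path; adjust to how the project imports it

# Build the MLA head with LoRA fine-tuning; pretrain=True is assumed to download DINOv2 weights.
model = dinov2_mla(num_classes=2, pretrain=True, vit_type="small",
                   frozen=False, finetune_method="lora")

x = torch.randn(1, 3, 224, 224)      # H and W must be multiples of the 14-pixel patch size
logits = model(x, size=(224, 224))   # the decoder resizes the prediction to `size`
print(logits.shape)                  # expected to be (1, 2, 224, 224)
```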
/models/dpt/__init__.py:
--------------------------------------------------------------------------------
1 | from .blocks import (
2 | FeatureFusionBlock,
3 | FeatureFusionBlock_custom,
4 | Interpolate,
5 | _make_scratch,
6 | )
7 | from .models import _make_fusion_block
--------------------------------------------------------------------------------
/models/dpt/base_model.py:
--------------------------------------------------------------------------------
1 | import torch
2 |
3 |
4 | class BaseModel(torch.nn.Module):
5 | def load(self, path):
6 | """Load model from file.
7 |
8 | Args:
9 | path (str): file path
10 | """
11 | parameters = torch.load(path, map_location=torch.device("cpu"))
12 |
13 | if "optimizer" in parameters:
14 | parameters = parameters["model"]
15 |
16 | self.load_state_dict(parameters)
17 |
--------------------------------------------------------------------------------
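A small sketch of the two checkpoint layouts `BaseModel.load` accepts, using a hypothetical `MyNet` subclass: either a bare `state_dict`, or a training checkpoint whose `"model"` entry is unwrapped when an `"optimizer"` key is present:

```python
import torch
from models.dpt.base_model import BaseModel

class MyNet(BaseModel):          # hypothetical subclass, just for illustration
    def __init__(self):
        super().__init__()
        self.linear = torch.nn.Linear(4, 2)

net = MyNet()
torch.save(net.state_dict(), "plain.pt")                             # bare state_dict
torch.save({"model": net.state_dict(), "optimizer": {}}, "ckpt.pt")  # training checkpoint

MyNet().load("plain.pt")   # loaded directly
MyNet().load("ckpt.pt")    # "optimizer" key detected, so the nested "model" dict is used
```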
/models/dpt/blocks.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 |
4 | from models.dpt.vit import (
5 | _make_pretrained_vitb_rn50_384,
6 | _make_pretrained_vitl16_384,
7 | _make_pretrained_vitb16_384,
8 | forward_vit,
9 | )
10 |
11 |
12 | def _make_encoder(
13 | backbone,
14 | features,
15 | use_pretrained,
16 | groups=1,
17 | expand=False,
18 | exportable=True,
19 | hooks=None,
20 | use_vit_only=False,
21 | use_readout="ignore",
22 | enable_attention_hooks=False,
23 | ):
24 | if backbone == "vitl16_384":
25 | pretrained = _make_pretrained_vitl16_384(
26 | use_pretrained,
27 | hooks=hooks,
28 | use_readout=use_readout,
29 | enable_attention_hooks=enable_attention_hooks,
30 | )
31 | scratch = _make_scratch(
32 | [256, 512, 1024, 1024], features, groups=groups, expand=expand
33 | ) # ViT-L/16 - 85.0% Top1 (backbone)
34 | elif backbone == "vitb_rn50_384":
35 | pretrained = _make_pretrained_vitb_rn50_384(
36 | use_pretrained,
37 | hooks=hooks,
38 | use_vit_only=use_vit_only,
39 | use_readout=use_readout,
40 | enable_attention_hooks=enable_attention_hooks,
41 | )
42 | scratch = _make_scratch(
43 | [256, 512, 768, 768], features, groups=groups, expand=expand
44 |         ) # ViT-Hybrid (ResNet-50 + ViT-B/16) backbone
45 | elif backbone == "vitb16_384":
46 | pretrained = _make_pretrained_vitb16_384(
47 | use_pretrained,
48 | hooks=hooks,
49 | use_readout=use_readout,
50 | enable_attention_hooks=enable_attention_hooks,
51 | )
52 | scratch = _make_scratch(
53 | [96, 192, 384, 768], features, groups=groups, expand=expand
54 | ) # ViT-B/16 - 84.6% Top1 (backbone)
55 | elif backbone == "resnext101_wsl":
56 | pretrained = _make_pretrained_resnext101_wsl(use_pretrained)
57 | scratch = _make_scratch(
58 | [256, 512, 1024, 2048], features, groups=groups, expand=expand
59 |         ) # ResNeXt-101 32x8d WSL backbone
60 | else:
61 |         # unknown backbone: fail early with a clear error
62 |         raise NotImplementedError(f"Backbone '{backbone}' not implemented")
63 |
64 | return pretrained, scratch
65 |
66 |
67 | def _make_scratch(in_shape, out_shape, groups=1, expand=False):
68 | scratch = nn.Module()
69 |
70 | out_shape1 = out_shape
71 | out_shape2 = out_shape
72 | out_shape3 = out_shape
73 | out_shape4 = out_shape
74 | if expand == True:
75 | out_shape1 = out_shape
76 | out_shape2 = out_shape * 2
77 | out_shape3 = out_shape * 4
78 | out_shape4 = out_shape * 8
79 |
80 | scratch.layer1_rn = nn.Conv2d(
81 | in_shape[0],
82 | out_shape1,
83 | kernel_size=3,
84 | stride=1,
85 | padding=1,
86 | bias=False,
87 | groups=groups,
88 | )
89 | scratch.layer2_rn = nn.Conv2d(
90 | in_shape[1],
91 | out_shape2,
92 | kernel_size=3,
93 | stride=1,
94 | padding=1,
95 | bias=False,
96 | groups=groups,
97 | )
98 | scratch.layer3_rn = nn.Conv2d(
99 | in_shape[2],
100 | out_shape3,
101 | kernel_size=3,
102 | stride=1,
103 | padding=1,
104 | bias=False,
105 | groups=groups,
106 | )
107 | scratch.layer4_rn = nn.Conv2d(
108 | in_shape[3],
109 | out_shape4,
110 | kernel_size=3,
111 | stride=1,
112 | padding=1,
113 | bias=False,
114 | groups=groups,
115 | )
116 |
117 | return scratch
118 |
119 |
120 | def _make_resnet_backbone(resnet):
121 | pretrained = nn.Module()
122 | pretrained.layer1 = nn.Sequential(
123 | resnet.conv1, resnet.bn1, resnet.relu, resnet.maxpool, resnet.layer1
124 | )
125 |
126 | pretrained.layer2 = resnet.layer2
127 | pretrained.layer3 = resnet.layer3
128 | pretrained.layer4 = resnet.layer4
129 |
130 | return pretrained
131 |
132 |
133 | def _make_pretrained_resnext101_wsl(use_pretrained):
134 | resnet = torch.hub.load("facebookresearch/WSL-Images", "resnext101_32x8d_wsl")
135 | return _make_resnet_backbone(resnet)
136 |
137 |
138 | class Interpolate(nn.Module):
139 | """Interpolation module."""
140 |
141 | def __init__(self, scale_factor, mode, align_corners=False):
142 | """Init.
143 |
144 | Args:
145 | scale_factor (float): scaling
146 | mode (str): interpolation mode
147 | """
148 | super(Interpolate, self).__init__()
149 |
150 | self.interp = nn.functional.interpolate
151 | self.scale_factor = scale_factor
152 | self.mode = mode
153 | self.align_corners = align_corners
154 |
155 | def forward(self, x):
156 | """Forward pass.
157 |
158 | Args:
159 | x (tensor): input
160 |
161 | Returns:
162 | tensor: interpolated data
163 | """
164 |
165 | x = self.interp(
166 | x,
167 | scale_factor=self.scale_factor,
168 | mode=self.mode,
169 | align_corners=self.align_corners,
170 | )
171 |
172 | return x
173 |
174 |
175 | class ResidualConvUnit(nn.Module):
176 | """Residual convolution module."""
177 |
178 | def __init__(self, features):
179 | """Init.
180 |
181 | Args:
182 | features (int): number of features
183 | """
184 | super().__init__()
185 |
186 | self.conv1 = nn.Conv2d(
187 | features, features, kernel_size=3, stride=1, padding=1, bias=True
188 | )
189 |
190 | self.conv2 = nn.Conv2d(
191 | features, features, kernel_size=3, stride=1, padding=1, bias=True
192 | )
193 |
194 | self.relu = nn.ReLU(inplace=True)
195 |
196 | def forward(self, x):
197 | """Forward pass.
198 |
199 | Args:
200 | x (tensor): input
201 |
202 | Returns:
203 | tensor: output
204 | """
205 | out = self.relu(x)
206 | out = self.conv1(out)
207 | out = self.relu(out)
208 | out = self.conv2(out)
209 |
210 | return out + x
211 |
212 |
213 | class FeatureFusionBlock(nn.Module):
214 | """Feature fusion block."""
215 |
216 | def __init__(self, features):
217 | """Init.
218 |
219 | Args:
220 | features (int): number of features
221 | """
222 | super(FeatureFusionBlock, self).__init__()
223 |
224 | self.resConfUnit1 = ResidualConvUnit(features)
225 | self.resConfUnit2 = ResidualConvUnit(features)
226 |
227 | def forward(self, *xs):
228 | """Forward pass.
229 |
230 | Returns:
231 | tensor: output
232 | """
233 | output = xs[0]
234 |
235 | if len(xs) == 2:
236 | output += self.resConfUnit1(xs[1])
237 |
238 | output = self.resConfUnit2(output)
239 |
240 | output = nn.functional.interpolate(
241 | output, scale_factor=2, mode="bilinear", align_corners=True
242 | )
243 |
244 | return output
245 |
246 |
247 | class ResidualConvUnit_custom(nn.Module):
248 | """Residual convolution module."""
249 |
250 | def __init__(self, features, activation, bn):
251 | """Init.
252 |
253 | Args:
254 | features (int): number of features
255 | """
256 | super().__init__()
257 |
258 | self.bn = bn
259 |
260 | self.groups = 1
261 |
262 | self.conv1 = nn.Conv2d(
263 | features,
264 | features,
265 | kernel_size=3,
266 | stride=1,
267 | padding=1,
268 | bias=not self.bn,
269 | groups=self.groups,
270 | )
271 |
272 | self.conv2 = nn.Conv2d(
273 | features,
274 | features,
275 | kernel_size=3,
276 | stride=1,
277 | padding=1,
278 | bias=not self.bn,
279 | groups=self.groups,
280 | )
281 |
282 | if self.bn == True:
283 | self.bn1 = nn.BatchNorm2d(features)
284 | self.bn2 = nn.BatchNorm2d(features)
285 |
286 | self.activation = activation
287 |
288 | self.skip_add = nn.quantized.FloatFunctional()
289 |
290 | def forward(self, x):
291 | """Forward pass.
292 |
293 | Args:
294 | x (tensor): input
295 |
296 | Returns:
297 | tensor: output
298 | """
299 |
300 | out = self.activation(x)
301 | out = self.conv1(out)
302 | if self.bn == True:
303 | out = self.bn1(out)
304 |
305 | out = self.activation(out)
306 | out = self.conv2(out)
307 | if self.bn == True:
308 | out = self.bn2(out)
309 |
310 | if self.groups > 1:
311 | out = self.conv_merge(out)
312 |
313 | return self.skip_add.add(out, x)
314 |
315 | # return out + x
316 |
317 |
318 | class FeatureFusionBlock_custom(nn.Module):
319 | """Feature fusion block."""
320 |
321 | def __init__(
322 | self,
323 | features,
324 | activation,
325 | deconv=False,
326 | bn=False,
327 | expand=False,
328 | align_corners=True,
329 | ):
330 | """Init.
331 |
332 | Args:
333 | features (int): number of features
334 | """
335 | super(FeatureFusionBlock_custom, self).__init__()
336 |
337 | self.deconv = deconv
338 | self.align_corners = align_corners
339 |
340 | self.groups = 1
341 |
342 | self.expand = expand
343 | out_features = features
344 | if self.expand == True:
345 | out_features = features // 2
346 |
347 | self.out_conv = nn.Conv2d(
348 | features,
349 | out_features,
350 | kernel_size=1,
351 | stride=1,
352 | padding=0,
353 | bias=True,
354 | groups=1,
355 | )
356 |
357 | self.resConfUnit1 = ResidualConvUnit_custom(features, activation, bn)
358 | self.resConfUnit2 = ResidualConvUnit_custom(features, activation, bn)
359 |
360 | self.skip_add = nn.quantized.FloatFunctional()
361 |
362 | def forward(self, size, *xs):
363 | """Forward pass.
364 |
365 | Returns:
366 | tensor: output
367 | """
368 | output = xs[0]
369 |
370 | if len(xs) == 2:
371 | res = self.resConfUnit1(xs[1])
372 | res = nn.functional.interpolate(
373 | res, size=(size[0]//2,size[1]//2), mode="bilinear", align_corners=self.align_corners
374 | )
375 | output = self.skip_add.add(output, res)
376 | # output += res
377 |
378 | output = self.resConfUnit2(output)
379 |
380 | output = nn.functional.interpolate(
381 | output, size=size, mode="bilinear", align_corners=self.align_corners
382 | )
383 |
384 | output = self.out_conv(output)
385 |
386 | return output
387 |
--------------------------------------------------------------------------------
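A brief sketch of how the modified `FeatureFusionBlock_custom` is called: unlike the upstream MiDaS block, its `forward` takes the target output size as the first argument (the feature width and spatial sizes below are illustrative):

```python
import torch
import torch.nn as nn
from models.dpt.blocks import FeatureFusionBlock_custom

# Same configuration that _make_fusion_block uses, with batch norm disabled.
fuse = FeatureFusionBlock_custom(256, nn.ReLU(False), deconv=False, bn=False,
                                 expand=False, align_corners=True)

coarse = torch.randn(1, 256, 16, 16)   # decoder path at half the target resolution
skip = torch.randn(1, 256, 16, 16)     # encoder skip connection
out = fuse((32, 32), coarse, skip)     # the target (H, W) comes first
print(out.shape)                       # torch.Size([1, 256, 32, 32])
```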
/models/dpt/layers/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | # All rights reserved.
3 | #
4 | # This source code is licensed under the license found in the
5 | # LICENSE file in the root directory of this source tree.
6 |
7 | from .dino_head import DINOHead
8 | from .mlp import Mlp
9 | from .patch_embed import PatchEmbed
10 | from .swiglu_ffn import SwiGLUFFN, SwiGLUFFNFused
11 | from .block import NestedTensorBlock
12 | from .attention import MemEffAttention,MemEffAttention_lora
13 |
--------------------------------------------------------------------------------
/models/dpt/layers/attention.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | # All rights reserved.
3 | #
4 | # This source code is licensed under the license found in the
5 | # LICENSE file in the root directory of this source tree.
6 |
7 | # References:
8 | # https://github.com/facebookresearch/dino/blob/master/vision_transformer.py
9 | # https://github.com/rwightman/pytorch-image-models/tree/master/timm/models/vision_transformer.py
10 |
11 | import logging
12 |
13 | from torch import Tensor
14 | from torch import nn
15 | import loralib as lora
16 |
17 | logger = logging.getLogger("dinov2")
18 |
19 |
20 | try:
21 | from xformers.ops import memory_efficient_attention, unbind, fmha
22 |
23 | XFORMERS_AVAILABLE = True
24 | except ImportError:
25 | logger.warning("xFormers not available")
26 | XFORMERS_AVAILABLE = False
27 |
28 |
29 | class Attention(nn.Module):
30 | def __init__(
31 | self,
32 | dim: int,
33 | num_heads: int = 8,
34 | qkv_bias: bool = False,
35 | proj_bias: bool = True,
36 | attn_drop: float = 0.0,
37 | proj_drop: float = 0.0,
38 | ) -> None:
39 | super().__init__()
40 | self.num_heads = num_heads
41 | head_dim = dim // num_heads
42 | self.scale = head_dim**-0.5
43 |
44 | self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
45 | self.attn_drop = nn.Dropout(attn_drop)
46 | self.proj = nn.Linear(dim, dim, bias=proj_bias)
47 | self.proj_drop = nn.Dropout(proj_drop)
48 |
49 | def forward(self, x: Tensor) -> Tensor:
50 | B, N, C = x.shape
51 | qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
52 |
53 | q, k, v = qkv[0] * self.scale, qkv[1], qkv[2]
54 | attn = q @ k.transpose(-2, -1)
55 |
56 | attn = attn.softmax(dim=-1)
57 | attn = self.attn_drop(attn)
58 |
59 | x = (attn @ v).transpose(1, 2).reshape(B, N, C)
60 | x = self.proj(x)
61 | x = self.proj_drop(x)
62 | return x
63 |
64 | class Attention_lora(nn.Module):
65 | def __init__(
66 | self,
67 | dim: int,
68 | num_heads: int = 8,
69 | qkv_bias: bool = False,
70 | proj_bias: bool = True,
71 | attn_drop: float = 0.0,
72 | proj_drop: float = 0.0,
73 | ) -> None:
74 | super().__init__()
75 | self.num_heads = num_heads
76 | head_dim = dim // num_heads
77 | self.scale = head_dim**-0.5
78 |
79 | self.qkv = lora.Linear(dim, dim * 3, bias=qkv_bias, r=8)
80 | self.attn_drop = nn.Dropout(attn_drop)
81 | self.proj = lora.Linear(dim, dim, bias=proj_bias, r=8)
82 | self.proj_drop = nn.Dropout(proj_drop)
83 |
84 | def forward(self, x: Tensor) -> Tensor:
85 | B, N, C = x.shape
86 | qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
87 |
88 | q, k, v = qkv[0] * self.scale, qkv[1], qkv[2]
89 | attn = q @ k.transpose(-2, -1)
90 |
91 | attn = attn.softmax(dim=-1)
92 | attn = self.attn_drop(attn)
93 |
94 | x = (attn @ v).transpose(1, 2).reshape(B, N, C)
95 | x = self.proj(x)
96 | x = self.proj_drop(x)
97 | return x
98 |
99 | class MemEffAttention(Attention):
100 | def forward(self, x: Tensor, attn_bias=None) -> Tensor:
101 | if not XFORMERS_AVAILABLE:
102 | assert attn_bias is None, "xFormers is required for nested tensors usage"
103 | return super().forward(x)
104 |
105 | B, N, C = x.shape
106 | qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads)
107 |
108 | q, k, v = unbind(qkv, 2)
109 |
110 | x = memory_efficient_attention(q, k, v, attn_bias=attn_bias)
111 | x = x.reshape([B, N, C])
112 |
113 | x = self.proj(x)
114 | x = self.proj_drop(x)
115 | return x
116 |
117 | class MemEffAttention_lora(Attention_lora):
118 | def forward(self, x: Tensor, attn_bias=None) -> Tensor:
119 | if not XFORMERS_AVAILABLE:
120 | assert attn_bias is None, "xFormers is required for nested tensors usage"
121 | return super().forward(x)
122 |
123 | B, N, C = x.shape
124 | qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads)
125 |
126 | q, k, v = unbind(qkv, 2)
127 |
128 | x = memory_efficient_attention(q, k, v, attn_bias=attn_bias)
129 | x = x.reshape([B, N, C])
130 |
131 | x = self.proj(x)
132 | x = self.proj_drop(x)
133 | return x
134 |
--------------------------------------------------------------------------------
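A short sketch of the LoRA-augmented attention: `loralib.mark_only_lora_as_trainable` freezes the base `qkv`/`proj` weights so that only the rank-8 adapter matrices receive gradients (requires the `loralib` package; dimensions are illustrative):

```python
import torch
import loralib as lora
from models.dpt.layers.attention import Attention_lora

attn = Attention_lora(dim=384, num_heads=6, qkv_bias=True)
lora.mark_only_lora_as_trainable(attn)   # freezes everything except the LoRA A/B matrices

x = torch.randn(2, 197, 384)             # (batch, tokens, dim)
print(attn(x).shape)                     # torch.Size([2, 197, 384])

trainable = sum(p.numel() for p in attn.parameters() if p.requires_grad)
print(f"trainable parameters: {trainable}")
```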
/models/dpt/layers/block.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | # All rights reserved.
3 | #
4 | # This source code is licensed under the license found in the
5 | # LICENSE file in the root directory of this source tree.
6 |
7 | # References:
8 | # https://github.com/facebookresearch/dino/blob/master/vision_transformer.py
9 | # https://github.com/rwightman/pytorch-image-models/tree/master/timm/layers/patch_embed.py
10 |
11 | import logging
12 | from typing import Callable, List, Any, Tuple, Dict
13 |
14 | import torch
15 | from torch import nn, Tensor
16 |
17 | from .attention import Attention, MemEffAttention
18 | from .drop_path import DropPath
19 | from .layer_scale import LayerScale
20 | from .mlp import Mlp
21 |
22 |
23 | logger = logging.getLogger("dinov2")
24 |
25 |
26 | try:
27 | from xformers.ops import fmha
28 | from xformers.ops import scaled_index_add, index_select_cat
29 |
30 | XFORMERS_AVAILABLE = True
31 | except ImportError:
32 | logger.warning("xFormers not available")
33 | XFORMERS_AVAILABLE = False
34 |
35 |
36 | class Block(nn.Module):
37 | def __init__(
38 | self,
39 | dim: int,
40 | num_heads: int,
41 | mlp_ratio: float = 4.0,
42 | qkv_bias: bool = False,
43 | proj_bias: bool = True,
44 | ffn_bias: bool = True,
45 | drop: float = 0.0,
46 | attn_drop: float = 0.0,
47 | init_values=None,
48 | drop_path: float = 0.0,
49 | act_layer: Callable[..., nn.Module] = nn.GELU,
50 | norm_layer: Callable[..., nn.Module] = nn.LayerNorm,
51 | attn_class: Callable[..., nn.Module] = Attention,
52 | ffn_layer: Callable[..., nn.Module] = Mlp,
53 | ) -> None:
54 | super().__init__()
55 | # print(f"biases: qkv: {qkv_bias}, proj: {proj_bias}, ffn: {ffn_bias}")
56 | self.norm1 = norm_layer(dim)
57 | self.attn = attn_class(
58 | dim,
59 | num_heads=num_heads,
60 | qkv_bias=qkv_bias,
61 | proj_bias=proj_bias,
62 | attn_drop=attn_drop,
63 | proj_drop=drop,
64 | )
65 | self.ls1 = LayerScale(dim, init_values=init_values) if init_values else nn.Identity()
66 | self.drop_path1 = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()
67 |
68 | self.norm2 = norm_layer(dim)
69 | mlp_hidden_dim = int(dim * mlp_ratio)
70 | self.mlp = ffn_layer(
71 | in_features=dim,
72 | hidden_features=mlp_hidden_dim,
73 | act_layer=act_layer,
74 | drop=drop,
75 | bias=ffn_bias,
76 | )
77 | self.ls2 = LayerScale(dim, init_values=init_values) if init_values else nn.Identity()
78 | self.drop_path2 = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()
79 |
80 | self.sample_drop_ratio = drop_path
81 |
82 | def forward(self, x: Tensor) -> Tensor:
83 | def attn_residual_func(x: Tensor) -> Tensor:
84 | return self.ls1(self.attn(self.norm1(x)))
85 |
86 | def ffn_residual_func(x: Tensor) -> Tensor:
87 | return self.ls2(self.mlp(self.norm2(x)))
88 |
89 | if self.training and self.sample_drop_ratio > 0.1:
90 | # the overhead is compensated only for a drop path rate larger than 0.1
91 | x = drop_add_residual_stochastic_depth(
92 | x,
93 | residual_func=attn_residual_func,
94 | sample_drop_ratio=self.sample_drop_ratio,
95 | )
96 | x = drop_add_residual_stochastic_depth(
97 | x,
98 | residual_func=ffn_residual_func,
99 | sample_drop_ratio=self.sample_drop_ratio,
100 | )
101 | elif self.training and self.sample_drop_ratio > 0.0:
102 | x = x + self.drop_path1(attn_residual_func(x))
103 |             x = x + self.drop_path2(ffn_residual_func(x))
104 | else:
105 | x = x + attn_residual_func(x)
106 | x = x + ffn_residual_func(x)
107 | return x
108 |
109 |
110 | def drop_add_residual_stochastic_depth(
111 | x: Tensor,
112 | residual_func: Callable[[Tensor], Tensor],
113 | sample_drop_ratio: float = 0.0,
114 | ) -> Tensor:
115 | # 1) extract subset using permutation
116 | b, n, d = x.shape
117 | sample_subset_size = max(int(b * (1 - sample_drop_ratio)), 1)
118 | brange = (torch.randperm(b, device=x.device))[:sample_subset_size]
119 | x_subset = x[brange]
120 |
121 | # 2) apply residual_func to get residual
122 | residual = residual_func(x_subset)
123 |
124 | x_flat = x.flatten(1)
125 | residual = residual.flatten(1)
126 |
127 | residual_scale_factor = b / sample_subset_size
128 |
129 | # 3) add the residual
130 | x_plus_residual = torch.index_add(x_flat, 0, brange, residual.to(dtype=x.dtype), alpha=residual_scale_factor)
131 | return x_plus_residual.view_as(x)
132 |
133 |
134 | def get_branges_scales(x, sample_drop_ratio=0.0):
135 | b, n, d = x.shape
136 | sample_subset_size = max(int(b * (1 - sample_drop_ratio)), 1)
137 | brange = (torch.randperm(b, device=x.device))[:sample_subset_size]
138 | residual_scale_factor = b / sample_subset_size
139 | return brange, residual_scale_factor
140 |
141 |
142 | def add_residual(x, brange, residual, residual_scale_factor, scaling_vector=None):
143 | if scaling_vector is None:
144 | x_flat = x.flatten(1)
145 | residual = residual.flatten(1)
146 | x_plus_residual = torch.index_add(x_flat, 0, brange, residual.to(dtype=x.dtype), alpha=residual_scale_factor)
147 | else:
148 | x_plus_residual = scaled_index_add(
149 | x, brange, residual.to(dtype=x.dtype), scaling=scaling_vector, alpha=residual_scale_factor
150 | )
151 | return x_plus_residual
152 |
153 |
154 | attn_bias_cache: Dict[Tuple, Any] = {}
155 |
156 |
157 | def get_attn_bias_and_cat(x_list, branges=None):
158 | """
159 | this will perform the index select, cat the tensors, and provide the attn_bias from cache
160 | """
161 | batch_sizes = [b.shape[0] for b in branges] if branges is not None else [x.shape[0] for x in x_list]
162 | all_shapes = tuple((b, x.shape[1]) for b, x in zip(batch_sizes, x_list))
163 | if all_shapes not in attn_bias_cache.keys():
164 | seqlens = []
165 | for b, x in zip(batch_sizes, x_list):
166 | for _ in range(b):
167 | seqlens.append(x.shape[1])
168 | attn_bias = fmha.BlockDiagonalMask.from_seqlens(seqlens)
169 | attn_bias._batch_sizes = batch_sizes
170 | attn_bias_cache[all_shapes] = attn_bias
171 |
172 | if branges is not None:
173 | cat_tensors = index_select_cat([x.flatten(1) for x in x_list], branges).view(1, -1, x_list[0].shape[-1])
174 | else:
175 | tensors_bs1 = tuple(x.reshape([1, -1, *x.shape[2:]]) for x in x_list)
176 | cat_tensors = torch.cat(tensors_bs1, dim=1)
177 |
178 | return attn_bias_cache[all_shapes], cat_tensors
179 |
180 |
181 | def drop_add_residual_stochastic_depth_list(
182 | x_list: List[Tensor],
183 | residual_func: Callable[[Tensor, Any], Tensor],
184 | sample_drop_ratio: float = 0.0,
185 | scaling_vector=None,
186 | ) -> Tensor:
187 | # 1) generate random set of indices for dropping samples in the batch
188 | branges_scales = [get_branges_scales(x, sample_drop_ratio=sample_drop_ratio) for x in x_list]
189 | branges = [s[0] for s in branges_scales]
190 | residual_scale_factors = [s[1] for s in branges_scales]
191 |
192 | # 2) get attention bias and index+concat the tensors
193 | attn_bias, x_cat = get_attn_bias_and_cat(x_list, branges)
194 |
195 | # 3) apply residual_func to get residual, and split the result
196 | residual_list = attn_bias.split(residual_func(x_cat, attn_bias=attn_bias)) # type: ignore
197 |
198 | outputs = []
199 | for x, brange, residual, residual_scale_factor in zip(x_list, branges, residual_list, residual_scale_factors):
200 | outputs.append(add_residual(x, brange, residual, residual_scale_factor, scaling_vector).view_as(x))
201 | return outputs
202 |
203 |
204 | class NestedTensorBlock(Block):
205 | def forward_nested(self, x_list: List[Tensor]) -> List[Tensor]:
206 | """
207 | x_list contains a list of tensors to nest together and run
208 | """
209 | assert isinstance(self.attn, MemEffAttention)
210 |
211 | if self.training and self.sample_drop_ratio > 0.0:
212 |
213 | def attn_residual_func(x: Tensor, attn_bias=None) -> Tensor:
214 | return self.attn(self.norm1(x), attn_bias=attn_bias)
215 |
216 | def ffn_residual_func(x: Tensor, attn_bias=None) -> Tensor:
217 | return self.mlp(self.norm2(x))
218 |
219 | x_list = drop_add_residual_stochastic_depth_list(
220 | x_list,
221 | residual_func=attn_residual_func,
222 | sample_drop_ratio=self.sample_drop_ratio,
223 | scaling_vector=self.ls1.gamma if isinstance(self.ls1, LayerScale) else None,
224 | )
225 | x_list = drop_add_residual_stochastic_depth_list(
226 | x_list,
227 | residual_func=ffn_residual_func,
228 | sample_drop_ratio=self.sample_drop_ratio,
229 |                 scaling_vector=self.ls2.gamma if isinstance(self.ls2, LayerScale) else None,
230 | )
231 | return x_list
232 | else:
233 |
234 | def attn_residual_func(x: Tensor, attn_bias=None) -> Tensor:
235 | return self.ls1(self.attn(self.norm1(x), attn_bias=attn_bias))
236 |
237 | def ffn_residual_func(x: Tensor, attn_bias=None) -> Tensor:
238 | return self.ls2(self.mlp(self.norm2(x)))
239 |
240 | attn_bias, x = get_attn_bias_and_cat(x_list)
241 | x = x + attn_residual_func(x, attn_bias=attn_bias)
242 | x = x + ffn_residual_func(x)
243 | return attn_bias.split(x)
244 |
245 | def forward(self, x_or_x_list):
246 | if isinstance(x_or_x_list, Tensor):
247 | return super().forward(x_or_x_list)
248 | elif isinstance(x_or_x_list, list):
249 | assert XFORMERS_AVAILABLE, "Please install xFormers for nested tensors usage"
250 | return self.forward_nested(x_or_x_list)
251 | else:
252 | raise AssertionError
253 |
--------------------------------------------------------------------------------
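A minimal sketch of the block applied to a regular tensor; in that case `NestedTensorBlock` simply falls back to the plain `Block.forward`, while the list path additionally requires xFormers (dimensions are illustrative):

```python
import torch
from models.dpt.layers.block import NestedTensorBlock
from models.dpt.layers.attention import MemEffAttention

blk = NestedTensorBlock(dim=384, num_heads=6, attn_class=MemEffAttention)
x = torch.randn(2, 197, 384)
print(blk(x).shape)   # torch.Size([2, 197, 384]), same shape in and out
```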
/models/dpt/layers/dino_head.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | # All rights reserved.
3 | #
4 | # This source code is licensed under the license found in the
5 | # LICENSE file in the root directory of this source tree.
6 |
7 | import torch
8 | import torch.nn as nn
9 | from torch.nn.init import trunc_normal_
10 | from torch.nn.utils import weight_norm
11 |
12 |
13 | class DINOHead(nn.Module):
14 | def __init__(
15 | self,
16 | in_dim,
17 | out_dim,
18 | use_bn=False,
19 | nlayers=3,
20 | hidden_dim=2048,
21 | bottleneck_dim=256,
22 | mlp_bias=True,
23 | ):
24 | super().__init__()
25 | nlayers = max(nlayers, 1)
26 | self.mlp = _build_mlp(nlayers, in_dim, bottleneck_dim, hidden_dim=hidden_dim, use_bn=use_bn, bias=mlp_bias)
27 | self.apply(self._init_weights)
28 | self.last_layer = weight_norm(nn.Linear(bottleneck_dim, out_dim, bias=False))
29 | self.last_layer.weight_g.data.fill_(1)
30 |
31 | def _init_weights(self, m):
32 | if isinstance(m, nn.Linear):
33 | trunc_normal_(m.weight, std=0.02)
34 | if isinstance(m, nn.Linear) and m.bias is not None:
35 | nn.init.constant_(m.bias, 0)
36 |
37 | def forward(self, x):
38 | x = self.mlp(x)
39 | eps = 1e-6 if x.dtype == torch.float16 else 1e-12
40 | x = nn.functional.normalize(x, dim=-1, p=2, eps=eps)
41 | x = self.last_layer(x)
42 | return x
43 |
44 |
45 | def _build_mlp(nlayers, in_dim, bottleneck_dim, hidden_dim=None, use_bn=False, bias=True):
46 | if nlayers == 1:
47 | return nn.Linear(in_dim, bottleneck_dim, bias=bias)
48 | else:
49 | layers = [nn.Linear(in_dim, hidden_dim, bias=bias)]
50 | if use_bn:
51 | layers.append(nn.BatchNorm1d(hidden_dim))
52 | layers.append(nn.GELU())
53 | for _ in range(nlayers - 2):
54 | layers.append(nn.Linear(hidden_dim, hidden_dim, bias=bias))
55 | if use_bn:
56 | layers.append(nn.BatchNorm1d(hidden_dim))
57 | layers.append(nn.GELU())
58 | layers.append(nn.Linear(hidden_dim, bottleneck_dim, bias=bias))
59 | return nn.Sequential(*layers)
60 |
--------------------------------------------------------------------------------
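A small sketch of the projection head's shapes (the dimensions are illustrative, not the DINOv2 defaults):

```python
import torch
from models.dpt.layers.dino_head import DINOHead

head = DINOHead(in_dim=384, out_dim=1024, hidden_dim=512, bottleneck_dim=128)
x = torch.randn(4, 384)   # one embedding per image
print(head(x).shape)      # torch.Size([4, 1024]) prototype logits
```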
/models/dpt/layers/drop_path.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | # All rights reserved.
3 | #
4 | # This source code is licensed under the license found in the
5 | # LICENSE file in the root directory of this source tree.
6 |
7 | # References:
8 | # https://github.com/facebookresearch/dino/blob/master/vision_transformer.py
9 | # https://github.com/rwightman/pytorch-image-models/tree/master/timm/layers/drop.py
10 |
11 |
12 | from torch import nn
13 |
14 |
15 | def drop_path(x, drop_prob: float = 0.0, training: bool = False):
16 | if drop_prob == 0.0 or not training:
17 | return x
18 | keep_prob = 1 - drop_prob
19 | shape = (x.shape[0],) + (1,) * (x.ndim - 1) # work with diff dim tensors, not just 2D ConvNets
20 | random_tensor = x.new_empty(shape).bernoulli_(keep_prob)
21 | if keep_prob > 0.0:
22 | random_tensor.div_(keep_prob)
23 | output = x * random_tensor
24 | return output
25 |
26 |
27 | class DropPath(nn.Module):
28 | """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks)."""
29 |
30 | def __init__(self, drop_prob=None):
31 | super(DropPath, self).__init__()
32 | self.drop_prob = drop_prob
33 |
34 | def forward(self, x):
35 | return drop_path(x, self.drop_prob, self.training)
36 |
--------------------------------------------------------------------------------
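A quick sketch of stochastic depth: during training roughly `drop_prob` of the samples have the branch zeroed while the survivors are rescaled by `1 / keep_prob`, and at inference the module is the identity:

```python
import torch
from models.dpt.layers.drop_path import DropPath

dp = DropPath(drop_prob=0.5)
x = torch.ones(8, 4, 16)

dp.train()
y = dp(x)
print(sorted(y[:, 0, 0].unique().tolist()))   # some subset of [0.0, 2.0]

dp.eval()
print(torch.equal(dp(x), x))                  # True: identity at inference time
```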
/models/dpt/layers/layer_scale.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | # All rights reserved.
3 | #
4 | # This source code is licensed under the license found in the
5 | # LICENSE file in the root directory of this source tree.
6 |
7 | # Modified from: https://github.com/huggingface/pytorch-image-models/blob/main/timm/models/vision_transformer.py#L103-L110
8 |
9 | from typing import Union
10 |
11 | import torch
12 | from torch import Tensor
13 | from torch import nn
14 |
15 |
16 | class LayerScale(nn.Module):
17 | def __init__(
18 | self,
19 | dim: int,
20 | init_values: Union[float, Tensor] = 1e-5,
21 | inplace: bool = False,
22 | ) -> None:
23 | super().__init__()
24 | self.inplace = inplace
25 | self.gamma = nn.Parameter(init_values * torch.ones(dim))
26 |
27 | def forward(self, x: Tensor) -> Tensor:
28 | return x.mul_(self.gamma) if self.inplace else x * self.gamma
29 |
--------------------------------------------------------------------------------
/models/dpt/layers/mlp.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | # All rights reserved.
3 | #
4 | # This source code is licensed under the license found in the
5 | # LICENSE file in the root directory of this source tree.
6 |
7 | # References:
8 | # https://github.com/facebookresearch/dino/blob/master/vision_transformer.py
9 | # https://github.com/rwightman/pytorch-image-models/tree/master/timm/layers/mlp.py
10 |
11 |
12 | from typing import Callable, Optional
13 |
14 | from torch import Tensor, nn
15 |
16 |
17 | class Mlp(nn.Module):
18 | def __init__(
19 | self,
20 | in_features: int,
21 | hidden_features: Optional[int] = None,
22 | out_features: Optional[int] = None,
23 | act_layer: Callable[..., nn.Module] = nn.GELU,
24 | drop: float = 0.0,
25 | bias: bool = True,
26 | ) -> None:
27 | super().__init__()
28 | out_features = out_features or in_features
29 | hidden_features = hidden_features or in_features
30 | self.fc1 = nn.Linear(in_features, hidden_features, bias=bias)
31 | self.act = act_layer()
32 | self.fc2 = nn.Linear(hidden_features, out_features, bias=bias)
33 | self.drop = nn.Dropout(drop)
34 |
35 | def forward(self, x: Tensor) -> Tensor:
36 | x = self.fc1(x)
37 | x = self.act(x)
38 | x = self.drop(x)
39 | x = self.fc2(x)
40 | x = self.drop(x)
41 | return x
42 |
--------------------------------------------------------------------------------
/models/dpt/layers/patch_embed.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | # All rights reserved.
3 | #
4 | # This source code is licensed under the license found in the
5 | # LICENSE file in the root directory of this source tree.
6 |
7 | # References:
8 | # https://github.com/facebookresearch/dino/blob/master/vision_transformer.py
9 | # https://github.com/rwightman/pytorch-image-models/tree/master/timm/layers/patch_embed.py
10 |
11 | from typing import Callable, Optional, Tuple, Union
12 |
13 | from torch import Tensor
14 | import torch.nn as nn
15 |
16 |
17 | def make_2tuple(x):
18 | if isinstance(x, tuple):
19 | assert len(x) == 2
20 | return x
21 |
22 | assert isinstance(x, int)
23 | return (x, x)
24 |
25 |
26 | class PatchEmbed(nn.Module):
27 | """
28 | 2D image to patch embedding: (B,C,H,W) -> (B,N,D)
29 |
30 | Args:
31 | img_size: Image size.
32 | patch_size: Patch token size.
33 | in_chans: Number of input image channels.
34 | embed_dim: Number of linear projection output channels.
35 | norm_layer: Normalization layer.
36 | """
37 |
38 | def __init__(
39 | self,
40 | img_size: Union[int, Tuple[int, int]] = 224,
41 | patch_size: Union[int, Tuple[int, int]] = 16,
42 | in_chans: int = 3,
43 | embed_dim: int = 768,
44 | norm_layer: Optional[Callable] = None,
45 | flatten_embedding: bool = True,
46 | ) -> None:
47 | super().__init__()
48 |
49 | image_HW = make_2tuple(img_size)
50 | patch_HW = make_2tuple(patch_size)
51 | patch_grid_size = (
52 | image_HW[0] // patch_HW[0],
53 | image_HW[1] // patch_HW[1],
54 | )
55 |
56 | self.img_size = image_HW
57 | self.patch_size = patch_HW
58 | self.patches_resolution = patch_grid_size
59 | self.num_patches = patch_grid_size[0] * patch_grid_size[1]
60 |
61 | self.in_chans = in_chans
62 | self.embed_dim = embed_dim
63 |
64 | self.flatten_embedding = flatten_embedding
65 |
66 | self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_HW, stride=patch_HW)
67 | self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity()
68 |
69 | def forward(self, x: Tensor) -> Tensor:
70 | _, _, H, W = x.shape
71 | patch_H, patch_W = self.patch_size
72 |
73 | assert H % patch_H == 0, f"Input image height {H} is not a multiple of patch height {patch_H}"
74 | assert W % patch_W == 0, f"Input image width {W} is not a multiple of patch width: {patch_W}"
75 |
76 | x = self.proj(x) # B C H W
77 | H, W = x.size(2), x.size(3)
78 | x = x.flatten(2).transpose(1, 2) # B HW C
79 | x = self.norm(x)
80 | if not self.flatten_embedding:
81 | x = x.reshape(-1, H, W, self.embed_dim) # B H W C
82 | return x
83 |
84 | def flops(self) -> float:
85 | Ho, Wo = self.patches_resolution
86 | flops = Ho * Wo * self.embed_dim * self.in_chans * (self.patch_size[0] * self.patch_size[1])
87 | if self.norm is not None:
88 | flops += Ho * Wo * self.embed_dim
89 | return flops
90 |
--------------------------------------------------------------------------------
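A short sketch of the patch-embedding shapes: with a 14-pixel patch, a 224x224 image gives a 16x16 grid, i.e. 256 tokens (the embedding width is illustrative):

```python
import torch
from models.dpt.layers.patch_embed import PatchEmbed

embed = PatchEmbed(img_size=224, patch_size=14, in_chans=3, embed_dim=384)
x = torch.randn(1, 3, 224, 224)
print(embed(x).shape)      # torch.Size([1, 256, 384]) -> (B, N, D)
print(embed.num_patches)   # 256
```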
/models/dpt/layers/swiglu_ffn.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | # All rights reserved.
3 | #
4 | # This source code is licensed under the license found in the
5 | # LICENSE file in the root directory of this source tree.
6 |
7 | from typing import Callable, Optional
8 |
9 | from torch import Tensor, nn
10 | import torch.nn.functional as F
11 |
12 |
13 | class SwiGLUFFN(nn.Module):
14 | def __init__(
15 | self,
16 | in_features: int,
17 | hidden_features: Optional[int] = None,
18 | out_features: Optional[int] = None,
19 | act_layer: Callable[..., nn.Module] = None,
20 | drop: float = 0.0,
21 | bias: bool = True,
22 | ) -> None:
23 | super().__init__()
24 | out_features = out_features or in_features
25 | hidden_features = hidden_features or in_features
26 | self.w12 = nn.Linear(in_features, 2 * hidden_features, bias=bias)
27 | self.w3 = nn.Linear(hidden_features, out_features, bias=bias)
28 |
29 | def forward(self, x: Tensor) -> Tensor:
30 | x12 = self.w12(x)
31 | x1, x2 = x12.chunk(2, dim=-1)
32 | hidden = F.silu(x1) * x2
33 | return self.w3(hidden)
34 |
35 |
36 | try:
37 | from xformers.ops import SwiGLU
38 |
39 | XFORMERS_AVAILABLE = True
40 | except ImportError:
41 | SwiGLU = SwiGLUFFN
42 | XFORMERS_AVAILABLE = False
43 |
44 |
45 | class SwiGLUFFNFused(SwiGLU):
46 | def __init__(
47 | self,
48 | in_features: int,
49 | hidden_features: Optional[int] = None,
50 | out_features: Optional[int] = None,
51 | act_layer: Callable[..., nn.Module] = None,
52 | drop: float = 0.0,
53 | bias: bool = True,
54 | ) -> None:
55 | out_features = out_features or in_features
56 | hidden_features = hidden_features or in_features
57 | hidden_features = (int(hidden_features * 2 / 3) + 7) // 8 * 8
58 | super().__init__(
59 | in_features=in_features,
60 | hidden_features=hidden_features,
61 | out_features=out_features,
62 | bias=bias,
63 | )
64 |
--------------------------------------------------------------------------------
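A small sketch of how `SwiGLUFFNFused` picks its hidden width: the requested size is scaled by 2/3 and rounded up to a multiple of 8, so the usual 4x expansion of a 384-dim model becomes 1024 (dimensions are illustrative):

```python
import torch
from models.dpt.layers.swiglu_ffn import SwiGLUFFNFused

ffn = SwiGLUFFNFused(in_features=384, hidden_features=4 * 384)
x = torch.randn(2, 197, 384)
print(ffn(x).shape)   # torch.Size([2, 197, 384])

# The hidden width actually used internally:
print((int(4 * 384 * 2 / 3) + 7) // 8 * 8)   # 1024
```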
/models/dpt/midas_net.py:
--------------------------------------------------------------------------------
1 | """MidasNet: Network for monocular depth estimation trained by mixing several datasets.
2 | This file contains code that is adapted from
3 | https://github.com/thomasjpfan/pytorch_refinenet/blob/master/pytorch_refinenet/refinenet/refinenet_4cascade.py
4 | """
5 | import torch
6 | import torch.nn as nn
7 |
8 | from models.dpt.base_model import BaseModel
9 | from models.dpt.blocks import FeatureFusionBlock, Interpolate, _make_encoder
10 |
11 |
12 | class MidasNet_large(BaseModel):
13 | """Network for monocular depth estimation."""
14 |
15 | def __init__(self, path=None, features=256, non_negative=True):
16 | """Init.
17 |
18 | Args:
19 | path (str, optional): Path to saved model. Defaults to None.
20 | features (int, optional): Number of features. Defaults to 256.
21 |             non_negative (bool, optional): If True, append a ReLU so the output is non-negative. Defaults to True.
22 | """
23 | print("Loading weights: ", path)
24 |
25 | super(MidasNet_large, self).__init__()
26 |
27 | use_pretrained = False if path is None else True
28 |
29 | self.pretrained, self.scratch = _make_encoder(
30 | backbone="resnext101_wsl", features=features, use_pretrained=use_pretrained
31 | )
32 |
33 | self.scratch.refinenet4 = FeatureFusionBlock(features)
34 | self.scratch.refinenet3 = FeatureFusionBlock(features)
35 | self.scratch.refinenet2 = FeatureFusionBlock(features)
36 | self.scratch.refinenet1 = FeatureFusionBlock(features)
37 |
38 | self.scratch.output_conv = nn.Sequential(
39 | nn.Conv2d(features, 128, kernel_size=3, stride=1, padding=1),
40 | Interpolate(scale_factor=2, mode="bilinear"),
41 | nn.Conv2d(128, 32, kernel_size=3, stride=1, padding=1),
42 | nn.ReLU(True),
43 | nn.Conv2d(32, 1, kernel_size=1, stride=1, padding=0),
44 | nn.ReLU(True) if non_negative else nn.Identity(),
45 | )
46 |
47 | if path:
48 | self.load(path)
49 |
50 | def forward(self, x):
51 | """Forward pass.
52 |
53 | Args:
54 | x (tensor): input data (image)
55 |
56 | Returns:
57 | tensor: depth
58 | """
59 |
60 | layer_1 = self.pretrained.layer1(x)
61 | layer_2 = self.pretrained.layer2(layer_1)
62 | layer_3 = self.pretrained.layer3(layer_2)
63 | layer_4 = self.pretrained.layer4(layer_3)
64 |
65 | layer_1_rn = self.scratch.layer1_rn(layer_1)
66 | layer_2_rn = self.scratch.layer2_rn(layer_2)
67 | layer_3_rn = self.scratch.layer3_rn(layer_3)
68 | layer_4_rn = self.scratch.layer4_rn(layer_4)
69 |
70 | path_4 = self.scratch.refinenet4(layer_4_rn)
71 | path_3 = self.scratch.refinenet3(path_4, layer_3_rn)
72 | path_2 = self.scratch.refinenet2(path_3, layer_2_rn)
73 | path_1 = self.scratch.refinenet1(path_2, layer_1_rn)
74 |
75 | out = self.scratch.output_conv(path_1)
76 |
77 | return torch.squeeze(out, dim=1)
78 |
--------------------------------------------------------------------------------
/models/dpt/models.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 |
5 | from models.dpt.base_model import BaseModel
6 | from models.dpt.blocks import (
7 | FeatureFusionBlock,
8 | FeatureFusionBlock_custom,
9 | Interpolate,
10 | _make_encoder,
11 | forward_vit,
12 | )
13 |
14 |
15 | def _make_fusion_block(features, use_bn):
16 | return FeatureFusionBlock_custom(
17 | features,
18 | nn.ReLU(False),
19 | deconv=False,
20 | bn=use_bn,
21 | expand=False,
22 | align_corners=True,
23 | )
24 |
25 |
26 | class DPT(BaseModel):
27 | def __init__(
28 | self,
29 | head,
30 | features=256,
31 | backbone="vitb_rn50_384",
32 | readout="project",
33 | channels_last=False,
34 | use_bn=False,
35 | enable_attention_hooks=False,
36 | ):
37 |
38 | super(DPT, self).__init__()
39 |
40 | self.channels_last = channels_last
41 |
42 | hooks = {
43 | "vitb_rn50_384": [0, 1, 8, 11],
44 | "vitb16_384": [2, 5, 8, 11],
45 | "vitl16_384": [5, 11, 17, 23],
46 | }
47 |
48 | # Instantiate backbone and reassemble blocks
49 | self.pretrained, self.scratch = _make_encoder(
50 | backbone,
51 | features,
52 |             False,  # use_pretrained: set to True to initialize the backbone with ImageNet weights
53 | groups=1,
54 | expand=False,
55 | exportable=False,
56 | hooks=hooks[backbone],
57 | use_readout=readout,
58 | enable_attention_hooks=enable_attention_hooks,
59 | )
60 |
61 | self.scratch.refinenet1 = _make_fusion_block(features, use_bn)
62 | self.scratch.refinenet2 = _make_fusion_block(features, use_bn)
63 | self.scratch.refinenet3 = _make_fusion_block(features, use_bn)
64 | self.scratch.refinenet4 = _make_fusion_block(features, use_bn)
65 |
66 | self.scratch.output_conv = head
67 |
68 | def forward(self, x):
69 | if self.channels_last == True:
70 | x.contiguous(memory_format=torch.channels_last)
71 |
72 | layer_1, layer_2, layer_3, layer_4 = forward_vit(self.pretrained, x)
73 |
74 | layer_1_rn = self.scratch.layer1_rn(layer_1)
75 | layer_2_rn = self.scratch.layer2_rn(layer_2)
76 | layer_3_rn = self.scratch.layer3_rn(layer_3)
77 | layer_4_rn = self.scratch.layer4_rn(layer_4)
78 |
79 | path_4 = self.scratch.refinenet4(layer_4_rn)
80 | path_3 = self.scratch.refinenet3(path_4, layer_3_rn)
81 | path_2 = self.scratch.refinenet2(path_3, layer_2_rn)
82 | path_1 = self.scratch.refinenet1(path_2, layer_1_rn)
83 |
84 | out = self.scratch.output_conv(path_1)
85 |
86 | return out
87 |
88 |
89 | class DPTDepthModel(DPT):
90 | def __init__(
91 | self, path=None, non_negative=True, scale=1.0, shift=0.0, invert=False, **kwargs
92 | ):
93 | features = kwargs["features"] if "features" in kwargs else 256
94 |
95 | self.scale = scale
96 | self.shift = shift
97 | self.invert = invert
98 |
99 | head = nn.Sequential(
100 | nn.Conv2d(features, features // 2, kernel_size=3, stride=1, padding=1),
101 | Interpolate(scale_factor=2, mode="bilinear", align_corners=True),
102 | nn.Conv2d(features // 2, 32, kernel_size=3, stride=1, padding=1),
103 | nn.ReLU(True),
104 | nn.Conv2d(32, 1, kernel_size=1, stride=1, padding=0),
105 | nn.ReLU(True) if non_negative else nn.Identity(),
106 | nn.Identity(),
107 | )
108 |
109 | super().__init__(head, **kwargs)
110 |
111 | if path is not None:
112 | self.load(path)
113 |
114 | def forward(self, x):
115 | inv_depth = super().forward(x).squeeze(dim=1)
116 |
117 | if self.invert:
118 | depth = self.scale * inv_depth + self.shift
119 | depth[depth < 1e-8] = 1e-8
120 | depth = 1.0 / depth
121 | return depth
122 | else:
123 | return inv_depth
124 |
125 |
126 | class DPTSegmentationModel(DPT):
127 | def __init__(self, num_classes, path=None, **kwargs):
128 |
129 | features = kwargs["features"] if "features" in kwargs else 256
130 |
131 | kwargs["use_bn"] = True
132 |
133 | head = nn.Sequential(
134 | nn.Conv2d(features, features, kernel_size=3, padding=1, bias=False),
135 | nn.BatchNorm2d(features),
136 | nn.ReLU(True),
137 | nn.Dropout(0.1, False),
138 | nn.Conv2d(features, num_classes, kernel_size=1),
139 | Interpolate(scale_factor=2, mode="bilinear", align_corners=True),
140 | )
141 |
142 | super().__init__(head, **kwargs)
143 |
144 | self.auxlayer = nn.Sequential(
145 | nn.Conv2d(features, features, kernel_size=3, padding=1, bias=False),
146 | nn.BatchNorm2d(features),
147 | nn.ReLU(True),
148 | nn.Dropout(0.1, False),
149 | nn.Conv2d(features, num_classes, kernel_size=1),
150 | )
151 |
152 | if path is not None:
153 | self.load(path)
154 |
--------------------------------------------------------------------------------
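A tiny sketch of the inverse-depth to metric-depth conversion that `DPTDepthModel.forward` applies when `invert=True` (the scale and shift values are illustrative):

```python
import torch

inv_depth = torch.tensor([0.5, 2.0, 10.0])   # network output (inverse depth)
scale, shift = 1.0, 0.0

depth = scale * inv_depth + shift
depth = depth.clamp(min=1e-8)   # same effect as depth[depth < 1e-8] = 1e-8
depth = 1.0 / depth
print(depth)                    # tensor([2.0000, 0.5000, 0.1000])
```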
/models/dpt/transforms.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import cv2
3 | import math
4 |
5 |
6 | def apply_min_size(sample, size, image_interpolation_method=cv2.INTER_AREA):
7 |     """Resize the sample to ensure the given size. Keeps aspect ratio.
8 |
9 | Args:
10 | sample (dict): sample
11 | size (tuple): image size
12 |
13 | Returns:
14 | tuple: new size
15 | """
16 | shape = list(sample["disparity"].shape)
17 |
18 | if shape[0] >= size[0] and shape[1] >= size[1]:
19 | return sample
20 |
21 | scale = [0, 0]
22 | scale[0] = size[0] / shape[0]
23 | scale[1] = size[1] / shape[1]
24 |
25 | scale = max(scale)
26 |
27 | shape[0] = math.ceil(scale * shape[0])
28 | shape[1] = math.ceil(scale * shape[1])
29 |
30 | # resize
31 | sample["image"] = cv2.resize(
32 | sample["image"], tuple(shape[::-1]), interpolation=image_interpolation_method
33 | )
34 |
35 | sample["disparity"] = cv2.resize(
36 | sample["disparity"], tuple(shape[::-1]), interpolation=cv2.INTER_NEAREST
37 | )
38 | sample["mask"] = cv2.resize(
39 | sample["mask"].astype(np.float32),
40 | tuple(shape[::-1]),
41 | interpolation=cv2.INTER_NEAREST,
42 | )
43 | sample["mask"] = sample["mask"].astype(bool)
44 |
45 | return tuple(shape)
46 |
47 |
48 | class Resize(object):
49 | """Resize sample to given size (width, height)."""
50 |
51 | def __init__(
52 | self,
53 | width,
54 | height,
55 | resize_target=True,
56 | keep_aspect_ratio=False,
57 | ensure_multiple_of=1,
58 | resize_method="lower_bound",
59 | image_interpolation_method=cv2.INTER_AREA,
60 | ):
61 | """Init.
62 |
63 | Args:
64 | width (int): desired output width
65 | height (int): desired output height
66 | resize_target (bool, optional):
67 | True: Resize the full sample (image, mask, target).
68 | False: Resize image only.
69 | Defaults to True.
70 | keep_aspect_ratio (bool, optional):
71 | True: Keep the aspect ratio of the input sample.
72 | Output sample might not have the given width and height, and
73 | resize behaviour depends on the parameter 'resize_method'.
74 | Defaults to False.
75 | ensure_multiple_of (int, optional):
76 | Output width and height is constrained to be multiple of this parameter.
77 | Defaults to 1.
78 | resize_method (str, optional):
79 | "lower_bound": Output will be at least as large as the given size.
80 |                 "upper_bound": Output will be at most as large as the given size. (Output size might be smaller than given size.)
81 |                 "minimal": Scale as little as possible. (Output size might be smaller than given size.)
82 | Defaults to "lower_bound".
83 | """
84 | self.__width = width
85 | self.__height = height
86 |
87 | self.__resize_target = resize_target
88 | self.__keep_aspect_ratio = keep_aspect_ratio
89 | self.__multiple_of = ensure_multiple_of
90 | self.__resize_method = resize_method
91 | self.__image_interpolation_method = image_interpolation_method
92 |
93 | def constrain_to_multiple_of(self, x, min_val=0, max_val=None):
94 | y = (np.round(x / self.__multiple_of) * self.__multiple_of).astype(int)
95 |
96 | if max_val is not None and y > max_val:
97 | y = (np.floor(x / self.__multiple_of) * self.__multiple_of).astype(int)
98 |
99 | if y < min_val:
100 | y = (np.ceil(x / self.__multiple_of) * self.__multiple_of).astype(int)
101 |
102 | return y
103 |
104 | def get_size(self, width, height):
105 | # determine new height and width
106 | scale_height = self.__height / height
107 | scale_width = self.__width / width
108 |
109 | if self.__keep_aspect_ratio:
110 | if self.__resize_method == "lower_bound":
111 | # scale such that output size is lower bound
112 | if scale_width > scale_height:
113 | # fit width
114 | scale_height = scale_width
115 | else:
116 | # fit height
117 | scale_width = scale_height
118 | elif self.__resize_method == "upper_bound":
119 | # scale such that output size is upper bound
120 | if scale_width < scale_height:
121 | # fit width
122 | scale_height = scale_width
123 | else:
124 | # fit height
125 | scale_width = scale_height
126 | elif self.__resize_method == "minimal":
127 |                 # scale as little as possible
128 | if abs(1 - scale_width) < abs(1 - scale_height):
129 | # fit width
130 | scale_height = scale_width
131 | else:
132 | # fit height
133 | scale_width = scale_height
134 | else:
135 | raise ValueError(
136 | f"resize_method {self.__resize_method} not implemented"
137 | )
138 |
139 | if self.__resize_method == "lower_bound":
140 | new_height = self.constrain_to_multiple_of(
141 | scale_height * height, min_val=self.__height
142 | )
143 | new_width = self.constrain_to_multiple_of(
144 | scale_width * width, min_val=self.__width
145 | )
146 | elif self.__resize_method == "upper_bound":
147 | new_height = self.constrain_to_multiple_of(
148 | scale_height * height, max_val=self.__height
149 | )
150 | new_width = self.constrain_to_multiple_of(
151 | scale_width * width, max_val=self.__width
152 | )
153 | elif self.__resize_method == "minimal":
154 | new_height = self.constrain_to_multiple_of(scale_height * height)
155 | new_width = self.constrain_to_multiple_of(scale_width * width)
156 | else:
157 | raise ValueError(f"resize_method {self.__resize_method} not implemented")
158 |
159 | return (new_width, new_height)
160 |
161 | def __call__(self, sample):
162 | width, height = self.get_size(
163 | sample["image"].shape[1], sample["image"].shape[0]
164 | )
165 |
166 | # resize sample
167 | sample["image"] = cv2.resize(
168 | sample["image"],
169 | (width, height),
170 | interpolation=self.__image_interpolation_method,
171 | )
172 |
173 | if self.__resize_target:
174 | if "disparity" in sample:
175 | sample["disparity"] = cv2.resize(
176 | sample["disparity"],
177 | (width, height),
178 | interpolation=cv2.INTER_NEAREST,
179 | )
180 |
181 | if "depth" in sample:
182 | sample["depth"] = cv2.resize(
183 | sample["depth"], (width, height), interpolation=cv2.INTER_NEAREST
184 | )
185 |
186 | sample["mask"] = cv2.resize(
187 | sample["mask"].astype(np.float32),
188 | (width, height),
189 | interpolation=cv2.INTER_NEAREST,
190 | )
191 | sample["mask"] = sample["mask"].astype(bool)
192 |
193 | return sample
194 |
195 |
196 | class NormalizeImage(object):
197 |     """Normalize image by given mean and std."""
198 |
199 | def __init__(self, mean, std):
200 | self.__mean = mean
201 | self.__std = std
202 |
203 | def __call__(self, sample):
204 | sample["image"] = (sample["image"] - self.__mean) / self.__std
205 |
206 | return sample
207 |
208 |
209 | class PrepareForNet(object):
210 | """Prepare sample for usage as network input."""
211 |
212 | def __init__(self):
213 | pass
214 |
215 | def __call__(self, sample):
216 | image = np.transpose(sample["image"], (2, 0, 1))
217 | sample["image"] = np.ascontiguousarray(image).astype(np.float32)
218 |
219 | if "mask" in sample:
220 | sample["mask"] = sample["mask"].astype(np.float32)
221 | sample["mask"] = np.ascontiguousarray(sample["mask"])
222 |
223 | if "disparity" in sample:
224 | disparity = sample["disparity"].astype(np.float32)
225 | sample["disparity"] = np.ascontiguousarray(disparity)
226 |
227 | if "depth" in sample:
228 | depth = sample["depth"].astype(np.float32)
229 | sample["depth"] = np.ascontiguousarray(depth)
230 |
231 | return sample
232 |
--------------------------------------------------------------------------------
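A short sketch chaining the preprocessing transforms on a dummy sample (the mean/std and sizes are illustrative; numpy and opencv-python are required):

```python
import numpy as np
from models.dpt.transforms import Resize, NormalizeImage, PrepareForNet

transforms = [
    Resize(384, 384, resize_target=False, keep_aspect_ratio=True,
           ensure_multiple_of=32, resize_method="lower_bound"),
    NormalizeImage(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
    PrepareForNet(),
]

sample = {"image": np.random.rand(300, 400, 3).astype(np.float32)}
for t in transforms:
    sample = t(sample)

print(sample["image"].shape)   # (3, 384, 512): CHW, at least 384 per side, divisible by 32
```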
/models/dpt/vit.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import timm
4 | import types
5 | import math
6 | import torch.nn.functional as F
7 |
8 |
9 | activations = {}
10 |
11 |
12 | def get_activation(name):
13 | def hook(model, input, output):
14 | activations[name] = output
15 |
16 | return hook
17 |
18 |
19 | attention = {}
20 |
21 |
22 | def get_attention(name):
23 | def hook(module, input, output):
24 | x = input[0]
25 | B, N, C = x.shape
26 | qkv = (
27 | module.qkv(x)
28 | .reshape(B, N, 3, module.num_heads, C // module.num_heads)
29 | .permute(2, 0, 3, 1, 4)
30 | )
31 | q, k, v = (
32 | qkv[0],
33 | qkv[1],
34 | qkv[2],
35 | ) # make torchscript happy (cannot use tensor as tuple)
36 |
37 | attn = (q @ k.transpose(-2, -1)) * module.scale
38 |
39 | attn = attn.softmax(dim=-1) # [:,:,1,1:]
40 | attention[name] = attn
41 |
42 | return hook
43 |
44 |
45 | def get_mean_attention_map(attn, token, shape):
46 | attn = attn[:, :, token, 1:]
47 | attn = attn.unflatten(2, torch.Size([shape[2] // 16, shape[3] // 16])).float()
48 | attn = torch.nn.functional.interpolate(
49 | attn, size=shape[2:], mode="bicubic", align_corners=False
50 | ).squeeze(0)
51 |
52 | all_attn = torch.mean(attn, 0)
53 |
54 | return all_attn
55 |
56 |
57 | class Slice(nn.Module):
58 | def __init__(self, start_index=1):
59 | super(Slice, self).__init__()
60 | self.start_index = start_index
61 |
62 | def forward(self, x):
63 | return x[:, self.start_index :]
64 |
65 |
66 | class AddReadout(nn.Module):
67 | def __init__(self, start_index=1):
68 | super(AddReadout, self).__init__()
69 | self.start_index = start_index
70 |
71 | def forward(self, x):
72 | if self.start_index == 2:
73 | readout = (x[:, 0] + x[:, 1]) / 2
74 | else:
75 | readout = x[:, 0]
76 | return x[:, self.start_index :] + readout.unsqueeze(1)
77 |
78 |
79 | class ProjectReadout(nn.Module):
80 | def __init__(self, in_features, start_index=1):
81 | super(ProjectReadout, self).__init__()
82 | self.start_index = start_index
83 |
84 | self.project = nn.Sequential(nn.Linear(2 * in_features, in_features), nn.GELU())
85 |
86 | def forward(self, x):
87 | readout = x[:, 0].unsqueeze(1).expand_as(x[:, self.start_index :])
88 | features = torch.cat((x[:, self.start_index :], readout), -1)
89 |
90 | return self.project(features)
91 |
92 |
93 | class Transpose(nn.Module):
94 | def __init__(self, dim0, dim1):
95 | super(Transpose, self).__init__()
96 | self.dim0 = dim0
97 | self.dim1 = dim1
98 |
99 | def forward(self, x):
100 | x = x.transpose(self.dim0, self.dim1)
101 | return x
102 |
103 |
104 | def forward_vit(pretrained, x):
105 | b, c, h, w = x.shape
106 |
107 | glob = pretrained.model.forward_flex(x)
108 |
109 | layer_1 = pretrained.activations["1"]
110 | layer_2 = pretrained.activations["2"]
111 | layer_3 = pretrained.activations["3"]
112 | layer_4 = pretrained.activations["4"]
113 |
114 | layer_1 = pretrained.act_postprocess1[0:2](layer_1)
115 | layer_2 = pretrained.act_postprocess2[0:2](layer_2)
116 | layer_3 = pretrained.act_postprocess3[0:2](layer_3)
117 | layer_4 = pretrained.act_postprocess4[0:2](layer_4)
118 |
119 | unflatten = nn.Sequential(
120 | nn.Unflatten(
121 | 2,
122 | torch.Size(
123 | [
124 | h // pretrained.model.patch_size[1],
125 | w // pretrained.model.patch_size[0],
126 | ]
127 | ),
128 | )
129 | )
130 |
131 | if layer_1.ndim == 3:
132 | layer_1 = unflatten(layer_1)
133 | if layer_2.ndim == 3:
134 | layer_2 = unflatten(layer_2)
135 | if layer_3.ndim == 3:
136 | layer_3 = unflatten(layer_3)
137 | if layer_4.ndim == 3:
138 | layer_4 = unflatten(layer_4)
139 |
140 | layer_1 = pretrained.act_postprocess1[3 : len(pretrained.act_postprocess1)](layer_1)
141 | layer_2 = pretrained.act_postprocess2[3 : len(pretrained.act_postprocess2)](layer_2)
142 | layer_3 = pretrained.act_postprocess3[3 : len(pretrained.act_postprocess3)](layer_3)
143 | layer_4 = pretrained.act_postprocess4[3 : len(pretrained.act_postprocess4)](layer_4)
144 |
145 | return layer_1, layer_2, layer_3, layer_4
146 |
147 |
148 | def _resize_pos_embed(self, posemb, gs_h, gs_w):
149 | posemb_tok, posemb_grid = (
150 | posemb[:, : self.start_index],
151 | posemb[0, self.start_index :],
152 | )
153 |
154 | gs_old = int(math.sqrt(len(posemb_grid)))
155 |
156 | posemb_grid = posemb_grid.reshape(1, gs_old, gs_old, -1).permute(0, 3, 1, 2)
157 | posemb_grid = F.interpolate(posemb_grid, size=(gs_h, gs_w), mode="bilinear")
158 | posemb_grid = posemb_grid.permute(0, 2, 3, 1).reshape(1, gs_h * gs_w, -1)
159 |
160 | posemb = torch.cat([posemb_tok, posemb_grid], dim=1)
161 |
162 | return posemb
163 |
164 |
165 | def forward_flex(self, x):
166 | b, c, h, w = x.shape
167 |
168 | pos_embed = self._resize_pos_embed(
169 | self.pos_embed, h // self.patch_size[1], w // self.patch_size[0]
170 | )
171 |
172 | B = x.shape[0]
173 |
174 | if hasattr(self.patch_embed, "backbone"):
175 | x = self.patch_embed.backbone(x)
176 | if isinstance(x, (list, tuple)):
177 | x = x[-1] # last feature if backbone outputs list/tuple of features
178 |
179 | x = self.patch_embed.proj(x).flatten(2).transpose(1, 2)
180 |
181 | if getattr(self, "dist_token", None) is not None:
182 | cls_tokens = self.cls_token.expand(
183 | B, -1, -1
184 | ) # stole cls_tokens impl from Phil Wang, thanks
185 | dist_token = self.dist_token.expand(B, -1, -1)
186 | x = torch.cat((cls_tokens, dist_token, x), dim=1)
187 | else:
188 | cls_tokens = self.cls_token.expand(
189 | B, -1, -1
190 | ) # stole cls_tokens impl from Phil Wang, thanks
191 | x = torch.cat((cls_tokens, x), dim=1)
192 |
193 | x = x + pos_embed
194 | x = self.pos_drop(x)
195 |
196 | for blk in self.blocks:
197 | x = blk(x)
198 |
199 | x = self.norm(x)
200 |
201 | return x
202 |
203 |
204 | def get_readout_oper(vit_features, features, use_readout, start_index=1):
205 | if use_readout == "ignore":
206 | readout_oper = [Slice(start_index)] * len(features)
207 | elif use_readout == "add":
208 | readout_oper = [AddReadout(start_index)] * len(features)
209 | elif use_readout == "project":
210 | readout_oper = [
211 | ProjectReadout(vit_features, start_index) for out_feat in features
212 | ]
213 | else:
214 | assert (
215 | False
216 | ), "wrong operation for readout token, use_readout can be 'ignore', 'add', or 'project'"
217 |
218 | return readout_oper
219 |
220 |
221 | def _make_vit_b16_backbone(
222 | model,
223 | features=[96, 192, 384, 768],
224 | size=[384, 384],
225 | hooks=[2, 5, 8, 11],
226 | vit_features=768,
227 | use_readout="ignore",
228 | start_index=1,
229 | enable_attention_hooks=False,
230 | ):
231 | pretrained = nn.Module()
232 |
233 | pretrained.model = model
234 | pretrained.model.blocks[hooks[0]].register_forward_hook(get_activation("1"))
235 | pretrained.model.blocks[hooks[1]].register_forward_hook(get_activation("2"))
236 | pretrained.model.blocks[hooks[2]].register_forward_hook(get_activation("3"))
237 | pretrained.model.blocks[hooks[3]].register_forward_hook(get_activation("4"))
238 |
239 | pretrained.activations = activations
240 |
241 | if enable_attention_hooks:
242 | pretrained.model.blocks[hooks[0]].attn.register_forward_hook(
243 | get_attention("attn_1")
244 | )
245 | pretrained.model.blocks[hooks[1]].attn.register_forward_hook(
246 | get_attention("attn_2")
247 | )
248 | pretrained.model.blocks[hooks[2]].attn.register_forward_hook(
249 | get_attention("attn_3")
250 | )
251 | pretrained.model.blocks[hooks[3]].attn.register_forward_hook(
252 | get_attention("attn_4")
253 | )
254 | pretrained.attention = attention
255 |
256 | readout_oper = get_readout_oper(vit_features, features, use_readout, start_index)
257 |
258 |     # reassemble tokens into feature maps at strides 4, 8, 16, 32 of the input resolution
259 | pretrained.act_postprocess1 = nn.Sequential(
260 | readout_oper[0],
261 | Transpose(1, 2),
262 | nn.Unflatten(2, torch.Size([size[0] // 16, size[1] // 16])),
263 | nn.Conv2d(
264 | in_channels=vit_features,
265 | out_channels=features[0],
266 | kernel_size=1,
267 | stride=1,
268 | padding=0,
269 | ),
270 | nn.ConvTranspose2d(
271 | in_channels=features[0],
272 | out_channels=features[0],
273 | kernel_size=4,
274 | stride=4,
275 | padding=0,
276 | bias=True,
277 | dilation=1,
278 | groups=1,
279 | ),
280 | )
281 |
282 | pretrained.act_postprocess2 = nn.Sequential(
283 | readout_oper[1],
284 | Transpose(1, 2),
285 | nn.Unflatten(2, torch.Size([size[0] // 16, size[1] // 16])),
286 | nn.Conv2d(
287 | in_channels=vit_features,
288 | out_channels=features[1],
289 | kernel_size=1,
290 | stride=1,
291 | padding=0,
292 | ),
293 | nn.ConvTranspose2d(
294 | in_channels=features[1],
295 | out_channels=features[1],
296 | kernel_size=2,
297 | stride=2,
298 | padding=0,
299 | bias=True,
300 | dilation=1,
301 | groups=1,
302 | ),
303 | )
304 |
305 | pretrained.act_postprocess3 = nn.Sequential(
306 | readout_oper[2],
307 | Transpose(1, 2),
308 | nn.Unflatten(2, torch.Size([size[0] // 16, size[1] // 16])),
309 | nn.Conv2d(
310 | in_channels=vit_features,
311 | out_channels=features[2],
312 | kernel_size=1,
313 | stride=1,
314 | padding=0,
315 | ),
316 | )
317 |
318 | pretrained.act_postprocess4 = nn.Sequential(
319 | readout_oper[3],
320 | Transpose(1, 2),
321 | nn.Unflatten(2, torch.Size([size[0] // 16, size[1] // 16])),
322 | nn.Conv2d(
323 | in_channels=vit_features,
324 | out_channels=features[3],
325 | kernel_size=1,
326 | stride=1,
327 | padding=0,
328 | ),
329 | nn.Conv2d(
330 | in_channels=features[3],
331 | out_channels=features[3],
332 | kernel_size=3,
333 | stride=2,
334 | padding=1,
335 | ),
336 | )
337 |
338 | pretrained.model.start_index = start_index
339 | pretrained.model.patch_size = [16, 16]
340 |
341 | # We inject this function into the VisionTransformer instances so that
342 | # we can use it with interpolated position embeddings without modifying the library source.
343 | pretrained.model.forward_flex = types.MethodType(forward_flex, pretrained.model)
344 | pretrained.model._resize_pos_embed = types.MethodType(
345 | _resize_pos_embed, pretrained.model
346 | )
347 |
348 | return pretrained
349 |
350 |
351 | def _make_vit_b_rn50_backbone(
352 | model,
353 | features=[256, 512, 768, 768],
354 | size=[384, 384],
355 | hooks=[0, 1, 8, 11],
356 | vit_features=768,
357 | use_vit_only=False,
358 | use_readout="ignore",
359 | start_index=1,
360 | enable_attention_hooks=False,
361 | ):
362 | pretrained = nn.Module()
363 |
364 | pretrained.model = model
365 |
366 |     if use_vit_only:
367 | pretrained.model.blocks[hooks[0]].register_forward_hook(get_activation("1"))
368 | pretrained.model.blocks[hooks[1]].register_forward_hook(get_activation("2"))
369 | else:
370 | pretrained.model.patch_embed.backbone.stages[0].register_forward_hook(
371 | get_activation("1")
372 | )
373 | pretrained.model.patch_embed.backbone.stages[1].register_forward_hook(
374 | get_activation("2")
375 | )
376 |
377 | pretrained.model.blocks[hooks[2]].register_forward_hook(get_activation("3"))
378 | pretrained.model.blocks[hooks[3]].register_forward_hook(get_activation("4"))
379 |
380 | if enable_attention_hooks:
381 | pretrained.model.blocks[2].attn.register_forward_hook(get_attention("attn_1"))
382 | pretrained.model.blocks[5].attn.register_forward_hook(get_attention("attn_2"))
383 | pretrained.model.blocks[8].attn.register_forward_hook(get_attention("attn_3"))
384 | pretrained.model.blocks[11].attn.register_forward_hook(get_attention("attn_4"))
385 | pretrained.attention = attention
386 |
387 | pretrained.activations = activations
388 |
389 | readout_oper = get_readout_oper(vit_features, features, use_readout, start_index)
390 |
391 |     if use_vit_only:
392 | pretrained.act_postprocess1 = nn.Sequential(
393 | readout_oper[0],
394 | Transpose(1, 2),
395 | nn.Unflatten(2, torch.Size([size[0] // 16, size[1] // 16])),
396 | nn.Conv2d(
397 | in_channels=vit_features,
398 | out_channels=features[0],
399 | kernel_size=1,
400 | stride=1,
401 | padding=0,
402 | ),
403 | nn.ConvTranspose2d(
404 | in_channels=features[0],
405 | out_channels=features[0],
406 | kernel_size=4,
407 | stride=4,
408 | padding=0,
409 | bias=True,
410 | dilation=1,
411 | groups=1,
412 | ),
413 | )
414 |
415 | pretrained.act_postprocess2 = nn.Sequential(
416 | readout_oper[1],
417 | Transpose(1, 2),
418 | nn.Unflatten(2, torch.Size([size[0] // 16, size[1] // 16])),
419 | nn.Conv2d(
420 | in_channels=vit_features,
421 | out_channels=features[1],
422 | kernel_size=1,
423 | stride=1,
424 | padding=0,
425 | ),
426 | nn.ConvTranspose2d(
427 | in_channels=features[1],
428 | out_channels=features[1],
429 | kernel_size=2,
430 | stride=2,
431 | padding=0,
432 | bias=True,
433 | dilation=1,
434 | groups=1,
435 | ),
436 | )
437 | else:
438 | pretrained.act_postprocess1 = nn.Sequential(
439 | nn.Identity(), nn.Identity(), nn.Identity()
440 | )
441 | pretrained.act_postprocess2 = nn.Sequential(
442 | nn.Identity(), nn.Identity(), nn.Identity()
443 | )
444 |
445 | pretrained.act_postprocess3 = nn.Sequential(
446 | readout_oper[2],
447 | Transpose(1, 2),
448 | nn.Unflatten(2, torch.Size([size[0] // 16, size[1] // 16])),
449 | nn.Conv2d(
450 | in_channels=vit_features,
451 | out_channels=features[2],
452 | kernel_size=1,
453 | stride=1,
454 | padding=0,
455 | ),
456 | )
457 |
458 | pretrained.act_postprocess4 = nn.Sequential(
459 | readout_oper[3],
460 | Transpose(1, 2),
461 | nn.Unflatten(2, torch.Size([size[0] // 16, size[1] // 16])),
462 | nn.Conv2d(
463 | in_channels=vit_features,
464 | out_channels=features[3],
465 | kernel_size=1,
466 | stride=1,
467 | padding=0,
468 | ),
469 | nn.Conv2d(
470 | in_channels=features[3],
471 | out_channels=features[3],
472 | kernel_size=3,
473 | stride=2,
474 | padding=1,
475 | ),
476 | )
477 |
478 | pretrained.model.start_index = start_index
479 | pretrained.model.patch_size = [16, 16]
480 |
481 | # We inject this function into the VisionTransformer instances so that
482 | # we can use it with interpolated position embeddings without modifying the library source.
483 | pretrained.model.forward_flex = types.MethodType(forward_flex, pretrained.model)
484 |
485 | # We inject this function into the VisionTransformer instances so that
486 | # we can use it with interpolated position embeddings without modifying the library source.
487 | pretrained.model._resize_pos_embed = types.MethodType(
488 | _resize_pos_embed, pretrained.model
489 | )
490 |
491 | return pretrained
492 |
493 |
494 | def _make_pretrained_vitb_rn50_384(
495 | pretrained,
496 | use_readout="ignore",
497 | hooks=None,
498 | use_vit_only=False,
499 | enable_attention_hooks=False,
500 | ):
501 | model = timm.create_model("vit_base_resnet50_384", pretrained=pretrained)
502 |
503 |     hooks = [0, 1, 8, 11] if hooks is None else hooks
504 | return _make_vit_b_rn50_backbone(
505 | model,
506 | features=[256, 512, 768, 768],
507 | size=[384, 384],
508 | hooks=hooks,
509 | use_vit_only=use_vit_only,
510 | use_readout=use_readout,
511 | enable_attention_hooks=enable_attention_hooks,
512 | )
513 |
514 |
515 | def _make_pretrained_vitl16_384(
516 | pretrained, use_readout="ignore", hooks=None, enable_attention_hooks=False
517 | ):
518 | model = timm.create_model("vit_large_patch16_384", pretrained=pretrained)
519 |
520 |     hooks = [5, 11, 17, 23] if hooks is None else hooks
521 | return _make_vit_b16_backbone(
522 | model,
523 | features=[256, 512, 1024, 1024],
524 | hooks=hooks,
525 | vit_features=1024,
526 | use_readout=use_readout,
527 | enable_attention_hooks=enable_attention_hooks,
528 | )
529 |
530 |
531 | def _make_pretrained_vitb16_384(
532 | pretrained, use_readout="ignore", hooks=None, enable_attention_hooks=False
533 | ):
534 | model = timm.create_model("vit_base_patch16_384", pretrained=pretrained)
535 |
536 |     hooks = [2, 5, 8, 11] if hooks is None else hooks
537 | return _make_vit_b16_backbone(
538 | model,
539 | features=[96, 192, 384, 768],
540 | hooks=hooks,
541 | use_readout=use_readout,
542 | enable_attention_hooks=enable_attention_hooks,
543 | )
544 |
545 |
546 | def _make_pretrained_deitb16_384(
547 | pretrained, use_readout="ignore", hooks=None, enable_attention_hooks=False
548 | ):
549 | model = timm.create_model("vit_deit_base_patch16_384", pretrained=pretrained)
550 |
551 |     hooks = [2, 5, 8, 11] if hooks is None else hooks
552 | return _make_vit_b16_backbone(
553 | model,
554 | features=[96, 192, 384, 768],
555 | hooks=hooks,
556 | use_readout=use_readout,
557 | enable_attention_hooks=enable_attention_hooks,
558 | )
559 |
560 |
561 | def _make_pretrained_deitb16_distil_384(
562 | pretrained, use_readout="ignore", hooks=None, enable_attention_hooks=False
563 | ):
564 | model = timm.create_model(
565 | "vit_deit_base_distilled_patch16_384", pretrained=pretrained
566 | )
567 |
568 |     hooks = [2, 5, 8, 11] if hooks is None else hooks
569 | return _make_vit_b16_backbone(
570 | model,
571 | features=[96, 192, 384, 768],
572 | hooks=hooks,
573 | use_readout=use_readout,
574 | start_index=2,
575 | enable_attention_hooks=enable_attention_hooks,
576 | )
577 |
--------------------------------------------------------------------------------
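The factories in vit.py above all follow the same hook-and-postprocess recipe: register forward hooks on selected transformer blocks, collect their outputs in the shared `activations` dict, and let the `act_postprocess*` heads reassemble the tokens into multi-scale feature maps. A minimal, self-contained sketch of that capture pattern (toy module and hypothetical names, not the repo's API):

```python
import torch
import torch.nn as nn

# Toy stand-in for a transformer trunk whose intermediate outputs we want to tap.
trunk = nn.Sequential(nn.Linear(8, 8), nn.Linear(8, 8), nn.Linear(8, 8))

activations = {}

def get_activation(name):
    # Returns a hook that stores the hooked module's output under `name`.
    def hook(module, inputs, output):
        activations[name] = output
    return hook

# Analogous to pretrained.model.blocks[hooks[i]].register_forward_hook(...) above.
trunk[0].register_forward_hook(get_activation("1"))
trunk[2].register_forward_hook(get_activation("2"))

x = torch.randn(4, 8)
_ = trunk(x)                   # one forward pass fills `activations`
print(activations["1"].shape)  # torch.Size([4, 8])
print(activations["2"].shape)  # torch.Size([4, 8])
```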
/models/layers/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | # All rights reserved.
3 | #
4 | # This source code is licensed under the license found in the
5 | # LICENSE file in the root directory of this source tree.
6 |
7 | from .dino_head import DINOHead
8 | from .mlp import Mlp
9 | from .patch_embed import PatchEmbed
10 | from .swiglu_ffn import SwiGLUFFN, SwiGLUFFNFused
11 | from .block import NestedTensorBlock
12 | from .attention import MemEffAttention, MemEffAttention_lora
13 |
--------------------------------------------------------------------------------
/models/layers/attention.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | # All rights reserved.
3 | #
4 | # This source code is licensed under the license found in the
5 | # LICENSE file in the root directory of this source tree.
6 |
7 | # References:
8 | # https://github.com/facebookresearch/dino/blob/master/vision_transformer.py
9 | # https://github.com/rwightman/pytorch-image-models/tree/master/timm/models/vision_transformer.py
10 |
11 | import logging
12 |
13 | from torch import Tensor
14 | from torch import nn
15 | import loralib as lora
16 |
17 | logger = logging.getLogger("dinov2")
18 |
19 |
20 | try:
21 | from xformers.ops import memory_efficient_attention, unbind, fmha
22 |
23 | XFORMERS_AVAILABLE = True
24 | except ImportError:
25 | logger.warning("xFormers not available")
26 | XFORMERS_AVAILABLE = False
27 |
28 |
29 | class Attention(nn.Module):
30 | def __init__(
31 | self,
32 | dim: int,
33 | num_heads: int = 8,
34 | qkv_bias: bool = False,
35 | proj_bias: bool = True,
36 | attn_drop: float = 0.0,
37 | proj_drop: float = 0.0,
38 | ) -> None:
39 | super().__init__()
40 | self.num_heads = num_heads
41 | head_dim = dim // num_heads
42 | self.scale = head_dim**-0.5
43 |
44 | self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
45 | self.attn_drop = nn.Dropout(attn_drop)
46 | self.proj = nn.Linear(dim, dim, bias=proj_bias)
47 | self.proj_drop = nn.Dropout(proj_drop)
48 |
49 | def forward(self, x: Tensor) -> Tensor:
50 | B, N, C = x.shape
51 | qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
52 |
53 | q, k, v = qkv[0] * self.scale, qkv[1], qkv[2]
54 | attn = q @ k.transpose(-2, -1)
55 |
56 | attn = attn.softmax(dim=-1)
57 | attn = self.attn_drop(attn)
58 |
59 | x = (attn @ v).transpose(1, 2).reshape(B, N, C)
60 | x = self.proj(x)
61 | x = self.proj_drop(x)
62 | return x
63 |
64 | class Attention_lora(nn.Module):
65 | def __init__(
66 | self,
67 | dim: int,
68 | num_heads: int = 8,
69 | qkv_bias: bool = False,
70 | proj_bias: bool = True,
71 | attn_drop: float = 0.0,
72 | proj_drop: float = 0.0,
73 | ) -> None:
74 | super().__init__()
75 | self.num_heads = num_heads
76 | head_dim = dim // num_heads
77 | self.scale = head_dim**-0.5
78 |
79 | self.qkv = lora.Linear(dim, dim * 3, bias=qkv_bias, r=8)
80 | self.attn_drop = nn.Dropout(attn_drop)
81 | self.proj = lora.Linear(dim, dim, bias=proj_bias, r=8)
82 | self.proj_drop = nn.Dropout(proj_drop)
83 |
84 | def forward(self, x: Tensor) -> Tensor:
85 | B, N, C = x.shape
86 | qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
87 |
88 | q, k, v = qkv[0] * self.scale, qkv[1], qkv[2]
89 | attn = q @ k.transpose(-2, -1)
90 |
91 | attn = attn.softmax(dim=-1)
92 | attn = self.attn_drop(attn)
93 |
94 | x = (attn @ v).transpose(1, 2).reshape(B, N, C)
95 | x = self.proj(x)
96 | x = self.proj_drop(x)
97 | return x
98 |
99 | class MemEffAttention(Attention):
100 | def forward(self, x: Tensor, attn_bias=None) -> Tensor:
101 | if not XFORMERS_AVAILABLE:
102 | assert attn_bias is None, "xFormers is required for nested tensors usage"
103 | return super().forward(x)
104 |
105 | B, N, C = x.shape
106 | qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads)
107 |
108 | q, k, v = unbind(qkv, 2)
109 |
110 | x = memory_efficient_attention(q, k, v, attn_bias=attn_bias)
111 | x = x.reshape([B, N, C])
112 |
113 | x = self.proj(x)
114 | x = self.proj_drop(x)
115 | return x
116 |
117 | class MemEffAttention_lora(Attention_lora):
118 | def forward(self, x: Tensor, attn_bias=None) -> Tensor:
119 | if not XFORMERS_AVAILABLE:
120 | assert attn_bias is None, "xFormers is required for nested tensors usage"
121 | return super().forward(x)
122 |
123 | B, N, C = x.shape
124 | qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads)
125 |
126 | q, k, v = unbind(qkv, 2)
127 |
128 | x = memory_efficient_attention(q, k, v, attn_bias=attn_bias)
129 | x = x.reshape([B, N, C])
130 |
131 | x = self.proj(x)
132 | x = self.proj_drop(x)
133 | return x
134 |
--------------------------------------------------------------------------------
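A hedged usage sketch for `Attention_lora`: the only change from `Attention` is that the qkv and output projections are `loralib` `lora.Linear` layers with rank 8, so fine-tuning normally freezes the pretrained weights and trains only the low-rank factors. This assumes the repo root is on `PYTHONPATH` and that `loralib`'s `mark_only_lora_as_trainable` helper is available in the installed version.

```python
import torch
import loralib as lora

from models.layers.attention import Attention_lora

attn = Attention_lora(dim=768, num_heads=12, qkv_bias=True)

# Freeze everything except the LoRA A/B matrices added by lora.Linear(r=8).
lora.mark_only_lora_as_trainable(attn)

trainable = [n for n, p in attn.named_parameters() if p.requires_grad]
print(trainable)          # e.g. ['qkv.lora_A', 'qkv.lora_B', 'proj.lora_A', 'proj.lora_B']

x = torch.randn(2, 197, 768)   # (batch, tokens, dim)
print(attn(x).shape)           # torch.Size([2, 197, 768])
```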
/models/layers/block.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | # All rights reserved.
3 | #
4 | # This source code is licensed under the license found in the
5 | # LICENSE file in the root directory of this source tree.
6 |
7 | # References:
8 | # https://github.com/facebookresearch/dino/blob/master/vision_transformer.py
9 | # https://github.com/rwightman/pytorch-image-models/tree/master/timm/layers/patch_embed.py
10 |
11 | import logging
12 | from typing import Callable, List, Any, Tuple, Dict
13 |
14 | import torch
15 | from torch import nn, Tensor
16 |
17 | from .attention import Attention, MemEffAttention
18 | from .drop_path import DropPath
19 | from .layer_scale import LayerScale
20 | from .mlp import Mlp
21 |
22 |
23 | logger = logging.getLogger("dinov2")
24 |
25 |
26 | try:
27 | from xformers.ops import fmha
28 | from xformers.ops import scaled_index_add, index_select_cat
29 |
30 | XFORMERS_AVAILABLE = True
31 | except ImportError:
32 | logger.warning("xFormers not available")
33 | XFORMERS_AVAILABLE = False
34 |
35 |
36 | class Block(nn.Module):
37 | def __init__(
38 | self,
39 | dim: int,
40 | num_heads: int,
41 | mlp_ratio: float = 4.0,
42 | qkv_bias: bool = False,
43 | proj_bias: bool = True,
44 | ffn_bias: bool = True,
45 | drop: float = 0.0,
46 | attn_drop: float = 0.0,
47 | init_values=None,
48 | drop_path: float = 0.0,
49 | act_layer: Callable[..., nn.Module] = nn.GELU,
50 | norm_layer: Callable[..., nn.Module] = nn.LayerNorm,
51 | attn_class: Callable[..., nn.Module] = Attention,
52 | ffn_layer: Callable[..., nn.Module] = Mlp,
53 | ) -> None:
54 | super().__init__()
55 | # print(f"biases: qkv: {qkv_bias}, proj: {proj_bias}, ffn: {ffn_bias}")
56 | self.norm1 = norm_layer(dim)
57 | self.attn = attn_class(
58 | dim,
59 | num_heads=num_heads,
60 | qkv_bias=qkv_bias,
61 | proj_bias=proj_bias,
62 | attn_drop=attn_drop,
63 | proj_drop=drop,
64 | )
65 | self.ls1 = LayerScale(dim, init_values=init_values) if init_values else nn.Identity()
66 | self.drop_path1 = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()
67 |
68 | self.norm2 = norm_layer(dim)
69 | mlp_hidden_dim = int(dim * mlp_ratio)
70 | self.mlp = ffn_layer(
71 | in_features=dim,
72 | hidden_features=mlp_hidden_dim,
73 | act_layer=act_layer,
74 | drop=drop,
75 | bias=ffn_bias,
76 | )
77 | self.ls2 = LayerScale(dim, init_values=init_values) if init_values else nn.Identity()
78 | self.drop_path2 = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()
79 |
80 | self.sample_drop_ratio = drop_path
81 |
82 | def forward(self, x: Tensor) -> Tensor:
83 | def attn_residual_func(x: Tensor) -> Tensor:
84 | return self.ls1(self.attn(self.norm1(x)))
85 |
86 | def ffn_residual_func(x: Tensor) -> Tensor:
87 | return self.ls2(self.mlp(self.norm2(x)))
88 |
89 | if self.training and self.sample_drop_ratio > 0.1:
90 | # the overhead is compensated only for a drop path rate larger than 0.1
91 | x = drop_add_residual_stochastic_depth(
92 | x,
93 | residual_func=attn_residual_func,
94 | sample_drop_ratio=self.sample_drop_ratio,
95 | )
96 | x = drop_add_residual_stochastic_depth(
97 | x,
98 | residual_func=ffn_residual_func,
99 | sample_drop_ratio=self.sample_drop_ratio,
100 | )
101 | elif self.training and self.sample_drop_ratio > 0.0:
102 | x = x + self.drop_path1(attn_residual_func(x))
103 |             x = x + self.drop_path2(ffn_residual_func(x))
104 | else:
105 | x = x + attn_residual_func(x)
106 | x = x + ffn_residual_func(x)
107 | return x
108 |
109 |
110 | def drop_add_residual_stochastic_depth(
111 | x: Tensor,
112 | residual_func: Callable[[Tensor], Tensor],
113 | sample_drop_ratio: float = 0.0,
114 | ) -> Tensor:
115 | # 1) extract subset using permutation
116 | b, n, d = x.shape
117 | sample_subset_size = max(int(b * (1 - sample_drop_ratio)), 1)
118 | brange = (torch.randperm(b, device=x.device))[:sample_subset_size]
119 | x_subset = x[brange]
120 |
121 | # 2) apply residual_func to get residual
122 | residual = residual_func(x_subset)
123 |
124 | x_flat = x.flatten(1)
125 | residual = residual.flatten(1)
126 |
127 | residual_scale_factor = b / sample_subset_size
128 |
129 | # 3) add the residual
130 | x_plus_residual = torch.index_add(x_flat, 0, brange, residual.to(dtype=x.dtype), alpha=residual_scale_factor)
131 | return x_plus_residual.view_as(x)
132 |
133 |
134 | def get_branges_scales(x, sample_drop_ratio=0.0):
135 | b, n, d = x.shape
136 | sample_subset_size = max(int(b * (1 - sample_drop_ratio)), 1)
137 | brange = (torch.randperm(b, device=x.device))[:sample_subset_size]
138 | residual_scale_factor = b / sample_subset_size
139 | return brange, residual_scale_factor
140 |
141 |
142 | def add_residual(x, brange, residual, residual_scale_factor, scaling_vector=None):
143 | if scaling_vector is None:
144 | x_flat = x.flatten(1)
145 | residual = residual.flatten(1)
146 | x_plus_residual = torch.index_add(x_flat, 0, brange, residual.to(dtype=x.dtype), alpha=residual_scale_factor)
147 | else:
148 | x_plus_residual = scaled_index_add(
149 | x, brange, residual.to(dtype=x.dtype), scaling=scaling_vector, alpha=residual_scale_factor
150 | )
151 | return x_plus_residual
152 |
153 |
154 | attn_bias_cache: Dict[Tuple, Any] = {}
155 |
156 |
157 | def get_attn_bias_and_cat(x_list, branges=None):
158 | """
159 | this will perform the index select, cat the tensors, and provide the attn_bias from cache
160 | """
161 | batch_sizes = [b.shape[0] for b in branges] if branges is not None else [x.shape[0] for x in x_list]
162 | all_shapes = tuple((b, x.shape[1]) for b, x in zip(batch_sizes, x_list))
163 | if all_shapes not in attn_bias_cache.keys():
164 | seqlens = []
165 | for b, x in zip(batch_sizes, x_list):
166 | for _ in range(b):
167 | seqlens.append(x.shape[1])
168 | attn_bias = fmha.BlockDiagonalMask.from_seqlens(seqlens)
169 | attn_bias._batch_sizes = batch_sizes
170 | attn_bias_cache[all_shapes] = attn_bias
171 |
172 | if branges is not None:
173 | cat_tensors = index_select_cat([x.flatten(1) for x in x_list], branges).view(1, -1, x_list[0].shape[-1])
174 | else:
175 | tensors_bs1 = tuple(x.reshape([1, -1, *x.shape[2:]]) for x in x_list)
176 | cat_tensors = torch.cat(tensors_bs1, dim=1)
177 |
178 | return attn_bias_cache[all_shapes], cat_tensors
179 |
180 |
181 | def drop_add_residual_stochastic_depth_list(
182 | x_list: List[Tensor],
183 | residual_func: Callable[[Tensor, Any], Tensor],
184 | sample_drop_ratio: float = 0.0,
185 | scaling_vector=None,
186 | ) -> Tensor:
187 | # 1) generate random set of indices for dropping samples in the batch
188 | branges_scales = [get_branges_scales(x, sample_drop_ratio=sample_drop_ratio) for x in x_list]
189 | branges = [s[0] for s in branges_scales]
190 | residual_scale_factors = [s[1] for s in branges_scales]
191 |
192 | # 2) get attention bias and index+concat the tensors
193 | attn_bias, x_cat = get_attn_bias_and_cat(x_list, branges)
194 |
195 | # 3) apply residual_func to get residual, and split the result
196 | residual_list = attn_bias.split(residual_func(x_cat, attn_bias=attn_bias)) # type: ignore
197 |
198 | outputs = []
199 | for x, brange, residual, residual_scale_factor in zip(x_list, branges, residual_list, residual_scale_factors):
200 | outputs.append(add_residual(x, brange, residual, residual_scale_factor, scaling_vector).view_as(x))
201 | return outputs
202 |
203 |
204 | class NestedTensorBlock(Block):
205 | def forward_nested(self, x_list: List[Tensor]) -> List[Tensor]:
206 | """
207 | x_list contains a list of tensors to nest together and run
208 | """
209 | assert isinstance(self.attn, MemEffAttention)
210 |
211 | if self.training and self.sample_drop_ratio > 0.0:
212 |
213 | def attn_residual_func(x: Tensor, attn_bias=None) -> Tensor:
214 | return self.attn(self.norm1(x), attn_bias=attn_bias)
215 |
216 | def ffn_residual_func(x: Tensor, attn_bias=None) -> Tensor:
217 | return self.mlp(self.norm2(x))
218 |
219 | x_list = drop_add_residual_stochastic_depth_list(
220 | x_list,
221 | residual_func=attn_residual_func,
222 | sample_drop_ratio=self.sample_drop_ratio,
223 | scaling_vector=self.ls1.gamma if isinstance(self.ls1, LayerScale) else None,
224 | )
225 | x_list = drop_add_residual_stochastic_depth_list(
226 | x_list,
227 | residual_func=ffn_residual_func,
228 | sample_drop_ratio=self.sample_drop_ratio,
229 |                 scaling_vector=self.ls2.gamma if isinstance(self.ls2, LayerScale) else None,
230 | )
231 | return x_list
232 | else:
233 |
234 | def attn_residual_func(x: Tensor, attn_bias=None) -> Tensor:
235 | return self.ls1(self.attn(self.norm1(x), attn_bias=attn_bias))
236 |
237 | def ffn_residual_func(x: Tensor, attn_bias=None) -> Tensor:
238 | return self.ls2(self.mlp(self.norm2(x)))
239 |
240 | attn_bias, x = get_attn_bias_and_cat(x_list)
241 | x = x + attn_residual_func(x, attn_bias=attn_bias)
242 | x = x + ffn_residual_func(x)
243 | return attn_bias.split(x)
244 |
245 | def forward(self, x_or_x_list):
246 | if isinstance(x_or_x_list, Tensor):
247 | return super().forward(x_or_x_list)
248 | elif isinstance(x_or_x_list, list):
249 | assert XFORMERS_AVAILABLE, "Please install xFormers for nested tensors usage"
250 | return self.forward_nested(x_or_x_list)
251 | else:
252 | raise AssertionError
253 |
--------------------------------------------------------------------------------
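For reference, a short sketch of pushing dummy tokens through one pre-norm `Block` from this file (assumes the repo root is importable and `loralib`, pulled in via `attention.py`, is installed). With `init_values` set, each residual branch is scaled by a learned `LayerScale` gamma before being added back; in `eval()` mode the stochastic-depth paths are skipped.

```python
import torch
from models.layers.block import Block

blk = Block(dim=384, num_heads=6, mlp_ratio=4.0, qkv_bias=True, init_values=1e-5)
blk.eval()                       # plain residual path, no stochastic depth

x = torch.randn(2, 197, 384)     # (batch, tokens, dim)
with torch.no_grad():
    y = blk(x)
print(y.shape)                   # torch.Size([2, 197, 384])
```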
/models/layers/dino_head.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | # All rights reserved.
3 | #
4 | # This source code is licensed under the license found in the
5 | # LICENSE file in the root directory of this source tree.
6 |
7 | import torch
8 | import torch.nn as nn
9 | from torch.nn.init import trunc_normal_
10 | from torch.nn.utils import weight_norm
11 |
12 |
13 | class DINOHead(nn.Module):
14 | def __init__(
15 | self,
16 | in_dim,
17 | out_dim,
18 | use_bn=False,
19 | nlayers=3,
20 | hidden_dim=2048,
21 | bottleneck_dim=256,
22 | mlp_bias=True,
23 | ):
24 | super().__init__()
25 | nlayers = max(nlayers, 1)
26 | self.mlp = _build_mlp(nlayers, in_dim, bottleneck_dim, hidden_dim=hidden_dim, use_bn=use_bn, bias=mlp_bias)
27 | self.apply(self._init_weights)
28 | self.last_layer = weight_norm(nn.Linear(bottleneck_dim, out_dim, bias=False))
29 | self.last_layer.weight_g.data.fill_(1)
30 |
31 | def _init_weights(self, m):
32 | if isinstance(m, nn.Linear):
33 | trunc_normal_(m.weight, std=0.02)
34 | if isinstance(m, nn.Linear) and m.bias is not None:
35 | nn.init.constant_(m.bias, 0)
36 |
37 | def forward(self, x):
38 | x = self.mlp(x)
39 | eps = 1e-6 if x.dtype == torch.float16 else 1e-12
40 | x = nn.functional.normalize(x, dim=-1, p=2, eps=eps)
41 | x = self.last_layer(x)
42 | return x
43 |
44 |
45 | def _build_mlp(nlayers, in_dim, bottleneck_dim, hidden_dim=None, use_bn=False, bias=True):
46 | if nlayers == 1:
47 | return nn.Linear(in_dim, bottleneck_dim, bias=bias)
48 | else:
49 | layers = [nn.Linear(in_dim, hidden_dim, bias=bias)]
50 | if use_bn:
51 | layers.append(nn.BatchNorm1d(hidden_dim))
52 | layers.append(nn.GELU())
53 | for _ in range(nlayers - 2):
54 | layers.append(nn.Linear(hidden_dim, hidden_dim, bias=bias))
55 | if use_bn:
56 | layers.append(nn.BatchNorm1d(hidden_dim))
57 | layers.append(nn.GELU())
58 | layers.append(nn.Linear(hidden_dim, bottleneck_dim, bias=bias))
59 | return nn.Sequential(*layers)
60 |
--------------------------------------------------------------------------------
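A minimal shape check for `DINOHead` (the 65536 prototype count below is an illustrative value, not something fixed by this repo): the MLP bottlenecks to `bottleneck_dim`, L2-normalizes, then projects to the prototype logits with a weight-normalized linear layer.

```python
import torch
from models.layers.dino_head import DINOHead

head = DINOHead(in_dim=768, out_dim=65536, nlayers=3)

z = head(torch.randn(4, 768))
print(z.shape)   # torch.Size([4, 65536])
```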
/models/layers/drop_path.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | # All rights reserved.
3 | #
4 | # This source code is licensed under the license found in the
5 | # LICENSE file in the root directory of this source tree.
6 |
7 | # References:
8 | # https://github.com/facebookresearch/dino/blob/master/vision_transformer.py
9 | # https://github.com/rwightman/pytorch-image-models/tree/master/timm/layers/drop.py
10 |
11 |
12 | from torch import nn
13 |
14 |
15 | def drop_path(x, drop_prob: float = 0.0, training: bool = False):
16 | if drop_prob == 0.0 or not training:
17 | return x
18 | keep_prob = 1 - drop_prob
19 | shape = (x.shape[0],) + (1,) * (x.ndim - 1) # work with diff dim tensors, not just 2D ConvNets
20 | random_tensor = x.new_empty(shape).bernoulli_(keep_prob)
21 | if keep_prob > 0.0:
22 | random_tensor.div_(keep_prob)
23 | output = x * random_tensor
24 | return output
25 |
26 |
27 | class DropPath(nn.Module):
28 | """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks)."""
29 |
30 | def __init__(self, drop_prob=None):
31 | super(DropPath, self).__init__()
32 | self.drop_prob = drop_prob
33 |
34 | def forward(self, x):
35 | return drop_path(x, self.drop_prob, self.training)
36 |
--------------------------------------------------------------------------------
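A small numeric sketch of why `drop_path` divides by `keep_prob`: scaling the surviving samples keeps the expected value of the residual branch unchanged, so train-time and eval-time magnitudes agree (illustrative only).

```python
import torch
from models.layers.drop_path import drop_path

torch.manual_seed(0)
x = torch.ones(10000, 8)                           # many "samples" of a residual branch

y = drop_path(x, drop_prob=0.3, training=True)
# About 30% of rows are zeroed; the survivors are scaled by 1 / 0.7,
# so the overall mean stays close to the original value of 1.0.
print(float(y.mean()))                             # ~1.0
print(float((y.sum(dim=1) == 0).float().mean()))   # ~0.3
```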
/models/layers/layer_scale.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | # All rights reserved.
3 | #
4 | # This source code is licensed under the license found in the
5 | # LICENSE file in the root directory of this source tree.
6 |
7 | # Modified from: https://github.com/huggingface/pytorch-image-models/blob/main/timm/models/vision_transformer.py#L103-L110
8 |
9 | from typing import Union
10 |
11 | import torch
12 | from torch import Tensor
13 | from torch import nn
14 |
15 |
16 | class LayerScale(nn.Module):
17 | def __init__(
18 | self,
19 | dim: int,
20 | init_values: Union[float, Tensor] = 1e-5,
21 | inplace: bool = False,
22 | ) -> None:
23 | super().__init__()
24 | self.inplace = inplace
25 | self.gamma = nn.Parameter(init_values * torch.ones(dim))
26 |
27 | def forward(self, x: Tensor) -> Tensor:
28 | return x.mul_(self.gamma) if self.inplace else x * self.gamma
29 |
--------------------------------------------------------------------------------
/models/layers/mlp.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | # All rights reserved.
3 | #
4 | # This source code is licensed under the license found in the
5 | # LICENSE file in the root directory of this source tree.
6 |
7 | # References:
8 | # https://github.com/facebookresearch/dino/blob/master/vision_transformer.py
9 | # https://github.com/rwightman/pytorch-image-models/tree/master/timm/layers/mlp.py
10 |
11 |
12 | from typing import Callable, Optional
13 |
14 | from torch import Tensor, nn
15 |
16 |
17 | class Mlp(nn.Module):
18 | def __init__(
19 | self,
20 | in_features: int,
21 | hidden_features: Optional[int] = None,
22 | out_features: Optional[int] = None,
23 | act_layer: Callable[..., nn.Module] = nn.GELU,
24 | drop: float = 0.0,
25 | bias: bool = True,
26 | ) -> None:
27 | super().__init__()
28 | out_features = out_features or in_features
29 | hidden_features = hidden_features or in_features
30 | self.fc1 = nn.Linear(in_features, hidden_features, bias=bias)
31 | self.act = act_layer()
32 | self.fc2 = nn.Linear(hidden_features, out_features, bias=bias)
33 | self.drop = nn.Dropout(drop)
34 |
35 | def forward(self, x: Tensor) -> Tensor:
36 | x = self.fc1(x)
37 | x = self.act(x)
38 | x = self.drop(x)
39 | x = self.fc2(x)
40 | x = self.drop(x)
41 | return x
42 |
--------------------------------------------------------------------------------
/models/layers/patch_embed.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | # All rights reserved.
3 | #
4 | # This source code is licensed under the license found in the
5 | # LICENSE file in the root directory of this source tree.
6 |
7 | # References:
8 | # https://github.com/facebookresearch/dino/blob/master/vision_transformer.py
9 | # https://github.com/rwightman/pytorch-image-models/tree/master/timm/layers/patch_embed.py
10 |
11 | from typing import Callable, Optional, Tuple, Union
12 |
13 | from torch import Tensor
14 | import torch.nn as nn
15 |
16 |
17 | def make_2tuple(x):
18 | if isinstance(x, tuple):
19 | assert len(x) == 2
20 | return x
21 |
22 | assert isinstance(x, int)
23 | return (x, x)
24 |
25 |
26 | class PatchEmbed(nn.Module):
27 | """
28 | 2D image to patch embedding: (B,C,H,W) -> (B,N,D)
29 |
30 | Args:
31 | img_size: Image size.
32 | patch_size: Patch token size.
33 | in_chans: Number of input image channels.
34 | embed_dim: Number of linear projection output channels.
35 | norm_layer: Normalization layer.
36 | """
37 |
38 | def __init__(
39 | self,
40 | img_size: Union[int, Tuple[int, int]] = 224,
41 | patch_size: Union[int, Tuple[int, int]] = 16,
42 | in_chans: int = 3,
43 | embed_dim: int = 768,
44 | norm_layer: Optional[Callable] = None,
45 | flatten_embedding: bool = True,
46 | ) -> None:
47 | super().__init__()
48 |
49 | image_HW = make_2tuple(img_size)
50 | patch_HW = make_2tuple(patch_size)
51 | patch_grid_size = (
52 | image_HW[0] // patch_HW[0],
53 | image_HW[1] // patch_HW[1],
54 | )
55 |
56 | self.img_size = image_HW
57 | self.patch_size = patch_HW
58 | self.patches_resolution = patch_grid_size
59 | self.num_patches = patch_grid_size[0] * patch_grid_size[1]
60 |
61 | self.in_chans = in_chans
62 | self.embed_dim = embed_dim
63 |
64 | self.flatten_embedding = flatten_embedding
65 |
66 | self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_HW, stride=patch_HW)
67 | self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity()
68 |
69 | def forward(self, x: Tensor) -> Tensor:
70 | _, _, H, W = x.shape
71 | patch_H, patch_W = self.patch_size
72 |
73 | assert H % patch_H == 0, f"Input image height {H} is not a multiple of patch height {patch_H}"
74 |         assert W % patch_W == 0, f"Input image width {W} is not a multiple of patch width {patch_W}"
75 |
76 | x = self.proj(x) # B C H W
77 | H, W = x.size(2), x.size(3)
78 | x = x.flatten(2).transpose(1, 2) # B HW C
79 | x = self.norm(x)
80 | if not self.flatten_embedding:
81 | x = x.reshape(-1, H, W, self.embed_dim) # B H W C
82 | return x
83 |
84 | def flops(self) -> float:
85 | Ho, Wo = self.patches_resolution
86 | flops = Ho * Wo * self.embed_dim * self.in_chans * (self.patch_size[0] * self.patch_size[1])
87 | if self.norm is not None:
88 | flops += Ho * Wo * self.embed_dim
89 | return flops
90 |
--------------------------------------------------------------------------------
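A quick shape check for `PatchEmbed` with its defaults (224x224 input, 16x16 patches, 768-dim embedding): the strided conv projection turns `(B, C, H, W)` into `(B, N, D)` with `N = (224/16)^2 = 196` tokens.

```python
import torch
from models.layers.patch_embed import PatchEmbed

embed = PatchEmbed(img_size=224, patch_size=16, in_chans=3, embed_dim=768)

x = torch.randn(2, 3, 224, 224)
tokens = embed(x)
print(tokens.shape)        # torch.Size([2, 196, 768]) -> (B, N, D)
print(embed.num_patches)   # 196
```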
/models/layers/swiglu_ffn.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | # All rights reserved.
3 | #
4 | # This source code is licensed under the license found in the
5 | # LICENSE file in the root directory of this source tree.
6 |
7 | from typing import Callable, Optional
8 |
9 | from torch import Tensor, nn
10 | import torch.nn.functional as F
11 |
12 |
13 | class SwiGLUFFN(nn.Module):
14 | def __init__(
15 | self,
16 | in_features: int,
17 | hidden_features: Optional[int] = None,
18 | out_features: Optional[int] = None,
19 | act_layer: Callable[..., nn.Module] = None,
20 | drop: float = 0.0,
21 | bias: bool = True,
22 | ) -> None:
23 | super().__init__()
24 | out_features = out_features or in_features
25 | hidden_features = hidden_features or in_features
26 | self.w12 = nn.Linear(in_features, 2 * hidden_features, bias=bias)
27 | self.w3 = nn.Linear(hidden_features, out_features, bias=bias)
28 |
29 | def forward(self, x: Tensor) -> Tensor:
30 | x12 = self.w12(x)
31 | x1, x2 = x12.chunk(2, dim=-1)
32 | hidden = F.silu(x1) * x2
33 | return self.w3(hidden)
34 |
35 |
36 | try:
37 | from xformers.ops import SwiGLU
38 |
39 | XFORMERS_AVAILABLE = True
40 | except ImportError:
41 | SwiGLU = SwiGLUFFN
42 | XFORMERS_AVAILABLE = False
43 |
44 |
45 | class SwiGLUFFNFused(SwiGLU):
46 | def __init__(
47 | self,
48 | in_features: int,
49 | hidden_features: Optional[int] = None,
50 | out_features: Optional[int] = None,
51 | act_layer: Callable[..., nn.Module] = None,
52 | drop: float = 0.0,
53 | bias: bool = True,
54 | ) -> None:
55 | out_features = out_features or in_features
56 | hidden_features = hidden_features or in_features
57 | hidden_features = (int(hidden_features * 2 / 3) + 7) // 8 * 8
58 | super().__init__(
59 | in_features=in_features,
60 | hidden_features=hidden_features,
61 | out_features=out_features,
62 | bias=bias,
63 | )
64 |
--------------------------------------------------------------------------------
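`SwiGLUFFNFused` shrinks the requested hidden width to 2/3 and rounds it up to a multiple of 8, which keeps the three SwiGLU projections at roughly the same parameter budget as a standard two-layer MLP. A minimal check of that arithmetic (the 3072 below is just the common `768 * mlp_ratio=4` example, not a value taken from this repo):

```python
def fused_hidden(hidden_features: int) -> int:
    # Same rule as SwiGLUFFNFused.__init__: 2/3 of the width, rounded up to a multiple of 8.
    return (int(hidden_features * 2 / 3) + 7) // 8 * 8

print(fused_hidden(3072))   # 2048
print(fused_hidden(1536))   # 1024
```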
/models/unet.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 | import numpy as np
5 |
6 | class DoubleConv(nn.Module):
7 | """(convolution => [BN] => ReLU) * 2"""
8 |
9 | def __init__(self, in_channels, out_channels, mid_channels=None):
10 | super().__init__()
11 | if not mid_channels:
12 | mid_channels = out_channels
13 | self.double_conv = nn.Sequential(
14 | nn.Conv2d(in_channels, mid_channels, kernel_size=3, padding=1, bias=False),
15 | nn.BatchNorm2d(mid_channels),
16 | nn.ReLU(inplace=True),
17 | nn.Conv2d(mid_channels, out_channels, kernel_size=3, padding=1, bias=False),
18 | nn.BatchNorm2d(out_channels),
19 | nn.ReLU(inplace=True)
20 | )
21 |
22 | def forward(self, x):
23 | return self.double_conv(x)
24 |
25 | class Down(nn.Module):
26 | """Downscaling with maxpool then double conv"""
27 |
28 | def __init__(self, in_channels, out_channels):
29 | super().__init__()
30 | self.maxpool_conv = nn.Sequential(
31 | nn.MaxPool2d(2),
32 | DoubleConv(in_channels, out_channels)
33 | )
34 |
35 | def forward(self, x):
36 | return self.maxpool_conv(x)
37 |
38 | class Up(nn.Module):
39 | """Upscaling then double conv"""
40 |
41 | def __init__(self, in_channels, out_channels, size, bilinear=True):
42 | super().__init__()
43 |
44 | # if bilinear, use the normal convolutions to reduce the number of channels
45 | if bilinear:
46 | self.up = nn.Upsample(size=size, mode='bilinear', align_corners=True)
47 | self.conv = DoubleConv(in_channels, out_channels, in_channels // 2)
48 | else:
49 | self.up = nn.ConvTranspose2d(in_channels, in_channels // 2, kernel_size=2, stride=2)
50 | self.conv = DoubleConv(in_channels, out_channels)
51 |
52 | def forward(self, x1, x2):
53 | x1 = self.up(x1)
54 | # input is CHW
55 | diffY = x2.size()[2] - x1.size()[2]
56 | diffX = x2.size()[3] - x1.size()[3]
57 |
58 | x1 = F.pad(x1, [diffX // 2, diffX - diffX // 2,
59 | diffY // 2, diffY - diffY // 2])
60 | # if you have padding issues, see
61 | # https://github.com/HaiyongJiang/U-Net-Pytorch-Unstructured-Buggy/commit/0e854509c2cea854e247a9c615f175f76fbb2e3a
62 | # https://github.com/xiaopeng-liao/Pytorch-UNet/commit/8ebac70e633bac59fc22bb5195e513d5832fb3bd
63 | x = torch.cat([x2, x1], dim=1)
64 | return self.conv(x)
65 |
66 | class Upconv(nn.Module):
67 | """Upscaling then double conv"""
68 |
69 | def __init__(self, in_channels, out_channels):
70 | super().__init__()
71 | self.conv = DoubleConv(in_channels, out_channels, in_channels // 2)
72 |
73 | def forward(self, x1, x2):
74 | # input is CHW
75 | diffY = x2.size()[2] - x1.size()[2]
76 | diffX = x2.size()[3] - x1.size()[3]
77 |
78 | x1 = F.pad(x1, [diffX // 2, diffX - diffX // 2,
79 | diffY // 2, diffY - diffY // 2])
80 | # if you have padding issues, see
81 | # https://github.com/HaiyongJiang/U-Net-Pytorch-Unstructured-Buggy/commit/0e854509c2cea854e247a9c615f175f76fbb2e3a
82 | # https://github.com/xiaopeng-liao/Pytorch-UNet/commit/8ebac70e633bac59fc22bb5195e513d5832fb3bd
83 | x = torch.cat([x2, x1], dim=1)
84 | return self.conv(x)
85 |
86 | class OutConv(nn.Module):
87 | def __init__(self, in_channels, out_channels):
88 | super(OutConv, self).__init__()
89 | self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=1)
90 |
91 | def forward(self, x):
92 | return self.conv(x)
93 |
94 | class U_Net(nn.Module):
95 | def __init__(self, n_channels, n_classes, bilinear=True, outScale=False):
96 | super(U_Net, self).__init__()
97 | self.n_channels = n_channels
98 | self.n_classes = n_classes
99 | self.bilinear = bilinear
100 |
101 | self.inc = (DoubleConv(n_channels, 32))
102 | self.down1 = (Down(32, 64))
103 | self.down2 = (Down(64, 128))
104 | self.down3 = (Down(128, 256))
105 | factor = 2 if bilinear else 1
106 | self.down4 = (Down(256, 512 // factor))
107 | self.up1 = (Upconv(512, 256 // factor))
108 | self.up2 = (Upconv(256, 128 // factor))
109 | self.up3 = (Upconv(128, 64 // factor))
110 | self.up4 = (Upconv(64, 32))
111 | self.outc = (OutConv(32, n_classes))
112 | self.outScale = outScale
113 |
114 | def forward(self, x, size):
115 | n1,n2 = size
116 | x1 = self.inc(x)
117 | x2 = self.down1(x1)
118 | x3 = self.down2(x2)
119 | x4 = self.down3(x3)
120 | x5 = self.down4(x4)
121 | x = self.up((n1//8,n2//8),x5)
122 |         x = self.up1(x, x4)  # use the upsampled features from the previous line, not x5
123 | x = self.up((n1//4,n2//4),x)
124 | x = self.up2(x, x3)
125 | x = self.up((n1//2,n2//2),x)
126 | x = self.up3(x, x2)
127 | x = self.up((n1 ,n2 ),x)
128 | x = self.up4(x, x1)
129 | logits = self.outc(x)
130 | if self.outScale:
131 | logits = F.relu(logits)+7.1
132 | return logits
133 |
134 | def up(self, size, x):
135 | up = nn.Upsample(size=size, mode='bilinear', align_corners=True)
136 | return up(x)
137 |
138 | if __name__ == "__main__":
139 |     device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
140 |     model = U_Net(n_channels=1, n_classes=6)  # U_Net takes (n_channels, n_classes[, bilinear, outScale])
141 |     model.eval()
142 |     model.to(device)
143 |     imgs_tensor = torch.zeros(5, 1, 224, 224, device=device)
144 |
145 |     out = model(imgs_tensor, (224, 224))  # forward also expects the target spatial size
146 | print(out.shape)
--------------------------------------------------------------------------------
/models/vision_transformer.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | # All rights reserved.
3 | #
4 | # This source code is licensed under the license found in the
5 | # LICENSE file in the root directory of this source tree.
6 |
7 | # References:
8 | # https://github.com/facebookresearch/dino/blob/main/vision_transformer.py
9 | # https://github.com/rwightman/pytorch-image-models/tree/master/timm/models/vision_transformer.py
10 |
11 | from functools import partial
12 | import math
13 | import logging
14 | from typing import Sequence, Tuple, Union, Callable
15 |
16 | import torch
17 | import torch.nn as nn
18 | import torch.utils.checkpoint
19 | from torch.nn.init import trunc_normal_
20 |
21 | from models.layers import Mlp, PatchEmbed, SwiGLUFFNFused, MemEffAttention, NestedTensorBlock as Block
22 |
23 |
24 | logger = logging.getLogger("dinov2")
25 |
26 |
27 | def named_apply(fn: Callable, module: nn.Module, name="", depth_first=True, include_root=False) -> nn.Module:
28 | if not depth_first and include_root:
29 | fn(module=module, name=name)
30 | for child_name, child_module in module.named_children():
31 | child_name = ".".join((name, child_name)) if name else child_name
32 | named_apply(fn=fn, module=child_module, name=child_name, depth_first=depth_first, include_root=True)
33 | if depth_first and include_root:
34 | fn(module=module, name=name)
35 | return module
36 |
37 |
38 | class BlockChunk(nn.ModuleList):
39 | def forward(self, x):
40 | for b in self:
41 | x = b(x)
42 | return x
43 |
44 |
45 | class DinoVisionTransformer(nn.Module):
46 | def __init__(
47 | self,
48 | img_size=224,
49 | patch_size=16,
50 | in_chans=3,
51 | embed_dim=768,
52 | depth=12,
53 | num_heads=12,
54 | mlp_ratio=4.0,
55 | qkv_bias=True,
56 | ffn_bias=True,
57 | proj_bias=True,
58 | drop_path_rate=0.0,
59 | drop_path_uniform=False,
60 | init_values=None, # for layerscale: None or 0 => no layerscale
61 | embed_layer=PatchEmbed,
62 | act_layer=nn.GELU,
63 | block_fn=Block,
64 | ffn_layer="mlp",
65 | block_chunks=1
66 | ):
67 | """
68 | Args:
69 | img_size (int, tuple): input image size
70 | patch_size (int, tuple): patch size
71 | in_chans (int): number of input channels
72 | embed_dim (int): embedding dimension
73 | depth (int): depth of transformer
74 | num_heads (int): number of attention heads
75 | mlp_ratio (int): ratio of mlp hidden dim to embedding dim
76 | qkv_bias (bool): enable bias for qkv if True
77 | proj_bias (bool): enable bias for proj in attn if True
78 | ffn_bias (bool): enable bias for ffn if True
79 | drop_path_rate (float): stochastic depth rate
80 | drop_path_uniform (bool): apply uniform drop rate across blocks
81 | weight_init (str): weight init scheme
82 | init_values (float): layer-scale init values
83 | embed_layer (nn.Module): patch embedding layer
84 | act_layer (nn.Module): MLP activation layer
85 | block_fn (nn.Module): transformer block class
86 | ffn_layer (str): "mlp", "swiglu", "swiglufused" or "identity"
87 | block_chunks: (int) split block sequence into block_chunks units for FSDP wrap
88 | """
89 | super().__init__()
90 | norm_layer = partial(nn.LayerNorm, eps=1e-6)
91 |
92 | self.num_features = self.embed_dim = embed_dim # num_features for consistency with other models
93 | self.num_tokens = 1
94 | self.n_blocks = depth
95 | self.num_heads = num_heads
96 | self.patch_size = patch_size
97 |
98 | self.patch_embed = embed_layer(img_size=img_size, patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim)
99 | num_patches = self.patch_embed.num_patches
100 |
101 | self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
102 | self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + 1114, embed_dim))
103 |
104 | if drop_path_uniform is True:
105 | dpr = [drop_path_rate] * depth
106 | else:
107 | dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)] # stochastic depth decay rule
108 |
109 | if ffn_layer == "mlp":
110 | logger.info("using MLP layer as FFN")
111 | ffn_layer = Mlp
112 | elif ffn_layer == "swiglufused" or ffn_layer == "swiglu":
113 | logger.info("using SwiGLU layer as FFN")
114 | ffn_layer = SwiGLUFFNFused
115 | elif ffn_layer == "identity":
116 | logger.info("using Identity layer as FFN")
117 |
118 | def f(*args, **kwargs):
119 | return nn.Identity()
120 |
121 | ffn_layer = f
122 | else:
123 | raise NotImplementedError
124 |
125 | blocks_list = [
126 | block_fn(
127 | dim=embed_dim,
128 | num_heads=num_heads,
129 | mlp_ratio=mlp_ratio,
130 | qkv_bias=qkv_bias,
131 | proj_bias=proj_bias,
132 | ffn_bias=ffn_bias,
133 | drop_path=dpr[i],
134 | norm_layer=norm_layer,
135 | act_layer=act_layer,
136 | ffn_layer=ffn_layer,
137 | init_values=init_values,
138 | )
139 | for i in range(depth)
140 | ]
141 | if block_chunks > 0:
142 | self.chunked_blocks = True
143 | chunked_blocks = []
144 | chunksize = depth // block_chunks
145 | for i in range(0, depth, chunksize):
146 | # this is to keep the block index consistent if we chunk the block list
147 | chunked_blocks.append([nn.Identity()] * i + blocks_list[i : i + chunksize])
148 | self.blocks = nn.ModuleList([BlockChunk(p) for p in chunked_blocks])
149 | else:
150 | self.chunked_blocks = False
151 | self.blocks = nn.ModuleList(blocks_list)
152 |
153 | self.norm = norm_layer(embed_dim)
154 | self.head = nn.Identity()
155 |
156 | self.mask_token = nn.Parameter(torch.zeros(1, embed_dim))
157 |
158 | self.init_weights()
159 |
160 | def init_weights(self):
161 | trunc_normal_(self.pos_embed, std=0.02)
162 | nn.init.normal_(self.cls_token, std=1e-6)
163 | named_apply(init_weights_vit_timm, self)
164 |
165 | def interpolate_pos_encoding(self, x, w, h):
166 | previous_dtype = x.dtype
167 | npatch = x.shape[1] - 1
168 | N = self.pos_embed.shape[1] - 1
169 | if npatch == N and w == h:
170 | return self.pos_embed
171 | pos_embed = self.pos_embed.float()
172 | class_pos_embed = pos_embed[:, 0]
173 | patch_pos_embed = pos_embed[:, 1:]
174 | dim = x.shape[-1]
175 | w0 = w // self.patch_size
176 | h0 = h // self.patch_size
177 | # we add a small number to avoid floating point error in the interpolation
178 | # see discussion at https://github.com/facebookresearch/dino/issues/8
179 | w0, h0 = w0 + 0.1, h0 + 0.1
180 |
181 | patch_pos_embed = nn.functional.interpolate(
182 | patch_pos_embed.reshape(1, int(math.sqrt(N)), int(math.sqrt(N)), dim).permute(0, 3, 1, 2),
183 | scale_factor=(w0 / math.sqrt(N), h0 / math.sqrt(N)),
184 | mode="bicubic",
185 | )
186 |
187 | assert int(w0) == patch_pos_embed.shape[-2] and int(h0) == patch_pos_embed.shape[-1]
188 | patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).view(1, -1, dim)
189 | return torch.cat((class_pos_embed.unsqueeze(0), patch_pos_embed), dim=1).to(previous_dtype)
190 |
191 | def prepare_tokens_with_masks(self, x, masks=None):
192 | B, nc, w, h = x.shape
193 | x = self.patch_embed(x)
194 | if masks is not None:
195 | x = torch.where(masks.unsqueeze(-1), self.mask_token.to(x.dtype).unsqueeze(0), x)
196 |
197 | x = torch.cat((self.cls_token.expand(x.shape[0], -1, -1), x), dim=1)
198 | x = x + self.interpolate_pos_encoding(x, w, h)
199 |
200 | return x
201 |
202 | def forward_features_list(self, x_list, masks_list):
203 | x = [self.prepare_tokens_with_masks(x, masks) for x, masks in zip(x_list, masks_list)]
204 |
205 | for blk in self.blocks:
206 | x = blk(x)
207 |
208 | all_x = x
209 | output = []
210 | for x, masks in zip(all_x, masks_list):
211 | x_norm = self.norm(x)
212 | output.append(
213 | {
214 | "x_norm_clstoken": x_norm[:, 0],
215 | "x_norm_patchtokens": x_norm[:, 1:],
216 | "x_prenorm": x,
217 | "masks": masks,
218 | }
219 | )
220 | return output
221 |
222 | def forward_features(self, x, masks=None):
223 | if isinstance(x, list):
224 | return self.forward_features_list(x, masks)
225 |
226 | x = self.prepare_tokens_with_masks(x, masks)
227 |
228 | count = 1
229 | x_middle = {}
230 | for blk in self.blocks:
231 | x = blk(x)
232 | if count == 3 or count == 6 or count == 9 or count == 12:
233 | x_middle[str(count)] = self.norm(x)[:, 1:]
234 | count = count + 1
235 |
236 | x_norm = self.norm(x)
237 | return {
238 | "x_norm_clstoken": x_norm[:, 0],
239 | "x_norm_patchtokens": x_norm[:, 1:],
240 | "x_prenorm": x,
241 | "masks": masks,
242 | }, x_middle
243 |
244 | def _get_intermediate_layers_not_chunked(self, x, n=1):
245 | x = self.prepare_tokens_with_masks(x)
246 | # If n is an int, take the n last blocks. If it's a list, take them
247 | output, total_block_len = [], len(self.blocks)
248 | blocks_to_take = range(total_block_len - n, total_block_len) if isinstance(n, int) else n
249 | for i, blk in enumerate(self.blocks):
250 | x = blk(x)
251 | if i in blocks_to_take:
252 | output.append(x)
253 | assert len(output) == len(blocks_to_take), f"only {len(output)} / {len(blocks_to_take)} blocks found"
254 | return output
255 |
256 | def _get_intermediate_layers_chunked(self, x, n=1):
257 | x = self.prepare_tokens_with_masks(x)
258 | output, i, total_block_len = [], 0, len(self.blocks[-1])
259 | # If n is an int, take the n last blocks. If it's a list, take them
260 | blocks_to_take = range(total_block_len - n, total_block_len) if isinstance(n, int) else n
261 | for block_chunk in self.blocks:
262 | for blk in block_chunk[i:]: # Passing the nn.Identity()
263 | x = blk(x)
264 | if i in blocks_to_take:
265 | output.append(x)
266 | i += 1
267 | assert len(output) == len(blocks_to_take), f"only {len(output)} / {len(blocks_to_take)} blocks found"
268 | return output
269 |
270 | def get_intermediate_layers(
271 | self,
272 | x: torch.Tensor,
273 | n: Union[int, Sequence] = 1, # Layers or n last layers to take
274 | reshape: bool = False,
275 | return_class_token: bool = False,
276 | norm=True,
277 | ) -> Tuple[Union[torch.Tensor, Tuple[torch.Tensor]]]:
278 | if self.chunked_blocks:
279 | outputs = self._get_intermediate_layers_chunked(x, n)
280 | else:
281 | outputs = self._get_intermediate_layers_not_chunked(x, n)
282 | if norm:
283 | outputs = [self.norm(out) for out in outputs]
284 | class_tokens = [out[:, 0] for out in outputs]
285 | outputs = [out[:, 1:] for out in outputs]
286 | if reshape:
287 | B, _, w, h = x.shape
288 | outputs = [
289 | out.reshape(B, w // self.patch_size, h // self.patch_size, -1).permute(0, 3, 1, 2).contiguous()
290 | for out in outputs
291 | ]
292 | if return_class_token:
293 | return tuple(zip(outputs, class_tokens))
294 | return tuple(outputs)
295 |
296 | def forward(self, *args, is_training=False, **kwargs):
297 |         ret, x_middle = self.forward_features(*args, **kwargs)  # forward_features returns (tokens dict, intermediates)
298 |         if is_training:
299 |             return ret, x_middle
300 | else:
301 | return self.head(ret["x_norm_clstoken"])
302 |
303 |
304 | def init_weights_vit_timm(module: nn.Module, name: str = ""):
305 | """ViT weight initialization, original timm impl (for reproducibility)"""
306 | if isinstance(module, nn.Linear):
307 | trunc_normal_(module.weight, std=0.02)
308 | if module.bias is not None:
309 | nn.init.zeros_(module.bias)
310 |
311 |
312 | def vit_small(patch_size=16, **kwargs):
313 | model = DinoVisionTransformer(
314 | patch_size=patch_size,
315 | embed_dim=384,
316 | depth=12,
317 | num_heads=6,
318 | mlp_ratio=4,
319 | block_fn=partial(Block, attn_class=MemEffAttention),
320 | **kwargs,
321 | )
322 | return model
323 |
324 |
325 | def vit_base(patch_size=16, **kwargs):
326 | model = DinoVisionTransformer(
327 | patch_size=patch_size,
328 | embed_dim=768,
329 | depth=12,
330 | num_heads=12,
331 | mlp_ratio=4,
332 | block_fn=partial(Block, attn_class=MemEffAttention),
333 | **kwargs,
334 | )
335 | return model
336 |
337 |
338 | def vit_large(patch_size=16, **kwargs):
339 | model = DinoVisionTransformer(
340 | patch_size=patch_size,
341 | embed_dim=1024,
342 | depth=24,
343 | num_heads=16,
344 | mlp_ratio=4,
345 | block_fn=partial(Block, attn_class=MemEffAttention),
346 | **kwargs,
347 | )
348 | return model
349 |
350 |
351 | def vit_giant2(patch_size=16, **kwargs):
352 | """
353 | Close to ViT-giant, with embed-dim 1536 and 24 heads => embed-dim per head 64
354 | """
355 | model = DinoVisionTransformer(
356 | patch_size=patch_size,
357 | embed_dim=1536,
358 | depth=40,
359 | num_heads=24,
360 | mlp_ratio=4,
361 | block_fn=partial(Block, attn_class=MemEffAttention),
362 | **kwargs,
363 | )
364 | return model
365 |
--------------------------------------------------------------------------------
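A hedged end-to-end sketch for the `DinoVisionTransformer` above (random weights, not a DINOv2 checkpoint; assumes the repo root is on `PYTHONPATH` and that `loralib` is installed, since `models.layers` imports it). With `patch_size=14` the `num_patches + 1114` position table becomes a 37x37 grid plus the cls token, matching the layout of the public DINOv2 checkpoints, so `interpolate_pos_encoding` can reshape it; `block_chunks=0` keeps the 12 blocks as a flat list so the `x_middle` taps after blocks 3/6/9/12 are actually filled.

```python
import torch
from models.vision_transformer import DinoVisionTransformer

model = DinoVisionTransformer(
    img_size=224, patch_size=14, embed_dim=384, depth=12, num_heads=6, block_chunks=0
)
model.eval()

x = torch.randn(1, 3, 224, 224)
with torch.no_grad():
    out, x_middle = model.forward_features(x)

print(out["x_norm_patchtokens"].shape)   # torch.Size([1, 256, 384]) -> 16x16 patch grid
print(sorted(x_middle.keys()))           # ['12', '3', '6', '9']
```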
/models/vision_transformer_lora.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | # All rights reserved.
3 | #
4 | # This source code is licensed under the license found in the
5 | # LICENSE file in the root directory of this source tree.
6 |
7 | # References:
8 | # https://github.com/facebookresearch/dino/blob/main/vision_transformer.py
9 | # https://github.com/rwightman/pytorch-image-models/tree/master/timm/models/vision_transformer.py
10 |
11 | from functools import partial
12 | import math
13 | import logging
14 | from typing import Sequence, Tuple, Union, Callable
15 |
16 | import torch
17 | import torch.nn as nn
18 | import torch.utils.checkpoint
19 | from torch.nn.init import trunc_normal_
20 |
21 | from models.layers import Mlp, PatchEmbed, SwiGLUFFNFused, MemEffAttention_lora, NestedTensorBlock as Block
22 |
23 |
24 | logger = logging.getLogger("dinov2")
25 |
26 |
27 | def named_apply(fn: Callable, module: nn.Module, name="", depth_first=True, include_root=False) -> nn.Module:
28 | if not depth_first and include_root:
29 | fn(module=module, name=name)
30 | for child_name, child_module in module.named_children():
31 | child_name = ".".join((name, child_name)) if name else child_name
32 | named_apply(fn=fn, module=child_module, name=child_name, depth_first=depth_first, include_root=True)
33 | if depth_first and include_root:
34 | fn(module=module, name=name)
35 | return module
36 |
37 |
38 | class BlockChunk(nn.ModuleList):
39 | def forward(self, x):
40 | for b in self:
41 | x = b(x)
42 | return x
43 |
44 |
45 | class DinoVisionTransformer(nn.Module):
46 | def __init__(
47 | self,
48 | img_size=224,
49 | patch_size=16,
50 | in_chans=3,
51 | embed_dim=768,
52 | depth=12,
53 | num_heads=12,
54 | mlp_ratio=4.0,
55 | qkv_bias=True,
56 | ffn_bias=True,
57 | proj_bias=True,
58 | drop_path_rate=0.0,
59 | drop_path_uniform=False,
60 | init_values=None, # for layerscale: None or 0 => no layerscale
61 | embed_layer=PatchEmbed,
62 | act_layer=nn.GELU,
63 | block_fn=Block,
64 | ffn_layer="mlp",
65 | block_chunks=1,
66 | ):
67 | """
68 | Args:
69 | img_size (int, tuple): input image size
70 | patch_size (int, tuple): patch size
71 | in_chans (int): number of input channels
72 | embed_dim (int): embedding dimension
73 | depth (int): depth of transformer
74 | num_heads (int): number of attention heads
75 | mlp_ratio (int): ratio of mlp hidden dim to embedding dim
76 | qkv_bias (bool): enable bias for qkv if True
77 | proj_bias (bool): enable bias for proj in attn if True
78 | ffn_bias (bool): enable bias for ffn if True
79 | drop_path_rate (float): stochastic depth rate
80 | drop_path_uniform (bool): apply uniform drop rate across blocks
81 | weight_init (str): weight init scheme
82 | init_values (float): layer-scale init values
83 | embed_layer (nn.Module): patch embedding layer
84 | act_layer (nn.Module): MLP activation layer
85 | block_fn (nn.Module): transformer block class
86 | ffn_layer (str): "mlp", "swiglu", "swiglufused" or "identity"
87 | block_chunks: (int) split block sequence into block_chunks units for FSDP wrap
88 | """
89 | super().__init__()
90 | norm_layer = partial(nn.LayerNorm, eps=1e-6)
91 |
92 | self.num_features = self.embed_dim = embed_dim # num_features for consistency with other models
93 | self.num_tokens = 1
94 | self.n_blocks = depth
95 | self.num_heads = num_heads
96 | self.patch_size = patch_size
97 |
98 | self.patch_embed = embed_layer(img_size=img_size, patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim)
99 | num_patches = self.patch_embed.num_patches
100 |
101 | self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
102 |         self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + self.num_tokens, embed_dim))
103 |
104 | if drop_path_uniform is True:
105 | dpr = [drop_path_rate] * depth
106 | else:
107 | dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)] # stochastic depth decay rule
108 |
109 | if ffn_layer == "mlp":
110 | logger.info("using MLP layer as FFN")
111 | ffn_layer = Mlp
112 | elif ffn_layer == "swiglufused" or ffn_layer == "swiglu":
113 | logger.info("using SwiGLU layer as FFN")
114 | ffn_layer = SwiGLUFFNFused
115 | elif ffn_layer == "identity":
116 | logger.info("using Identity layer as FFN")
117 |
118 | def f(*args, **kwargs):
119 | return nn.Identity()
120 |
121 | ffn_layer = f
122 | else:
123 | raise NotImplementedError
124 |
125 | blocks_list = [
126 | block_fn(
127 | dim=embed_dim,
128 | num_heads=num_heads,
129 | mlp_ratio=mlp_ratio,
130 | qkv_bias=qkv_bias,
131 | proj_bias=proj_bias,
132 | ffn_bias=ffn_bias,
133 | drop_path=dpr[i],
134 | norm_layer=norm_layer,
135 | act_layer=act_layer,
136 | ffn_layer=ffn_layer,
137 | init_values=init_values,
138 | )
139 | for i in range(depth)
140 | ]
141 | if block_chunks > 0:
142 | self.chunked_blocks = True
143 | chunked_blocks = []
144 | chunksize = depth // block_chunks
145 | for i in range(0, depth, chunksize):
146 | # this is to keep the block index consistent if we chunk the block list
147 | chunked_blocks.append([nn.Identity()] * i + blocks_list[i : i + chunksize])
148 | self.blocks = nn.ModuleList([BlockChunk(p) for p in chunked_blocks])
149 | else:
150 | self.chunked_blocks = False
151 | self.blocks = nn.ModuleList(blocks_list)
152 |
153 | self.norm = norm_layer(embed_dim)
154 | self.head = nn.Identity()
155 |
156 | self.mask_token = nn.Parameter(torch.zeros(1, embed_dim))
157 |
158 | self.init_weights()
159 |
160 | def init_weights(self):
161 | trunc_normal_(self.pos_embed, std=0.02)
162 | nn.init.normal_(self.cls_token, std=1e-6)
163 | named_apply(init_weights_vit_timm, self)
164 |
165 | def interpolate_pos_encoding(self, x, w, h):
166 | previous_dtype = x.dtype
167 | npatch = x.shape[1] - 1
168 | N = self.pos_embed.shape[1] - 1
169 | if npatch == N and w == h:
170 | return self.pos_embed
171 | pos_embed = self.pos_embed.float()
172 | class_pos_embed = pos_embed[:, 0]
173 | patch_pos_embed = pos_embed[:, 1:]
174 | dim = x.shape[-1]
175 | w0 = w // self.patch_size
176 | h0 = h // self.patch_size
177 | # we add a small number to avoid floating point error in the interpolation
178 | # see discussion at https://github.com/facebookresearch/dino/issues/8
179 | w0, h0 = w0 + 0.1, h0 + 0.1
180 |
181 | patch_pos_embed = nn.functional.interpolate(
182 | patch_pos_embed.reshape(1, int(math.sqrt(N)), int(math.sqrt(N)), dim).permute(0, 3, 1, 2),
183 | scale_factor=(w0 / math.sqrt(N), h0 / math.sqrt(N)),
184 | mode="bicubic",
185 | )
186 |
187 | assert int(w0) == patch_pos_embed.shape[-2] and int(h0) == patch_pos_embed.shape[-1]
188 | patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).view(1, -1, dim)
189 | return torch.cat((class_pos_embed.unsqueeze(0), patch_pos_embed), dim=1).to(previous_dtype)
190 |
191 | def prepare_tokens_with_masks(self, x, masks=None):
192 | B, nc, w, h = x.shape
193 | x = self.patch_embed(x)
194 | if masks is not None:
195 | x = torch.where(masks.unsqueeze(-1), self.mask_token.to(x.dtype).unsqueeze(0), x)
196 |
197 | x = torch.cat((self.cls_token.expand(x.shape[0], -1, -1), x), dim=1)
198 | x = x + self.interpolate_pos_encoding(x, w, h)
199 |
200 | return x
201 |
202 | def forward_features_list(self, x_list, masks_list):
203 | x = [self.prepare_tokens_with_masks(x, masks) for x, masks in zip(x_list, masks_list)]
204 |
205 | for blk in self.blocks:
206 | x = blk(x)
207 |
208 | all_x = x
209 | output = []
210 | for x, masks in zip(all_x, masks_list):
211 | x_norm = self.norm(x)
212 | output.append(
213 | {
214 | "x_norm_clstoken": x_norm[:, 0],
215 | "x_norm_patchtokens": x_norm[:, 1:],
216 | "x_prenorm": x,
217 | "masks": masks,
218 | }
219 | )
220 | return output
221 |
222 | def forward_features(self, x, masks=None):
223 | if isinstance(x, list):
224 | return self.forward_features_list(x, masks)
225 |
226 | x = self.prepare_tokens_with_masks(x, masks)
227 |
228 | count = 1
229 | x_middle = {}
230 | for blk in self.blocks:
231 | x = blk(x)
232 |             if count in (3, 6, 9, 12):  # tap normalized patch tokens after blocks 3, 6, 9, 12 (depth-12 backbones)
233 | x_middle[str(count)] = self.norm(x)[:, 1:]
234 | count = count + 1
235 |
236 | x_norm = self.norm(x)
237 | return {
238 | "x_norm_clstoken": x_norm[:, 0],
239 | "x_norm_patchtokens": x_norm[:, 1:],
240 | "x_prenorm": x,
241 | "masks": masks,
242 | }, x_middle
243 |
244 | def _get_intermediate_layers_not_chunked(self, x, n=1):
245 | x = self.prepare_tokens_with_masks(x)
246 | # If n is an int, take the n last blocks. If it's a list, take them
247 | output, total_block_len = [], len(self.blocks)
248 | blocks_to_take = range(total_block_len - n, total_block_len) if isinstance(n, int) else n
249 | for i, blk in enumerate(self.blocks):
250 | x = blk(x)
251 | if i in blocks_to_take:
252 | output.append(x)
253 | assert len(output) == len(blocks_to_take), f"only {len(output)} / {len(blocks_to_take)} blocks found"
254 | return output
255 |
256 | def _get_intermediate_layers_chunked(self, x, n=1):
257 | x = self.prepare_tokens_with_masks(x)
258 | output, i, total_block_len = [], 0, len(self.blocks[-1])
259 | # If n is an int, take the n last blocks. If it's a list, take them
260 | blocks_to_take = range(total_block_len - n, total_block_len) if isinstance(n, int) else n
261 | for block_chunk in self.blocks:
262 | for blk in block_chunk[i:]: # Passing the nn.Identity()
263 | x = blk(x)
264 | if i in blocks_to_take:
265 | output.append(x)
266 | i += 1
267 | assert len(output) == len(blocks_to_take), f"only {len(output)} / {len(blocks_to_take)} blocks found"
268 | return output
269 |
270 | def get_intermediate_layers(
271 | self,
272 | x: torch.Tensor,
273 | n: Union[int, Sequence] = 1, # Layers or n last layers to take
274 | reshape: bool = False,
275 | return_class_token: bool = False,
276 | norm=True,
277 | ) -> Tuple[Union[torch.Tensor, Tuple[torch.Tensor]]]:
278 | if self.chunked_blocks:
279 | outputs = self._get_intermediate_layers_chunked(x, n)
280 | else:
281 | outputs = self._get_intermediate_layers_not_chunked(x, n)
282 | if norm:
283 | outputs = [self.norm(out) for out in outputs]
284 | class_tokens = [out[:, 0] for out in outputs]
285 | outputs = [out[:, 1:] for out in outputs]
286 | if reshape:
287 | B, _, w, h = x.shape
288 | outputs = [
289 | out.reshape(B, w // self.patch_size, h // self.patch_size, -1).permute(0, 3, 1, 2).contiguous()
290 | for out in outputs
291 | ]
292 | if return_class_token:
293 | return tuple(zip(outputs, class_tokens))
294 | return tuple(outputs)
295 |
296 | def forward(self, *args, is_training=False, **kwargs):
297 | ret = self.forward_features(*args, **kwargs)
298 | if is_training:
299 | return ret
300 | else:
301 |             return self.head(ret[0]["x_norm_clstoken"])  # forward_features returns (token dict, x_middle)
302 |
303 |
304 | def init_weights_vit_timm(module: nn.Module, name: str = ""):
305 | """ViT weight initialization, original timm impl (for reproducibility)"""
306 | if isinstance(module, nn.Linear):
307 | trunc_normal_(module.weight, std=0.02)
308 | if module.bias is not None:
309 | nn.init.zeros_(module.bias)
310 |
311 |
312 | def vit_small_lora(patch_size=16, **kwargs):
313 | model = DinoVisionTransformer(
314 | patch_size=patch_size,
315 | embed_dim=384,
316 | depth=12,
317 | num_heads=6,
318 | mlp_ratio=4,
319 | block_fn=partial(Block, attn_class=MemEffAttention_lora),
320 | **kwargs,
321 | )
322 | return model
323 |
324 |
325 | def vit_base_lora(patch_size=16, **kwargs):
326 | model = DinoVisionTransformer(
327 | patch_size=patch_size,
328 | embed_dim=768,
329 | depth=12,
330 | num_heads=12,
331 | mlp_ratio=4,
332 | block_fn=partial(Block, attn_class=MemEffAttention_lora),
333 | **kwargs,
334 | )
335 | return model
336 |
337 |
338 | def vit_large(patch_size=16, **kwargs):
339 | model = DinoVisionTransformer(
340 | patch_size=patch_size,
341 | embed_dim=1024,
342 | depth=24,
343 | num_heads=16,
344 | mlp_ratio=4,
345 | block_fn=partial(Block, attn_class=MemEffAttention_lora),
346 | **kwargs,
347 | )
348 | return model
349 |
350 |
351 | def vit_giant2(patch_size=16, **kwargs):
352 | """
353 | Close to ViT-giant, with embed-dim 1536 and 24 heads => embed-dim per head 64
354 | """
355 | model = DinoVisionTransformer(
356 | patch_size=patch_size,
357 | embed_dim=1536,
358 | depth=40,
359 | num_heads=24,
360 | mlp_ratio=4,
361 | block_fn=partial(Block, attn_class=MemEffAttention_lora),
362 | **kwargs,
363 | )
364 | return model
365 |
--------------------------------------------------------------------------------
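A minimal sketch of how the LoRA-enabled backbone above might be exercised, assuming the repository root is on `PYTHONPATH` and the pinned dependencies from `requirements.txt` (notably `loralib` and `xformers`) are installed; the input size and the `block_chunks=0` choice are illustrative, not taken from the training scripts:

```python
# Illustrative only: build the ViT-S/16 LoRA backbone defined in
# models/vision_transformer_lora.py and inspect what forward_features returns.
import torch
from models.vision_transformer_lora import vit_small_lora

# block_chunks=0 keeps a flat block list, so the intermediate taps at
# blocks 3/6/9/12 inside forward_features are actually reached.
model = vit_small_lora(patch_size=16, img_size=224, block_chunks=0)
model.eval()

x = torch.randn(1, 3, 224, 224)               # dummy batch (B, C, H, W)
with torch.no_grad():
    tokens, mid = model.forward_features(x)   # (token dict, intermediate features)

print(tokens["x_norm_clstoken"].shape)        # torch.Size([1, 384])       class token
print(tokens["x_norm_patchtokens"].shape)     # torch.Size([1, 196, 384])  14x14 patch grid
print(sorted(mid.keys(), key=int))            # ['3', '6', '9', '12']      tapped blocks
```

The second return value simply collects the normalized patch tokens after blocks 3, 6, 9 and 12, giving a multi-level feature stack alongside the final tokens.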
/requirements.txt:
--------------------------------------------------------------------------------
1 | fvcore==0.1.5.post20221221
2 | loralib==0.1.2
3 | matplotlib==3.6.2
4 | numpy==1.23.4
5 | Pillow==10.4.0
6 | scikit_learn==1.5.1
7 | tensorboardX==2.6.2.2
8 | timm==1.0.9
9 | torch==2.4.0
10 | torchmetrics==1.4.1
11 | torchvision==0.19.0
12 | tqdm==4.66.1
13 | xformers==0.0.27.post2
14 |
--------------------------------------------------------------------------------
/run/mla_crater.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | echo "----------Training Stage--------------"
4 | python ../demo_classification.py -net "mla" \
5 | -d "crater" \
6 | -dn "cuda" \
7 | -v "small"\
8 | -loss "wdice"\
9 | -cp "unfrozen"\
10 | -l 1e-5
11 | echo "----------Training Over----------------"
12 |
13 | echo "---------------Evaluate----------------"
14 | python ../evaluate_classification.py -net "mla" \
15 | -d "crater" \
16 | -dn "cuda" \
17 | -v "small"\
18 | -loss "wdice"\
19 | -cp "unfrozen"
20 | echo "Done"
21 |
--------------------------------------------------------------------------------
/run/mla_das.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | echo "----------Training Stage--------------"
4 | python ../demo_classification.py -net "mla" \
5 | -d "das" \
6 | -dn "cuda" \
7 | -v "small"\
8 | -loss "wdice"\
9 | -cp "unfrozen"\
10 | -l 1e-5
11 | echo "----------Training Over----------------"
12 |
13 | echo "---------------Evaluate----------------"
14 | python ../evaluate_classification.py -net "mla" \
15 | -d "das" \
16 | -dn "cuda" \
17 | -v "small"\
18 | -loss "wdice"\
19 | -cp "unfrozen"
20 | echo "Done"
21 |
--------------------------------------------------------------------------------
/run/mla_facies.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | echo "----------Training Stage--------------"
4 | python ../demo_classification.py -net "mla" \
5 | -d "seam" \
6 | -dn "cuda" \
7 | -v "small"\
8 | -loss "wdice"\
9 | -cp "unfrozen"\
10 | -l 1e-5
11 | echo "----------Training Over----------------"
12 |
13 | echo "---------------Evaluate----------------"
14 | python ../evaluate_classification.py -net "mla" \
15 | -d "seam" \
16 | -dn "cuda" \
17 | -v "small"\
18 | -loss "wdice"\
19 | -cp "unfrozen"
20 | echo "Done"
21 |
--------------------------------------------------------------------------------
/run/mla_fault.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | echo "----------Training Stage--------------"
4 | python ../demo_classification.py -net "mla" \
5 | -d "seam" \
6 | -dn "cuda" \
7 | -v "small"\
8 | -loss "wdice"\
9 | -cp "lora"\
10 | -l 1e-5
11 | echo "----------Training Over----------------"
12 |
13 | echo "---------------Evaluate----------------"
14 | python ../evaluate_classification.py -net "mla" \
15 | -d "seam" \
16 | -dn "cuda" \
17 | -v "small"\
18 | -loss "wdice"\
19 | -cp "lora"
20 | echo "Done"
21 |
--------------------------------------------------------------------------------
/run/mla_salt.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | echo "----------Training Stage--------------"
4 | python ../demo_classification.py -net "mla" \
5 | -d "salt" \
6 | -dn "cuda" \
7 | -v "small"\
8 | -loss "wdice"\
9 | -cp "unfrozen"\
10 | -l 1e-5
11 | echo "----------Training Over----------------"
12 |
13 | echo "---------------Evaluate----------------"
14 | python ../evaluate_classification.py -net "mla" \
15 | -d "salt" \
16 | -dn "cuda" \
17 | -v "small"\
18 | -loss "wdice"\
19 | -cp "unfrozen"
20 | echo "Done"
21 |
--------------------------------------------------------------------------------
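Among the run scripts, run/mla_fault.sh is the one launched with `-cp "lora"`, which presumably selects the LoRA fine-tuning path built on `MemEffAttention_lora`: the pretrained dense weights stay frozen and only the low-rank factors are updated. A hedged sketch of that mechanism with `loralib` (pinned in requirements.txt); `TinyBlock`, the dimensions, and the rank `r` below are illustrative, not the repository's actual modules:

```python
# Conceptual sketch of the LoRA mechanism behind the "-cp lora" option: freeze the
# dense weights and train only the low-rank A/B factors that loralib attaches.
import torch.nn as nn
import loralib as lora

class TinyBlock(nn.Module):
    def __init__(self, dim=384, r=4):
        super().__init__()
        self.qkv = lora.Linear(dim, 3 * dim, r=r)   # dense weight frozen, lora_A/lora_B trainable
        self.proj = nn.Linear(dim, dim)             # ordinary layer: frozen by the helper below

    def forward(self, x):
        q, k, v = self.qkv(x).chunk(3, dim=-1)      # toy "attention" projection
        return self.proj(v)

block = TinyBlock()
lora.mark_only_lora_as_trainable(block)             # disable grads for everything except LoRA factors
print([n for n, p in block.named_parameters() if p.requires_grad])
# -> ['qkv.lora_A', 'qkv.lora_B']
```

Saving only the adapter weights can then go through `lora.lora_state_dict(block)`, which is why LoRA checkpoints remain small compared with fully unfrozen fine-tuning.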