├── .gitignore
├── LICENSE
├── README.md
├── dataset.py
├── demo_classification.py
├── evaluate_classification.py
├── loss
│   ├── DiceLoss.py
│   ├── FocalLoss.py
│   ├── WeightDiceLoss.py
│   ├── metric.py
│   └── ssim.py
├── models
│   ├── Dpt.py
│   ├── __init__.py
│   ├── adapter.py
│   ├── dpt
│   │   ├── __init__.py
│   │   ├── base_model.py
│   │   ├── blocks.py
│   │   ├── layers
│   │   │   ├── __init__.py
│   │   │   ├── attention.py
│   │   │   ├── block.py
│   │   │   ├── dino_head.py
│   │   │   ├── drop_path.py
│   │   │   ├── layer_scale.py
│   │   │   ├── mlp.py
│   │   │   ├── patch_embed.py
│   │   │   └── swiglu_ffn.py
│   │   ├── midas_net.py
│   │   ├── models.py
│   │   ├── transforms.py
│   │   └── vit.py
│   ├── layers
│   │   ├── __init__.py
│   │   ├── attention.py
│   │   ├── block.py
│   │   ├── dino_head.py
│   │   ├── drop_path.py
│   │   ├── layer_scale.py
│   │   ├── mlp.py
│   │   ├── patch_embed.py
│   │   └── swiglu_ffn.py
│   ├── unet.py
│   ├── vision_transformer.py
│   └── vision_transformer_lora.py
├── requirements.txt
└── run
    ├── mla_crater.sh
    ├── mla_das.sh
    ├── mla_facies.sh
    ├── mla_fault.sh
    └── mla_salt.sh

/.gitignore:
--------------------------------------------------------------------------------
__pycache__/
*.pyc
*.dat

--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
MIT License

Copyright (c) 2024 Zhixiang Guo

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.

--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
# 🌏 Cross-Domain Foundation Model Adaptation: Pioneering Computer Vision Models for Geophysical Data Analysis

🏢 [Computational Interpretation Group (CIG)](https://cig.ustc.edu.cn/main.htm)

[Zhixiang Guo<sup>1</sup>](https://cig.ustc.edu.cn/guo/list.htm),
[Xinming Wu<sup>1*</sup>](https://cig.ustc.edu.cn/xinming/list.htm),
[Luming Liang<sup>2</sup>](https://www.microsoft.com/en-us/research/people/lulian/),
[Hanlin Sheng<sup>1</sup>](https://cig.ustc.edu.cn/hanlin/list.htm),
[Nuo Chen<sup>1</sup>](https://cig.ustc.edu.cn/nuo/list.htm),
[Zhengfa Bi<sup>3</sup>](https://profiles.lbl.gov/416831-zhengfa-bi)

<sup>1</sup> School of Earth and Space Sciences, University of Science and Technology of China, Hefei, China

<sup>2</sup> Microsoft Applied Sciences Group, Redmond, WA 98052, United States

<sup>3</sup> Lawrence Berkeley National Laboratory, 1 Cyclotron Rd, CA 94707, USA

## :mega: News
:flying_saucer: The dataset, model, code, and demo are coming soon!

:collision: [2025.02.23]: The paper has been accepted for publication in [JGR: Machine Learning and Computation](https://agupubs.onlinelibrary.wiley.com/doi/pdf/10.1029/2025JH000601).

:collision: [2024.09.01]: The code has been uploaded.

:collision: [2024.08.23]: The paper has been submitted to arXiv: https://arxiv.org/pdf/2408.12396

:collision: [2024.07.23]: Uploaded the [dataset](https://github.com/ProgrammerZXG/Cross-Domain-Foundation-Model-Adaptation/blob/master/README.md#package-dataset).

:collision: [2024.07.07]: GitHub repository initialized.

## :sparkles: Introduction

Workflow for adapting pre-trained foundation models to geophysics.
First, we prepare geophysical training datasets (1st column),
which involves collecting and processing relevant geophysical data
to ensure it is suitable for adaptation fine-tuning. Next, we load the pre-trained
foundation model as the data feature encoder (2nd column)
and fine-tune the model to make it adaptable to geophysical data.
To map the encoder features to the task-specific targets,
we explore suitable decoders
(3rd column) for geophysical downstream adaptation. Finally, the adapted model
is applied to various downstream tasks within the geophysics
field (4th column).
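
As a rough illustration of the encoder/decoder split described above, the sketch below maps an n1 × n2 section onto a ViT token grid and decodes it with a per-token linear head. It is a NumPy-only sketch under stated assumptions: random features stand in for DINOv2 patch tokens, the 14-pixel patch size is inferred from the `T.Resize((patch_h * 14, patch_w * 14))` call in `dataset.py`, and `patch_grid` and `linear_decode` are hypothetical helpers, not functions from this repository. It upsamples with nearest-neighbour indexing for brevity; the repository's decoders may interpolate differently.

```python
from __future__ import annotations

import numpy as np

PATCH = 14  # assumed ViT patch size, from T.Resize((patch_h * 14, patch_w * 14)) in dataset.py


def patch_grid(n1: int, n2: int, patch: int = PATCH) -> tuple[int, int]:
    """Smallest token grid whose resized image covers an n1 x n2 section (ceiling division)."""
    return -(-n1 // patch), -(-n2 // patch)


def linear_decode(tokens: np.ndarray, grid: tuple[int, int], weight: np.ndarray,
                  bias: np.ndarray, out_size: tuple[int, int]) -> np.ndarray:
    """Hypothetical linear decoder: per-token class scores, then nearest-neighbour upsampling.

    tokens: (grid_h * grid_w, dim) patch features from the (stand-in) encoder
    weight: (dim, classes); bias: (classes,); out_size: (n1, n2) of the original section
    """
    gh, gw = grid
    logits = (tokens @ weight + bias).reshape(gh, gw, -1)  # class scores on the token grid
    rows = np.arange(out_size[0]) * gh // out_size[0]      # source token row for each output row
    cols = np.arange(out_size[1]) * gw // out_size[1]      # source token column for each output column
    upsampled = logits[rows[:, None], cols[None, :]]       # (n1, n2, classes)
    return upsampled.argmax(axis=-1)                       # per-pixel class label


if __name__ == "__main__":
    rng = np.random.default_rng(0)
    gh, gw = patch_grid(1006, 782)                 # SEAM section size from the dataset table
    dim, classes = 384, 6                          # 384 is an assumed embedding width
    tokens = rng.standard_normal((gh * gw, dim))   # stand-in for frozen encoder features
    labels = linear_decode(tokens, (gh, gw), rng.standard_normal((dim, classes)),
                           np.zeros(classes), (1006, 782))
    print((gh, gw), labels.shape)                  # (72, 56) (1006, 782)
```

For the SEAM sections this reproduces the 72 × 56 grid hard-coded in `demo_classification.py`; the crater, DAS and fault settings (73, 37, 64) also equal ceil(n / 14), while the salt setting (20 × 20 for 224 × 224 inputs) is larger than the 16 × 16 minimum.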

## 🚀 Quick Start

### 1. Clone the repository
Our code provides demos for the datasets described in the paper:
seismic facies, geological bodies, DAS, faults, and craters.
You can run them by following the steps below.

First, clone the repository to your local machine:

```bash
git clone git@github.com:ProgrammerZXG/Cross-Domain-Foundation-Model-Adaptation.git
cd Cross-Domain-Foundation-Model-Adaptation
```

### 2. Install dependencies

```bash
pip install -r requirements.txt
```

### 3. Download the dataset

Before running the code, download the dataset from [Zenodo](https://zenodo.org/records/12798750) and place it in `data/` at the repository root.

### 4. Run the code

```bash
cd run
bash mla_facies.sh
```
If you run `bash run/mla_facies.sh` from the repository root instead, check the dataset paths first: most data paths in `dataset.py` are relative (`../data/...`) and assume the script is launched from `run/`.

## :stars: Results

### Quantitative Metrics for Downstream Tasks

#### Mean Intersection over Union (mIoU)

| Network | Seismic Facies<br>Classification | Seismic Geobody<br>Identification | Crater<br>Detection | DAS Seismic<br>Event Detection | Deep Fault<br>Detection |
|---------------|:------------:|:------------:|:------------:|:------------:|:------------:|
| Unet | 0.5490 | 0.8636 | 0.5812 | 0.7271 | 0.6858 |
| DINOv2-LINEAR | 0.6565 | 0.8965 | 0.6857 | 0.8112 | 0.6372 |
| DINOv2-PUP | **0.6885** | 0.8935 | 0.6937 | 0.8487 | 0.7088 |
| DINOv2-DPT | 0.6709 | 0.8912 | 0.6917 | **0.8672** | 0.7334 |
| DINOv2-MLA | 0.6826 | **0.8969** | **0.6949** | 0.8591 | **0.7613** |

#### Mean Pixel Accuracy (mPA)

| Network | Seismic Facies<br>Classification | Seismic Geobody<br>Identification | Crater<br>Detection | DAS Seismic<br>Event Detection | Deep Fault<br>Detection |
|---------------|:------------:|:------------:|:------------:|:------------:|:------------:|
| Unet | 0.7693 | 0.9112 | 0.6265 | 0.7865 | 0.7439 |
| DINOv2-LINEAR | 0.8732 | 0.9374 | 0.7481 | 0.9033 | 0.7519 |
| DINOv2-PUP | **0.9102** | 0.9357 | **0.7529** | 0.9210 | 0.7793 |
| DINOv2-DPT | 0.8826 | 0.9377 | 0.7462 | 0.9119 | 0.7985 |
| DINOv2-MLA | 0.8975 | **0.9383** | 0.7476 | **0.9222** | **0.8195** |

## :package: Dataset
All data is available at [Zenodo](https://zenodo.org/records/12798750).

[![DOI](https://zenodo.org/badge/DOI/10.5281/zenodo.12798750.svg)](https://doi.org/10.5281/zenodo.12798750)

| Task | Data Sources | Data Size | Training<br>Samples | Test<br>Samples |
|------------------------------|-----------------------------------------------|--------------|-----------------|-------------|
| Seismic Facies Classification | provided by [(SEAM, 2020)](https://www.aicrowd.com/challenges/seismic-facies-identification-challenge/discussion) | 1006 × 782 | 250 | 45 |
| Salt Body Identification | provided by [(Addison Howard et al., 2018)](https://www.kaggle.com/competitions/tgs-salt-identification-challenge) | 224 × 224 | 3000 | 1000 |
| Crater Detection | original data provided by [CAS](https://moon.bao.ac.cn/), labelled by authors | 1022 × 1022 | 1000 | 199 |
| DAS Seismic Event Detection | provided by [(Biondi et al., 2023)](https://zenodo.org/records/8270895) | 512 × 512 | 115 | 28 |
| Deep Fault Detection | original data provided from field surveys, labelled by authors | 896 × 896 | 1081 | 269 |
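
Each sample in the table above is stored on disk as a raw file: `BasicDataset` in `dataset.py` reads `<index>.dat` as a headerless binary (float32 for the input, int8 for the label), reshapes it to the task's n1 × n2 size, and stacks a left-right flipped copy. The sketch below round-trips one sample in that layout, which may help when preparing your own data. `save_pair`, `flip_stack` and `load_pair` are hypothetical helpers, and because the sub-directory names differ by task in `dataset.py` (`input`/`target` for SEAM and salt, `image`/`label` for the others) they are passed in as a parameter.

```python
import os
import tempfile

import numpy as np


def save_pair(root, split, index, image, label, dirs=("image", "label")):
    """Write one sample as raw C-order binaries: <root>/<split>/<dir>/<index>.dat."""
    for sub, arr, dtype in zip(dirs, (image, label), (np.float32, np.int8)):
        folder = os.path.join(root, split, sub)
        os.makedirs(folder, exist_ok=True)
        np.asarray(arr, dtype=dtype).tofile(os.path.join(folder, f"{index}.dat"))


def flip_stack(a):
    """Original plus left-right mirror, shaped (2, 1, n1, n2) like BasicDataset's stacking."""
    return np.stack([a, np.fliplr(a)])[:, None]


def load_pair(root, split, index, n1, n2, dirs=("image", "label")):
    """Read a sample back the way BasicDataset.__getitem__ does (before any resizing)."""
    img_dir, lbl_dir = dirs
    image = np.fromfile(os.path.join(root, split, img_dir, f"{index}.dat"), np.float32)
    label = np.fromfile(os.path.join(root, split, lbl_dir, f"{index}.dat"), np.int8)
    return flip_stack(image.reshape(n1, n2)), flip_stack(label.reshape(n1, n2))


if __name__ == "__main__":
    with tempfile.TemporaryDirectory() as root:
        section = np.random.default_rng(0).standard_normal((512, 512))  # DAS-sized sample
        mask = (section > 0).astype(np.int8)
        save_pair(root, "train", 0, section, mask)
        x, y = load_pair(root, "train", 0, 512, 512)
        print(x.shape, x.dtype, y.shape, y.dtype)  # (2, 1, 512, 512) float32 (2, 1, 512, 512) int8
```

Because the files carry no header, the n1 × n2 size in the table must match exactly; a wrong size makes `reshape` fail rather than silently misreading the data.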

## :bookmark: Citation

If you find this work useful, please consider citing our paper:

```bibtex
@misc{guo2024crossdomainfoundationmodeladaptation,
      title={Cross-Domain Foundation Model Adaptation: Pioneering Computer Vision Models for Geophysical Data Analysis},
      author={Zhixiang Guo and Xinming Wu and Luming Liang and Hanlin Sheng and Nuo Chen and Zhengfa Bi},
      year={2024},
      eprint={2408.12396},
      archivePrefix={arXiv},
      primaryClass={cs.CV},
      url={https://arxiv.org/abs/2408.12396},
}
```

## :memo: Acknowledgment
This study is strongly supported by the Supercomputing
Center of the University of Science and Technology of China,
particularly through the provision of NVIDIA A100 80 GB GPUs,
which were crucial for our experiments.
We also thank [SEAM](https://seg.org/SEAM) for providing the seismic facies classification dataset,
[TGS](https://www.kaggle.com/competitions/tgs-salt-identification-challenge) for the geobody identification dataset,
[CAS](https://moon.bao.ac.cn) for the crater detection dataset,
[Biondi](https://www.science.org/doi/full/10.1126/sciadv.adi9878) for the DAS seismic event detection dataset,
and [CIG](https://cig.ustc.edu.cn/main.htm) for the deep fault detection dataset.

## :postbox: Contact
If you have any questions about this work,
please feel free to contact xinmwu@ustc.edu.cn or zxg3@mail.ustc.edu.cn.
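
The mIoU and mPA columns in the Results tables are presumably averages over classes. As a hedged sketch, assuming both are macro averages computed from a pixel confusion matrix (built with the same `bincount` construction as `genConfusionMatrix` in `loss/metric.py`), they can be computed as below. The function names are illustrative, not this repository's API. Training logs IoU through torchmetrics' `JaccardIndex`, and the per-batch `pa_tmp` in `demo_classification.py` is overall pixel accuracy, which differs from a class-averaged mPA when classes are imbalanced.

```python
import numpy as np


def confusion_matrix(pred, label, num_classes):
    """Rows = ground-truth class, columns = predicted class; invalid label values are ignored."""
    pred, label = np.asarray(pred).ravel(), np.asarray(label).ravel()
    mask = (label >= 0) & (label < num_classes)
    idx = num_classes * label[mask].astype(np.int64) + pred[mask].astype(np.int64)
    return np.bincount(idx, minlength=num_classes ** 2).reshape(num_classes, num_classes)


def mean_iou(cm):
    """Average over classes of TP / (TP + FP + FN); classes absent from both maps are skipped."""
    tp = np.diag(cm).astype(float)
    union = cm.sum(axis=0) + cm.sum(axis=1) - tp
    with np.errstate(divide="ignore", invalid="ignore"):
        return float(np.nanmean(tp / union))


def mean_pixel_accuracy(cm):
    """Average over classes of TP / (ground-truth pixels of that class)."""
    with np.errstate(divide="ignore", invalid="ignore"):
        return float(np.nanmean(np.diag(cm) / cm.sum(axis=1)))


if __name__ == "__main__":
    label = np.array([[0, 0], [1, 1]])
    pred = np.array([[0, 1], [1, 1]])
    cm = confusion_matrix(pred, label, 2)
    print(cm.tolist(), round(mean_iou(cm), 4), mean_pixel_accuracy(cm))  # [[1, 1], [0, 2]] 0.5833 0.75
```

In the toy example class 0 has IoU 1/2 and class 1 has IoU 2/3, so mIoU is their mean; skipping classes that never occur (rather than scoring them 0) keeps small validation sets from being penalised for classes they do not contain.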
--------------------------------------------------------------------------------
/dataset.py:
--------------------------------------------------------------------------------
import os
import numpy as np
from torch.utils.data import Dataset
from PIL import Image
import torchvision.transforms as T

class BasicDataset(Dataset):

    def __init__(self, patch_h, patch_w, datasetName, netType, train_mode=False):
        # patch_h / patch_w: ViT token-grid size; inputs are resized to (patch_h*14, patch_w*14)
        self.patch_h = patch_h
        self.patch_w = patch_w

        # CNN baselines take the raw section; ViT-based models take a resized 3-channel image
        if netType == 'unet' or netType == 'deeplabv3plus':
            self.imgTrans = False
        else:
            self.imgTrans = True

        self.transform = T.Compose([
            T.Resize((patch_h * 14, patch_w * 14)),
            T.ToTensor(),
        ])

        self.dataset = datasetName

        # n1 x n2: size of each raw .dat sample; paths are relative to run/
        if datasetName == 'seam':
            self.n1 = 1006
            self.n2 = 782
            self.train_data_dir = '../data/seismicFace/train/input'
            self.train_label_dir = '../data/seismicFace/train/target'
            self.valid_data_dir = '../data/seismicFace/valid/input'
            self.valid_label_dir = '../data/seismicFace/valid/target'
        elif datasetName == 'salt':
            self.n1 = 224
            self.n2 = 224
            self.train_data_dir = '../data/geobody/train/input'
            self.train_label_dir = '../data/geobody/train/target'
            self.valid_data_dir = '../data/geobody/valid/input'
            self.valid_label_dir = '../data/geobody/valid/target'
        elif datasetName == 'fault':
            self.n1 = 896
            self.n2 = 896
            self.train_data_dir = '../data/deepFault/train/image'
            self.train_label_dir = '../data/deepFault/train/label'
            self.valid_data_dir = '../data/deepFault/valid/image'
            self.valid_label_dir = '../data/deepFault/valid/label'
        elif datasetName == 'crater':
            self.n1 = 1022
            self.n2 = 1022
            self.train_data_dir = '../data/crater/train/image'
            self.train_label_dir = '../data/crater/train/label'
            self.valid_data_dir = '../data/crater/valid/image'
            self.valid_label_dir = '../data/crater/valid/label'
        elif datasetName == 'das':
            self.n1 = 512
            self.n2 = 512
            self.train_data_dir = '../data/das/train/image'
            self.train_label_dir = '../data/das/train/label'
            self.valid_data_dir = '../data/das/valid/image'
            self.valid_label_dir = '../data/das/valid/label'
        else:
            # Fail early instead of hitting an AttributeError on the missing paths below
            raise ValueError(f"Unknown dataset: {datasetName}")
        print('netType:' + netType)
        print('dataset:' + datasetName)
        print('patch_h:' + str(patch_h))
        print('patch_w:' + str(patch_w))

        if train_mode:
            self.data_dir = self.train_data_dir
            self.label_dir = self.train_label_dir
        else:
            self.data_dir = self.valid_data_dir
            self.label_dir = self.valid_label_dir

        # Samples are named 0.dat, 1.dat, ..., so the file count defines the index range
        self.ids = len(os.listdir(self.data_dir))

    def __len__(self):
        return self.ids

    def __getitem__(self, index):

        dPath = self.data_dir + '/' + str(index) + '.dat'
        tPath = self.label_dir + '/' + str(index) + '.dat'
        # Headerless raw binaries: float32 amplitudes, int8 class labels
        data = np.fromfile(dPath, np.float32).reshape(self.n1, self.n2)
        label = np.fromfile(tPath, np.int8).reshape(self.n1, self.n2)

        # Stack the original with a left-right flipped copy -> shape (2, 1, n1, n2)
        data = np.reshape(data, (1, 1, self.n1, self.n2))
        data = np.concatenate([data, self.data_aug(data)], axis=0)
        label = np.reshape(label, (1, 1, self.n1, self.n2))
        label = np.concatenate([label, self.data_aug(label)], axis=0)

        if self.imgTrans:
            img_tensor = np.zeros([2, 1, self.patch_h * 14, self.patch_w * 14], np.float32)
            for i in range(data.shape[0]):
                img = Image.fromarray(np.uint8(data[i, 0]))
                # ToTensor yields (1, H, W); keep the single channel as a NumPy array
                img_tensor[i, 0] = self.transform(img)[0].numpy()
            data = img_tensor
            # Replicate the single channel to the 3 channels the ViT encoder expects
            data = data.repeat(3, axis=1)
        else:
            data = data / 255
        return data, label

    def data_aug(self, data):
        # Left-right flip of a (1, 1, h, w) array; np.squeeze assumes b = c = 1
        b, c, h, w = data.shape
        data_fliplr = np.fliplr(np.squeeze(data))
        return data_fliplr.reshape(b, c, h, w)

if __name__ == '__main__':

    # Signature: (patch_h, patch_w, datasetName, netType, train_mode)
    train_set = BasicDataset(72, 56, 'seam', 'setr1', train_mode=True)
    print(train_set.__getitem__(0)[1].shape)
    print(len(train_set))

--------------------------------------------------------------------------------
/demo_classification.py:
--------------------------------------------------------------------------------
import logging
import os
import torch
import torch.nn as nn
from torch import optim
from torch.utils.data import DataLoader
from tqdm import tqdm
from dataset import BasicDataset
from models.adapter import dinov2_mla, dinov2_pup, dinov2_linear
from models.Dpt import dinov2_dpt
from models.unet import U_Net
import numpy as np
from tensorboardX import SummaryWriter
from torchmetrics.classification import JaccardIndex
from loss.FocalLoss import Focal_Loss
from loss.DiceLoss import DiceLoss
from loss.WeightDiceLoss import WeightedDiceLoss
import random
import argparse
# from logger import Logger
import loralib as lora

random.seed(1234)
np.random.seed(1234)
torch.manual_seed(1234)
torch.cuda.manual_seed(1234)
torch.cuda.manual_seed_all(1234)


def main(args, logger):
    dir_checkpoint = '../checkpoint/' + args.dataset + "/" + args.loss + "/" + args.netType
    if not os.path.exists(dir_checkpoint):
        os.makedirs(dir_checkpoint)

    # Per-dataset sample size (n1, n2), class count, ViT token grid and batch size
    if args.dataset == 'seam':
        args.n1, args.n2 = 1006, 782
        args.classes = 6
        args.patch_h = 72
        args.patch_w = 56
        args.batch_size = 3
    elif args.dataset == 'salt':
        args.n1, args.n2 = 224, 224
        args.classes = 2
        args.patch_h = 20
        args.patch_w = 20
        args.batch_size = 32
    elif args.dataset == 'crater':
        args.n1, args.n2 = 1022, 1022
        args.classes = 2
        args.patch_h = 73
        args.patch_w = 73
        args.batch_size = 3
    elif args.dataset == 'das':
        args.n1, args.n2 = 512, 512
        args.classes = 2
        args.patch_h = 37
        args.patch_w = 37
        args.batch_size = 6
    elif args.dataset == 'fault':
        args.n1, args.n2 = 896, 896
        args.classes = 2
        args.patch_h = 64
        args.patch_w = 64
        args.batch_size = 6

    # Only "frozen" keeps the pre-trained encoder fixed; "unfrozen" and "lora" leave it trainable.
    # Defined unconditionally so an unexpected checkpointName cannot cause a NameError below.
    frozen = args.checkpointName == "frozen"

    if args.netType == "unet":
        net = U_Net(1, args.classes)
    elif args.netType == "linear":
        net = dinov2_linear(args.classes, pretrain=args.dpt, vit_type=args.vt, frozen=frozen, finetune_method=args.checkpointName)
    elif args.netType == "mla":
        net = dinov2_mla(args.classes, pretrain=args.dpt, vit_type=args.vt, frozen=frozen, finetune_method=args.checkpointName)
    elif args.netType == "pup":
        net = dinov2_pup(args.classes, pretrain=args.dpt, vit_type=args.vt, frozen=frozen, finetune_method=args.checkpointName)
    elif args.netType == "dpt":
        net = dinov2_dpt(args.classes, pretrain=args.dpt, vit_type=args.vt, frozen=frozen, finetune_method=args.checkpointName)
    else:
        raise ValueError(f"Unknown netType: {args.netType}")

    logger.info(f'\t{args.netType} NetWork:\n'
                f'\t{args.classes} num classes\n'
                f'\t{args.dataset} dataset\n'
                f'\t{args.vt} vitType\n'
                f'\t{args.loss} loss\n')
    # net = torch.nn.DataParallel(net, device_ids=range(device_count))
    goTrain(args,
            dir_checkpoint,
            net=net,
            patch_h=args.patch_h,
            patch_w=args.patch_w,
            epochs=args.epochs,
            batch_size=int(args.batch_size),
            learning_rate=args.lr,
            num_classes=args.classes,
            save_checkpoint=args.save_checkpoint
            )

def goTrain(args,
            dir_checkpoint,
            net,
            patch_h,
            patch_w,
            num_classes: int,
            epochs: int = 5,
            batch_size: int = 1,
            learning_rate: float = 1e-4,
            save_checkpoint: bool = True):

    net.to(device)
    get_parameter_number(net)

    # Create dataset
    train_set = BasicDataset(patch_h, patch_w, args.dataset, args.netType, train_mode=True)
    valid_set = BasicDataset(patch_h, patch_w, args.dataset, args.netType, train_mode=False)

    # Create data loaders
    train_loader = DataLoader(dataset=train_set, batch_size=batch_size, shuffle=True)
    valid_loader = DataLoader(dataset=valid_set, batch_size=batch_size, shuffle=False)

    logger.info(f'''Starting training:
        Epochs:          {epochs}
        Batch size:      {batch_size}
        Learning rate:   {learning_rate}
        Training size:   {len(train_set)}
        Validation size: {len(valid_set)}
        Checkpoints:     {save_checkpoint}
    ''')

    jaccard = JaccardIndex(task='multiclass', num_classes=num_classes).to(device)
    # Set up the optimizer, the loss and the learning rate scheduler
    # optimizer = optim.Adam(net.parameters(), lr=learning_rate, weight_decay=1e-8)
    # optimizer = optim.AdamW(net.parameters(), lr=learning_rate, weight_decay=0.05)
    optimizer = optim.AdamW(net.parameters(), lr=learning_rate, weight_decay=0.01, betas=(0.7, 0.999))
    # Works whether args.al arrives as a bool or as the string "True"/"False"
    use_al = str(args.al).lower() == "true"
    if use_al:
        scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, args.epochs)
    if args.loss == "ce":
        criterion = nn.CrossEntropyLoss()
    elif args.loss == "bce":
        criterion = nn.BCEWithLogitsLoss()
    elif args.loss == "focal":
        criterion = Focal_Loss(args.classes, device=args.device)
    elif args.loss == "dice":
        criterion = DiceLoss(args.classes)
    elif args.loss == "wdice":
        criterion = WeightedDiceLoss(args.classes, device=args.device)
    elif args.loss == "bace":
        if args.dataset == "seam":
            # Class-balanced CE: the weight must be a 1-D tensor with one entry per class
            weight = torch.tensor([1.216, 0.395, 3.673, 0.573, 14.193, 1.798]).to(args.device)
            criterion = nn.CrossEntropyLoss(weight=weight)
        else:
            raise ValueError("'bace' class weights are only defined for the seam dataset")

    # Tensorboard open
    writer = SummaryWriter('../Tensorboard/' + args.dataset + '/' + args.loss + '/')

    # Begin training
    train_loss = []
    valid_loss = []
    train_iou = []
    valid_iou = []
    train_pa = []
    valid_pa = []
    MaxTrainIoU = 0
    MaxValidIoU = 0
    MinTrainLoss = 1e7
    MinValidLoss = 1e7
    net.train()
    warmup_steps = 10
    ini_lr = learning_rate * 10
    for epoch in range(1, epochs + 1):
        if use_al:
            if epoch < warmup_steps:
                # Ramp the learning rate linearly towards ini_lr (10 x learning_rate)
                warmup_percent_done = epoch / warmup_steps
                optimizer.param_groups[0]['lr'] = ini_lr * warmup_percent_done
            else:
                scheduler.step()
        total_train_loss = []
        total_valid_loss = []
        total_train_iou = []
        total_valid_iou = []
        total_train_pa = []
        total_valid_pa = []
        with tqdm(total=len(train_set), desc=f'Epoch {epoch}/{epochs}', unit='img') as t:
            for data, label in train_loader:
                # Each sample holds (original, flipped): merge that pair into the batch dimension
                b1, b2, c, h, w = data.shape
                data = data.to(device).reshape(b1 * b2, c, h, w)
                b1, b2, c, h, w = label.shape
                label = label.to(device).reshape(b1 * b2, h, w)
                optimizer.zero_grad()
                outputs = net(data, (args.n1, args.n2))
                if args.loss == "bce":
                    loss = criterion(outputs, label.unsqueeze(1).expand(-1, 2, -1, -1).float())
                else:
                    loss = criterion(outputs, label.long())
                _, preds = torch.max(outputs, 1)
                iou_tmp = jaccard(preds, label.long()).detach().cpu().numpy()
                # Overall pixel accuracy of this batch
                pa_tmp = ((preds == label).sum().item() / (b1 * b2 * h * w))
                loss.backward()
                optimizer.step()
                t.update(batch_size)
                t.set_postfix(**{'train_loss': loss.item(), 'iou': iou_tmp, 'accuracy': pa_tmp, 'lr': optimizer.param_groups[0]['lr']})
                total_train_loss.append(loss.item())
                total_train_iou.append(iou_tmp)
                total_train_pa.append(pa_tmp)
        train_loss.append(np.mean(total_train_loss))
        train_iou.append(np.mean(total_train_iou))
train_pa.append(np.mean(total_train_pa)) 207 | logger.info(f"Epoch {epoch} - TrainSet - Loss: {train_loss[-1]}, IoU: {train_iou[-1]}, Accuracy: {train_pa[-1]}") 208 | 209 | # if save_checkpoint and epoch%5==0: 210 | # torch.save(net.state_dict(), dir_checkpoint + "/"+args.checkpointName + "_" + args.vt+"_epoch"+str(epoch)+"_train.pth") 211 | if train_iou[-1]>MaxTrainIoU: 212 | torch.save(net.state_dict(), dir_checkpoint + "/"+args.checkpointName + "_" + args.vt+"_maxiou_train.pth") 213 | if args.checkpointName=="lora": 214 | torch.save(lora.lora_state_dict(net), dir_checkpoint + "/"+args.checkpointName + "_" + args.vt+"_maxiou_train_lora.pth") 215 | MaxTrainIoU = train_iou[-1] 216 | logger.info(f'max_train_iou saved!') 217 | if train_loss[-1]MaxValidIoU: 254 | torch.save(net.state_dict(), dir_checkpoint + "/"+args.checkpointName + "_" + args.vt+"_maxiou_valid.pth") 255 | if args.checkpointName=="lora": 256 | torch.save(lora.lora_state_dict(net), dir_checkpoint + "/"+args.checkpointName + "_" + args.vt+"_maxiou_valid_lora.pth") 257 | MaxValidIoU = valid_iou[-1] 258 | logger.info(f'max_valid_iou saved!') 259 | if valid_loss[-1]= 0) & (imgLabel < self.numClass) 38 | label = self.numClass * imgLabel[mask] + imgPredict[mask] 39 | count = np.bincount(label, minlength=self.numClass**2) 40 | confusionMatrix = count.reshape(self.numClass, self.numClass) 41 | return confusionMatrix 42 | 43 | def Frequency_Weighted_Intersection_over_Union(self): 44 | # FWIOU = [(TP+FN)/(TP+FP+TN+FN)] *[TP / (TP + FP + FN)] 45 | freq = np.sum(self.confusion_matrix, axis=1) / np.sum(self.confusion_matrix) 46 | iu = np.diag(self.confusion_matrix) / ( 47 | np.sum(self.confusion_matrix, axis=1) + np.sum(self.confusion_matrix, axis=0) - 48 | np.diag(self.confusion_matrix)) 49 | FWIoU = (freq[freq > 0] * iu[freq > 0]).sum() 50 | return FWIoU 51 | 52 | 53 | def addBatch(self, imgPredict, imgLabel): 54 | assert imgPredict.shape == imgLabel.shape 55 | self.confusionMatrix += 
self.genConfusionMatrix(imgPredict, imgLabel)

    def reset(self):
        self.confusionMatrix = np.zeros((self.numClass, self.numClass))
--------------------------------------------------------------------------------
/loss/ssim.py:
--------------------------------------------------------------------------------
import torch
import torch.nn.functional as F
import numpy as np
import torch.nn as nn
# Converts the legacy size_average/reduce arguments into a reduction string
from torch.nn import _reduction as _Reduction

def _fspecial_gaussian(size, channel, sigma):
    # 2-D Gaussian window: softmax over -r^2 / (2 sigma^2) normalises it to sum to 1
    coords = torch.tensor([(x - (size - 1.) / 2.) for x in range(size)])
    coords = -coords ** 2 / (2. * sigma ** 2)
    grid = coords.view(1, -1) + coords.view(-1, 1)
    grid = grid.view(1, -1)
    grid = grid.softmax(-1)
    kernel = grid.view(1, 1, size, size)
    kernel = kernel.expand(channel, 1, size, size).contiguous()
    return kernel

# zfbi
def _fspecial_gaussian3d(size, channel, sigma):
    # 3-D counterpart of _fspecial_gaussian
    coords = torch.tensor([(x - (size - 1.) / 2.) for x in range(size)])
    coords = -coords ** 2 / (2. * sigma ** 2)
    grid = coords.view(1, -1, 1) + coords.view(-1, 1, 1) + coords.view(1, 1, -1)
    grid = grid.view(1, -1)
    grid = grid.softmax(-1)
    kernel = grid.view(1, 1, size, size, size)
    kernel = kernel.expand(channel, 1, size, size, size).contiguous()
    return kernel

def _ssim(output, target, max_val, k1, k2, channel, kernel):
    c1 = (k1 * max_val) ** 2
    c2 = (k2 * max_val) ** 2

    # Local means, variances and covariance via depthwise Gaussian filtering (no padding)
    mu1 = F.conv2d(output, kernel, groups=channel)
    mu2 = F.conv2d(target, kernel, groups=channel)

    mu1_sq = mu1 ** 2
    mu2_sq = mu2 ** 2
    mu1_mu2 = mu1 * mu2

    sigma1_sq = F.conv2d(output * output, kernel, groups=channel) - mu1_sq
    sigma2_sq = F.conv2d(target * target, kernel, groups=channel) - mu2_sq
    sigma12 = F.conv2d(output * target, kernel, groups=channel) - mu1_mu2

    v1 = 2 * sigma12 + c2
    v2 = sigma1_sq + sigma2_sq + c2

    # SSIM map plus the contrast-structure term v1 / v2 used by MS-SSIM
    ssim = ((2 * mu1_mu2 + c1) * v1) / ((mu1_sq + mu2_sq + c1) * v2)
    return ssim, v1 / v2

# zfbi
def _ssim3d(input, target, max_val, k1, k2, channel, kernel):
    c1 = (k1 * max_val) ** 2
    c2 = (k2 * max_val) ** 2

    mu1 = F.conv3d(input, kernel, groups=channel)
    mu2 = F.conv3d(target, kernel, groups=channel)

    mu1_sq = mu1 ** 2
    mu2_sq = mu2 ** 2
    mu1_mu2 = mu1 * mu2

    sigma1_sq = F.conv3d(input * input, kernel, groups=channel) - mu1_sq
    sigma2_sq = F.conv3d(target * target, kernel, groups=channel) - mu2_sq
    sigma12 = F.conv3d(input * target, kernel, groups=channel) - mu1_mu2

    v1 = 2 * sigma12 + c2
    v2 = sigma1_sq + sigma2_sq + c2

    ssim = ((2 * mu1_mu2 + c1) * v1) / ((mu1_sq + mu2_sq + c1) * v2)
    return ssim, v1 / v2


def ssim_loss(input, target, max_val, filter_size=7, k1=0.01, k2=0.03,
              sigma=1.5, kernel=None, size_average=None, reduce=None, reduction='mean'):

    if input.size() != target.size():
        raise ValueError('Expected input size ({}) to match target size ({}).'
                         .format(input.size(), target.size()))

    if size_average is not None or reduce is not None:
        reduction = _Reduction.legacy_get_string(size_average, reduce)

    # Promote (H, W) and (C, H, W) inputs to (N, C, H, W); Tensor.dim() takes no index, so use .shape
    dim = input.dim()
    if dim == 2:
        input = input.expand(1, 1, input.shape[-2], input.shape[-1])
        target = target.expand(1, 1, target.shape[-2], target.shape[-1])
    elif dim == 3:
        input = input.expand(1, input.shape[-3], input.shape[-2], input.shape[-1])
        target = target.expand(1, target.shape[-3], target.shape[-2], target.shape[-1])
    elif dim != 4:
        raise ValueError('Expected 2, 3, or 4 dimensions (got {})'.format(dim))

    _, channel, _, _ = input.size()

    if kernel is None:
        kernel = _fspecial_gaussian(filter_size, channel, sigma)
    kernel = kernel.to(device=input.device)

    ret, _ = _ssim(input, target, max_val, k1, k2, channel, kernel)

    if reduction != 'none':
        ret = torch.mean(ret) if reduction == 'mean' else torch.sum(ret)
    return ret

def ssim_loss3d(input, target, max_val, filter_size=7, k1=0.01, k2=0.03,
                sigma=1.5, kernel=None, size_average=None, reduce=None, reduction='mean'):

    if input.size() != target.size():
        raise ValueError('Expected input size ({}) to match target size ({}).'
                         .format(input.size(), target.size()))

    if size_average is not None or reduce is not None:
        reduction = _Reduction.legacy_get_string(size_average, reduce)

    # Promote 2-D, 3-D and 4-D inputs to (N, C, D, H, W)
    dim = input.dim()
    if dim == 2:
        input = input.expand(1, 1, 1, input.shape[-2], input.shape[-1])
        target = target.expand(1, 1, 1, target.shape[-2], target.shape[-1])
    elif dim == 3:
        input = input.expand(1, 1, input.shape[-3], input.shape[-2], input.shape[-1])
        target = target.expand(1, 1, target.shape[-3], target.shape[-2], target.shape[-1])
    elif dim == 4:
        input = input.expand(1, input.shape[-4], input.shape[-3], input.shape[-2], input.shape[-1])
        target = target.expand(1, target.shape[-4], target.shape[-3], target.shape[-2], target.shape[-1])
    elif dim != 5:
        raise ValueError('Expected 2, 3, 4, or 5 dimensions (got {})'.format(dim))

    _, channel, _, _, _ = input.size()

    if kernel is None:
        kernel = _fspecial_gaussian3d(filter_size, channel, sigma)
    kernel = kernel.to(device=input.device)

    ret, _ = _ssim3d(input, target, max_val, k1, k2, channel, kernel)

    if reduction != 'none':
        ret = torch.mean(ret) if reduction == 'mean' else torch.sum(ret)
    return ret

def ms_ssim_loss(input, target, max_val, filter_size=7, k1=0.01, k2=0.03,
                 sigma=1.5, kernel=None, weights=None, size_average=None, reduce=None, reduction='mean'):

    if input.size() != target.size():
        raise ValueError('Expected input size ({}) to match target size ({}).'
                         .format(input.size(), target.size()))

    if size_average is not None or reduce is not None:
        reduction = _Reduction.legacy_get_string(size_average, reduce)

    dim = input.dim()
    if dim == 2:
        input = input.expand(1, 1, input.shape[-2], input.shape[-1])
        target = target.expand(1, 1, target.shape[-2], target.shape[-1])
    elif dim == 3:
        input = input.expand(1, input.shape[-3], input.shape[-2], input.shape[-1])
        target = target.expand(1, target.shape[-3], target.shape[-2], target.shape[-1])
    elif dim != 4:
        raise ValueError('Expected 2, 3, or 4 dimensions (got {})'.format(dim))

    _, channel, _, _ = input.size()

    if kernel is None:
        kernel = _fspecial_gaussian(filter_size, channel, sigma)
    kernel = kernel.to(device=input.device)

    # Default 5-scale weights of the standard MS-SSIM formulation
    if weights is None:
        weights = [0.0448, 0.2856, 0.3001, 0.2363, 0.1333]
    weights = torch.tensor(weights, device=input.device)
    weights = weights.unsqueeze(-1).unsqueeze(-1)  # (levels, 1, 1) to broadcast over (levels, N, C)
    levels = weights.size(0)
    mssim = []
    mcs = []
    for _ in range(levels):
        ssim, cs = _ssim(input, target, max_val, k1, k2, channel, kernel)
        ssim = ssim.mean((2, 3))
        cs = cs.mean((2, 3))
        mssim.append(ssim)
        mcs.append(cs)

        # Halve the resolution for the next scale
        input = F.avg_pool2d(input, (2, 2))
        target = F.avg_pool2d(target, (2, 2))

    mssim = torch.stack(mssim)
    mcs = torch.stack(mcs)
    # Normalize
    mssim = (mssim + 1) / 2
    mcs = (mcs + 1) / 2
    p1 = mcs ** weights
    p2 = mssim ** weights

    # Contrast-structure terms of the finer scales times full SSIM at the coarsest scale
    ret = torch.prod(p1[:-1], 0) * p2[-1]

    if reduction != 'none':
        ret = torch.mean(ret) if reduction == 'mean' else torch.sum(ret)
    return ret


# zfbi
def ms_ssim_loss3d(input, target, max_val, filter_size=7, k1=0.01, k2=0.03,
                   sigma=1.5, kernel=None, weights=None, size_average=None, reduce=None, reduction='mean'):

    if input.size() != target.size():
        raise ValueError('Expected input size ({}) to match target size ({}).'
                         .format(input.size(), target.size()))

    if size_average is not None or reduce is not None:
        reduction = _Reduction.legacy_get_string(size_average, reduce)

    dim = input.dim()
    if dim == 2:
        input = input.expand(1, 1, 1, input.shape[-2], input.shape[-1])
        target = target.expand(1, 1, 1, target.shape[-2], target.shape[-1])
    elif dim == 3:
        input = input.expand(1, 1, input.shape[-3], input.shape[-2], input.shape[-1])
        target = target.expand(1, 1, target.shape[-3], target.shape[-2], target.shape[-1])
    elif dim == 4:
        input = input.expand(1, input.shape[-4], input.shape[-3], input.shape[-2], input.shape[-1])
        target = target.expand(1, target.shape[-4], target.shape[-3], target.shape[-2], target.shape[-1])
    elif dim != 5:
        raise ValueError('Expected 2, 3, 4, or 5 dimensions (got {})'.format(dim))

    _, channel, _, _, _ = input.size()

    if kernel is None:
        kernel = _fspecial_gaussian3d(filter_size, channel, sigma)
    kernel = kernel.to(device=input.device)

    if weights is None:
        weights = [0.0448, 0.2856, 0.3001, 0.2363, 0.1333]
    weights = torch.tensor(weights, device=input.device)
    # Per-level scores are averaged over (D, H, W) to shape (levels, N, C), so weights need
    # shape (levels, 1, 1); a third unsqueeze would broadcast to (levels, levels, N, C)
    weights = weights.unsqueeze(-1).unsqueeze(-1)
    levels = weights.size(0)
    mssim = []
    mcs = []
    for _ in range(levels):
        ssim, cs = _ssim3d(input, target, max_val, k1, k2, channel, kernel)
        ssim = ssim.mean((2, 3, 4))
        cs = cs.mean((2, 3, 4))
        mssim.append(ssim)
        mcs.append(cs)

        input = F.avg_pool3d(input, (2, 2, 2))
        target = F.avg_pool3d(target, (2, 2, 2))

    mssim = torch.stack(mssim)
    mcs = torch.stack(mcs)
    # Normalize
    mssim = (mssim + 1) / 2
    mcs = (mcs + 1) / 2
    p1 = mcs ** weights
    p2 = mssim ** weights

    ret = torch.prod(p1[:-1], 0) * p2[-1]

    if reduction != 'none':
        ret = torch.mean(ret) if reduction == 'mean' else torch.sum(ret)
    return ret

class _Loss(torch.nn.Module):
    def __init__(self, size_average=None, reduce=None, reduction='mean'):
        super(_Loss, self).__init__()
        if size_average is not None or reduce is not None:
            self.reduction = _Reduction.legacy_get_string(size_average, reduce)
        else:
            self.reduction = reduction

class SSIMLoss(_Loss):

    __constants__ = ['filter_size', 'k1', 'k2', 'sigma', 'kernel', 'reduction']

    def __init__(self, channel=3, filter_size=7, k1=0.01, k2=0.03, sigma=1.5, size_average=None, reduce=None, reduction='mean'):
        super(SSIMLoss, self).__init__(size_average, reduce, reduction)
        self.filter_size = filter_size
        self.k1 = k1
        self.k2 = k2
        self.sigma = sigma
        self.kernel = _fspecial_gaussian(filter_size, channel, sigma)

    def forward(self, input, target, max_val=1.):
        return ssim_loss(input, target, max_val=max_val, filter_size=self.filter_size, k1=self.k1, k2=self.k2,
                         sigma=self.sigma, reduction=self.reduction, kernel=self.kernel)

class SSIMLoss3D(_Loss):

    __constants__ = ['filter_size', 'k1', 'k2', 'sigma', 'kernel', 'reduction']

    def __init__(self, channel=3, filter_size=7, k1=0.01, k2=0.03, sigma=1.5, size_average=None, reduce=None, reduction='mean'):
        super(SSIMLoss3D, self).__init__(size_average, reduce, reduction)
        self.filter_size = filter_size
        self.k1 = k1
        self.k2 = k2
        self.sigma = sigma
        self.kernel = _fspecial_gaussian3d(filter_size, channel, sigma)

    def forward(self, input, target, max_val=1.):
        return ssim_loss3d(input, target, max_val=max_val, filter_size=self.filter_size, k1=self.k1, k2=self.k2,
                           sigma=self.sigma, reduction=self.reduction, kernel=self.kernel)

class MultiScaleSSIMLoss(_Loss):

    __constants__ = ['filter_size', 'k1',
'k2', 'sigma', 'kernel', 'reduction'] 299 | 300 | def __init__(self, channel=3, filter_size=7, k1=0.01, k2=0.03, sigma=1.5, size_average=None, reduce=None, reduction='mean'): 301 | super(MultiScaleSSIMLoss, self).__init__(size_average, reduce, reduction) 302 | self.filter_size = filter_size 303 | self.k1 = k1 304 | self.k2 = k2 305 | self.sigma = sigma 306 | self.kernel = _fspecial_gaussian(filter_size, channel, sigma) 307 | 308 | def forward(self, input, target, weights=[0.0448, 0.2856, 0.3001, 0.2363, 0.1333], max_val=1.): 309 | return ms_ssim_loss(input, target, max_val=max_val, k1=self.k1, k2=self.k2, sigma=self.sigma, kernel=self.kernel, 310 | weights=weights, filter_size=self.filter_size, reduction=self.reduction) 311 | # zfbi 312 | class MultiScaleSSIMLoss3D(_Loss): 313 | 314 | __constants__ = ['filter_size', 'k1', 'k2', 'sigma', 'kernel', 'reduction'] 315 | 316 | def __init__(self, channel=3, filter_size=7, k1=0.01, k2=0.03, sigma=1.5, size_average=None, reduce=None, reduction='mean'): 317 | super(MultiScaleSSIMLoss3D, self).__init__(size_average, reduce, reduction) 318 | self.filter_size = filter_size 319 | self.k1 = k1 320 | self.k2 = k2 321 | self.sigma = sigma 322 | self.kernel = _fspecial_gaussian3d(filter_size, channel, sigma) 323 | 324 | def forward(self, input, target, weights=[0.0448, 0.2856, 0.3001, 0.2363, 0.1333], max_val=1.): 325 | return ms_ssim_loss3d(input, target, max_val=max_val, k1=self.k1, k2=self.k2, sigma=self.sigma, kernel=self.kernel, 326 | weights=weights, filter_size=self.filter_size, reduction=self.reduction) 327 | 328 | class MSSIMLoss(nn.Module): 329 | def __init__(self, channel, filter_size): 330 | super(MSSIMLoss, self).__init__() 331 | self.mssim = MultiScaleSSIMLoss(channel=channel, filter_size=filter_size) 332 | def forward(self, output, target): 333 | loss = (1 - self.mssim(output, target)) 334 | return loss 335 | 336 | class NSSIMLoss(nn.Module): 337 | def __init__(self, channel, filter_size): 338 | super(NSSIMLoss, 
self).__init__() 339 | self.ssim = SSIMLoss(channel=channel, filter_size=filter_size) 340 | def forward(self, output, target): 341 | loss = (1 - self.ssim(output, target)) 342 | return loss -------------------------------------------------------------------------------- /models/Dpt.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | from models.vision_transformer_lora import vit_small_lora,vit_base_lora 5 | from models.vision_transformer import vit_small,vit_base 6 | from models.dpt import _make_fusion_block,_make_scratch 7 | import logging 8 | import loralib as lora 9 | ######################################################################################################################## 10 | 11 | _DINOV2_BASE_URL = "https://dl.fbaipublicfiles.com/dinov2" 12 | 13 | def load_pretrained_weights(model, pretrained_weights, checkpoint_key): 14 | logger = logging.getLogger("dinov2") 15 | state_dict = torch.load(pretrained_weights, map_location="cpu") 16 | if checkpoint_key is not None and checkpoint_key in state_dict: 17 | logger.info(f"Take key {checkpoint_key} in provided checkpoint dict") 18 | state_dict = state_dict[checkpoint_key] 19 | # remove `module.` prefix 20 | state_dict = {k.replace("module.", ""): v for k, v in state_dict.items()} 21 | # remove `backbone.` prefix induced by multicrop wrapper 22 | state_dict = {k.replace("backbone.", ""): v for k, v in state_dict.items()} 23 | msg = model.load_state_dict(state_dict, strict=False) 24 | logger.info("Pretrained weights found at {} and loaded with msg: {}".format(pretrained_weights, msg)) 25 | 26 | def make_dinov2_model_name(arch_name: str, patch_size: int) -> str: 27 | compact_arch_name = arch_name.replace("_", "")[:4] 28 | return f"dinov2_{compact_arch_name}{patch_size}" 29 | 30 | def make_vit_encoder(dino_pretrain="False",vit_type="small",finetune_method="unfrozen"): 31 | vit_kwargs = dict( 32 
| in_chans = 3, 33 | img_size=224, 34 | patch_size=14, 35 | init_values=1.0e-05, 36 | ffn_layer="mlp", 37 | block_chunks=0, 38 | qkv_bias=True, 39 | proj_bias=True, 40 | ffn_bias=True 41 | ) 42 | if str(dino_pretrain) == "True": # accept both the bool True and the CLI string "True" 43 | model_name = make_dinov2_model_name("vit_"+vit_type, 14) 44 | url = _DINOV2_BASE_URL + f"/{model_name}/{model_name}_pretrain.pth" 45 | pretrained_weights = torch.hub.load_state_dict_from_url(url, map_location="cpu") 46 | if finetune_method == "unfrozen" or finetune_method == "frozen": 47 | if vit_type == "small": 48 | encoder = vit_small(**vit_kwargs) 49 | emb = 384 50 | strict = True 51 | elif vit_type == "base": 52 | encoder = vit_base(**vit_kwargs) 53 | emb = 768 54 | strict = True 55 | else: 56 | raise ValueError(f"Unsupported vit_type: {vit_type}") 57 | elif finetune_method == "lora": 58 | if vit_type == "small": 59 | encoder = vit_small_lora(**vit_kwargs) 60 | emb = 384 61 | strict = False 62 | elif vit_type == "base": 63 | encoder = vit_base_lora(**vit_kwargs) 64 | emb = 768 65 | strict = False 66 | else: 67 | raise ValueError(f"Unsupported vit_type: {vit_type}") 68 | if str(dino_pretrain) == "True": 69 | encoder.load_state_dict(pretrained_weights, strict=strict) 70 | return encoder,emb 71 | 72 | class dinov2_dpt(nn.Module): 73 | def __init__(self, num_classes, pretrain = True, vit_type="small",frozen=False,finetune_method="unfrozen"): 74 | super(dinov2_dpt,self).__init__() 75 | 76 | features = 256 77 | 78 | self.encoder, self.emb = make_vit_encoder(pretrain,vit_type,finetune_method) 79 | self.scratch = _make_scratch([self.emb,self.emb,self.emb,self.emb], 80 | out_shape=features) 81 | self.scratch.refinenet1 = _make_fusion_block(features, use_bn=True) 82 | self.scratch.refinenet2 = _make_fusion_block(features, use_bn=True) 83 | self.scratch.refinenet3 = _make_fusion_block(features, use_bn=True) 84 | self.scratch.refinenet4 = _make_fusion_block(features, use_bn=True) 85 | 86 | self.scratch.single_conv = nn.Conv2d(features, features // 2, kernel_size=3, stride=1, padding=1) 87 | 
self.scratch.output_conv = nn.Sequential( 88 | nn.Conv2d(features // 2, 32, kernel_size=3, stride=1, padding=1), 89 | nn.ReLU(True), 90 | nn.Conv2d(32, num_classes, kernel_size=1, stride=1, padding=0), 91 | nn.ReLU(True), 92 | nn.Identity(), 93 | ) 94 | 95 | if frozen: 96 | for param in self.encoder.parameters(): 97 | param.requires_grad = False 98 | else: 99 | if finetune_method == "unfrozen": 100 | for param in self.encoder.parameters(): 101 | param.requires_grad = True 102 | elif finetune_method == "lora": 103 | lora.mark_only_lora_as_trainable(self.encoder) 104 | 105 | def forward(self,x,size): 106 | B,_,H,W = x.shape 107 | _, x_middle = self.encoder.forward_features(x) 108 | xm = [] 109 | for k,x in x_middle.items(): 110 | x = x.view( 111 | x.size(0), 112 | int(H / 14), 113 | int(W / 14), 114 | self.emb, 115 | ) 116 | x = x.permute(0, 3, 1, 2).contiguous() 117 | xm.append(x) 118 | layer_1, layer_2, layer_3, layer_4 = xm 119 | layer_1_rn = self.scratch.layer1_rn(layer_1) 120 | layer_2_rn = self.scratch.layer2_rn(layer_2) 121 | layer_3_rn = self.scratch.layer3_rn(layer_3) 122 | layer_4_rn = self.scratch.layer4_rn(layer_4) 123 | 124 | path_4 = self.scratch.refinenet4((size[0]//16,size[1]//16), layer_4_rn) 125 | path_3 = self.scratch.refinenet3((size[0]//8,size[1]//8),path_4, layer_3_rn) 126 | path_2 = self.scratch.refinenet2((size[0]//4,size[1]//4),path_3, layer_2_rn) 127 | path_1 = self.scratch.refinenet1((size[0]//2,size[1]//2),path_2, layer_1_rn) 128 | 129 | out = self.scratch.single_conv(path_1) 130 | out = F.interpolate(out,size=size) 131 | out = self.scratch.output_conv(out) 132 | return out 133 | 134 | if __name__ == "__main__": 135 | device = torch.device('cuda:1' if torch.cuda.is_available() else 'cpu') 136 | model = dinov2_dpt(1).to(device=device) 137 | total_num = sum(p.numel() for p in model.parameters()) 138 | trainable_num = sum(p.numel() for p in model.parameters() if p.requires_grad) 139 | print('Model Total: %d'%total_num) 140 | print('Model 
Trainable: %d'%trainable_num) 141 | x1 = torch.randn(1,3,434,994).to(device=device,dtype=torch.float32) 142 | y1 = model(x1,size=(434,994)) 143 | print(x1.shape) 144 | print(y1.shape) -------------------------------------------------------------------------------- /models/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | from .vision_transformer import vit_small 8 | from .layers import * 9 | -------------------------------------------------------------------------------- /models/adapter.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | import torchvision.transforms as T 5 | from models.vision_transformer_lora import vit_small_lora,vit_base_lora 6 | from models.vision_transformer import vit_small,vit_base 7 | import fvcore.nn.weight_init as weight_init 8 | 9 | import logging 10 | import loralib as lora 11 | 12 | _DINOV2_BASE_URL = "https://dl.fbaipublicfiles.com/dinov2" 13 | 14 | def load_pretrained_weights(model, pretrained_weights, checkpoint_key): 15 | logger = logging.getLogger("dinov2") 16 | state_dict = torch.load(pretrained_weights, map_location="cpu") 17 | if checkpoint_key is not None and checkpoint_key in state_dict: 18 | logger.info(f"Take key {checkpoint_key} in provided checkpoint dict") 19 | state_dict = state_dict[checkpoint_key] 20 | # remove `module.` prefix 21 | state_dict = {k.replace("module.", ""): v for k, v in state_dict.items()} 22 | # remove `backbone.` prefix induced by multicrop wrapper 23 | state_dict = {k.replace("backbone.", ""): v for k, v in state_dict.items()} 24 | msg = model.load_state_dict(state_dict, strict=False) 25 | logger.info("Pretrained 
weights found at {} and loaded with msg: {}".format(pretrained_weights, msg)) 26 | 27 | def make_dinov2_model_name(arch_name: str, patch_size: int) -> str: 28 | compact_arch_name = arch_name.replace("_", "")[:4] 29 | return f"dinov2_{compact_arch_name}{patch_size}" 30 | 31 | def make_vit_encoder(dino_pretrain="False",vit_type="small",finetune_method="unfrozen"): 32 | vit_kwargs = dict( 33 | in_chans = 3, 34 | img_size=224, 35 | patch_size=14, 36 | init_values=1.0e-05, 37 | ffn_layer="mlp", 38 | block_chunks=0, 39 | qkv_bias=True, 40 | proj_bias=True, 41 | ffn_bias=True 42 | ) 43 | if str(dino_pretrain) == "True": # accept both the bool True and the CLI string "True" 44 | model_name = make_dinov2_model_name("vit_"+vit_type, 14) 45 | url = _DINOV2_BASE_URL + f"/{model_name}/{model_name}_pretrain.pth" 46 | pretrained_weights = torch.hub.load_state_dict_from_url(url, map_location="cpu") 47 | if finetune_method == "unfrozen" or finetune_method == "frozen": 48 | if vit_type == "small": 49 | encoder = vit_small(**vit_kwargs) 50 | emb = 384 51 | strict = True 52 | elif vit_type == "base": 53 | encoder = vit_base(**vit_kwargs) 54 | emb = 768 55 | strict = True 56 | else: 57 | raise ValueError(f"Unsupported vit_type: {vit_type}") 58 | elif finetune_method == "lora": 59 | if vit_type == "small": 60 | encoder = vit_small_lora(**vit_kwargs) 61 | emb = 384 62 | strict = False 63 | elif vit_type == "base": 64 | encoder = vit_base_lora(**vit_kwargs) 65 | emb = 768 66 | strict = False 67 | else: 68 | raise ValueError(f"Unsupported vit_type: {vit_type}") 69 | if str(dino_pretrain) == "True": 70 | encoder.load_state_dict(pretrained_weights, strict=strict) 71 | return encoder,emb 72 | 73 | class IntermediateSequential(nn.Sequential): 74 | def __init__(self, *args, return_intermediate=True): 75 | super().__init__(*args) 76 | self.return_intermediate = return_intermediate 77 | 78 | def forward(self, input): 79 | if not self.return_intermediate: 80 | return super().forward(input) 81 | 82 | intermediate_outputs = {} 83 | output = input 84 | for name, module in self.named_children(): 85 | output = 
intermediate_outputs[name] = module(output) 86 | 87 | return output, intermediate_outputs 88 | 89 | 90 | class SETR_PUP(nn.Module): 91 | def __init__(self,embedding_dim,num_classes): 92 | super(SETR_PUP,self).__init__() 93 | 94 | self.embedding_dim = embedding_dim 95 | self.num_classes = num_classes 96 | 97 | extra_in_channels = int(self.embedding_dim/4) 98 | in_channels = [ 99 | self.embedding_dim, 100 | extra_in_channels, 101 | extra_in_channels, 102 | extra_in_channels, 103 | ] 104 | out_channels = [ 105 | extra_in_channels, 106 | extra_in_channels, 107 | extra_in_channels, 108 | extra_in_channels, 109 | ] 110 | 111 | modules = [] 112 | for i, (in_channel, out_channel) in enumerate( 113 | zip(in_channels, out_channels) 114 | ): 115 | modules.append( 116 | self.conv_block(in_channel,out_channel) 117 | ) 118 | modules.append(nn.Upsample(size=(1//(2**(3-i)),1//(2**(3-i))), mode='bilinear')) 119 | 120 | modules.append( 121 | nn.Conv2d( 122 | in_channels=out_channels[-1], out_channels=self.num_classes, 123 | kernel_size=1, stride=1, 124 | padding=self._get_padding('VALID', (1, 1),), 125 | )) 126 | self.decode_net = IntermediateSequential( 127 | *modules, return_intermediate=False 128 | ) 129 | 130 | def forward(self,x,size): 131 | n1,n2 = size 132 | self.decode_net[1] = nn.Upsample(size=(n1//(2**(3)),n2//(2**(3))), mode='bilinear') 133 | self.decode_net[3] = nn.Upsample(size=(n1//(2**(2)),n2//(2**(2))), mode='bilinear') 134 | self.decode_net[5] = nn.Upsample(size=(n1//(2**(1)),n2//(2**(1))), mode='bilinear') 135 | self.decode_net[7] = nn.Upsample(size=(n1,n2), mode='bilinear') 136 | return self.decode_net(x) 137 | 138 | def conv_block(self,in_channels, out_channels): 139 | conv = nn.Sequential( 140 | nn.Conv2d( 141 | int(in_channels), int(out_channels), 3, 1, 142 | padding=self._get_padding('SAME', (3, 3),), 143 | ), 144 | nn.BatchNorm2d(int(out_channels)), 145 | nn.ReLU(inplace=True), 146 | 147 | nn.Conv2d( 148 | int(out_channels), int(out_channels), 3, 1, 149 | 
padding=self._get_padding('SAME', (3, 3),), 150 | ), 151 | nn.BatchNorm2d(int(out_channels)), 152 | nn.ReLU(inplace=True) 153 | ) 154 | return conv 155 | 156 | def _get_padding(self, padding_type, kernel_size): 157 | assert padding_type in ['SAME', 'VALID'] 158 | if padding_type == 'SAME': 159 | _list = [(k - 1) // 2 for k in kernel_size] 160 | return tuple(_list) 161 | return tuple(0 for _ in kernel_size) 162 | 163 | class SETR_MLA(nn.Module): 164 | def __init__(self,embedding_dim,num_classes): 165 | super(SETR_MLA,self).__init__() 166 | 167 | self.embedding_dim = embedding_dim 168 | self.num_classes = num_classes 169 | 170 | self.net1_in, self.net1_intmd, self.net1_out = self._define_agg_net() 171 | self.net2_in, self.net2_intmd, self.net2_out = self._define_agg_net() 172 | self.net3_in, self.net3_intmd, self.net3_out = self._define_agg_net() 173 | self.net4_in, self.net4_intmd, self.net4_out = self._define_agg_net() 174 | 175 | self.output_net = IntermediateSequential(return_intermediate=False) 176 | self.output_net.add_module( 177 | "conv_1", 178 | nn.Conv2d( 179 | in_channels=self.embedding_dim, out_channels=self.num_classes, 180 | kernel_size=1, stride=1, 181 | padding=self._get_padding('VALID', (1, 1),), 182 | ) 183 | ) 184 | self.output_net.add_module( 185 | "upsample_1", 186 | nn.Upsample(size = (1,1), mode='bilinear') 187 | ) 188 | 189 | def forward(self,x,size): 190 | n1,n2 = size 191 | self.output_net[-1] = nn.Upsample(size = (n1,n2), mode='bilinear') 192 | x3,x6,x9,x12 = x 193 | 194 | x12_intmd_in = self.net1_in(x12) 195 | x12_out = self.net1_out(x12_intmd_in) 196 | 197 | x9_in = self.net2_in(x9) 198 | x9_intmd_in = x9_in + x12_intmd_in 199 | x9_intmd_out = self.net2_intmd(x9_intmd_in) 200 | x9_out = self.net2_out(x9_intmd_out) 201 | 202 | x6_in = self.net3_in(x6) 203 | x6_intmd_in = x6_in + x9_intmd_in 204 | x6_intmd_out = self.net3_intmd(x6_intmd_in) 205 | x6_out = self.net3_out(x6_intmd_out) 206 | 207 | x3_in = self.net4_in(x3) 208 | x3_intmd_in = 
x3_in + x6_intmd_in 209 | x3_intmd_out = self.net4_intmd(x3_intmd_in) 210 | x3_out = self.net4_out(x3_intmd_out) 211 | 212 | out = torch.cat((x12_out, x9_out, x6_out, x3_out), dim=1) 213 | out = self.output_net(out) 214 | 215 | return out 216 | 217 | def conv_block(self,in_channels, out_channels): 218 | conv = nn.Sequential( 219 | nn.Conv2d( 220 | int(in_channels), int(out_channels), 3, 1, 221 | padding=self._get_padding('SAME', (3, 3),), 222 | ), 223 | nn.BatchNorm2d(int(out_channels)), 224 | nn.ReLU(inplace=True), 225 | 226 | nn.Conv2d( 227 | int(out_channels), int(out_channels), 3, 1, 228 | padding=self._get_padding('SAME', (3, 3),), 229 | ), 230 | nn.BatchNorm2d(int(out_channels)), 231 | nn.ReLU(inplace=True) 232 | ) 233 | return conv 234 | 235 | def _define_agg_net(self): 236 | model_in = IntermediateSequential(return_intermediate=False) 237 | model_in.add_module( 238 | "layer_1", 239 | self.conv_block(self.embedding_dim,int(self.embedding_dim/2)) 240 | ) 241 | 242 | model_intmd = IntermediateSequential(return_intermediate=False) 243 | model_intmd.add_module( 244 | "layer_intmd", 245 | self.conv_block(int(self.embedding_dim/2),int(self.embedding_dim/2)) 246 | ) 247 | 248 | model_out = IntermediateSequential(return_intermediate=False) 249 | model_out.add_module( 250 | "layer_2", 251 | self.conv_block(int(self.embedding_dim/2),int(self.embedding_dim/2)) 252 | ) 253 | model_out.add_module( 254 | "layer_3", 255 | self.conv_block(int(self.embedding_dim/2),int(self.embedding_dim/4)) 256 | ) 257 | model_out.add_module( 258 | "upsample", nn.Upsample(scale_factor=4, mode='bilinear') 259 | ) 260 | model_out.add_module( 261 | "layer_4", 262 | self.conv_block(int(self.embedding_dim/4),int(self.embedding_dim/4)) 263 | ) 264 | return model_in, model_intmd, model_out 265 | 266 | def _get_padding(self, padding_type, kernel_size): 267 | assert padding_type in ['SAME', 'VALID'] 268 | if padding_type == 'SAME': 269 | _list = [(k - 1) // 2 for k in kernel_size] 270 | return 
tuple(_list) 271 | return tuple(0 for _ in kernel_size) 272 | 273 | class dinov2_pup(nn.Module): 274 | def __init__(self, num_classes, pretrain = True, vit_type="small",frozen=False,finetune_method="unfrozen"): 275 | super(dinov2_pup,self).__init__() 276 | 277 | self.encoder, self.emb = make_vit_encoder(pretrain,vit_type,finetune_method) 278 | self.decoder = SETR_PUP(self.emb, num_classes) 279 | 280 | if frozen: 281 | for param in self.encoder.parameters(): 282 | param.requires_grad = False 283 | else: 284 | if finetune_method == "unfrozen": 285 | for param in self.encoder.parameters(): 286 | param.requires_grad = True 287 | elif finetune_method == "lora": 288 | lora.mark_only_lora_as_trainable(self.encoder) 289 | 290 | def forward(self,x,size): 291 | B,_,H,W = x.shape 292 | features,_ = self.encoder.forward_features(x) 293 | fea_img = features['x_norm_patchtokens'] 294 | fea_img = fea_img.view(fea_img.size(0),int(H / 14),int(W / 14),self.emb) 295 | fea_img = fea_img.permute(0, 3, 1, 2).contiguous() 296 | out = self.decoder(fea_img,size) 297 | return out 298 | 299 | class dinov2_mla(nn.Module): 300 | def __init__(self, num_classes, pretrain = True, vit_type="small",frozen=False,finetune_method="unfrozen"): 301 | super(dinov2_mla,self).__init__() 302 | 303 | self.encoder, self.emb = make_vit_encoder(pretrain,vit_type,finetune_method) 304 | self.decoder = SETR_MLA(self.emb, num_classes) 305 | if frozen: 306 | for param in self.encoder.parameters(): 307 | param.requires_grad = False 308 | else: 309 | if finetune_method == "unfrozen": 310 | for param in self.encoder.parameters(): 311 | param.requires_grad = True 312 | elif finetune_method == "lora": 313 | lora.mark_only_lora_as_trainable(self.encoder) 314 | 315 | def forward(self,x,size): 316 | B,_,H,W = x.shape 317 | _, x_middle = self.encoder.forward_features(x) 318 | xm = [] 319 | for k,x in x_middle.items(): 320 | x = x.view( 321 | x.size(0), 322 | int(H / 14), 323 | int(W / 14), 324 | self.emb, 325 | ) 326 | x = 
x.permute(0, 3, 1, 2).contiguous() 327 | xm.append(x) 328 | out = self.decoder(xm,size) 329 | return out 330 | 331 | class dinov2_linear(nn.Module): 332 | def __init__(self, num_classes, pretrain = True, vit_type="small",frozen=False,finetune_method="unfrozen"): 333 | super(dinov2_linear,self).__init__() 334 | 335 | self.encoder, self.emb = make_vit_encoder(pretrain,vit_type,finetune_method) 336 | self.decoder = nn.Conv2d(self.emb, num_classes, kernel_size=1) 337 | 338 | if frozen: 339 | for param in self.encoder.parameters(): 340 | param.requires_grad = False 341 | else: 342 | if finetune_method == "unfrozen": 343 | for param in self.encoder.parameters(): 344 | param.requires_grad = True 345 | elif finetune_method == "lora": 346 | lora.mark_only_lora_as_trainable(self.encoder) 347 | 348 | def forward(self,x,size): 349 | B,_,H,W = x.shape 350 | features,_ = self.encoder.forward_features(x) 351 | fea_img = features['x_norm_patchtokens'] 352 | fea_img = fea_img.view(fea_img.size(0),int(H / 14),int(W / 14),self.emb) 353 | fea_img = fea_img.permute(0, 3, 1, 2).contiguous() 354 | out = self.decoder(fea_img) 355 | out = F.interpolate(out,size=size) 356 | return out 357 | 358 | -------------------------------------------------------------------------------- /models/dpt/__init__.py: -------------------------------------------------------------------------------- 1 | from .blocks import ( 2 | FeatureFusionBlock, 3 | FeatureFusionBlock_custom, 4 | Interpolate, 5 | _make_scratch, 6 | ) 7 | from .models import _make_fusion_block -------------------------------------------------------------------------------- /models/dpt/base_model.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | 4 | class BaseModel(torch.nn.Module): 5 | def load(self, path): 6 | """Load model from file. 
7 | 8 | Args: 9 | path (str): file path 10 | """ 11 | parameters = torch.load(path, map_location=torch.device("cpu")) 12 | 13 | if "optimizer" in parameters: 14 | parameters = parameters["model"] 15 | 16 | self.load_state_dict(parameters) 17 | -------------------------------------------------------------------------------- /models/dpt/blocks.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | from models.dpt.vit import ( 5 | _make_pretrained_vitb_rn50_384, 6 | _make_pretrained_vitl16_384, 7 | _make_pretrained_vitb16_384, 8 | forward_vit, 9 | ) 10 | 11 | 12 | def _make_encoder( 13 | backbone, 14 | features, 15 | use_pretrained, 16 | groups=1, 17 | expand=False, 18 | exportable=True, 19 | hooks=None, 20 | use_vit_only=False, 21 | use_readout="ignore", 22 | enable_attention_hooks=False, 23 | ): 24 | if backbone == "vitl16_384": 25 | pretrained = _make_pretrained_vitl16_384( 26 | use_pretrained, 27 | hooks=hooks, 28 | use_readout=use_readout, 29 | enable_attention_hooks=enable_attention_hooks, 30 | ) 31 | scratch = _make_scratch( 32 | [256, 512, 1024, 1024], features, groups=groups, expand=expand 33 | ) # ViT-L/16 - 85.0% Top1 (backbone) 34 | elif backbone == "vitb_rn50_384": 35 | pretrained = _make_pretrained_vitb_rn50_384( 36 | use_pretrained, 37 | hooks=hooks, 38 | use_vit_only=use_vit_only, 39 | use_readout=use_readout, 40 | enable_attention_hooks=enable_attention_hooks, 41 | ) 42 | scratch = _make_scratch( 43 | [256, 512, 768, 768], features, groups=groups, expand=expand 44 | ) # ViT-Hybrid (ResNet-50 + ViT-B/16) backbone 45 | elif backbone == "vitb16_384": 46 | pretrained = _make_pretrained_vitb16_384( 47 | use_pretrained, 48 | hooks=hooks, 49 | use_readout=use_readout, 50 | enable_attention_hooks=enable_attention_hooks, 51 | ) 52 | scratch = _make_scratch( 53 | [96, 192, 384, 768], features, groups=groups, expand=expand 54 | ) # ViT-B/16 - 84.6% Top1 (backbone) 55 | elif backbone == 
"resnext101_wsl": 56 | pretrained = _make_pretrained_resnext101_wsl(use_pretrained) 57 | scratch = _make_scratch( 58 | [256, 512, 1024, 2048], features, groups=groups, expand=expand 59 | ) # efficientnet_lite3 60 | else: 61 | print(f"Backbone '{backbone}' not implemented") 62 | assert False 63 | 64 | return pretrained, scratch 65 | 66 | 67 | def _make_scratch(in_shape, out_shape, groups=1, expand=False): 68 | scratch = nn.Module() 69 | 70 | out_shape1 = out_shape 71 | out_shape2 = out_shape 72 | out_shape3 = out_shape 73 | out_shape4 = out_shape 74 | if expand == True: 75 | out_shape1 = out_shape 76 | out_shape2 = out_shape * 2 77 | out_shape3 = out_shape * 4 78 | out_shape4 = out_shape * 8 79 | 80 | scratch.layer1_rn = nn.Conv2d( 81 | in_shape[0], 82 | out_shape1, 83 | kernel_size=3, 84 | stride=1, 85 | padding=1, 86 | bias=False, 87 | groups=groups, 88 | ) 89 | scratch.layer2_rn = nn.Conv2d( 90 | in_shape[1], 91 | out_shape2, 92 | kernel_size=3, 93 | stride=1, 94 | padding=1, 95 | bias=False, 96 | groups=groups, 97 | ) 98 | scratch.layer3_rn = nn.Conv2d( 99 | in_shape[2], 100 | out_shape3, 101 | kernel_size=3, 102 | stride=1, 103 | padding=1, 104 | bias=False, 105 | groups=groups, 106 | ) 107 | scratch.layer4_rn = nn.Conv2d( 108 | in_shape[3], 109 | out_shape4, 110 | kernel_size=3, 111 | stride=1, 112 | padding=1, 113 | bias=False, 114 | groups=groups, 115 | ) 116 | 117 | return scratch 118 | 119 | 120 | def _make_resnet_backbone(resnet): 121 | pretrained = nn.Module() 122 | pretrained.layer1 = nn.Sequential( 123 | resnet.conv1, resnet.bn1, resnet.relu, resnet.maxpool, resnet.layer1 124 | ) 125 | 126 | pretrained.layer2 = resnet.layer2 127 | pretrained.layer3 = resnet.layer3 128 | pretrained.layer4 = resnet.layer4 129 | 130 | return pretrained 131 | 132 | 133 | def _make_pretrained_resnext101_wsl(use_pretrained): 134 | resnet = torch.hub.load("facebookresearch/WSL-Images", "resnext101_32x8d_wsl") 135 | return _make_resnet_backbone(resnet) 136 | 137 | 138 | class 
Interpolate(nn.Module): 139 | """Interpolation module.""" 140 | 141 | def __init__(self, scale_factor, mode, align_corners=False): 142 | """Init. 143 | 144 | Args: 145 | scale_factor (float): scaling 146 | mode (str): interpolation mode 147 | """ 148 | super(Interpolate, self).__init__() 149 | 150 | self.interp = nn.functional.interpolate 151 | self.scale_factor = scale_factor 152 | self.mode = mode 153 | self.align_corners = align_corners 154 | 155 | def forward(self, x): 156 | """Forward pass. 157 | 158 | Args: 159 | x (tensor): input 160 | 161 | Returns: 162 | tensor: interpolated data 163 | """ 164 | 165 | x = self.interp( 166 | x, 167 | scale_factor=self.scale_factor, 168 | mode=self.mode, 169 | align_corners=self.align_corners, 170 | ) 171 | 172 | return x 173 | 174 | 175 | class ResidualConvUnit(nn.Module): 176 | """Residual convolution module.""" 177 | 178 | def __init__(self, features): 179 | """Init. 180 | 181 | Args: 182 | features (int): number of features 183 | """ 184 | super().__init__() 185 | 186 | self.conv1 = nn.Conv2d( 187 | features, features, kernel_size=3, stride=1, padding=1, bias=True 188 | ) 189 | 190 | self.conv2 = nn.Conv2d( 191 | features, features, kernel_size=3, stride=1, padding=1, bias=True 192 | ) 193 | 194 | self.relu = nn.ReLU(inplace=True) 195 | 196 | def forward(self, x): 197 | """Forward pass. 198 | 199 | Args: 200 | x (tensor): input 201 | 202 | Returns: 203 | tensor: output 204 | """ 205 | out = self.relu(x) 206 | out = self.conv1(out) 207 | out = self.relu(out) 208 | out = self.conv2(out) 209 | 210 | return out + x 211 | 212 | 213 | class FeatureFusionBlock(nn.Module): 214 | """Feature fusion block.""" 215 | 216 | def __init__(self, features): 217 | """Init. 
218 | 219 | Args: 220 | features (int): number of features 221 | """ 222 | super(FeatureFusionBlock, self).__init__() 223 | 224 | self.resConfUnit1 = ResidualConvUnit(features) 225 | self.resConfUnit2 = ResidualConvUnit(features) 226 | 227 | def forward(self, *xs): 228 | """Forward pass. 229 | 230 | Returns: 231 | tensor: output 232 | """ 233 | output = xs[0] 234 | 235 | if len(xs) == 2: 236 | output += self.resConfUnit1(xs[1]) 237 | 238 | output = self.resConfUnit2(output) 239 | 240 | output = nn.functional.interpolate( 241 | output, scale_factor=2, mode="bilinear", align_corners=True 242 | ) 243 | 244 | return output 245 | 246 | 247 | class ResidualConvUnit_custom(nn.Module): 248 | """Residual convolution module.""" 249 | 250 | def __init__(self, features, activation, bn): 251 | """Init. 252 | 253 | Args: 254 | features (int): number of features 255 | """ 256 | super().__init__() 257 | 258 | self.bn = bn 259 | 260 | self.groups = 1 261 | 262 | self.conv1 = nn.Conv2d( 263 | features, 264 | features, 265 | kernel_size=3, 266 | stride=1, 267 | padding=1, 268 | bias=not self.bn, 269 | groups=self.groups, 270 | ) 271 | 272 | self.conv2 = nn.Conv2d( 273 | features, 274 | features, 275 | kernel_size=3, 276 | stride=1, 277 | padding=1, 278 | bias=not self.bn, 279 | groups=self.groups, 280 | ) 281 | 282 | if self.bn == True: 283 | self.bn1 = nn.BatchNorm2d(features) 284 | self.bn2 = nn.BatchNorm2d(features) 285 | 286 | self.activation = activation 287 | 288 | self.skip_add = nn.quantized.FloatFunctional() 289 | 290 | def forward(self, x): 291 | """Forward pass. 
292 | 293 | Args: 294 | x (tensor): input 295 | 296 | Returns: 297 | tensor: output 298 | """ 299 | 300 | out = self.activation(x) 301 | out = self.conv1(out) 302 | if self.bn == True: 303 | out = self.bn1(out) 304 | 305 | out = self.activation(out) 306 | out = self.conv2(out) 307 | if self.bn == True: 308 | out = self.bn2(out) 309 | 310 | if self.groups > 1: 311 | out = self.conv_merge(out) 312 | 313 | return self.skip_add.add(out, x) 314 | 315 | # return out + x 316 | 317 | 318 | class FeatureFusionBlock_custom(nn.Module): 319 | """Feature fusion block.""" 320 | 321 | def __init__( 322 | self, 323 | features, 324 | activation, 325 | deconv=False, 326 | bn=False, 327 | expand=False, 328 | align_corners=True, 329 | ): 330 | """Init. 331 | 332 | Args: 333 | features (int): number of features 334 | """ 335 | super(FeatureFusionBlock_custom, self).__init__() 336 | 337 | self.deconv = deconv 338 | self.align_corners = align_corners 339 | 340 | self.groups = 1 341 | 342 | self.expand = expand 343 | out_features = features 344 | if self.expand == True: 345 | out_features = features // 2 346 | 347 | self.out_conv = nn.Conv2d( 348 | features, 349 | out_features, 350 | kernel_size=1, 351 | stride=1, 352 | padding=0, 353 | bias=True, 354 | groups=1, 355 | ) 356 | 357 | self.resConfUnit1 = ResidualConvUnit_custom(features, activation, bn) 358 | self.resConfUnit2 = ResidualConvUnit_custom(features, activation, bn) 359 | 360 | self.skip_add = nn.quantized.FloatFunctional() 361 | 362 | def forward(self, size, *xs): 363 | """Forward pass. 
364 | 365 | Returns: 366 | tensor: output 367 | """ 368 | output = xs[0] 369 | 370 | if len(xs) == 2: 371 | res = self.resConfUnit1(xs[1]) 372 | res = nn.functional.interpolate( 373 | res, size=(size[0]//2,size[1]//2), mode="bilinear", align_corners=self.align_corners 374 | ) 375 | output = self.skip_add.add(output, res) 376 | # output += res 377 | 378 | output = self.resConfUnit2(output) 379 | 380 | output = nn.functional.interpolate( 381 | output, size=size, mode="bilinear", align_corners=self.align_corners 382 | ) 383 | 384 | output = self.out_conv(output) 385 | 386 | return output 387 | -------------------------------------------------------------------------------- /models/dpt/layers/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | from .dino_head import DINOHead 8 | from .mlp import Mlp 9 | from .patch_embed import PatchEmbed 10 | from .swiglu_ffn import SwiGLUFFN, SwiGLUFFNFused 11 | from .block import NestedTensorBlock 12 | from .attention import MemEffAttention,MemEffAttention_lora 13 | -------------------------------------------------------------------------------- /models/dpt/layers/attention.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 
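# Sanity-check sketch (an addition, not part of the upstream DINOv2 file): the manual path in
# Attention.forward computes softmax(q k^T / sqrt(head_dim)) v, which should agree with PyTorch's
# fused torch.nn.functional.scaled_dot_product_attention. This assumes PyTorch >= 2.0; the helper
# name _check_manual_attention_matches_sdpa is hypothetical.
import torch
import torch.nn.functional as F


def _check_manual_attention_matches_sdpa(B=2, N=5, C=16, num_heads=4, seed=0):
    """Return the max absolute difference between the manual and fused attention outputs."""
    torch.manual_seed(seed)
    head_dim = C // num_heads
    qkv = torch.randn(B, N, 3 * C)  # stands in for the output of self.qkv(x)
    # Same reshape/permute as Attention.forward: (3, B, num_heads, N, head_dim)
    qkv = qkv.reshape(B, N, 3, num_heads, head_dim).permute(2, 0, 3, 1, 4)
    q, k, v = qkv[0], qkv[1], qkv[2]

    # Manual path: pre-scale q, then softmax over the key axis
    attn = (q * head_dim**-0.5) @ k.transpose(-2, -1)
    manual = attn.softmax(dim=-1) @ v

    # Fused path: with no explicit scale, SDPA uses 1 / sqrt(head_dim)
    fused = F.scaled_dot_product_attention(q, k, v)
    return (manual - fused).abs().max().item()


# Usage: the returned difference is expected to be at float32 rounding level (well below 1e-5).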
6 | 7 | # References: 8 | # https://github.com/facebookresearch/dino/blob/master/vision_transformer.py 9 | # https://github.com/rwightman/pytorch-image-models/tree/master/timm/models/vision_transformer.py 10 | 11 | import logging 12 | 13 | from torch import Tensor 14 | from torch import nn 15 | import loralib as lora 16 | 17 | logger = logging.getLogger("dinov2") 18 | 19 | 20 | try: 21 | from xformers.ops import memory_efficient_attention, unbind, fmha 22 | 23 | XFORMERS_AVAILABLE = True 24 | except ImportError: 25 | logger.warning("xFormers not available") 26 | XFORMERS_AVAILABLE = False 27 | 28 | 29 | class Attention(nn.Module): 30 | def __init__( 31 | self, 32 | dim: int, 33 | num_heads: int = 8, 34 | qkv_bias: bool = False, 35 | proj_bias: bool = True, 36 | attn_drop: float = 0.0, 37 | proj_drop: float = 0.0, 38 | ) -> None: 39 | super().__init__() 40 | self.num_heads = num_heads 41 | head_dim = dim // num_heads 42 | self.scale = head_dim**-0.5 43 | 44 | self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) 45 | self.attn_drop = nn.Dropout(attn_drop) 46 | self.proj = nn.Linear(dim, dim, bias=proj_bias) 47 | self.proj_drop = nn.Dropout(proj_drop) 48 | 49 | def forward(self, x: Tensor) -> Tensor: 50 | B, N, C = x.shape 51 | qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4) 52 | 53 | q, k, v = qkv[0] * self.scale, qkv[1], qkv[2] 54 | attn = q @ k.transpose(-2, -1) 55 | 56 | attn = attn.softmax(dim=-1) 57 | attn = self.attn_drop(attn) 58 | 59 | x = (attn @ v).transpose(1, 2).reshape(B, N, C) 60 | x = self.proj(x) 61 | x = self.proj_drop(x) 62 | return x 63 | 64 | class Attention_lora(nn.Module): 65 | def __init__( 66 | self, 67 | dim: int, 68 | num_heads: int = 8, 69 | qkv_bias: bool = False, 70 | proj_bias: bool = True, 71 | attn_drop: float = 0.0, 72 | proj_drop: float = 0.0, 73 | ) -> None: 74 | super().__init__() 75 | self.num_heads = num_heads 76 | head_dim = dim // num_heads 77 | self.scale = head_dim**-0.5 78 | 79 | 
self.qkv = lora.Linear(dim, dim * 3, bias=qkv_bias, r=8) 80 | self.attn_drop = nn.Dropout(attn_drop) 81 | self.proj = lora.Linear(dim, dim, bias=proj_bias, r=8) 82 | self.proj_drop = nn.Dropout(proj_drop) 83 | 84 | def forward(self, x: Tensor) -> Tensor: 85 | B, N, C = x.shape 86 | qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4) 87 | 88 | q, k, v = qkv[0] * self.scale, qkv[1], qkv[2] 89 | attn = q @ k.transpose(-2, -1) 90 | 91 | attn = attn.softmax(dim=-1) 92 | attn = self.attn_drop(attn) 93 | 94 | x = (attn @ v).transpose(1, 2).reshape(B, N, C) 95 | x = self.proj(x) 96 | x = self.proj_drop(x) 97 | return x 98 | 99 | class MemEffAttention(Attention): 100 | def forward(self, x: Tensor, attn_bias=None) -> Tensor: 101 | if not XFORMERS_AVAILABLE: 102 | assert attn_bias is None, "xFormers is required for nested tensors usage" 103 | return super().forward(x) 104 | 105 | B, N, C = x.shape 106 | qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads) 107 | 108 | q, k, v = unbind(qkv, 2) 109 | 110 | x = memory_efficient_attention(q, k, v, attn_bias=attn_bias) 111 | x = x.reshape([B, N, C]) 112 | 113 | x = self.proj(x) 114 | x = self.proj_drop(x) 115 | return x 116 | 117 | class MemEffAttention_lora(Attention_lora): 118 | def forward(self, x: Tensor, attn_bias=None) -> Tensor: 119 | if not XFORMERS_AVAILABLE: 120 | assert attn_bias is None, "xFormers is required for nested tensors usage" 121 | return super().forward(x) 122 | 123 | B, N, C = x.shape 124 | qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads) 125 | 126 | q, k, v = unbind(qkv, 2) 127 | 128 | x = memory_efficient_attention(q, k, v, attn_bias=attn_bias) 129 | x = x.reshape([B, N, C]) 130 | 131 | x = self.proj(x) 132 | x = self.proj_drop(x) 133 | return x 134 | -------------------------------------------------------------------------------- /models/dpt/layers/block.py: 
-------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | # References: 8 | # https://github.com/facebookresearch/dino/blob/master/vision_transformer.py 9 | # https://github.com/rwightman/pytorch-image-models/tree/master/timm/layers/patch_embed.py 10 | 11 | import logging 12 | from typing import Callable, List, Any, Tuple, Dict 13 | 14 | import torch 15 | from torch import nn, Tensor 16 | 17 | from .attention import Attention, MemEffAttention 18 | from .drop_path import DropPath 19 | from .layer_scale import LayerScale 20 | from .mlp import Mlp 21 | 22 | 23 | logger = logging.getLogger("dinov2") 24 | 25 | 26 | try: 27 | from xformers.ops import fmha 28 | from xformers.ops import scaled_index_add, index_select_cat 29 | 30 | XFORMERS_AVAILABLE = True 31 | except ImportError: 32 | logger.warning("xFormers not available") 33 | XFORMERS_AVAILABLE = False 34 | 35 | 36 | class Block(nn.Module): 37 | def __init__( 38 | self, 39 | dim: int, 40 | num_heads: int, 41 | mlp_ratio: float = 4.0, 42 | qkv_bias: bool = False, 43 | proj_bias: bool = True, 44 | ffn_bias: bool = True, 45 | drop: float = 0.0, 46 | attn_drop: float = 0.0, 47 | init_values=None, 48 | drop_path: float = 0.0, 49 | act_layer: Callable[..., nn.Module] = nn.GELU, 50 | norm_layer: Callable[..., nn.Module] = nn.LayerNorm, 51 | attn_class: Callable[..., nn.Module] = Attention, 52 | ffn_layer: Callable[..., nn.Module] = Mlp, 53 | ) -> None: 54 | super().__init__() 55 | # print(f"biases: qkv: {qkv_bias}, proj: {proj_bias}, ffn: {ffn_bias}") 56 | self.norm1 = norm_layer(dim) 57 | self.attn = attn_class( 58 | dim, 59 | num_heads=num_heads, 60 | qkv_bias=qkv_bias, 61 | proj_bias=proj_bias, 62 | attn_drop=attn_drop, 63 | proj_drop=drop, 64 | ) 65 | self.ls1 = 
LayerScale(dim, init_values=init_values) if init_values else nn.Identity() 66 | self.drop_path1 = DropPath(drop_path) if drop_path > 0.0 else nn.Identity() 67 | 68 | self.norm2 = norm_layer(dim) 69 | mlp_hidden_dim = int(dim * mlp_ratio) 70 | self.mlp = ffn_layer( 71 | in_features=dim, 72 | hidden_features=mlp_hidden_dim, 73 | act_layer=act_layer, 74 | drop=drop, 75 | bias=ffn_bias, 76 | ) 77 | self.ls2 = LayerScale(dim, init_values=init_values) if init_values else nn.Identity() 78 | self.drop_path2 = DropPath(drop_path) if drop_path > 0.0 else nn.Identity() 79 | 80 | self.sample_drop_ratio = drop_path 81 | 82 | def forward(self, x: Tensor) -> Tensor: 83 | def attn_residual_func(x: Tensor) -> Tensor: 84 | return self.ls1(self.attn(self.norm1(x))) 85 | 86 | def ffn_residual_func(x: Tensor) -> Tensor: 87 | return self.ls2(self.mlp(self.norm2(x))) 88 | 89 | if self.training and self.sample_drop_ratio > 0.1: 90 | # the overhead is compensated only for a drop path rate larger than 0.1 91 | x = drop_add_residual_stochastic_depth( 92 | x, 93 | residual_func=attn_residual_func, 94 | sample_drop_ratio=self.sample_drop_ratio, 95 | ) 96 | x = drop_add_residual_stochastic_depth( 97 | x, 98 | residual_func=ffn_residual_func, 99 | sample_drop_ratio=self.sample_drop_ratio, 100 | ) 101 | elif self.training and self.sample_drop_ratio > 0.0: 102 | x = x + self.drop_path1(attn_residual_func(x)) 103 | x = x + self.drop_path2(ffn_residual_func(x)) 104 | else: 105 | x = x + attn_residual_func(x) 106 | x = x + ffn_residual_func(x) 107 | return x 108 | 109 | 110 | def drop_add_residual_stochastic_depth( 111 | x: Tensor, 112 | residual_func: Callable[[Tensor], Tensor], 113 | sample_drop_ratio: float = 0.0, 114 | ) -> Tensor: 115 | # 1) extract subset using permutation 116 | b, n, d = x.shape 117 | sample_subset_size = max(int(b * (1 - sample_drop_ratio)), 1) 118 | brange = (torch.randperm(b, device=x.device))[:sample_subset_size] 119 | x_subset = x[brange] 120 | 121 | #
2) apply residual_func to get residual 122 | residual = residual_func(x_subset) 123 | 124 | x_flat = x.flatten(1) 125 | residual = residual.flatten(1) 126 | 127 | residual_scale_factor = b / sample_subset_size 128 | 129 | # 3) add the residual 130 | x_plus_residual = torch.index_add(x_flat, 0, brange, residual.to(dtype=x.dtype), alpha=residual_scale_factor) 131 | return x_plus_residual.view_as(x) 132 | 133 | 134 | def get_branges_scales(x, sample_drop_ratio=0.0): 135 | b, n, d = x.shape 136 | sample_subset_size = max(int(b * (1 - sample_drop_ratio)), 1) 137 | brange = (torch.randperm(b, device=x.device))[:sample_subset_size] 138 | residual_scale_factor = b / sample_subset_size 139 | return brange, residual_scale_factor 140 | 141 | 142 | def add_residual(x, brange, residual, residual_scale_factor, scaling_vector=None): 143 | if scaling_vector is None: 144 | x_flat = x.flatten(1) 145 | residual = residual.flatten(1) 146 | x_plus_residual = torch.index_add(x_flat, 0, brange, residual.to(dtype=x.dtype), alpha=residual_scale_factor) 147 | else: 148 | x_plus_residual = scaled_index_add( 149 | x, brange, residual.to(dtype=x.dtype), scaling=scaling_vector, alpha=residual_scale_factor 150 | ) 151 | return x_plus_residual 152 | 153 | 154 | attn_bias_cache: Dict[Tuple, Any] = {} 155 | 156 | 157 | def get_attn_bias_and_cat(x_list, branges=None): 158 | """ 159 | this will perform the index select, cat the tensors, and provide the attn_bias from cache 160 | """ 161 | batch_sizes = [b.shape[0] for b in branges] if branges is not None else [x.shape[0] for x in x_list] 162 | all_shapes = tuple((b, x.shape[1]) for b, x in zip(batch_sizes, x_list)) 163 | if all_shapes not in attn_bias_cache.keys(): 164 | seqlens = [] 165 | for b, x in zip(batch_sizes, x_list): 166 | for _ in range(b): 167 | seqlens.append(x.shape[1]) 168 | attn_bias = fmha.BlockDiagonalMask.from_seqlens(seqlens) 169 | attn_bias._batch_sizes = batch_sizes 170 | attn_bias_cache[all_shapes] = attn_bias 171 | 172 | if 
branges is not None: 173 | cat_tensors = index_select_cat([x.flatten(1) for x in x_list], branges).view(1, -1, x_list[0].shape[-1]) 174 | else: 175 | tensors_bs1 = tuple(x.reshape([1, -1, *x.shape[2:]]) for x in x_list) 176 | cat_tensors = torch.cat(tensors_bs1, dim=1) 177 | 178 | return attn_bias_cache[all_shapes], cat_tensors 179 | 180 | 181 | def drop_add_residual_stochastic_depth_list( 182 | x_list: List[Tensor], 183 | residual_func: Callable[[Tensor, Any], Tensor], 184 | sample_drop_ratio: float = 0.0, 185 | scaling_vector=None, 186 | ) -> Tensor: 187 | # 1) generate random set of indices for dropping samples in the batch 188 | branges_scales = [get_branges_scales(x, sample_drop_ratio=sample_drop_ratio) for x in x_list] 189 | branges = [s[0] for s in branges_scales] 190 | residual_scale_factors = [s[1] for s in branges_scales] 191 | 192 | # 2) get attention bias and index+concat the tensors 193 | attn_bias, x_cat = get_attn_bias_and_cat(x_list, branges) 194 | 195 | # 3) apply residual_func to get residual, and split the result 196 | residual_list = attn_bias.split(residual_func(x_cat, attn_bias=attn_bias)) # type: ignore 197 | 198 | outputs = [] 199 | for x, brange, residual, residual_scale_factor in zip(x_list, branges, residual_list, residual_scale_factors): 200 | outputs.append(add_residual(x, brange, residual, residual_scale_factor, scaling_vector).view_as(x)) 201 | return outputs 202 | 203 | 204 | class NestedTensorBlock(Block): 205 | def forward_nested(self, x_list: List[Tensor]) -> List[Tensor]: 206 | """ 207 | x_list contains a list of tensors to nest together and run 208 | """ 209 | assert isinstance(self.attn, MemEffAttention) 210 | 211 | if self.training and self.sample_drop_ratio > 0.0: 212 | 213 | def attn_residual_func(x: Tensor, attn_bias=None) -> Tensor: 214 | return self.attn(self.norm1(x), attn_bias=attn_bias) 215 | 216 | def ffn_residual_func(x: Tensor, attn_bias=None) -> Tensor: 217 | return self.mlp(self.norm2(x)) 218 | 219 | x_list = 
drop_add_residual_stochastic_depth_list( 220 | x_list, 221 | residual_func=attn_residual_func, 222 | sample_drop_ratio=self.sample_drop_ratio, 223 | scaling_vector=self.ls1.gamma if isinstance(self.ls1, LayerScale) else None, 224 | ) 225 | x_list = drop_add_residual_stochastic_depth_list( 226 | x_list, 227 | residual_func=ffn_residual_func, 228 | sample_drop_ratio=self.sample_drop_ratio, 229 | scaling_vector=self.ls2.gamma if isinstance(self.ls2, LayerScale) else None, 230 | ) 231 | return x_list 232 | else: 233 | 234 | def attn_residual_func(x: Tensor, attn_bias=None) -> Tensor: 235 | return self.ls1(self.attn(self.norm1(x), attn_bias=attn_bias)) 236 | 237 | def ffn_residual_func(x: Tensor, attn_bias=None) -> Tensor: 238 | return self.ls2(self.mlp(self.norm2(x))) 239 | 240 | attn_bias, x = get_attn_bias_and_cat(x_list) 241 | x = x + attn_residual_func(x, attn_bias=attn_bias) 242 | x = x + ffn_residual_func(x) 243 | return attn_bias.split(x) 244 | 245 | def forward(self, x_or_x_list): 246 | if isinstance(x_or_x_list, Tensor): 247 | return super().forward(x_or_x_list) 248 | elif isinstance(x_or_x_list, list): 249 | assert XFORMERS_AVAILABLE, "Please install xFormers for nested tensors usage" 250 | return self.forward_nested(x_or_x_list) 251 | else: 252 | raise AssertionError 253 | -------------------------------------------------------------------------------- /models/dpt/layers/dino_head.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree.
6 | 7 | import torch 8 | import torch.nn as nn 9 | from torch.nn.init import trunc_normal_ 10 | from torch.nn.utils import weight_norm 11 | 12 | 13 | class DINOHead(nn.Module): 14 | def __init__( 15 | self, 16 | in_dim, 17 | out_dim, 18 | use_bn=False, 19 | nlayers=3, 20 | hidden_dim=2048, 21 | bottleneck_dim=256, 22 | mlp_bias=True, 23 | ): 24 | super().__init__() 25 | nlayers = max(nlayers, 1) 26 | self.mlp = _build_mlp(nlayers, in_dim, bottleneck_dim, hidden_dim=hidden_dim, use_bn=use_bn, bias=mlp_bias) 27 | self.apply(self._init_weights) 28 | self.last_layer = weight_norm(nn.Linear(bottleneck_dim, out_dim, bias=False)) 29 | self.last_layer.weight_g.data.fill_(1) 30 | 31 | def _init_weights(self, m): 32 | if isinstance(m, nn.Linear): 33 | trunc_normal_(m.weight, std=0.02) 34 | if isinstance(m, nn.Linear) and m.bias is not None: 35 | nn.init.constant_(m.bias, 0) 36 | 37 | def forward(self, x): 38 | x = self.mlp(x) 39 | eps = 1e-6 if x.dtype == torch.float16 else 1e-12 40 | x = nn.functional.normalize(x, dim=-1, p=2, eps=eps) 41 | x = self.last_layer(x) 42 | return x 43 | 44 | 45 | def _build_mlp(nlayers, in_dim, bottleneck_dim, hidden_dim=None, use_bn=False, bias=True): 46 | if nlayers == 1: 47 | return nn.Linear(in_dim, bottleneck_dim, bias=bias) 48 | else: 49 | layers = [nn.Linear(in_dim, hidden_dim, bias=bias)] 50 | if use_bn: 51 | layers.append(nn.BatchNorm1d(hidden_dim)) 52 | layers.append(nn.GELU()) 53 | for _ in range(nlayers - 2): 54 | layers.append(nn.Linear(hidden_dim, hidden_dim, bias=bias)) 55 | if use_bn: 56 | layers.append(nn.BatchNorm1d(hidden_dim)) 57 | layers.append(nn.GELU()) 58 | layers.append(nn.Linear(hidden_dim, bottleneck_dim, bias=bias)) 59 | return nn.Sequential(*layers) 60 | -------------------------------------------------------------------------------- /models/dpt/layers/drop_path.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 
2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | # References: 8 | # https://github.com/facebookresearch/dino/blob/master/vision_transformer.py 9 | # https://github.com/rwightman/pytorch-image-models/tree/master/timm/layers/drop.py 10 | 11 | 12 | from torch import nn 13 | 14 | 15 | def drop_path(x, drop_prob: float = 0.0, training: bool = False): 16 | if drop_prob == 0.0 or not training: 17 | return x 18 | keep_prob = 1 - drop_prob 19 | shape = (x.shape[0],) + (1,) * (x.ndim - 1) # work with diff dim tensors, not just 2D ConvNets 20 | random_tensor = x.new_empty(shape).bernoulli_(keep_prob) 21 | if keep_prob > 0.0: 22 | random_tensor.div_(keep_prob) 23 | output = x * random_tensor 24 | return output 25 | 26 | 27 | class DropPath(nn.Module): 28 | """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).""" 29 | 30 | def __init__(self, drop_prob=None): 31 | super(DropPath, self).__init__() 32 | self.drop_prob = drop_prob 33 | 34 | def forward(self, x): 35 | return drop_path(x, self.drop_prob, self.training) 36 | -------------------------------------------------------------------------------- /models/dpt/layers/layer_scale.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 
6 | 7 | # Modified from: https://github.com/huggingface/pytorch-image-models/blob/main/timm/models/vision_transformer.py#L103-L110 8 | 9 | from typing import Union 10 | 11 | import torch 12 | from torch import Tensor 13 | from torch import nn 14 | 15 | 16 | class LayerScale(nn.Module): 17 | def __init__( 18 | self, 19 | dim: int, 20 | init_values: Union[float, Tensor] = 1e-5, 21 | inplace: bool = False, 22 | ) -> None: 23 | super().__init__() 24 | self.inplace = inplace 25 | self.gamma = nn.Parameter(init_values * torch.ones(dim)) 26 | 27 | def forward(self, x: Tensor) -> Tensor: 28 | return x.mul_(self.gamma) if self.inplace else x * self.gamma 29 | -------------------------------------------------------------------------------- /models/dpt/layers/mlp.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 
6 | 7 | # References: 8 | # https://github.com/facebookresearch/dino/blob/master/vision_transformer.py 9 | # https://github.com/rwightman/pytorch-image-models/tree/master/timm/layers/mlp.py 10 | 11 | 12 | from typing import Callable, Optional 13 | 14 | from torch import Tensor, nn 15 | 16 | 17 | class Mlp(nn.Module): 18 | def __init__( 19 | self, 20 | in_features: int, 21 | hidden_features: Optional[int] = None, 22 | out_features: Optional[int] = None, 23 | act_layer: Callable[..., nn.Module] = nn.GELU, 24 | drop: float = 0.0, 25 | bias: bool = True, 26 | ) -> None: 27 | super().__init__() 28 | out_features = out_features or in_features 29 | hidden_features = hidden_features or in_features 30 | self.fc1 = nn.Linear(in_features, hidden_features, bias=bias) 31 | self.act = act_layer() 32 | self.fc2 = nn.Linear(hidden_features, out_features, bias=bias) 33 | self.drop = nn.Dropout(drop) 34 | 35 | def forward(self, x: Tensor) -> Tensor: 36 | x = self.fc1(x) 37 | x = self.act(x) 38 | x = self.drop(x) 39 | x = self.fc2(x) 40 | x = self.drop(x) 41 | return x 42 | -------------------------------------------------------------------------------- /models/dpt/layers/patch_embed.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 
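# Shape sketch (an addition, not part of the upstream file): PatchEmbed.proj is a Conv2d whose
# kernel_size and stride both equal the patch size, so a (B, C, H, W) image becomes (B, N, D)
# tokens with N = (H // p) * (W // p). The helper _patchify_shape is hypothetical; its defaults
# mirror PatchEmbed's defaults (224 px input, 16 px patches, 768-dim embedding).
import torch
import torch.nn as nn


def _patchify_shape(B=2, C=3, H=224, W=224, patch=16, embed_dim=768):
    """Return the token shape produced by a patch-size-strided Conv2d projection."""
    assert H % patch == 0 and W % patch == 0, "H and W must be multiples of the patch size"
    proj = nn.Conv2d(C, embed_dim, kernel_size=patch, stride=patch)
    with torch.no_grad():
        x = proj(torch.randn(B, C, H, W))  # (B, D, H // p, W // p)
    x = x.flatten(2).transpose(1, 2)  # (B, N, D), as in PatchEmbed.forward
    return tuple(x.shape)


# Usage: _patchify_shape() returns (2, 196, 768), since (224 // 16) * (224 // 16) = 196.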
6 | 7 | # References: 8 | # https://github.com/facebookresearch/dino/blob/master/vision_transformer.py 9 | # https://github.com/rwightman/pytorch-image-models/tree/master/timm/layers/patch_embed.py 10 | 11 | from typing import Callable, Optional, Tuple, Union 12 | 13 | from torch import Tensor 14 | import torch.nn as nn 15 | 16 | 17 | def make_2tuple(x): 18 | if isinstance(x, tuple): 19 | assert len(x) == 2 20 | return x 21 | 22 | assert isinstance(x, int) 23 | return (x, x) 24 | 25 | 26 | class PatchEmbed(nn.Module): 27 | """ 28 | 2D image to patch embedding: (B,C,H,W) -> (B,N,D) 29 | 30 | Args: 31 | img_size: Image size. 32 | patch_size: Patch token size. 33 | in_chans: Number of input image channels. 34 | embed_dim: Number of linear projection output channels. 35 | norm_layer: Normalization layer. 36 | """ 37 | 38 | def __init__( 39 | self, 40 | img_size: Union[int, Tuple[int, int]] = 224, 41 | patch_size: Union[int, Tuple[int, int]] = 16, 42 | in_chans: int = 3, 43 | embed_dim: int = 768, 44 | norm_layer: Optional[Callable] = None, 45 | flatten_embedding: bool = True, 46 | ) -> None: 47 | super().__init__() 48 | 49 | image_HW = make_2tuple(img_size) 50 | patch_HW = make_2tuple(patch_size) 51 | patch_grid_size = ( 52 | image_HW[0] // patch_HW[0], 53 | image_HW[1] // patch_HW[1], 54 | ) 55 | 56 | self.img_size = image_HW 57 | self.patch_size = patch_HW 58 | self.patches_resolution = patch_grid_size 59 | self.num_patches = patch_grid_size[0] * patch_grid_size[1] 60 | 61 | self.in_chans = in_chans 62 | self.embed_dim = embed_dim 63 | 64 | self.flatten_embedding = flatten_embedding 65 | 66 | self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_HW, stride=patch_HW) 67 | self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity() 68 | 69 | def forward(self, x: Tensor) -> Tensor: 70 | _, _, H, W = x.shape 71 | patch_H, patch_W = self.patch_size 72 | 73 | assert H % patch_H == 0, f"Input image height {H} is not a multiple of patch height {patch_H}" 74 | 
assert W % patch_W == 0, f"Input image width {W} is not a multiple of patch width: {patch_W}" 75 | 76 | x = self.proj(x) # B C H W 77 | H, W = x.size(2), x.size(3) 78 | x = x.flatten(2).transpose(1, 2) # B HW C 79 | x = self.norm(x) 80 | if not self.flatten_embedding: 81 | x = x.reshape(-1, H, W, self.embed_dim) # B H W C 82 | return x 83 | 84 | def flops(self) -> float: 85 | Ho, Wo = self.patches_resolution 86 | flops = Ho * Wo * self.embed_dim * self.in_chans * (self.patch_size[0] * self.patch_size[1]) 87 | if self.norm is not None: 88 | flops += Ho * Wo * self.embed_dim 89 | return flops 90 | -------------------------------------------------------------------------------- /models/dpt/layers/swiglu_ffn.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 
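# Arithmetic sketch (an addition, not part of the upstream file): SwiGLUFFNFused scales the
# requested hidden width by 2/3 and rounds up to a multiple of 8. SwiGLUFFN holds w12 (d -> 2h)
# and w3 (h -> d), i.e. about 3 * h * d weights, versus 2 * h' * d for a plain Mlp of width h';
# choosing h = 2h'/3 presumably keeps the two roughly equal in size. The function name is hypothetical.
def _fused_swiglu_hidden(hidden_features: int) -> int:
    """Mirror the hidden-width rounding done in SwiGLUFFNFused.__init__."""
    return (int(hidden_features * 2 / 3) + 7) // 8 * 8


# Usage: a width of 3072 (768 * 4) gives 2048; a width of 4096 (1024 * 4) gives 2736.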
6 | 7 | from typing import Callable, Optional 8 | 9 | from torch import Tensor, nn 10 | import torch.nn.functional as F 11 | 12 | 13 | class SwiGLUFFN(nn.Module): 14 | def __init__( 15 | self, 16 | in_features: int, 17 | hidden_features: Optional[int] = None, 18 | out_features: Optional[int] = None, 19 | act_layer: Callable[..., nn.Module] = None, 20 | drop: float = 0.0, 21 | bias: bool = True, 22 | ) -> None: 23 | super().__init__() 24 | out_features = out_features or in_features 25 | hidden_features = hidden_features or in_features 26 | self.w12 = nn.Linear(in_features, 2 * hidden_features, bias=bias) 27 | self.w3 = nn.Linear(hidden_features, out_features, bias=bias) 28 | 29 | def forward(self, x: Tensor) -> Tensor: 30 | x12 = self.w12(x) 31 | x1, x2 = x12.chunk(2, dim=-1) 32 | hidden = F.silu(x1) * x2 33 | return self.w3(hidden) 34 | 35 | 36 | try: 37 | from xformers.ops import SwiGLU 38 | 39 | XFORMERS_AVAILABLE = True 40 | except ImportError: 41 | SwiGLU = SwiGLUFFN 42 | XFORMERS_AVAILABLE = False 43 | 44 | 45 | class SwiGLUFFNFused(SwiGLU): 46 | def __init__( 47 | self, 48 | in_features: int, 49 | hidden_features: Optional[int] = None, 50 | out_features: Optional[int] = None, 51 | act_layer: Callable[..., nn.Module] = None, 52 | drop: float = 0.0, 53 | bias: bool = True, 54 | ) -> None: 55 | out_features = out_features or in_features 56 | hidden_features = hidden_features or in_features 57 | hidden_features = (int(hidden_features * 2 / 3) + 7) // 8 * 8 58 | super().__init__( 59 | in_features=in_features, 60 | hidden_features=hidden_features, 61 | out_features=out_features, 62 | bias=bias, 63 | ) 64 | -------------------------------------------------------------------------------- /models/dpt/midas_net.py: -------------------------------------------------------------------------------- 1 | """MidasNet: Network for monocular depth estimation trained by mixing several datasets.
2 | This file contains code that is adapted from 3 | https://github.com/thomasjpfan/pytorch_refinenet/blob/master/pytorch_refinenet/refinenet/refinenet_4cascade.py 4 | """ 5 | import torch 6 | import torch.nn as nn 7 | 8 | from models.dpt.base_model import BaseModel 9 | from models.dpt.blocks import FeatureFusionBlock, Interpolate, _make_encoder 10 | 11 | 12 | class MidasNet_large(BaseModel): 13 | """Network for monocular depth estimation.""" 14 | 15 | def __init__(self, path=None, features=256, non_negative=True): 16 | """Init. 17 | 18 | Args: 19 | path (str, optional): Path to saved model. Defaults to None. 20 | features (int, optional): Number of features. Defaults to 256. 21 | non_negative (bool, optional): If True, apply a final ReLU so the predicted depth is non-negative. Defaults to True. 22 | """ 23 | print("Loading weights: ", path) 24 | 25 | super(MidasNet_large, self).__init__() 26 | 27 | use_pretrained = False if path is None else True 28 | 29 | self.pretrained, self.scratch = _make_encoder( 30 | backbone="resnext101_wsl", features=features, use_pretrained=use_pretrained 31 | ) 32 | 33 | self.scratch.refinenet4 = FeatureFusionBlock(features) 34 | self.scratch.refinenet3 = FeatureFusionBlock(features) 35 | self.scratch.refinenet2 = FeatureFusionBlock(features) 36 | self.scratch.refinenet1 = FeatureFusionBlock(features) 37 | 38 | self.scratch.output_conv = nn.Sequential( 39 | nn.Conv2d(features, 128, kernel_size=3, stride=1, padding=1), 40 | Interpolate(scale_factor=2, mode="bilinear"), 41 | nn.Conv2d(128, 32, kernel_size=3, stride=1, padding=1), 42 | nn.ReLU(True), 43 | nn.Conv2d(32, 1, kernel_size=1, stride=1, padding=0), 44 | nn.ReLU(True) if non_negative else nn.Identity(), 45 | ) 46 | 47 | if path: 48 | self.load(path) 49 | 50 | def forward(self, x): 51 | """Forward pass.
52 | 53 | Args: 54 | x (tensor): input data (image) 55 | 56 | Returns: 57 | tensor: depth 58 | """ 59 | 60 | layer_1 = self.pretrained.layer1(x) 61 | layer_2 = self.pretrained.layer2(layer_1) 62 | layer_3 = self.pretrained.layer3(layer_2) 63 | layer_4 = self.pretrained.layer4(layer_3) 64 | 65 | layer_1_rn = self.scratch.layer1_rn(layer_1) 66 | layer_2_rn = self.scratch.layer2_rn(layer_2) 67 | layer_3_rn = self.scratch.layer3_rn(layer_3) 68 | layer_4_rn = self.scratch.layer4_rn(layer_4) 69 | 70 | path_4 = self.scratch.refinenet4(layer_4_rn) 71 | path_3 = self.scratch.refinenet3(path_4, layer_3_rn) 72 | path_2 = self.scratch.refinenet2(path_3, layer_2_rn) 73 | path_1 = self.scratch.refinenet1(path_2, layer_1_rn) 74 | 75 | out = self.scratch.output_conv(path_1) 76 | 77 | return torch.squeeze(out, dim=1) 78 | -------------------------------------------------------------------------------- /models/dpt/models.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | from models.dpt.base_model import BaseModel 6 | from models.dpt.blocks import ( 7 | FeatureFusionBlock, 8 | FeatureFusionBlock_custom, 9 | Interpolate, 10 | _make_encoder, 11 | forward_vit, 12 | ) 13 | 14 | 15 | def _make_fusion_block(features, use_bn): 16 | return FeatureFusionBlock_custom( 17 | features, 18 | nn.ReLU(False), 19 | deconv=False, 20 | bn=use_bn, 21 | expand=False, 22 | align_corners=True, 23 | ) 24 | 25 | 26 | class DPT(BaseModel): 27 | def __init__( 28 | self, 29 | head, 30 | features=256, 31 | backbone="vitb_rn50_384", 32 | readout="project", 33 | channels_last=False, 34 | use_bn=False, 35 | enable_attention_hooks=False, 36 | ): 37 | 38 | super(DPT, self).__init__() 39 | 40 | self.channels_last = channels_last 41 | 42 | hooks = { 43 | "vitb_rn50_384": [0, 1, 8, 11], 44 | "vitb16_384": [2, 5, 8, 11], 45 | "vitl16_384": [5, 11, 17, 23], 46 | } 47 | 48 | # Instantiate backbone and 
reassemble blocks 49 | self.pretrained, self.scratch = _make_encoder( 50 | backbone, 51 | features, 52 | False, # use_pretrained: True initializes the backbone with ImageNet weights; False trains it from scratch 53 | groups=1, 54 | expand=False, 55 | exportable=False, 56 | hooks=hooks[backbone], 57 | use_readout=readout, 58 | enable_attention_hooks=enable_attention_hooks, 59 | ) 60 | 61 | self.scratch.refinenet1 = _make_fusion_block(features, use_bn) 62 | self.scratch.refinenet2 = _make_fusion_block(features, use_bn) 63 | self.scratch.refinenet3 = _make_fusion_block(features, use_bn) 64 | self.scratch.refinenet4 = _make_fusion_block(features, use_bn) 65 | 66 | self.scratch.output_conv = head 67 | 68 | def forward(self, x): 69 | if self.channels_last: 70 | x = x.contiguous(memory_format=torch.channels_last) 71 | 72 | layer_1, layer_2, layer_3, layer_4 = forward_vit(self.pretrained, x) 73 | 74 | layer_1_rn = self.scratch.layer1_rn(layer_1) 75 | layer_2_rn = self.scratch.layer2_rn(layer_2) 76 | layer_3_rn = self.scratch.layer3_rn(layer_3) 77 | layer_4_rn = self.scratch.layer4_rn(layer_4) 78 | 79 | path_4 = self.scratch.refinenet4(layer_4_rn) 80 | path_3 = self.scratch.refinenet3(path_4, layer_3_rn) 81 | path_2 = self.scratch.refinenet2(path_3, layer_2_rn) 82 | path_1 = self.scratch.refinenet1(path_2, layer_1_rn) 83 | 84 | out = self.scratch.output_conv(path_1) 85 | 86 | return out 87 | 88 | 89 | class DPTDepthModel(DPT): 90 | def __init__( 91 | self, path=None, non_negative=True, scale=1.0, shift=0.0, invert=False, **kwargs 92 | ): 93 | features = kwargs["features"] if "features" in kwargs else 256 94 | 95 | self.scale = scale 96 | self.shift = shift 97 | self.invert = invert 98 | 99 | head = nn.Sequential( 100 | nn.Conv2d(features, features // 2, kernel_size=3, stride=1, padding=1), 101 | Interpolate(scale_factor=2, mode="bilinear", align_corners=True), 102 | nn.Conv2d(features // 2, 32, kernel_size=3, stride=1, padding=1), 103 | nn.ReLU(True), 104 | nn.Conv2d(32, 1, kernel_size=1, stride=1,
padding=0), 105 | nn.ReLU(True) if non_negative else nn.Identity(), 106 | nn.Identity(), 107 | ) 108 | 109 | super().__init__(head, **kwargs) 110 | 111 | if path is not None: 112 | self.load(path) 113 | 114 | def forward(self, x): 115 | inv_depth = super().forward(x).squeeze(dim=1) 116 | 117 | if self.invert: 118 | depth = self.scale * inv_depth + self.shift 119 | depth[depth < 1e-8] = 1e-8 120 | depth = 1.0 / depth 121 | return depth 122 | else: 123 | return inv_depth 124 | 125 | 126 | class DPTSegmentationModel(DPT): 127 | def __init__(self, num_classes, path=None, **kwargs): 128 | 129 | features = kwargs["features"] if "features" in kwargs else 256 130 | 131 | kwargs["use_bn"] = True 132 | 133 | head = nn.Sequential( 134 | nn.Conv2d(features, features, kernel_size=3, padding=1, bias=False), 135 | nn.BatchNorm2d(features), 136 | nn.ReLU(True), 137 | nn.Dropout(0.1, False), 138 | nn.Conv2d(features, num_classes, kernel_size=1), 139 | Interpolate(scale_factor=2, mode="bilinear", align_corners=True), 140 | ) 141 | 142 | super().__init__(head, **kwargs) 143 | 144 | self.auxlayer = nn.Sequential( 145 | nn.Conv2d(features, features, kernel_size=3, padding=1, bias=False), 146 | nn.BatchNorm2d(features), 147 | nn.ReLU(True), 148 | nn.Dropout(0.1, False), 149 | nn.Conv2d(features, num_classes, kernel_size=1), 150 | ) 151 | 152 | if path is not None: 153 | self.load(path) 154 | -------------------------------------------------------------------------------- /models/dpt/transforms.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import cv2 3 | import math 4 | 5 | 6 | def apply_min_size(sample, size, image_interpolation_method=cv2.INTER_AREA): 7 | """Resize the sample so that it is at least the given size. Keeps aspect ratio.
8 | 9 | Args: 10 | sample (dict): sample with "image", "disparity" and "mask" entries 11 | size (tuple): minimum (height, width) 12 | 13 | Returns: 14 | tuple: new (height, width) 15 | """ 16 | shape = list(sample["disparity"].shape) 17 | 18 | if shape[0] >= size[0] and shape[1] >= size[1]: 19 | return tuple(shape) # already large enough; return the size, as documented 20 | 21 | scale = [0, 0] 22 | scale[0] = size[0] / shape[0] 23 | scale[1] = size[1] / shape[1] 24 | 25 | scale = max(scale) 26 | 27 | shape[0] = math.ceil(scale * shape[0]) 28 | shape[1] = math.ceil(scale * shape[1]) 29 | 30 | # resize 31 | sample["image"] = cv2.resize( 32 | sample["image"], tuple(shape[::-1]), interpolation=image_interpolation_method 33 | ) 34 | 35 | sample["disparity"] = cv2.resize( 36 | sample["disparity"], tuple(shape[::-1]), interpolation=cv2.INTER_NEAREST 37 | ) 38 | sample["mask"] = cv2.resize( 39 | sample["mask"].astype(np.float32), 40 | tuple(shape[::-1]), 41 | interpolation=cv2.INTER_NEAREST, 42 | ) 43 | sample["mask"] = sample["mask"].astype(bool) 44 | 45 | return tuple(shape) 46 | 47 | 48 | class Resize(object): 49 | """Resize sample to given size (width, height).""" 50 | 51 | def __init__( 52 | self, 53 | width, 54 | height, 55 | resize_target=True, 56 | keep_aspect_ratio=False, 57 | ensure_multiple_of=1, 58 | resize_method="lower_bound", 59 | image_interpolation_method=cv2.INTER_AREA, 60 | ): 61 | """Init. 62 | 63 | Args: 64 | width (int): desired output width 65 | height (int): desired output height 66 | resize_target (bool, optional): 67 | True: Resize the full sample (image, mask, target). 68 | False: Resize image only. 69 | Defaults to True. 70 | keep_aspect_ratio (bool, optional): 71 | True: Keep the aspect ratio of the input sample. 72 | Output sample might not have the given width and height, and 73 | resize behaviour depends on the parameter 'resize_method'. 74 | Defaults to False. 
75 | ensure_multiple_of (int, optional): 76 | Output width and height are constrained to be a multiple of this parameter. 77 | Defaults to 1.
78 | resize_method (str, optional): 79 | "lower_bound": Output will be at least as large as the given size. 80 | "upper_bound": Output will be at most as large as the given size. (Output size might be smaller than given size.) 81 | "minimal": Scale as little as possible. (Output size might differ from the given size.) 82 | Defaults to "lower_bound". 83 | """ 84 | self.__width = width 85 | self.__height = height 86 | 87 | self.__resize_target = resize_target 88 | self.__keep_aspect_ratio = keep_aspect_ratio 89 | self.__multiple_of = ensure_multiple_of 90 | self.__resize_method = resize_method 91 | self.__image_interpolation_method = image_interpolation_method 92 | 93 | def constrain_to_multiple_of(self, x, min_val=0, max_val=None): 94 | y = (np.round(x / self.__multiple_of) * self.__multiple_of).astype(int) 95 | 96 | if max_val is not None and y > max_val: 97 | y = (np.floor(x / self.__multiple_of) * self.__multiple_of).astype(int) 98 | 99 | if y < min_val: 100 | y = (np.ceil(x / self.__multiple_of) * self.__multiple_of).astype(int) 101 | 102 | return y 103 | 104 | def get_size(self, width, height): 105 | # determine new height and width 106 | scale_height = self.__height / height 107 | scale_width = self.__width / width 108 | 109 | if self.__keep_aspect_ratio: 110 | if self.__resize_method == "lower_bound": 111 | # scale such that output size is lower bound 112 | if scale_width > scale_height: 113 | # fit width 114 | scale_height = scale_width 115 | else: 116 | # fit height 117 | scale_width = scale_height 118 | elif self.__resize_method == "upper_bound": 119 | # scale such that output size is upper bound 120 | if scale_width < scale_height: 121 | # fit width 122 | scale_height = scale_width 123 | else: 124 | # fit height 125 | scale_width = scale_height 126 | elif self.__resize_method == "minimal": 127 | # scale as little as possible 128 | if abs(1 - scale_width) < abs(1 - scale_height): 129 | # fit width 130 | scale_height = scale_width 131 | else: 132 | # fit
height 133 | scale_width = scale_height 134 | else: 135 | raise ValueError( 136 | f"resize_method {self.__resize_method} not implemented" 137 | ) 138 | 139 | if self.__resize_method == "lower_bound": 140 | new_height = self.constrain_to_multiple_of( 141 | scale_height * height, min_val=self.__height 142 | ) 143 | new_width = self.constrain_to_multiple_of( 144 | scale_width * width, min_val=self.__width 145 | ) 146 | elif self.__resize_method == "upper_bound": 147 | new_height = self.constrain_to_multiple_of( 148 | scale_height * height, max_val=self.__height 149 | ) 150 | new_width = self.constrain_to_multiple_of( 151 | scale_width * width, max_val=self.__width 152 | ) 153 | elif self.__resize_method == "minimal": 154 | new_height = self.constrain_to_multiple_of(scale_height * height) 155 | new_width = self.constrain_to_multiple_of(scale_width * width) 156 | else: 157 | raise ValueError(f"resize_method {self.__resize_method} not implemented") 158 | 159 | return (new_width, new_height) 160 | 161 | def __call__(self, sample): 162 | width, height = self.get_size( 163 | sample["image"].shape[1], sample["image"].shape[0] 164 | ) 165 | 166 | # resize sample 167 | sample["image"] = cv2.resize( 168 | sample["image"], 169 | (width, height), 170 | interpolation=self.__image_interpolation_method, 171 | ) 172 | 173 | if self.__resize_target: 174 | if "disparity" in sample: 175 | sample["disparity"] = cv2.resize( 176 | sample["disparity"], 177 | (width, height), 178 | interpolation=cv2.INTER_NEAREST, 179 | ) 180 | 181 | if "depth" in sample: 182 | sample["depth"] = cv2.resize( 183 | sample["depth"], (width, height), interpolation=cv2.INTER_NEAREST 184 | ) 185 | if "mask" in sample: # mask is optional, like disparity and depth 186 | sample["mask"] = cv2.resize( 187 | sample["mask"].astype(np.float32), 188 | (width, height), 189 | interpolation=cv2.INTER_NEAREST, 190 | ) 191 | sample["mask"] = sample["mask"].astype(bool) 192 | 193 | return sample 194 | 195 | 196 | class 
NormalizeImage(object): 197 | """Normalize image by the given mean and std."""
198 | 199 | def __init__(self, mean, std): 200 | self.__mean = mean 201 | self.__std = std 202 | 203 | def __call__(self, sample): 204 | sample["image"] = (sample["image"] - self.__mean) / self.__std 205 | 206 | return sample 207 | 208 | 209 | class PrepareForNet(object): 210 | """Prepare sample for usage as network input.""" 211 | 212 | def __init__(self): 213 | pass 214 | 215 | def __call__(self, sample): 216 | image = np.transpose(sample["image"], (2, 0, 1)) 217 | sample["image"] = np.ascontiguousarray(image).astype(np.float32) 218 | 219 | if "mask" in sample: 220 | sample["mask"] = sample["mask"].astype(np.float32) 221 | sample["mask"] = np.ascontiguousarray(sample["mask"]) 222 | 223 | if "disparity" in sample: 224 | disparity = sample["disparity"].astype(np.float32) 225 | sample["disparity"] = np.ascontiguousarray(disparity) 226 | 227 | if "depth" in sample: 228 | depth = sample["depth"].astype(np.float32) 229 | sample["depth"] = np.ascontiguousarray(depth) 230 | 231 | return sample 232 | -------------------------------------------------------------------------------- /models/dpt/vit.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import timm 4 | import types 5 | import math 6 | import torch.nn.functional as F 7 | 8 | 9 | activations = {} 10 | 11 | 12 | def get_activation(name): 13 | def hook(model, input, output): 14 | activations[name] = output 15 | 16 | return hook 17 | 18 | 19 | attention = {} 20 | 21 | 22 | def get_attention(name): 23 | def hook(module, input, output): 24 | x = input[0] 25 | B, N, C = x.shape 26 | qkv = ( 27 | module.qkv(x) 28 | .reshape(B, N, 3, module.num_heads, C // module.num_heads) 29 | .permute(2, 0, 3, 1, 4) 30 | ) 31 | q, k, v = ( 32 | qkv[0], 33 | qkv[1], 34 | qkv[2], 35 | ) # make torchscript happy (cannot use tensor as tuple) 36 | 37 | attn = (q @ k.transpose(-2, -1)) * module.scale 38 | 39 | attn = attn.softmax(dim=-1) # [:,:,1,1:] 40 | 
attention[name] = attn 41 | 42 | return hook 43 | 44 | 45 | def get_mean_attention_map(attn, token, shape): 46 | attn = attn[:, :, token, 1:] 47 | attn = attn.unflatten(2, torch.Size([shape[2] // 16, shape[3] // 16])).float() 48 | attn = torch.nn.functional.interpolate( 49 | attn, size=shape[2:], mode="bicubic", align_corners=False 50 | ).squeeze(0) 51 | 52 | all_attn = torch.mean(attn, 0) 53 | 54 | return all_attn 55 | 56 | 57 | class Slice(nn.Module): 58 | def __init__(self, start_index=1): 59 | super(Slice, self).__init__() 60 | self.start_index = start_index 61 | 62 | def forward(self, x): 63 | return x[:, self.start_index :] 64 | 65 | 66 | class AddReadout(nn.Module): 67 | def __init__(self, start_index=1): 68 | super(AddReadout, self).__init__() 69 | self.start_index = start_index 70 | 71 | def forward(self, x): 72 | if self.start_index == 2: 73 | readout = (x[:, 0] + x[:, 1]) / 2 74 | else: 75 | readout = x[:, 0] 76 | return x[:, self.start_index :] + readout.unsqueeze(1) 77 | 78 | 79 | class ProjectReadout(nn.Module): 80 | def __init__(self, in_features, start_index=1): 81 | super(ProjectReadout, self).__init__() 82 | self.start_index = start_index 83 | 84 | self.project = nn.Sequential(nn.Linear(2 * in_features, in_features), nn.GELU()) 85 | 86 | def forward(self, x): 87 | readout = x[:, 0].unsqueeze(1).expand_as(x[:, self.start_index :]) 88 | features = torch.cat((x[:, self.start_index :], readout), -1) 89 | 90 | return self.project(features) 91 | 92 | 93 | class Transpose(nn.Module): 94 | def __init__(self, dim0, dim1): 95 | super(Transpose, self).__init__() 96 | self.dim0 = dim0 97 | self.dim1 = dim1 98 | 99 | def forward(self, x): 100 | x = x.transpose(self.dim0, self.dim1) 101 | return x 102 | 103 | 104 | def forward_vit(pretrained, x): 105 | b, c, h, w = x.shape 106 | 107 | glob = pretrained.model.forward_flex(x) 108 | 109 | layer_1 = pretrained.activations["1"] 110 | layer_2 = pretrained.activations["2"] 111 | layer_3 = pretrained.activations["3"] 
112 | layer_4 = pretrained.activations["4"] 113 | 114 | layer_1 = pretrained.act_postprocess1[0:2](layer_1) 115 | layer_2 = pretrained.act_postprocess2[0:2](layer_2) 116 | layer_3 = pretrained.act_postprocess3[0:2](layer_3) 117 | layer_4 = pretrained.act_postprocess4[0:2](layer_4) 118 | 119 | unflatten = nn.Sequential( 120 | nn.Unflatten( 121 | 2, 122 | torch.Size( 123 | [ 124 | h // pretrained.model.patch_size[1], 125 | w // pretrained.model.patch_size[0], 126 | ] 127 | ), 128 | ) 129 | ) 130 | 131 | if layer_1.ndim == 3: 132 | layer_1 = unflatten(layer_1) 133 | if layer_2.ndim == 3: 134 | layer_2 = unflatten(layer_2) 135 | if layer_3.ndim == 3: 136 | layer_3 = unflatten(layer_3) 137 | if layer_4.ndim == 3: 138 | layer_4 = unflatten(layer_4) 139 | 140 | layer_1 = pretrained.act_postprocess1[3 : len(pretrained.act_postprocess1)](layer_1) 141 | layer_2 = pretrained.act_postprocess2[3 : len(pretrained.act_postprocess2)](layer_2) 142 | layer_3 = pretrained.act_postprocess3[3 : len(pretrained.act_postprocess3)](layer_3) 143 | layer_4 = pretrained.act_postprocess4[3 : len(pretrained.act_postprocess4)](layer_4) 144 | 145 | return layer_1, layer_2, layer_3, layer_4 146 | 147 | 148 | def _resize_pos_embed(self, posemb, gs_h, gs_w): 149 | posemb_tok, posemb_grid = ( 150 | posemb[:, : self.start_index], 151 | posemb[0, self.start_index :], 152 | ) 153 | 154 | gs_old = int(math.sqrt(len(posemb_grid))) 155 | 156 | posemb_grid = posemb_grid.reshape(1, gs_old, gs_old, -1).permute(0, 3, 1, 2) 157 | posemb_grid = F.interpolate(posemb_grid, size=(gs_h, gs_w), mode="bilinear") 158 | posemb_grid = posemb_grid.permute(0, 2, 3, 1).reshape(1, gs_h * gs_w, -1) 159 | 160 | posemb = torch.cat([posemb_tok, posemb_grid], dim=1) 161 | 162 | return posemb 163 | 164 | 165 | def forward_flex(self, x): 166 | b, c, h, w = x.shape 167 | 168 | pos_embed = self._resize_pos_embed( 169 | self.pos_embed, h // self.patch_size[1], w // self.patch_size[0] 170 | ) 171 | 172 | B = x.shape[0] 173 | 174 | if 
hasattr(self.patch_embed, "backbone"): 175 | x = self.patch_embed.backbone(x) 176 | if isinstance(x, (list, tuple)): 177 | x = x[-1] # last feature if backbone outputs list/tuple of features 178 | 179 | x = self.patch_embed.proj(x).flatten(2).transpose(1, 2) 180 | 181 | if getattr(self, "dist_token", None) is not None: 182 | cls_tokens = self.cls_token.expand( 183 | B, -1, -1 184 | ) # stole cls_tokens impl from Phil Wang, thanks 185 | dist_token = self.dist_token.expand(B, -1, -1) 186 | x = torch.cat((cls_tokens, dist_token, x), dim=1) 187 | else: 188 | cls_tokens = self.cls_token.expand( 189 | B, -1, -1 190 | ) # stole cls_tokens impl from Phil Wang, thanks 191 | x = torch.cat((cls_tokens, x), dim=1) 192 | 193 | x = x + pos_embed 194 | x = self.pos_drop(x) 195 | 196 | for blk in self.blocks: 197 | x = blk(x) 198 | 199 | x = self.norm(x) 200 | 201 | return x 202 | 203 | 204 | def get_readout_oper(vit_features, features, use_readout, start_index=1): 205 | if use_readout == "ignore": 206 | readout_oper = [Slice(start_index)] * len(features) 207 | elif use_readout == "add": 208 | readout_oper = [AddReadout(start_index)] * len(features) 209 | elif use_readout == "project": 210 | readout_oper = [ 211 | ProjectReadout(vit_features, start_index) for _ in features 212 | ] 213 | else: 214 | raise ValueError( # an assert would be skipped under python -O 215 | f"wrong operation for readout token: use_readout must be 'ignore', 'add', or 'project', got {use_readout!r}" 216 | ) 217 | 218 | return readout_oper 219 | 220 | 221 | def _make_vit_b16_backbone( 222 | model, 223 | features=[96, 192, 384, 768], 224 | size=[384, 384], 225 | hooks=[2, 5, 8, 11], 226 | vit_features=768, 227 | use_readout="ignore", 228 | start_index=1, 229 | enable_attention_hooks=False, 230 | ): 231 | pretrained = nn.Module() 232 | 233 | pretrained.model = model 234 | pretrained.model.blocks[hooks[0]].register_forward_hook(get_activation("1")) 235 | 
pretrained.model.blocks[hooks[1]].register_forward_hook(get_activation("2"))
pretrained.model.blocks[hooks[2]].register_forward_hook(get_activation("3")) 237 | pretrained.model.blocks[hooks[3]].register_forward_hook(get_activation("4")) 238 | 239 | pretrained.activations = activations 240 | 241 | if enable_attention_hooks: 242 | pretrained.model.blocks[hooks[0]].attn.register_forward_hook( 243 | get_attention("attn_1") 244 | ) 245 | pretrained.model.blocks[hooks[1]].attn.register_forward_hook( 246 | get_attention("attn_2") 247 | ) 248 | pretrained.model.blocks[hooks[2]].attn.register_forward_hook( 249 | get_attention("attn_3") 250 | ) 251 | pretrained.model.blocks[hooks[3]].attn.register_forward_hook( 252 | get_attention("attn_4") 253 | ) 254 | pretrained.attention = attention 255 | 256 | readout_oper = get_readout_oper(vit_features, features, use_readout, start_index) 257 | 258 | # 32, 48, 136, 384 259 | pretrained.act_postprocess1 = nn.Sequential( 260 | readout_oper[0], 261 | Transpose(1, 2), 262 | nn.Unflatten(2, torch.Size([size[0] // 16, size[1] // 16])), 263 | nn.Conv2d( 264 | in_channels=vit_features, 265 | out_channels=features[0], 266 | kernel_size=1, 267 | stride=1, 268 | padding=0, 269 | ), 270 | nn.ConvTranspose2d( 271 | in_channels=features[0], 272 | out_channels=features[0], 273 | kernel_size=4, 274 | stride=4, 275 | padding=0, 276 | bias=True, 277 | dilation=1, 278 | groups=1, 279 | ), 280 | ) 281 | 282 | pretrained.act_postprocess2 = nn.Sequential( 283 | readout_oper[1], 284 | Transpose(1, 2), 285 | nn.Unflatten(2, torch.Size([size[0] // 16, size[1] // 16])), 286 | nn.Conv2d( 287 | in_channels=vit_features, 288 | out_channels=features[1], 289 | kernel_size=1, 290 | stride=1, 291 | padding=0, 292 | ), 293 | nn.ConvTranspose2d( 294 | in_channels=features[1], 295 | out_channels=features[1], 296 | kernel_size=2, 297 | stride=2, 298 | padding=0, 299 | bias=True, 300 | dilation=1, 301 | groups=1, 302 | ), 303 | ) 304 | 305 | pretrained.act_postprocess3 = nn.Sequential( 306 | readout_oper[2], 307 | Transpose(1, 2), 308 | 
nn.Unflatten(2, torch.Size([size[0] // 16, size[1] // 16])), 309 | nn.Conv2d( 310 | in_channels=vit_features, 311 | out_channels=features[2], 312 | kernel_size=1, 313 | stride=1, 314 | padding=0, 315 | ), 316 | ) 317 | 318 | pretrained.act_postprocess4 = nn.Sequential( 319 | readout_oper[3], 320 | Transpose(1, 2), 321 | nn.Unflatten(2, torch.Size([size[0] // 16, size[1] // 16])), 322 | nn.Conv2d( 323 | in_channels=vit_features, 324 | out_channels=features[3], 325 | kernel_size=1, 326 | stride=1, 327 | padding=0, 328 | ), 329 | nn.Conv2d( 330 | in_channels=features[3], 331 | out_channels=features[3], 332 | kernel_size=3, 333 | stride=2, 334 | padding=1, 335 | ), 336 | ) 337 | 338 | pretrained.model.start_index = start_index 339 | pretrained.model.patch_size = [16, 16] 340 | 341 | # We inject this function into the VisionTransformer instances so that 342 | # we can use it with interpolated position embeddings without modifying the library source. 343 | pretrained.model.forward_flex = types.MethodType(forward_flex, pretrained.model) 344 | pretrained.model._resize_pos_embed = types.MethodType( 345 | _resize_pos_embed, pretrained.model 346 | ) 347 | 348 | return pretrained 349 | 350 | 351 | def _make_vit_b_rn50_backbone( 352 | model, 353 | features=[256, 512, 768, 768], 354 | size=[384, 384], 355 | hooks=[0, 1, 8, 11], 356 | vit_features=768, 357 | use_vit_only=False, 358 | use_readout="ignore", 359 | start_index=1, 360 | enable_attention_hooks=False, 361 | ): 362 | pretrained = nn.Module() 363 | 364 | pretrained.model = model 365 | 366 | if use_vit_only == True: 367 | pretrained.model.blocks[hooks[0]].register_forward_hook(get_activation("1")) 368 | pretrained.model.blocks[hooks[1]].register_forward_hook(get_activation("2")) 369 | else: 370 | pretrained.model.patch_embed.backbone.stages[0].register_forward_hook( 371 | get_activation("1") 372 | ) 373 | pretrained.model.patch_embed.backbone.stages[1].register_forward_hook( 374 | get_activation("2") 375 | ) 376 | 377 | 
pretrained.model.blocks[hooks[2]].register_forward_hook(get_activation("3")) 378 | pretrained.model.blocks[hooks[3]].register_forward_hook(get_activation("4")) 379 | 380 | if enable_attention_hooks: 381 | pretrained.model.blocks[2].attn.register_forward_hook(get_attention("attn_1")) 382 | pretrained.model.blocks[5].attn.register_forward_hook(get_attention("attn_2")) 383 | pretrained.model.blocks[8].attn.register_forward_hook(get_attention("attn_3")) 384 | pretrained.model.blocks[11].attn.register_forward_hook(get_attention("attn_4")) 385 | pretrained.attention = attention 386 | 387 | pretrained.activations = activations 388 | 389 | readout_oper = get_readout_oper(vit_features, features, use_readout, start_index) 390 | 391 | if use_vit_only == True: 392 | pretrained.act_postprocess1 = nn.Sequential( 393 | readout_oper[0], 394 | Transpose(1, 2), 395 | nn.Unflatten(2, torch.Size([size[0] // 16, size[1] // 16])), 396 | nn.Conv2d( 397 | in_channels=vit_features, 398 | out_channels=features[0], 399 | kernel_size=1, 400 | stride=1, 401 | padding=0, 402 | ), 403 | nn.ConvTranspose2d( 404 | in_channels=features[0], 405 | out_channels=features[0], 406 | kernel_size=4, 407 | stride=4, 408 | padding=0, 409 | bias=True, 410 | dilation=1, 411 | groups=1, 412 | ), 413 | ) 414 | 415 | pretrained.act_postprocess2 = nn.Sequential( 416 | readout_oper[1], 417 | Transpose(1, 2), 418 | nn.Unflatten(2, torch.Size([size[0] // 16, size[1] // 16])), 419 | nn.Conv2d( 420 | in_channels=vit_features, 421 | out_channels=features[1], 422 | kernel_size=1, 423 | stride=1, 424 | padding=0, 425 | ), 426 | nn.ConvTranspose2d( 427 | in_channels=features[1], 428 | out_channels=features[1], 429 | kernel_size=2, 430 | stride=2, 431 | padding=0, 432 | bias=True, 433 | dilation=1, 434 | groups=1, 435 | ), 436 | ) 437 | else: 438 | pretrained.act_postprocess1 = nn.Sequential( 439 | nn.Identity(), nn.Identity(), nn.Identity() 440 | ) 441 | pretrained.act_postprocess2 = nn.Sequential( 442 | nn.Identity(), 
nn.Identity(), nn.Identity() 443 | ) 444 | 445 | pretrained.act_postprocess3 = nn.Sequential( 446 | readout_oper[2], 447 | Transpose(1, 2), 448 | nn.Unflatten(2, torch.Size([size[0] // 16, size[1] // 16])), 449 | nn.Conv2d( 450 | in_channels=vit_features, 451 | out_channels=features[2], 452 | kernel_size=1, 453 | stride=1, 454 | padding=0, 455 | ), 456 | ) 457 | 458 | pretrained.act_postprocess4 = nn.Sequential( 459 | readout_oper[3], 460 | Transpose(1, 2), 461 | nn.Unflatten(2, torch.Size([size[0] // 16, size[1] // 16])), 462 | nn.Conv2d( 463 | in_channels=vit_features, 464 | out_channels=features[3], 465 | kernel_size=1, 466 | stride=1, 467 | padding=0, 468 | ), 469 | nn.Conv2d( 470 | in_channels=features[3], 471 | out_channels=features[3], 472 | kernel_size=3, 473 | stride=2, 474 | padding=1, 475 | ), 476 | ) 477 | 478 | pretrained.model.start_index = start_index 479 | pretrained.model.patch_size = [16, 16] 480 | 481 | # We inject this function into the VisionTransformer instances so that 482 | # we can use it with interpolated position embeddings without modifying the library source. 483 | pretrained.model.forward_flex = types.MethodType(forward_flex, pretrained.model) 484 | 485 | # We inject this function into the VisionTransformer instances so that 486 | # we can use it with interpolated position embeddings without modifying the library source. 
487 | pretrained.model._resize_pos_embed = types.MethodType( 488 | _resize_pos_embed, pretrained.model 489 | ) 490 | 491 | return pretrained 492 | 493 | 494 | def _make_pretrained_vitb_rn50_384( 495 | pretrained, 496 | use_readout="ignore", 497 | hooks=None, 498 | use_vit_only=False, 499 | enable_attention_hooks=False, 500 | ): 501 | model = timm.create_model("vit_base_resnet50_384", pretrained=pretrained) 502 | 503 | hooks = [0, 1, 8, 11] if hooks is None else hooks 504 | return _make_vit_b_rn50_backbone( 505 | model, 506 | features=[256, 512, 768, 768], 507 | size=[384, 384], 508 | hooks=hooks, 509 | use_vit_only=use_vit_only, 510 | use_readout=use_readout, 511 | enable_attention_hooks=enable_attention_hooks, 512 | ) 513 | 514 | 515 | def _make_pretrained_vitl16_384( 516 | pretrained, use_readout="ignore", hooks=None, enable_attention_hooks=False 517 | ): 518 | model = timm.create_model("vit_large_patch16_384", pretrained=pretrained) 519 | 520 | hooks = [5, 11, 17, 23] if hooks is None else hooks 521 | return _make_vit_b16_backbone( 522 | model, 523 | features=[256, 512, 1024, 1024], 524 | hooks=hooks, 525 | vit_features=1024, 526 | use_readout=use_readout, 527 | enable_attention_hooks=enable_attention_hooks, 528 | ) 529 | 530 | 531 | def _make_pretrained_vitb16_384( 532 | pretrained, use_readout="ignore", hooks=None, enable_attention_hooks=False 533 | ): 534 | model = timm.create_model("vit_base_patch16_384", pretrained=pretrained) 535 | 536 | hooks = [2, 5, 8, 11] if hooks is None else hooks 537 | return _make_vit_b16_backbone( 538 | model, 539 | features=[96, 192, 384, 768], 540 | hooks=hooks, 541 | use_readout=use_readout, 542 | enable_attention_hooks=enable_attention_hooks, 543 | ) 544 | 545 | 546 | def _make_pretrained_deitb16_384( 547 | pretrained, use_readout="ignore", hooks=None, enable_attention_hooks=False 548 | ): 549 | model = timm.create_model("vit_deit_base_patch16_384", pretrained=pretrained) 550 | 551 | hooks = [2, 5, 
8, 11] if hooks is None else
hooks 552 | return _make_vit_b16_backbone( 553 | model, 554 | features=[96, 192, 384, 768], 555 | hooks=hooks, 556 | use_readout=use_readout, 557 | enable_attention_hooks=enable_attention_hooks, 558 | ) 559 | 560 | 561 | def _make_pretrained_deitb16_distil_384( 562 | pretrained, use_readout="ignore", hooks=None, enable_attention_hooks=False 563 | ): 564 | model = timm.create_model( 565 | "vit_deit_base_distilled_patch16_384", pretrained=pretrained 566 | ) 567 | 568 | hooks = [2, 5, 8, 11] if hooks is None else hooks 569 | return _make_vit_b16_backbone( 570 | model, 571 | features=[96, 192, 384, 768], 572 | hooks=hooks, 573 | use_readout=use_readout, 574 | start_index=2, 575 | enable_attention_hooks=enable_attention_hooks, 576 | ) 577 | -------------------------------------------------------------------------------- /models/layers/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | from .dino_head import DINOHead 8 | from .mlp import Mlp 9 | from .patch_embed import PatchEmbed 10 | from .swiglu_ffn import SwiGLUFFN, SwiGLUFFNFused 11 | from .block import NestedTensorBlock 12 | from .attention import MemEffAttention, MemEffAttention_lora 13 | -------------------------------------------------------------------------------- /models/layers/attention.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree.
6 | 7 | # References: 8 | # https://github.com/facebookresearch/dino/blob/master/vision_transformer.py 9 | # https://github.com/rwightman/pytorch-image-models/tree/master/timm/models/vision_transformer.py 10 | 11 | import logging 12 | 13 | from torch import Tensor 14 | from torch import nn 15 | import loralib as lora 16 | 17 | logger = logging.getLogger("dinov2") 18 | 19 | 20 | try: 21 | from xformers.ops import memory_efficient_attention, unbind, fmha 22 | 23 | XFORMERS_AVAILABLE = True 24 | except ImportError: 25 | logger.warning("xFormers not available") 26 | XFORMERS_AVAILABLE = False 27 | 28 | 29 | class Attention(nn.Module): 30 | def __init__( 31 | self, 32 | dim: int, 33 | num_heads: int = 8, 34 | qkv_bias: bool = False, 35 | proj_bias: bool = True, 36 | attn_drop: float = 0.0, 37 | proj_drop: float = 0.0, 38 | ) -> None: 39 | super().__init__() 40 | self.num_heads = num_heads 41 | head_dim = dim // num_heads 42 | self.scale = head_dim**-0.5 43 | 44 | self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) 45 | self.attn_drop = nn.Dropout(attn_drop) 46 | self.proj = nn.Linear(dim, dim, bias=proj_bias) 47 | self.proj_drop = nn.Dropout(proj_drop) 48 | 49 | def forward(self, x: Tensor) -> Tensor: 50 | B, N, C = x.shape 51 | qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4) 52 | 53 | q, k, v = qkv[0] * self.scale, qkv[1], qkv[2] 54 | attn = q @ k.transpose(-2, -1) 55 | 56 | attn = attn.softmax(dim=-1) 57 | attn = self.attn_drop(attn) 58 | 59 | x = (attn @ v).transpose(1, 2).reshape(B, N, C) 60 | x = self.proj(x) 61 | x = self.proj_drop(x) 62 | return x 63 | 64 | class Attention_lora(nn.Module): 65 | def __init__( 66 | self, 67 | dim: int, 68 | num_heads: int = 8, 69 | qkv_bias: bool = False, 70 | proj_bias: bool = True, 71 | attn_drop: float = 0.0, 72 | proj_drop: float = 0.0, 73 | ) -> None: 74 | super().__init__() 75 | self.num_heads = num_heads 76 | head_dim = dim // num_heads 77 | self.scale = head_dim**-0.5 78 | 79 | 
self.qkv = lora.Linear(dim, dim * 3, bias=qkv_bias, r=8) 80 | self.attn_drop = nn.Dropout(attn_drop) 81 | self.proj = lora.Linear(dim, dim, bias=proj_bias, r=8) 82 | self.proj_drop = nn.Dropout(proj_drop) 83 | 84 | def forward(self, x: Tensor) -> Tensor: 85 | B, N, C = x.shape 86 | qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4) 87 | 88 | q, k, v = qkv[0] * self.scale, qkv[1], qkv[2] 89 | attn = q @ k.transpose(-2, -1) 90 | 91 | attn = attn.softmax(dim=-1) 92 | attn = self.attn_drop(attn) 93 | 94 | x = (attn @ v).transpose(1, 2).reshape(B, N, C) 95 | x = self.proj(x) 96 | x = self.proj_drop(x) 97 | return x 98 | 99 | class MemEffAttention(Attention): 100 | def forward(self, x: Tensor, attn_bias=None) -> Tensor: 101 | if not XFORMERS_AVAILABLE: 102 | assert attn_bias is None, "xFormers is required for nested tensors usage" 103 | return super().forward(x) 104 | 105 | B, N, C = x.shape 106 | qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads) 107 | 108 | q, k, v = unbind(qkv, 2) 109 | 110 | x = memory_efficient_attention(q, k, v, attn_bias=attn_bias) 111 | x = x.reshape([B, N, C]) 112 | 113 | x = self.proj(x) 114 | x = self.proj_drop(x) 115 | return x 116 | 117 | class MemEffAttention_lora(Attention_lora): 118 | def forward(self, x: Tensor, attn_bias=None) -> Tensor: 119 | if not XFORMERS_AVAILABLE: 120 | assert attn_bias is None, "xFormers is required for nested tensors usage" 121 | return super().forward(x) 122 | 123 | B, N, C = x.shape 124 | qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads) 125 | 126 | q, k, v = unbind(qkv, 2) 127 | 128 | x = memory_efficient_attention(q, k, v, attn_bias=attn_bias) 129 | x = x.reshape([B, N, C]) 130 | 131 | x = self.proj(x) 132 | x = self.proj_drop(x) 133 | return x 134 | -------------------------------------------------------------------------------- /models/layers/block.py: 
-------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | # References: 8 | # https://github.com/facebookresearch/dino/blob/master/vision_transformer.py 9 | # https://github.com/rwightman/pytorch-image-models/tree/master/timm/layers/patch_embed.py 10 | 11 | import logging 12 | from typing import Callable, List, Any, Tuple, Dict 13 | 14 | import torch 15 | from torch import nn, Tensor 16 | 17 | from .attention import Attention, MemEffAttention 18 | from .drop_path import DropPath 19 | from .layer_scale import LayerScale 20 | from .mlp import Mlp 21 | 22 | 23 | logger = logging.getLogger("dinov2") 24 | 25 | 26 | try: 27 | from xformers.ops import fmha 28 | from xformers.ops import scaled_index_add, index_select_cat 29 | 30 | XFORMERS_AVAILABLE = True 31 | except ImportError: 32 | logger.warning("xFormers not available") 33 | XFORMERS_AVAILABLE = False 34 | 35 | 36 | class Block(nn.Module): 37 | def __init__( 38 | self, 39 | dim: int, 40 | num_heads: int, 41 | mlp_ratio: float = 4.0, 42 | qkv_bias: bool = False, 43 | proj_bias: bool = True, 44 | ffn_bias: bool = True, 45 | drop: float = 0.0, 46 | attn_drop: float = 0.0, 47 | init_values=None, 48 | drop_path: float = 0.0, 49 | act_layer: Callable[..., nn.Module] = nn.GELU, 50 | norm_layer: Callable[..., nn.Module] = nn.LayerNorm, 51 | attn_class: Callable[..., nn.Module] = Attention, 52 | ffn_layer: Callable[..., nn.Module] = Mlp, 53 | ) -> None: 54 | super().__init__() 55 | # print(f"biases: qkv: {qkv_bias}, proj: {proj_bias}, ffn: {ffn_bias}") 56 | self.norm1 = norm_layer(dim) 57 | self.attn = attn_class( 58 | dim, 59 | num_heads=num_heads, 60 | qkv_bias=qkv_bias, 61 | proj_bias=proj_bias, 62 | attn_drop=attn_drop, 63 | proj_drop=drop, 64 | ) 65 | self.ls1 = 
LayerScale(dim, init_values=init_values) if init_values else nn.Identity() 66 | self.drop_path1 = DropPath(drop_path) if drop_path > 0.0 else nn.Identity() 67 | 68 | self.norm2 = norm_layer(dim) 69 | mlp_hidden_dim = int(dim * mlp_ratio) 70 | self.mlp = ffn_layer( 71 | in_features=dim, 72 | hidden_features=mlp_hidden_dim, 73 | act_layer=act_layer, 74 | drop=drop, 75 | bias=ffn_bias, 76 | ) 77 | self.ls2 = LayerScale(dim, init_values=init_values) if init_values else nn.Identity() 78 | self.drop_path2 = DropPath(drop_path) if drop_path > 0.0 else nn.Identity() 79 | 80 | self.sample_drop_ratio = drop_path 81 | 82 | def forward(self, x: Tensor) -> Tensor: 83 | def attn_residual_func(x: Tensor) -> Tensor: 84 | return self.ls1(self.attn(self.norm1(x))) 85 | 86 | def ffn_residual_func(x: Tensor) -> Tensor: 87 | return self.ls2(self.mlp(self.norm2(x))) 88 | 89 | if self.training and self.sample_drop_ratio > 0.1: 90 | # the overhead is compensated only for a drop path rate larger than 0.1 91 | x = drop_add_residual_stochastic_depth( 92 | x, 93 | residual_func=attn_residual_func, 94 | sample_drop_ratio=self.sample_drop_ratio, 95 | ) 96 | x = drop_add_residual_stochastic_depth( 97 | x, 98 | residual_func=ffn_residual_func, 99 | sample_drop_ratio=self.sample_drop_ratio, 100 | ) 101 | elif self.training and self.sample_drop_ratio > 0.0: 102 | x = x + self.drop_path1(attn_residual_func(x)) 103 | x = x + self.drop_path2(ffn_residual_func(x)) 104 | else: 105 | x = x + attn_residual_func(x) 106 | x = x + ffn_residual_func(x) 107 | return x 108 | 109 | 110 | def drop_add_residual_stochastic_depth( 111 | x: Tensor, 112 | residual_func: Callable[[Tensor], Tensor], 113 | sample_drop_ratio: float = 0.0, 114 | ) -> Tensor: 115 | # 1) extract subset using permutation 116 | b, n, d = x.shape 117 | sample_subset_size = max(int(b * (1 - sample_drop_ratio)), 1) 118 | brange = (torch.randperm(b, device=x.device))[:sample_subset_size] 119 | x_subset = x[brange] 120 | 121 | # 
2) apply residual_func to get residual 122 | residual = residual_func(x_subset) 123 | 124 | x_flat = x.flatten(1) 125 | residual = residual.flatten(1) 126 | 127 | residual_scale_factor = b / sample_subset_size 128 | 129 | # 3) add the residual 130 | x_plus_residual = torch.index_add(x_flat, 0, brange, residual.to(dtype=x.dtype), alpha=residual_scale_factor) 131 | return x_plus_residual.view_as(x) 132 | 133 | 134 | def get_branges_scales(x, sample_drop_ratio=0.0): 135 | b, n, d = x.shape 136 | sample_subset_size = max(int(b * (1 - sample_drop_ratio)), 1) 137 | brange = (torch.randperm(b, device=x.device))[:sample_subset_size] 138 | residual_scale_factor = b / sample_subset_size 139 | return brange, residual_scale_factor 140 | 141 | 142 | def add_residual(x, brange, residual, residual_scale_factor, scaling_vector=None): 143 | if scaling_vector is None: 144 | x_flat = x.flatten(1) 145 | residual = residual.flatten(1) 146 | x_plus_residual = torch.index_add(x_flat, 0, brange, residual.to(dtype=x.dtype), alpha=residual_scale_factor) 147 | else: 148 | x_plus_residual = scaled_index_add( 149 | x, brange, residual.to(dtype=x.dtype), scaling=scaling_vector, alpha=residual_scale_factor 150 | ) 151 | return x_plus_residual 152 | 153 | 154 | attn_bias_cache: Dict[Tuple, Any] = {} 155 | 156 | 157 | def get_attn_bias_and_cat(x_list, branges=None): 158 | """ 159 | this will perform the index select, cat the tensors, and provide the attn_bias from cache 160 | """ 161 | batch_sizes = [b.shape[0] for b in branges] if branges is not None else [x.shape[0] for x in x_list] 162 | all_shapes = tuple((b, x.shape[1]) for b, x in zip(batch_sizes, x_list)) 163 | if all_shapes not in attn_bias_cache.keys(): 164 | seqlens = [] 165 | for b, x in zip(batch_sizes, x_list): 166 | for _ in range(b): 167 | seqlens.append(x.shape[1]) 168 | attn_bias = fmha.BlockDiagonalMask.from_seqlens(seqlens) 169 | attn_bias._batch_sizes = batch_sizes 170 | attn_bias_cache[all_shapes] = attn_bias 171 | 172 | if 
branges is not None: 173 | cat_tensors = index_select_cat([x.flatten(1) for x in x_list], branges).view(1, -1, x_list[0].shape[-1]) 174 | else: 175 | tensors_bs1 = tuple(x.reshape([1, -1, *x.shape[2:]]) for x in x_list) 176 | cat_tensors = torch.cat(tensors_bs1, dim=1) 177 | 178 | return attn_bias_cache[all_shapes], cat_tensors 179 | 180 | 181 | def drop_add_residual_stochastic_depth_list( 182 | x_list: List[Tensor], 183 | residual_func: Callable[[Tensor, Any], Tensor], 184 | sample_drop_ratio: float = 0.0, 185 | scaling_vector=None, 186 | ) -> Tensor: 187 | # 1) generate random set of indices for dropping samples in the batch 188 | branges_scales = [get_branges_scales(x, sample_drop_ratio=sample_drop_ratio) for x in x_list] 189 | branges = [s[0] for s in branges_scales] 190 | residual_scale_factors = [s[1] for s in branges_scales] 191 | 192 | # 2) get attention bias and index+concat the tensors 193 | attn_bias, x_cat = get_attn_bias_and_cat(x_list, branges) 194 | 195 | # 3) apply residual_func to get residual, and split the result 196 | residual_list = attn_bias.split(residual_func(x_cat, attn_bias=attn_bias)) # type: ignore 197 | 198 | outputs = [] 199 | for x, brange, residual, residual_scale_factor in zip(x_list, branges, residual_list, residual_scale_factors): 200 | outputs.append(add_residual(x, brange, residual, residual_scale_factor, scaling_vector).view_as(x)) 201 | return outputs 202 | 203 | 204 | class NestedTensorBlock(Block): 205 | def forward_nested(self, x_list: List[Tensor]) -> List[Tensor]: 206 | """ 207 | x_list contains a list of tensors to nest together and run 208 | """ 209 | assert isinstance(self.attn, MemEffAttention) 210 | 211 | if self.training and self.sample_drop_ratio > 0.0: 212 | 213 | def attn_residual_func(x: Tensor, attn_bias=None) -> Tensor: 214 | return self.attn(self.norm1(x), attn_bias=attn_bias) 215 | 216 | def ffn_residual_func(x: Tensor, attn_bias=None) -> Tensor: 217 | return self.mlp(self.norm2(x)) 218 | 219 | x_list = 
drop_add_residual_stochastic_depth_list( 220 | x_list, 221 | residual_func=attn_residual_func, 222 | sample_drop_ratio=self.sample_drop_ratio, 223 | scaling_vector=self.ls1.gamma if isinstance(self.ls1, LayerScale) else None, 224 | ) 225 | x_list = drop_add_residual_stochastic_depth_list( 226 | x_list, 227 | residual_func=ffn_residual_func, 228 | sample_drop_ratio=self.sample_drop_ratio, 229 | scaling_vector=self.ls2.gamma if isinstance(self.ls2, LayerScale) else None, 230 | ) 231 | return x_list 232 | else: 233 | 234 | def attn_residual_func(x: Tensor, attn_bias=None) -> Tensor: 235 | return self.ls1(self.attn(self.norm1(x), attn_bias=attn_bias)) 236 | 237 | def ffn_residual_func(x: Tensor, attn_bias=None) -> Tensor: 238 | return self.ls2(self.mlp(self.norm2(x))) 239 | 240 | attn_bias, x = get_attn_bias_and_cat(x_list) 241 | x = x + attn_residual_func(x, attn_bias=attn_bias) 242 | x = x + ffn_residual_func(x) 243 | return attn_bias.split(x) 244 | 245 | def forward(self, x_or_x_list): 246 | if isinstance(x_or_x_list, Tensor): 247 | return super().forward(x_or_x_list) 248 | elif isinstance(x_or_x_list, list): 249 | assert XFORMERS_AVAILABLE, "Please install xFormers for nested tensors usage" 250 | return self.forward_nested(x_or_x_list) 251 | else: 252 | raise AssertionError 253 | -------------------------------------------------------------------------------- /models/layers/dino_head.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree.
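`drop_add_residual_stochastic_depth` in `models/layers/block.py` above implements stochastic depth at the batch level. It evaluates the residual branch only on a random subset of samples, then adds the result back scaled by `b / subset_size`. `Block.forward` takes this path only when the drop ratio exceeds 0.1. The following is a minimal NumPy sketch of the same bookkeeping. `drop_add_residual_subset` is a hypothetical name, and `np.add.at` stands in for `torch.index_add(..., alpha=...)`.

```python
import numpy as np

def drop_add_residual_subset(x, residual_func, sample_drop_ratio, rng):
    """NumPy sketch of drop_add_residual_stochastic_depth (block.py).

    x: (b, n, d). Only a random subset of samples gets the residual added.
    """
    b = x.shape[0]
    subset_size = max(int(b * (1 - sample_drop_ratio)), 1)
    brange = rng.permutation(b)[:subset_size]   # kept samples, like torch.randperm(b)[:k]
    residual = residual_func(x[brange])          # the branch runs on the subset only
    scale = b / subset_size                      # compensates for the dropped samples
    out = x.copy()
    np.add.at(out, brange, scale * residual)     # role of torch.index_add(..., alpha=scale)
    return out, brange, scale
```

With a batch of 4 and a ratio of 0.5, two samples receive `1 + 2 * 2 = 5` when `x` is all ones and the branch returns 2. The other two samples pass through unchanged. The cost of the branch is halved.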
6 | 7 | import torch 8 | import torch.nn as nn 9 | from torch.nn.init import trunc_normal_ 10 | from torch.nn.utils import weight_norm 11 | 12 | 13 | class DINOHead(nn.Module): 14 | def __init__( 15 | self, 16 | in_dim, 17 | out_dim, 18 | use_bn=False, 19 | nlayers=3, 20 | hidden_dim=2048, 21 | bottleneck_dim=256, 22 | mlp_bias=True, 23 | ): 24 | super().__init__() 25 | nlayers = max(nlayers, 1) 26 | self.mlp = _build_mlp(nlayers, in_dim, bottleneck_dim, hidden_dim=hidden_dim, use_bn=use_bn, bias=mlp_bias) 27 | self.apply(self._init_weights) 28 | self.last_layer = weight_norm(nn.Linear(bottleneck_dim, out_dim, bias=False)) 29 | self.last_layer.weight_g.data.fill_(1) 30 | 31 | def _init_weights(self, m): 32 | if isinstance(m, nn.Linear): 33 | trunc_normal_(m.weight, std=0.02) 34 | if isinstance(m, nn.Linear) and m.bias is not None: 35 | nn.init.constant_(m.bias, 0) 36 | 37 | def forward(self, x): 38 | x = self.mlp(x) 39 | eps = 1e-6 if x.dtype == torch.float16 else 1e-12 40 | x = nn.functional.normalize(x, dim=-1, p=2, eps=eps) 41 | x = self.last_layer(x) 42 | return x 43 | 44 | 45 | def _build_mlp(nlayers, in_dim, bottleneck_dim, hidden_dim=None, use_bn=False, bias=True): 46 | if nlayers == 1: 47 | return nn.Linear(in_dim, bottleneck_dim, bias=bias) 48 | else: 49 | layers = [nn.Linear(in_dim, hidden_dim, bias=bias)] 50 | if use_bn: 51 | layers.append(nn.BatchNorm1d(hidden_dim)) 52 | layers.append(nn.GELU()) 53 | for _ in range(nlayers - 2): 54 | layers.append(nn.Linear(hidden_dim, hidden_dim, bias=bias)) 55 | if use_bn: 56 | layers.append(nn.BatchNorm1d(hidden_dim)) 57 | layers.append(nn.GELU()) 58 | layers.append(nn.Linear(hidden_dim, bottleneck_dim, bias=bias)) 59 | return nn.Sequential(*layers) 60 | -------------------------------------------------------------------------------- /models/layers/drop_path.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 
2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | # References: 8 | # https://github.com/facebookresearch/dino/blob/master/vision_transformer.py 9 | # https://github.com/rwightman/pytorch-image-models/tree/master/timm/layers/drop.py 10 | 11 | 12 | from torch import nn 13 | 14 | 15 | def drop_path(x, drop_prob: float = 0.0, training: bool = False): 16 | if drop_prob == 0.0 or not training: 17 | return x 18 | keep_prob = 1 - drop_prob 19 | shape = (x.shape[0],) + (1,) * (x.ndim - 1) # work with diff dim tensors, not just 2D ConvNets 20 | random_tensor = x.new_empty(shape).bernoulli_(keep_prob) 21 | if keep_prob > 0.0: 22 | random_tensor.div_(keep_prob) 23 | output = x * random_tensor 24 | return output 25 | 26 | 27 | class DropPath(nn.Module): 28 | """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).""" 29 | 30 | def __init__(self, drop_prob=None): 31 | super(DropPath, self).__init__() 32 | self.drop_prob = drop_prob 33 | 34 | def forward(self, x): 35 | return drop_path(x, self.drop_prob, self.training) 36 | -------------------------------------------------------------------------------- /models/layers/layer_scale.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 
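`drop_path` in `models/layers/drop_path.py` above makes one keep/drop decision per sample and broadcasts it over all remaining axes. Surviving samples are divided by `keep_prob`, so the expected output equals the input. The following is a minimal NumPy sketch of that logic. `drop_path_np` is a hypothetical name, and NumPy's `Generator.binomial` replaces `Tensor.bernoulli_`.

```python
import numpy as np

def drop_path_np(x, drop_prob, training, rng):
    """NumPy sketch of drop_path: zero whole samples, rescale the survivors."""
    if drop_prob == 0.0 or not training:
        return x
    keep_prob = 1 - drop_prob
    # One draw per sample, broadcast over every remaining axis.
    shape = (x.shape[0],) + (1,) * (x.ndim - 1)
    mask = rng.binomial(1, keep_prob, size=shape).astype(x.dtype)
    if keep_prob > 0.0:
        mask /= keep_prob  # keeps E[output] == x
    return x * mask
```

Each sample is therefore either entirely zero or entirely scaled by `1 / keep_prob`. In evaluation mode the input is returned unchanged.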
6 | 7 | # Modified from: https://github.com/huggingface/pytorch-image-models/blob/main/timm/models/vision_transformer.py#L103-L110 8 | 9 | from typing import Union 10 | 11 | import torch 12 | from torch import Tensor 13 | from torch import nn 14 | 15 | 16 | class LayerScale(nn.Module): 17 | def __init__( 18 | self, 19 | dim: int, 20 | init_values: Union[float, Tensor] = 1e-5, 21 | inplace: bool = False, 22 | ) -> None: 23 | super().__init__() 24 | self.inplace = inplace 25 | self.gamma = nn.Parameter(init_values * torch.ones(dim)) 26 | 27 | def forward(self, x: Tensor) -> Tensor: 28 | return x.mul_(self.gamma) if self.inplace else x * self.gamma 29 | -------------------------------------------------------------------------------- /models/layers/mlp.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 
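`LayerScale` in `models/layers/layer_scale.py` above multiplies a residual branch channel-wise by `gamma`, which is initialised to `init_values` (1e-5 by default). A freshly built block therefore stays close to the identity map. `Block` only wraps its branches in `LayerScale` when `init_values` is truthy. The following is a minimal NumPy sketch of that effect. `residual_with_layerscale` is a hypothetical helper.

```python
import numpy as np

def residual_with_layerscale(x, branch, gamma):
    """x + gamma * branch(x), with gamma broadcast over the channel (last) axis."""
    return x + gamma * branch(x)

dim = 4
gamma = 1e-5 * np.ones(dim)  # mirrors LayerScale(dim, init_values=1e-5)
```

Even a branch that amplifies its input tenfold changes the output by only about `1e-4 * x` at initialisation.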
6 | 7 | # References: 8 | # https://github.com/facebookresearch/dino/blob/master/vision_transformer.py 9 | # https://github.com/rwightman/pytorch-image-models/tree/master/timm/layers/mlp.py 10 | 11 | 12 | from typing import Callable, Optional 13 | 14 | from torch import Tensor, nn 15 | 16 | 17 | class Mlp(nn.Module): 18 | def __init__( 19 | self, 20 | in_features: int, 21 | hidden_features: Optional[int] = None, 22 | out_features: Optional[int] = None, 23 | act_layer: Callable[..., nn.Module] = nn.GELU, 24 | drop: float = 0.0, 25 | bias: bool = True, 26 | ) -> None: 27 | super().__init__() 28 | out_features = out_features or in_features 29 | hidden_features = hidden_features or in_features 30 | self.fc1 = nn.Linear(in_features, hidden_features, bias=bias) 31 | self.act = act_layer() 32 | self.fc2 = nn.Linear(hidden_features, out_features, bias=bias) 33 | self.drop = nn.Dropout(drop) 34 | 35 | def forward(self, x: Tensor) -> Tensor: 36 | x = self.fc1(x) 37 | x = self.act(x) 38 | x = self.drop(x) 39 | x = self.fc2(x) 40 | x = self.drop(x) 41 | return x 42 | -------------------------------------------------------------------------------- /models/layers/patch_embed.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 
6 | 7 | # References: 8 | # https://github.com/facebookresearch/dino/blob/master/vision_transformer.py 9 | # https://github.com/rwightman/pytorch-image-models/tree/master/timm/layers/patch_embed.py 10 | 11 | from typing import Callable, Optional, Tuple, Union 12 | 13 | from torch import Tensor 14 | import torch.nn as nn 15 | 16 | 17 | def make_2tuple(x): 18 | if isinstance(x, tuple): 19 | assert len(x) == 2 20 | return x 21 | 22 | assert isinstance(x, int) 23 | return (x, x) 24 | 25 | 26 | class PatchEmbed(nn.Module): 27 | """ 28 | 2D image to patch embedding: (B,C,H,W) -> (B,N,D) 29 | 30 | Args: 31 | img_size: Image size. 32 | patch_size: Patch token size. 33 | in_chans: Number of input image channels. 34 | embed_dim: Number of linear projection output channels. 35 | norm_layer: Normalization layer. 36 | """ 37 | 38 | def __init__( 39 | self, 40 | img_size: Union[int, Tuple[int, int]] = 224, 41 | patch_size: Union[int, Tuple[int, int]] = 16, 42 | in_chans: int = 3, 43 | embed_dim: int = 768, 44 | norm_layer: Optional[Callable] = None, 45 | flatten_embedding: bool = True, 46 | ) -> None: 47 | super().__init__() 48 | 49 | image_HW = make_2tuple(img_size) 50 | patch_HW = make_2tuple(patch_size) 51 | patch_grid_size = ( 52 | image_HW[0] // patch_HW[0], 53 | image_HW[1] // patch_HW[1], 54 | ) 55 | 56 | self.img_size = image_HW 57 | self.patch_size = patch_HW 58 | self.patches_resolution = patch_grid_size 59 | self.num_patches = patch_grid_size[0] * patch_grid_size[1] 60 | 61 | self.in_chans = in_chans 62 | self.embed_dim = embed_dim 63 | 64 | self.flatten_embedding = flatten_embedding 65 | 66 | self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_HW, stride=patch_HW) 67 | self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity() 68 | 69 | def forward(self, x: Tensor) -> Tensor: 70 | _, _, H, W = x.shape 71 | patch_H, patch_W = self.patch_size 72 | 73 | assert H % patch_H == 0, f"Input image height {H} is not a multiple of patch height {patch_H}" 74 | 
assert W % patch_W == 0, f"Input image width {W} is not a multiple of patch width {patch_W}" 75 | 76 | x = self.proj(x) # B C H W 77 | H, W = x.size(2), x.size(3) 78 | x = x.flatten(2).transpose(1, 2) # B HW C 79 | x = self.norm(x) 80 | if not self.flatten_embedding: 81 | x = x.reshape(-1, H, W, self.embed_dim) # B H W C 82 | return x 83 | 84 | def flops(self) -> float: 85 | Ho, Wo = self.patches_resolution 86 | flops = Ho * Wo * self.embed_dim * self.in_chans * (self.patch_size[0] * self.patch_size[1]) 87 | if not isinstance(self.norm, nn.Identity): # self.norm is never None; it defaults to nn.Identity() 88 | flops += Ho * Wo * self.embed_dim 89 | return flops 90 | -------------------------------------------------------------------------------- /models/layers/swiglu_ffn.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree.
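`PatchEmbed` above uses an `nn.Conv2d` whose kernel size equals its stride. This is equivalent to cutting the image into non-overlapping `p x p` patches, flattening each patch, and applying one shared linear map. The resulting tokens are ordered row-major over the patch grid, which matches `x.flatten(2).transpose(1, 2)`. The following is a minimal NumPy sketch of that equivalence. It assumes square patches, and `patch_embed_np` is a hypothetical name.

```python
import numpy as np

def patch_embed_np(x, weight, bias, patch):
    """x: (B, C, H, W); weight: (D, C, p, p) as in nn.Conv2d; bias: (D,). Returns (B, N, D)."""
    B, C, H, W = x.shape
    assert H % patch == 0 and W % patch == 0, "image size must be a multiple of the patch size"
    gh, gw = H // patch, W // patch
    # (B, C, gh, p, gw, p) -> (B, gh, gw, C, p, p): one block per patch, row-major over the grid
    patches = x.reshape(B, C, gh, patch, gw, patch).transpose(0, 2, 4, 1, 3, 5)
    patches = patches.reshape(B, gh * gw, C * patch * patch)
    D = weight.shape[0]
    # Flattening the conv kernel in (C, p, p) order matches the patch flattening above.
    return patches @ weight.reshape(D, -1).T + bias
```

A 224 x 224 input with 14 x 14 patches, for example, yields `16 * 16 = 256` tokens before the class token is prepended.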
6 | 7 | from typing import Callable, Optional 8 | 9 | from torch import Tensor, nn 10 | import torch.nn.functional as F 11 | 12 | 13 | class SwiGLUFFN(nn.Module): 14 | def __init__( 15 | self, 16 | in_features: int, 17 | hidden_features: Optional[int] = None, 18 | out_features: Optional[int] = None, 19 | act_layer: Callable[..., nn.Module] = None, 20 | drop: float = 0.0, 21 | bias: bool = True, 22 | ) -> None: 23 | super().__init__() 24 | out_features = out_features or in_features 25 | hidden_features = hidden_features or in_features 26 | self.w12 = nn.Linear(in_features, 2 * hidden_features, bias=bias) 27 | self.w3 = nn.Linear(hidden_features, out_features, bias=bias) 28 | 29 | def forward(self, x: Tensor) -> Tensor: 30 | x12 = self.w12(x) 31 | x1, x2 = x12.chunk(2, dim=-1) 32 | hidden = F.silu(x1) * x2 33 | return self.w3(hidden) 34 | 35 | 36 | try: 37 | from xformers.ops import SwiGLU 38 | 39 | XFORMERS_AVAILABLE = True 40 | except ImportError: 41 | SwiGLU = SwiGLUFFN 42 | XFORMERS_AVAILABLE = False 43 | 44 | 45 | class SwiGLUFFNFused(SwiGLU): 46 | def __init__( 47 | self, 48 | in_features: int, 49 | hidden_features: Optional[int] = None, 50 | out_features: Optional[int] = None, 51 | act_layer: Callable[..., nn.Module] = None, 52 | drop: float = 0.0, 53 | bias: bool = True, 54 | ) -> None: 55 | out_features = out_features or in_features 56 | hidden_features = hidden_features or in_features 57 | hidden_features = (int(hidden_features * 2 / 3) + 7) // 8 * 8 58 | super().__init__( 59 | in_features=in_features, 60 | hidden_features=hidden_features, 61 | out_features=out_features, 62 | bias=bias, 63 | ) 64 | -------------------------------------------------------------------------------- /models/unet.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | import numpy as np 5 | 6 | class DoubleConv(nn.Module): 7 | """(convolution => [BN] => ReLU) * 2""" 8 | 
9 | def __init__(self, in_channels, out_channels, mid_channels=None): 10 | super().__init__() 11 | if not mid_channels: 12 | mid_channels = out_channels 13 | self.double_conv = nn.Sequential( 14 | nn.Conv2d(in_channels, mid_channels, kernel_size=3, padding=1, bias=False), 15 | nn.BatchNorm2d(mid_channels), 16 | nn.ReLU(inplace=True), 17 | nn.Conv2d(mid_channels, out_channels, kernel_size=3, padding=1, bias=False), 18 | nn.BatchNorm2d(out_channels), 19 | nn.ReLU(inplace=True) 20 | ) 21 | 22 | def forward(self, x): 23 | return self.double_conv(x) 24 | 25 | class Down(nn.Module): 26 | """Downscaling with maxpool then double conv""" 27 | 28 | def __init__(self, in_channels, out_channels): 29 | super().__init__() 30 | self.maxpool_conv = nn.Sequential( 31 | nn.MaxPool2d(2), 32 | DoubleConv(in_channels, out_channels) 33 | ) 34 | 35 | def forward(self, x): 36 | return self.maxpool_conv(x) 37 | 38 | class Up(nn.Module): 39 | """Upscaling then double conv""" 40 | 41 | def __init__(self, in_channels, out_channels, size, bilinear=True): 42 | super().__init__() 43 | 44 | # if bilinear, use the normal convolutions to reduce the number of channels 45 | if bilinear: 46 | self.up = nn.Upsample(size=size, mode='bilinear', align_corners=True) 47 | self.conv = DoubleConv(in_channels, out_channels, in_channels // 2) 48 | else: 49 | self.up = nn.ConvTranspose2d(in_channels, in_channels // 2, kernel_size=2, stride=2) 50 | self.conv = DoubleConv(in_channels, out_channels) 51 | 52 | def forward(self, x1, x2): 53 | x1 = self.up(x1) 54 | # input is CHW 55 | diffY = x2.size()[2] - x1.size()[2] 56 | diffX = x2.size()[3] - x1.size()[3] 57 | 58 | x1 = F.pad(x1, [diffX // 2, diffX - diffX // 2, 59 | diffY // 2, diffY - diffY // 2]) 60 | # if you have padding issues, see 61 | # https://github.com/HaiyongJiang/U-Net-Pytorch-Unstructured-Buggy/commit/0e854509c2cea854e247a9c615f175f76fbb2e3a 62 | # https://github.com/xiaopeng-liao/Pytorch-UNet/commit/8ebac70e633bac59fc22bb5195e513d5832fb3bd 63 | x = 
torch.cat([x2, x1], dim=1) 64 | return self.conv(x) 65 | 66 | class Upconv(nn.Module): 67 | """Pad to the skip tensor's size, concatenate, then double conv (upsampling is done in U_Net.forward)""" 68 | 69 | def __init__(self, in_channels, out_channels): 70 | super().__init__() 71 | self.conv = DoubleConv(in_channels, out_channels, in_channels // 2) 72 | 73 | def forward(self, x1, x2): 74 | # input is CHW 75 | diffY = x2.size()[2] - x1.size()[2] 76 | diffX = x2.size()[3] - x1.size()[3] 77 | 78 | x1 = F.pad(x1, [diffX // 2, diffX - diffX // 2, 79 | diffY // 2, diffY - diffY // 2]) 80 | # if you have padding issues, see 81 | # https://github.com/HaiyongJiang/U-Net-Pytorch-Unstructured-Buggy/commit/0e854509c2cea854e247a9c615f175f76fbb2e3a 82 | # https://github.com/xiaopeng-liao/Pytorch-UNet/commit/8ebac70e633bac59fc22bb5195e513d5832fb3bd 83 | x = torch.cat([x2, x1], dim=1) 84 | return self.conv(x) 85 | 86 | class OutConv(nn.Module): 87 | def __init__(self, in_channels, out_channels): 88 | super(OutConv, self).__init__() 89 | self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=1) 90 | 91 | def forward(self, x): 92 | return self.conv(x) 93 | 94 | class U_Net(nn.Module): 95 | def __init__(self, n_channels, n_classes, bilinear=True, outScale=False): 96 | super(U_Net, self).__init__() 97 | self.n_channels = n_channels 98 | self.n_classes = n_classes 99 | self.bilinear = bilinear 100 | 101 | self.inc = (DoubleConv(n_channels, 32)) 102 | self.down1 = (Down(32, 64)) 103 | self.down2 = (Down(64, 128)) 104 | self.down3 = (Down(128, 256)) 105 | factor = 2 if bilinear else 1 106 | self.down4 = (Down(256, 512 // factor)) 107 | self.up1 = (Upconv(512, 256 // factor)) 108 | self.up2 = (Upconv(256, 128 // factor)) 109 | self.up3 = (Upconv(128, 64 // factor)) 110 | self.up4 = (Upconv(64, 32)) 111 | self.outc = (OutConv(32, n_classes)) 112 | self.outScale = outScale 113 | 114 | def forward(self, x, size): 115 | n1,n2 = size 116 | x1 = self.inc(x) 117 | x2 = self.down1(x1) 118 | x3 = 
self.down2(x2) 119 | x4 = self.down3(x3) 120 | x5 = self.down4(x4)
121 | x = self.up((n1//8,n2//8),x5) 122 | x = self.up1(x, x4) # use the upsampled x, not x5 123 | x = self.up((n1//4,n2//4),x) 124 | x = self.up2(x, x3) 125 | x = self.up((n1//2,n2//2),x) 126 | x = self.up3(x, x2) 127 | x = self.up((n1 ,n2 ),x) 128 | x = self.up4(x, x1) 129 | logits = self.outc(x) 130 | if self.outScale: 131 | logits = F.relu(logits)+7.1 132 | return logits 133 | 134 | def up(self, size, x): 135 | up = nn.Upsample(size=size, mode='bilinear', align_corners=True) 136 | return up(x) 137 | 138 | if __name__ == "__main__": 139 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') 140 | model = U_Net(1, 6) # n_channels=1, n_classes=6 141 | model.eval() 142 | model.to(device) 143 | imgs_tensor = torch.zeros(5, 1, 224, 224, device=device) 144 | 145 | out = model(imgs_tensor, (224, 224)) # forward() also needs the output size 146 | print(out.shape) -------------------------------------------------------------------------------- /models/vision_transformer.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree.
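`SwiGLUFFNFused` in `models/layers/swiglu_ffn.py` above shrinks the requested hidden width to about two thirds and rounds it up to a multiple of 8. A gated FFN has three `dim x hidden` weight matrices (`w12` counts as two) where an `Mlp` has two. The 2/3 factor therefore keeps the parameter count roughly equal to the `Mlp` it replaces. The following is a minimal NumPy sketch of the width rule and of the fallback `SwiGLUFFN.forward`, which is used when xFormers is not installed. Biases are omitted, and `swiglu_hidden` and `swiglu_ffn` are hypothetical names.

```python
import numpy as np

def swiglu_hidden(hidden_features):
    """Hidden width used by SwiGLUFFNFused: ~2/3 of the request, rounded up to a multiple of 8."""
    return (int(hidden_features * 2 / 3) + 7) // 8 * 8

def silu(x):
    # SiLU / swish: x * sigmoid(x)
    return x / (1.0 + np.exp(-x))

def swiglu_ffn(x, w12, w3):
    """NumPy sketch of SwiGLUFFN.forward without biases: w3(silu(x1) * x2)."""
    x1, x2 = np.split(x @ w12, 2, axis=-1)  # same split as x12.chunk(2, dim=-1)
    return (silu(x1) * x2) @ w3
```

With `embed_dim=1536` and `mlp_ratio=4`, as in `vit_giant2`, the requested 6144 hidden units become 4096. A 1024-dim model would get 2736 instead of 4096.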
6 | 7 | # References: 8 | # https://github.com/facebookresearch/dino/blob/main/vision_transformer.py 9 | # https://github.com/rwightman/pytorch-image-models/tree/master/timm/models/vision_transformer.py 10 | 11 | from functools import partial 12 | import math 13 | import logging 14 | from typing import Sequence, Tuple, Union, Callable 15 | 16 | import torch 17 | import torch.nn as nn 18 | import torch.utils.checkpoint 19 | from torch.nn.init import trunc_normal_ 20 | 21 | from models.layers import Mlp, PatchEmbed, SwiGLUFFNFused, MemEffAttention, NestedTensorBlock as Block 22 | 23 | 24 | logger = logging.getLogger("dinov2") 25 | 26 | 27 | def named_apply(fn: Callable, module: nn.Module, name="", depth_first=True, include_root=False) -> nn.Module: 28 | if not depth_first and include_root: 29 | fn(module=module, name=name) 30 | for child_name, child_module in module.named_children(): 31 | child_name = ".".join((name, child_name)) if name else child_name 32 | named_apply(fn=fn, module=child_module, name=child_name, depth_first=depth_first, include_root=True) 33 | if depth_first and include_root: 34 | fn(module=module, name=name) 35 | return module 36 | 37 | 38 | class BlockChunk(nn.ModuleList): 39 | def forward(self, x): 40 | for b in self: 41 | x = b(x) 42 | return x 43 | 44 | 45 | class DinoVisionTransformer(nn.Module): 46 | def __init__( 47 | self, 48 | img_size=224, 49 | patch_size=16, 50 | in_chans=3, 51 | embed_dim=768, 52 | depth=12, 53 | num_heads=12, 54 | mlp_ratio=4.0, 55 | qkv_bias=True, 56 | ffn_bias=True, 57 | proj_bias=True, 58 | drop_path_rate=0.0, 59 | drop_path_uniform=False, 60 | init_values=None, # for layerscale: None or 0 => no layerscale 61 | embed_layer=PatchEmbed, 62 | act_layer=nn.GELU, 63 | block_fn=Block, 64 | ffn_layer="mlp", 65 | block_chunks=1 66 | ): 67 | """ 68 | Args: 69 | img_size (int, tuple): input image size 70 | patch_size (int, tuple): patch size 71 | in_chans (int): number of input channels 72 | embed_dim (int): embedding 
dimension 73 | depth (int): depth of transformer 74 | num_heads (int): number of attention heads 75 | mlp_ratio (int): ratio of mlp hidden dim to embedding dim 76 | qkv_bias (bool): enable bias for qkv if True 77 | proj_bias (bool): enable bias for proj in attn if True 78 | ffn_bias (bool): enable bias for ffn if True 79 | drop_path_rate (float): stochastic depth rate 80 | drop_path_uniform (bool): apply uniform drop rate across blocks 81 | weight_init (str): weight init scheme 82 | init_values (float): layer-scale init values 83 | embed_layer (nn.Module): patch embedding layer 84 | act_layer (nn.Module): MLP activation layer 85 | block_fn (nn.Module): transformer block class 86 | ffn_layer (str): "mlp", "swiglu", "swiglufused" or "identity" 87 | block_chunks: (int) split block sequence into block_chunks units for FSDP wrap 88 | """ 89 | super().__init__() 90 | norm_layer = partial(nn.LayerNorm, eps=1e-6) 91 | 92 | self.num_features = self.embed_dim = embed_dim # num_features for consistency with other models 93 | self.num_tokens = 1 94 | self.n_blocks = depth 95 | self.num_heads = num_heads 96 | self.patch_size = patch_size 97 | 98 | self.patch_embed = embed_layer(img_size=img_size, patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim) 99 | num_patches = self.patch_embed.num_patches 100 | 101 | self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim)) 102 | self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + 1114, embed_dim)) 103 | 104 | if drop_path_uniform is True: 105 | dpr = [drop_path_rate] * depth 106 | else: 107 | dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)] # stochastic depth decay rule 108 | 109 | if ffn_layer == "mlp": 110 | logger.info("using MLP layer as FFN") 111 | ffn_layer = Mlp 112 | elif ffn_layer == "swiglufused" or ffn_layer == "swiglu": 113 | logger.info("using SwiGLU layer as FFN") 114 | ffn_layer = SwiGLUFFNFused 115 | elif ffn_layer == "identity": 116 | logger.info("using Identity layer as FFN") 117 | 
            def f(*args, **kwargs):
                return nn.Identity()

            ffn_layer = f
        else:
            raise NotImplementedError

        blocks_list = [
            block_fn(
                dim=embed_dim,
                num_heads=num_heads,
                mlp_ratio=mlp_ratio,
                qkv_bias=qkv_bias,
                proj_bias=proj_bias,
                ffn_bias=ffn_bias,
                drop_path=dpr[i],
                norm_layer=norm_layer,
                act_layer=act_layer,
                ffn_layer=ffn_layer,
                init_values=init_values,
            )
            for i in range(depth)
        ]
        if block_chunks > 0:
            self.chunked_blocks = True
            chunked_blocks = []
            chunksize = depth // block_chunks
            for i in range(0, depth, chunksize):
                # this is to keep the block index consistent if we chunk the block list
                chunked_blocks.append([nn.Identity()] * i + blocks_list[i : i + chunksize])
            self.blocks = nn.ModuleList([BlockChunk(p) for p in chunked_blocks])
        else:
            self.chunked_blocks = False
            self.blocks = nn.ModuleList(blocks_list)

        self.norm = norm_layer(embed_dim)
        self.head = nn.Identity()

        self.mask_token = nn.Parameter(torch.zeros(1, embed_dim))

        self.init_weights()

    def init_weights(self):
        trunc_normal_(self.pos_embed, std=0.02)
        nn.init.normal_(self.cls_token, std=1e-6)
        named_apply(init_weights_vit_timm, self)

    def interpolate_pos_encoding(self, x, w, h):
        previous_dtype = x.dtype
        npatch = x.shape[1] - 1
        N = self.pos_embed.shape[1] - 1
        if npatch == N and w == h:
            return self.pos_embed
        pos_embed = self.pos_embed.float()
        class_pos_embed = pos_embed[:, 0]
        patch_pos_embed = pos_embed[:, 1:]
        dim = x.shape[-1]
        w0 = w // self.patch_size
        h0 = h // self.patch_size
        # we add a small number to avoid floating point error in the interpolation
        # see discussion at https://github.com/facebookresearch/dino/issues/8
        w0, h0 = w0 + 0.1, h0 + 0.1

        patch_pos_embed = nn.functional.interpolate(
            patch_pos_embed.reshape(1, int(math.sqrt(N)), int(math.sqrt(N)), dim).permute(0, 3, 1, 2),
            scale_factor=(w0 / math.sqrt(N), h0 / math.sqrt(N)),
            mode="bicubic",
        )

        assert int(w0) == patch_pos_embed.shape[-2] and int(h0) == patch_pos_embed.shape[-1]
        patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).view(1, -1, dim)
        return torch.cat((class_pos_embed.unsqueeze(0), patch_pos_embed), dim=1).to(previous_dtype)

    def prepare_tokens_with_masks(self, x, masks=None):
        B, nc, w, h = x.shape
        x = self.patch_embed(x)
        if masks is not None:
            x = torch.where(masks.unsqueeze(-1), self.mask_token.to(x.dtype).unsqueeze(0), x)

        x = torch.cat((self.cls_token.expand(x.shape[0], -1, -1), x), dim=1)
        x = x + self.interpolate_pos_encoding(x, w, h)

        return x

    def forward_features_list(self, x_list, masks_list):
        x = [self.prepare_tokens_with_masks(x, masks) for x, masks in zip(x_list, masks_list)]

        for blk in self.blocks:
            x = blk(x)

        all_x = x
        output = []
        for x, masks in zip(all_x, masks_list):
            x_norm = self.norm(x)
            output.append(
                {
                    "x_norm_clstoken": x_norm[:, 0],
                    "x_norm_patchtokens": x_norm[:, 1:],
                    "x_prenorm": x,
                    "masks": masks,
                }
            )
        return output

    def forward_features(self, x, masks=None):
        if isinstance(x, list):
            return self.forward_features_list(x, masks)

        x = self.prepare_tokens_with_masks(x, masks)

        count = 1
        x_middle = {}
        for blk in self.blocks:
            x = blk(x)
            # keep the normalized patch tokens after the 3rd, 6th, 9th and 12th element of self.blocks
            if count in (3, 6, 9, 12):
                x_middle[str(count)] = self.norm(x)[:, 1:]
            count += 1

        x_norm = self.norm(x)
        return {
            "x_norm_clstoken": x_norm[:, 0],
            "x_norm_patchtokens": x_norm[:, 1:],
            "x_prenorm": x,
            "masks": masks,
        }, x_middle

    def _get_intermediate_layers_not_chunked(self, x, n=1):
        x = self.prepare_tokens_with_masks(x)
        # If n is an int, take the n last blocks. If it's a list, take them
        output, total_block_len = [], len(self.blocks)
        blocks_to_take = range(total_block_len - n, total_block_len) if isinstance(n, int) else n
        for i, blk in enumerate(self.blocks):
            x = blk(x)
            if i in blocks_to_take:
                output.append(x)
        assert len(output) == len(blocks_to_take), f"only {len(output)} / {len(blocks_to_take)} blocks found"
        return output

    def _get_intermediate_layers_chunked(self, x, n=1):
        x = self.prepare_tokens_with_masks(x)
        output, i, total_block_len = [], 0, len(self.blocks[-1])
        # If n is an int, take the n last blocks. If it's a list, take them
        blocks_to_take = range(total_block_len - n, total_block_len) if isinstance(n, int) else n
        for block_chunk in self.blocks:
            for blk in block_chunk[i:]:  # Passing the nn.Identity()
                x = blk(x)
                if i in blocks_to_take:
                    output.append(x)
                i += 1
        assert len(output) == len(blocks_to_take), f"only {len(output)} / {len(blocks_to_take)} blocks found"
        return output

    def get_intermediate_layers(
        self,
        x: torch.Tensor,
        n: Union[int, Sequence] = 1,  # Layers or n last layers to take
        reshape: bool = False,
        return_class_token: bool = False,
        norm=True,
    ) -> Tuple[Union[torch.Tensor, Tuple[torch.Tensor]]]:
        if self.chunked_blocks:
            outputs = self._get_intermediate_layers_chunked(x, n)
        else:
            outputs = self._get_intermediate_layers_not_chunked(x, n)
        if norm:
            outputs = [self.norm(out) for out in outputs]
        class_tokens = [out[:, 0] for out in outputs]
        outputs = [out[:, 1:] for out in outputs]
        if reshape:
            B, _, w, h = x.shape
            outputs = [
                out.reshape(B, w // self.patch_size, h // self.patch_size, -1).permute(0, 3, 1, 2).contiguous()
                for out in outputs
            ]
        if return_class_token:
            return tuple(zip(outputs, class_tokens))
        return tuple(outputs)

    def forward(self, *args, is_training=False, **kwargs):
        ret = self.forward_features(*args, **kwargs)
        if is_training:
            return ret
        # for a tensor input, forward_features returns a (features, x_middle) tuple
        features = ret[0] if isinstance(ret, tuple) else ret
        return self.head(features["x_norm_clstoken"])


def init_weights_vit_timm(module: nn.Module, name: str = ""):
    """ViT weight initialization, original timm impl (for reproducibility)"""
    if isinstance(module, nn.Linear):
        trunc_normal_(module.weight, std=0.02)
        if module.bias is not None:
            nn.init.zeros_(module.bias)


def vit_small(patch_size=16, **kwargs):
    model = DinoVisionTransformer(
        patch_size=patch_size,
        embed_dim=384,
        depth=12,
        num_heads=6,
        mlp_ratio=4,
        block_fn=partial(Block, attn_class=MemEffAttention),
        **kwargs,
    )
    return model


def vit_base(patch_size=16, **kwargs):
    model = DinoVisionTransformer(
        patch_size=patch_size,
        embed_dim=768,
        depth=12,
        num_heads=12,
        mlp_ratio=4,
        block_fn=partial(Block, attn_class=MemEffAttention),
        **kwargs,
    )
    return model


def vit_large(patch_size=16, **kwargs):
    model = DinoVisionTransformer(
        patch_size=patch_size,
        embed_dim=1024,
        depth=24,
        num_heads=16,
        mlp_ratio=4,
        block_fn=partial(Block, attn_class=MemEffAttention),
        **kwargs,
    )
    return model


def vit_giant2(patch_size=16, **kwargs):
    """
    Close to ViT-giant, with embed-dim 1536 and 24 heads => embed-dim per head 64
    """
    model = DinoVisionTransformer(
        patch_size=patch_size,
        embed_dim=1536,
        depth=40,
        num_heads=24,
        mlp_ratio=4,
        block_fn=partial(Block, attn_class=MemEffAttention),
        **kwargs,
    )
    return model

--------------------------------------------------------------------------------
/models/vision_transformer_lora.py:
--------------------------------------------------------------------------------
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

# References:
#   https://github.com/facebookresearch/dino/blob/main/vision_transformer.py
#   https://github.com/rwightman/pytorch-image-models/tree/master/timm/models/vision_transformer.py

from functools import partial
import math
import logging
from typing import Sequence, Tuple, Union, Callable

import torch
import torch.nn as nn
import torch.utils.checkpoint
from torch.nn.init import trunc_normal_

from models.layers import Mlp, PatchEmbed, SwiGLUFFNFused, MemEffAttention_lora, NestedTensorBlock as Block


logger = logging.getLogger("dinov2")


def named_apply(fn: Callable, module: nn.Module, name="", depth_first=True, include_root=False) -> nn.Module:
    if not depth_first and include_root:
        fn(module=module, name=name)
    for child_name, child_module in module.named_children():
        child_name = ".".join((name, child_name)) if name else child_name
        named_apply(fn=fn, module=child_module, name=child_name, depth_first=depth_first, include_root=True)
    if depth_first and include_root:
        fn(module=module, name=name)
    return module


class BlockChunk(nn.ModuleList):
    def forward(self, x):
        for b in self:
            x = b(x)
        return x


class DinoVisionTransformer(nn.Module):
    def __init__(
        self,
        img_size=224,
        patch_size=16,
        in_chans=3,
        embed_dim=768,
        depth=12,
        num_heads=12,
        mlp_ratio=4.0,
        qkv_bias=True,
        ffn_bias=True,
        proj_bias=True,
        drop_path_rate=0.0,
        drop_path_uniform=False,
        init_values=None,  # for layerscale: None or 0 => no layerscale
        embed_layer=PatchEmbed,
        act_layer=nn.GELU,
        block_fn=Block,
        ffn_layer="mlp",
        block_chunks=1,
    ):
        """
        Args:
            img_size (int, tuple): input image size
            patch_size (int, tuple): patch size
            in_chans (int): number of input channels
            embed_dim (int): embedding dimension
            depth (int): depth of transformer
            num_heads (int): number of attention heads
            mlp_ratio (int): ratio of mlp hidden dim to embedding dim
            qkv_bias (bool): enable bias for qkv if True
            proj_bias (bool): enable bias for proj in attn if True
            ffn_bias (bool): enable bias for ffn if True
            drop_path_rate (float): stochastic depth rate
            drop_path_uniform (bool): apply uniform drop rate across blocks
            weight_init (str): weight init scheme
            init_values (float): layer-scale init values
            embed_layer (nn.Module): patch embedding layer
            act_layer (nn.Module): MLP activation layer
            block_fn (nn.Module): transformer block class
            ffn_layer (str): "mlp", "swiglu", "swiglufused" or "identity"
            block_chunks: (int) split block sequence into block_chunks units for FSDP wrap
        """
        super().__init__()
        norm_layer = partial(nn.LayerNorm, eps=1e-6)

        self.num_features = self.embed_dim = embed_dim  # num_features for consistency with other models
        self.num_tokens = 1
        self.n_blocks = depth
        self.num_heads = num_heads
        self.patch_size = patch_size

        self.patch_embed = embed_layer(img_size=img_size, patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim)
        num_patches = self.patch_embed.num_patches

        self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
        self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + 1114, embed_dim))

        if drop_path_uniform is True:
            dpr = [drop_path_rate] * depth
        else:
            dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)]  # stochastic depth decay rule

        if ffn_layer == "mlp":
            logger.info("using MLP layer as FFN")
            ffn_layer = Mlp
        elif ffn_layer == "swiglufused" or ffn_layer == "swiglu":
            logger.info("using SwiGLU layer as FFN")
            ffn_layer = SwiGLUFFNFused
        elif ffn_layer == "identity":
            logger.info("using Identity layer as FFN")

            def f(*args, **kwargs):
                return nn.Identity()

            ffn_layer = f
        else:
            raise NotImplementedError

        blocks_list = [
            block_fn(
                dim=embed_dim,
                num_heads=num_heads,
                mlp_ratio=mlp_ratio,
                qkv_bias=qkv_bias,
                proj_bias=proj_bias,
                ffn_bias=ffn_bias,
                drop_path=dpr[i],
                norm_layer=norm_layer,
                act_layer=act_layer,
                ffn_layer=ffn_layer,
                init_values=init_values,
            )
            for i in range(depth)
        ]
        if block_chunks > 0:
            self.chunked_blocks = True
            chunked_blocks = []
            chunksize = depth // block_chunks
            for i in range(0, depth, chunksize):
                # this is to keep the block index consistent if we chunk the block list
                chunked_blocks.append([nn.Identity()] * i + blocks_list[i : i + chunksize])
            self.blocks = nn.ModuleList([BlockChunk(p) for p in chunked_blocks])
        else:
            self.chunked_blocks = False
            self.blocks = nn.ModuleList(blocks_list)

        self.norm = norm_layer(embed_dim)
        self.head = nn.Identity()

        self.mask_token = nn.Parameter(torch.zeros(1, embed_dim))

        self.init_weights()

    def init_weights(self):
        trunc_normal_(self.pos_embed, std=0.02)
        nn.init.normal_(self.cls_token, std=1e-6)
        named_apply(init_weights_vit_timm, self)

    def interpolate_pos_encoding(self, x, w, h):
        previous_dtype = x.dtype
        npatch = x.shape[1] - 1
        N = self.pos_embed.shape[1] - 1
        if npatch == N and w == h:
            return self.pos_embed
        pos_embed = self.pos_embed.float()
        class_pos_embed = pos_embed[:, 0]
        patch_pos_embed = pos_embed[:, 1:]
        dim = x.shape[-1]
        w0 = w // self.patch_size
        h0 = h // self.patch_size
        # we add a small number to avoid floating point error in the interpolation
        # see discussion at https://github.com/facebookresearch/dino/issues/8
        w0, h0 = w0 + 0.1, h0 + 0.1

        patch_pos_embed = nn.functional.interpolate(
            patch_pos_embed.reshape(1, int(math.sqrt(N)), int(math.sqrt(N)), dim).permute(0, 3, 1, 2),
            scale_factor=(w0 / math.sqrt(N), h0 / math.sqrt(N)),
            mode="bicubic",
        )

        assert int(w0) == patch_pos_embed.shape[-2] and int(h0) == patch_pos_embed.shape[-1]
        patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).view(1, -1, dim)
        return torch.cat((class_pos_embed.unsqueeze(0), patch_pos_embed), dim=1).to(previous_dtype)

    def prepare_tokens_with_masks(self, x, masks=None):
        B, nc, w, h = x.shape
        x = self.patch_embed(x)
        if masks is not None:
            x = torch.where(masks.unsqueeze(-1), self.mask_token.to(x.dtype).unsqueeze(0), x)

        x = torch.cat((self.cls_token.expand(x.shape[0], -1, -1), x), dim=1)
        x = x + self.interpolate_pos_encoding(x, w, h)

        return x

    def forward_features_list(self, x_list, masks_list):
        x = [self.prepare_tokens_with_masks(x, masks) for x, masks in zip(x_list, masks_list)]

        for blk in self.blocks:
            x = blk(x)

        all_x = x
        output = []
        for x, masks in zip(all_x, masks_list):
            x_norm = self.norm(x)
            output.append(
                {
                    "x_norm_clstoken": x_norm[:, 0],
                    "x_norm_patchtokens": x_norm[:, 1:],
                    "x_prenorm": x,
                    "masks": masks,
                }
            )
        return output

    def forward_features(self, x, masks=None):
        if isinstance(x, list):
            return self.forward_features_list(x, masks)

        x = self.prepare_tokens_with_masks(x, masks)

        count = 1
        x_middle = {}
        for blk in self.blocks:
            x = blk(x)
            # keep the normalized patch tokens after the 3rd, 6th, 9th and 12th element of self.blocks
            if count in (3, 6, 9, 12):
                x_middle[str(count)] = self.norm(x)[:, 1:]
            count += 1

        x_norm = self.norm(x)
        return {
            "x_norm_clstoken": x_norm[:, 0],
            "x_norm_patchtokens": x_norm[:, 1:],
            "x_prenorm": x,
            "masks": masks,
        }, x_middle

    def _get_intermediate_layers_not_chunked(self, x, n=1):
        x = self.prepare_tokens_with_masks(x)
        # If n is an int, take the n last blocks. If it's a list, take them
        output, total_block_len = [], len(self.blocks)
        blocks_to_take = range(total_block_len - n, total_block_len) if isinstance(n, int) else n
        for i, blk in enumerate(self.blocks):
            x = blk(x)
            if i in blocks_to_take:
                output.append(x)
        assert len(output) == len(blocks_to_take), f"only {len(output)} / {len(blocks_to_take)} blocks found"
        return output

    def _get_intermediate_layers_chunked(self, x, n=1):
        x = self.prepare_tokens_with_masks(x)
        output, i, total_block_len = [], 0, len(self.blocks[-1])
        # If n is an int, take the n last blocks. If it's a list, take them
        blocks_to_take = range(total_block_len - n, total_block_len) if isinstance(n, int) else n
        for block_chunk in self.blocks:
            for blk in block_chunk[i:]:  # Passing the nn.Identity()
                x = blk(x)
                if i in blocks_to_take:
                    output.append(x)
                i += 1
        assert len(output) == len(blocks_to_take), f"only {len(output)} / {len(blocks_to_take)} blocks found"
        return output

    def get_intermediate_layers(
        self,
        x: torch.Tensor,
        n: Union[int, Sequence] = 1,  # Layers or n last layers to take
        reshape: bool = False,
        return_class_token: bool = False,
        norm=True,
    ) -> Tuple[Union[torch.Tensor, Tuple[torch.Tensor]]]:
        if self.chunked_blocks:
            outputs = self._get_intermediate_layers_chunked(x, n)
        else:
            outputs = self._get_intermediate_layers_not_chunked(x, n)
        if norm:
            outputs = [self.norm(out) for out in outputs]
        class_tokens = [out[:, 0] for out in outputs]
        outputs = [out[:, 1:] for out in outputs]
        if reshape:
            B, _, w, h = x.shape
            outputs = [
                out.reshape(B, w // self.patch_size, h // self.patch_size, -1).permute(0, 3, 1, 2).contiguous()
                for out in outputs
            ]
        if return_class_token:
            return tuple(zip(outputs, class_tokens))
        return tuple(outputs)

    def forward(self, *args, is_training=False, **kwargs):
        ret = self.forward_features(*args, **kwargs)
        if is_training:
            return ret
        # for a tensor input, forward_features returns a (features, x_middle) tuple
        features = ret[0] if isinstance(ret, tuple) else ret
        return self.head(features["x_norm_clstoken"])


def init_weights_vit_timm(module: nn.Module, name: str = ""):
    """ViT weight initialization, original timm impl (for reproducibility)"""
    if isinstance(module, nn.Linear):
        trunc_normal_(module.weight, std=0.02)
        if module.bias is not None:
            nn.init.zeros_(module.bias)


def vit_small_lora(patch_size=16, **kwargs):
    model = DinoVisionTransformer(
        patch_size=patch_size,
        embed_dim=384,
        depth=12,
        num_heads=6,
        mlp_ratio=4,
        block_fn=partial(Block, attn_class=MemEffAttention_lora),
        **kwargs,
    )
    return model


def vit_base_lora(patch_size=16, **kwargs):
    model = DinoVisionTransformer(
        patch_size=patch_size,
        embed_dim=768,
        depth=12,
        num_heads=12,
        mlp_ratio=4,
        block_fn=partial(Block, attn_class=MemEffAttention_lora),
        **kwargs,
    )
    return model


def vit_large(patch_size=16, **kwargs):
    model = DinoVisionTransformer(
        patch_size=patch_size,
        embed_dim=1024,
        depth=24,
        num_heads=16,
        mlp_ratio=4,
        block_fn=partial(Block, attn_class=MemEffAttention_lora),
        **kwargs,
    )
    return model


def vit_giant2(patch_size=16, **kwargs):
    """
    Close to ViT-giant, with embed-dim 1536 and 24 heads => embed-dim per head 64
    """
    model = DinoVisionTransformer(
        patch_size=patch_size,
        embed_dim=1536,
        depth=40,
        num_heads=24,
        mlp_ratio=4,
        block_fn=partial(Block, attn_class=MemEffAttention_lora),
        **kwargs,
    )
    return model

--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
fvcore==0.1.5.post20221221
loralib==0.1.2
matplotlib==3.6.2
numpy==1.23.4
Pillow==10.4.0
scikit_learn==1.5.1
tensorboardX==2.6.2.2
timm==1.0.9
torch==2.4.0
torchmetrics==1.4.1
torchvision==0.19.0
tqdm==4.66.1
xformers==0.0.27.post2

--------------------------------------------------------------------------------
/run/mla_crater.sh:
--------------------------------------------------------------------------------
#!/bin/bash

echo "----------Training Start---------------"
python ../demo_classification.py -net "mla" \
    -d "crater" \
    -dn "cuda" \
    -v "small" \
    -loss "wdice" \
    -cp "unfrozen" \
    -l 1e-5
echo "----------Training Over----------------"

echo "---------------Evaluate----------------"
python ../evaluate_classification.py -net "mla" \
    -d "crater" \
    -dn "cuda" \
    -v "small" \
    -loss "wdice" \
    -cp "unfrozen"
echo "Done"

--------------------------------------------------------------------------------
/run/mla_das.sh:
--------------------------------------------------------------------------------
#!/bin/bash

echo "----------Training Start---------------"
python ../demo_classification.py -net "mla" \
    -d "das" \
    -dn "cuda" \
    -v "small" \
    -loss "wdice" \
    -cp "unfrozen" \
    -l 1e-5
echo "----------Training Over----------------"

echo "---------------Evaluate----------------"
python ../evaluate_classification.py -net "mla" \
    -d "das" \
    -dn "cuda" \
    -v "small" \
    -loss "wdice" \
    -cp "unfrozen"
echo "Done"

--------------------------------------------------------------------------------
/run/mla_facies.sh:
--------------------------------------------------------------------------------
#!/bin/bash

echo "----------Training Start---------------"
python ../demo_classification.py -net "mla" \
    -d "seam" \
    -dn "cuda" \
    -v "small" \
    -loss "wdice" \
    -cp "unfrozen" \
    -l 1e-5
echo "----------Training Over----------------"

echo "---------------Evaluate----------------"
python ../evaluate_classification.py -net "mla" \
    -d "seam" \
    -dn "cuda" \
    -v "small" \
    -loss "wdice" \
    -cp "unfrozen"
echo "Done"

--------------------------------------------------------------------------------
/run/mla_fault.sh:
--------------------------------------------------------------------------------
#!/bin/bash

echo "----------Training Start---------------"
python ../demo_classification.py -net "mla" \
    -d "seam" \
    -dn "cuda" \
    -v "small" \
    -loss "wdice" \
    -cp "lora" \
    -l 1e-5
echo "----------Training Over----------------"

echo "---------------Evaluate----------------"
python ../evaluate_classification.py -net "mla" \
    -d "seam" \
    -dn "cuda" \
    -v "small" \
    -loss "wdice" \
    -cp "lora"
echo "Done"

--------------------------------------------------------------------------------
/run/mla_salt.sh:
--------------------------------------------------------------------------------
#!/bin/bash

echo "----------Training Start---------------"
python ../demo_classification.py -net "mla" \
    -d "salt" \
    -dn "cuda" \
    -v "small" \
    -loss "wdice" \
    -cp "unfrozen" \
    -l 1e-5
echo "----------Training Over----------------"

echo "---------------Evaluate----------------"
python ../evaluate_classification.py -net "mla" \
    -d "salt" \
    -dn "cuda" \
    -v "small" \
    -loss "wdice" \
    -cp "unfrozen"
echo "Done"

--------------------------------------------------------------------------------
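Both `vision_transformer.py` and `vision_transformer_lora.py` initialise weights through `named_apply(init_weights_vit_timm, self)`. To make the traversal order concrete, the sketch below copies the control flow of `named_apply` onto a hypothetical `Node` class that only exposes `named_children()` (a stand-in for `nn.Module`, so no PyTorch is needed). Under that assumption it suggests that, with the defaults `depth_first=True, include_root=False`, every submodule is visited after its own children and the root module itself is never passed to `fn`.

```python
class Node:
    """Hypothetical stand-in for nn.Module: only exposes named_children()."""

    def __init__(self, **children):
        self._children = children

    def named_children(self):
        return self._children.items()


def named_apply(fn, module, name="", depth_first=True, include_root=False):
    # same control flow as named_apply in the dumped files
    if not depth_first and include_root:
        fn(module=module, name=name)  # pre-order visit
    for child_name, child_module in module.named_children():
        child_name = ".".join((name, child_name)) if name else child_name
        named_apply(fn=fn, module=child_module, name=child_name, depth_first=depth_first, include_root=True)
    if depth_first and include_root:
        fn(module=module, name=name)  # post-order visit
    return module


def visit_order(root, depth_first=True):
    """Record the dotted names in the order named_apply calls fn."""
    seen = []
    named_apply(lambda module, name: seen.append(name), root, depth_first=depth_first)
    return seen


tree = Node(a=Node(c=Node()), b=Node())
print(visit_order(tree))                     # ['a.c', 'a', 'b']
print(visit_order(tree, depth_first=False))  # ['a', 'a.c', 'b']
```

Post-order visiting is harmless for `init_weights_vit_timm`, since it only touches `nn.Linear` leaves.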
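In `vision_transformer_lora.py` the positional table is allocated with `num_patches + 1114` rows rather than DINOv2's usual `num_patches + self.num_tokens`. For `img_size=224` and `patch_size=14` that gives 256 + 1114 = 1370 = 37 * 37 + 1 rows, i.e. a 37 x 37 grid plus the class token; the source does not say which checkpoint this is meant to match. Because `interpolate_pos_encoding` reshapes the table to a `sqrt(N) x sqrt(N)` grid, N must be a perfect square. The hypothetical helper `pos_embed_grid_side` below checks that arithmetic and assumes a square input; for example the constructor's default `patch_size=16` with `img_size=224` gives N = 1309, which is not square.

```python
import math


def pos_embed_grid_side(img_size: int, patch_size: int, extra_positions: int = 1114) -> int:
    """Side of the square patch grid in pos_embed, excluding the class token.

    Mirrors pos_embed = zeros(1, num_patches + 1114, dim) and N = pos_embed.shape[1] - 1.
    """
    num_patches = (img_size // patch_size) ** 2  # assumes a square input image
    n = num_patches + extra_positions - 1
    side = math.isqrt(n)
    if side * side != n:
        # interpolate_pos_encoding reshapes to (1, sqrt(N), sqrt(N), dim), which needs a square N
        raise ValueError(f"N = {n} is not a perfect square")
    return side


print(pos_embed_grid_side(224, 14))  # 37
```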
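With `drop_path_uniform=False`, the "stochastic depth decay rule" gives block i a drop-path probability that grows linearly from 0 for the first block to `drop_path_rate` for the last. A pure-Python equivalent of `[x.item() for x in torch.linspace(0, drop_path_rate, depth)]`, written as a hypothetical helper `drop_path_rates`, makes the schedule explicit; its values may differ from torch's in the last floating-point bits.

```python
def drop_path_rates(drop_path_rate: float, depth: int, uniform: bool = False) -> list:
    """Per-block stochastic-depth rates, following the dpr computation in __init__."""
    if uniform:
        # drop_path_uniform=True: every block gets the same rate
        return [drop_path_rate] * depth
    if depth == 1:
        # torch.linspace(start, end, steps=1) returns only the start value
        return [0.0]
    step = drop_path_rate / (depth - 1)
    # linear ramp from 0 (first block) to drop_path_rate (last block)
    return [i * step for i in range(depth)]


rates = drop_path_rates(0.1, depth=12)
print(rates[0], rates[-1])
```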
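When `block_chunks > 0`, each chunk is left-padded with `nn.Identity()` placeholders so that a block keeps its global index inside every chunk, and `_get_intermediate_layers_chunked` skips that padding with `block_chunk[i:]`. The sketch below replays this bookkeeping with plain lists: `None` stands in for `nn.Identity()`, strings stand in for blocks, and the function names `chunk_blocks` and `take_blocks` are hypothetical.

```python
def chunk_blocks(blocks, block_chunks):
    # same slicing as __init__: pad the chunk starting at i with i placeholders
    depth = len(blocks)
    chunksize = depth // block_chunks
    return [[None] * i + blocks[i : i + chunksize] for i in range(0, depth, chunksize)]


def take_blocks(chunks, n=1):
    # same walk as _get_intermediate_layers_chunked, returning block labels instead of activations
    total_block_len = len(chunks[-1])
    blocks_to_take = range(total_block_len - n, total_block_len) if isinstance(n, int) else n
    output, i = [], 0
    for chunk in chunks:
        for blk in chunk[i:]:  # starts past the None placeholders
            if i in blocks_to_take:
                output.append(blk)
            i += 1
    return output


blocks = [f"block{k}" for k in range(12)]
chunks = chunk_blocks(blocks, block_chunks=4)
print([len(c) for c in chunks])            # [3, 6, 9, 12]
print(take_blocks(chunks, n=2))            # ['block10', 'block11']
print(take_blocks(chunks, [2, 5, 8, 11]))  # ['block2', 'block5', 'block8', 'block11']
```

Because the last chunk is padded up to the full depth, `len(chunks[-1])` equals the total number of blocks, which is why the code can use it as `total_block_len`.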
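`interpolate_pos_encoding` adds 0.1 to the target grid size before turning it into a `scale_factor`. When `scale_factor` is given, PyTorch derives the output size by rounding `input_size * scale_factor` down, so a product that lands a hair below an integer loses a row and the `assert int(w0) == ...` check fails. The stdlib-only sketch below (the helper name `interpolated_size` is hypothetical, and it models only that size rule, not the bicubic values) shows such a case: in IEEE double precision, `49 * (1 / 49)` evaluates to 0.9999999999999999.

```python
import math


def interpolated_size(grid: int, target: int, offset: float = 0.1) -> int:
    """Output length when `grid` positions are resized with scale_factor = (target + offset) / grid."""
    scale = (target + offset) / grid        # like w0 / math.sqrt(N)
    return math.floor(float(grid) * scale)  # size rule: floor(input_size * scale_factor)


print(interpolated_size(49, 1, offset=0.0))  # 0
print(interpolated_size(49, 1))              # 1
```

The offset is far larger than any rounding error in the product, so `floor` recovers the intended integer, while `int(w0)` in the assertion still equals `w // patch_size`.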
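`forward_features` stores normalized patch tokens in `x_middle` whenever its counter reaches 3, 6, 9 or 12, but the counter advances once per element of `self.blocks`. With `block_chunks=0` those elements are the individual blocks; with the constructor's default `block_chunks=1` there is a single `BlockChunk`, so the loop body runs once and `x_middle` comes back empty. The hypothetical helper below reproduces only the counting, assuming the chunking rule `depth // block_chunks` from `__init__`, so a configuration can be checked before relying on `x_middle`.

```python
def middle_keys(depth: int, block_chunks: int, taps=(3, 6, 9, 12)) -> list:
    """Keys that forward_features would store in x_middle for a given depth and block_chunks."""
    if block_chunks > 0:
        # self.blocks is a ModuleList of BlockChunk objects, one per chunk
        chunksize = depth // block_chunks
        num_elements = len(range(0, depth, chunksize))
    else:
        # self.blocks holds the individual transformer blocks
        num_elements = depth
    keys = []
    for count in range(1, num_elements + 1):  # count starts at 1, as in forward_features
        if count in taps:
            keys.append(str(count))
    return keys


print(middle_keys(12, block_chunks=0))  # ['3', '6', '9', '12']
print(middle_keys(12, block_chunks=1))  # []
```

Note also that for deeper models such as `vit_large` (depth 24) only the first twelve blocks are tapped.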