├── LICENSE ├── README.md ├── asset └── images │ ├── depth_results.png │ ├── normal_results.png │ └── teasor.png ├── config ├── dataset │ ├── data_diode_all.yaml │ ├── data_eth3d.yaml │ ├── data_hypersim_train.yaml │ ├── data_hypersim_val.yaml │ ├── data_kitti_eigen_test.yaml │ ├── data_kitti_val.yaml │ ├── data_nyu_test.yaml │ ├── data_nyu_train.yaml │ ├── data_scannet_val.yaml │ ├── data_vkitti_train.yaml │ ├── data_vkitti_val.yaml │ ├── dataset_train.yaml │ ├── dataset_train_hypr.yaml │ ├── dataset_val.yaml │ └── dataset_vis.yaml ├── logging.yaml ├── train_merge.yaml └── wandb.yaml ├── data ├── demo_1.png ├── val_01.png └── val_02.png ├── data_split ├── diode │ ├── diode_val_all_filename_list.txt │ ├── diode_val_indoor_filename_list.txt │ └── diode_val_outdoor_filename_list.txt ├── eth3d │ └── eth3d_filename_list.txt ├── hypersim │ ├── filename_list_test_filtered.txt │ ├── filename_list_train_filtered.txt │ ├── filename_list_val_filtered.txt │ ├── filename_list_val_filtered_small_80.txt │ ├── hypersim_caption.json │ └── selected_vis_sample.txt ├── kitti │ ├── eigen_test_files_with_gt.txt │ ├── eigen_val_from_train_800.txt │ └── eigen_val_from_train_sub_100.txt ├── nyu │ └── labeled │ │ ├── filename_list_test.txt │ │ ├── filename_list_train.txt │ │ ├── filename_list_train_small_100.txt │ │ └── nyu_test_caption.json ├── scannet │ └── scannet_val_sampled_list_800_1.txt └── vkitti │ ├── vkitti_caption.json │ ├── vkitti_train.txt │ └── vkitti_val.txt ├── inference_merge_base_depth.py ├── inference_merge_large_depth.py ├── merge ├── __init__.py ├── pipeline │ ├── embeddings.py │ ├── layers.py │ ├── merge_transformer.py │ ├── merge_transformer_flux.py │ ├── pipeline_merge.py │ ├── pipeline_merge_flux.py │ ├── transformer_attentions.py │ ├── transformer_blocks.py │ └── util │ │ ├── batchsize.py │ │ ├── ensemble.py │ │ └── image_util.py ├── train_merge_base_depth.py └── train_merge_large_depth.py ├── requirements.txt ├── src ├── __init__.py ├── dataset │ ├── __init__.py │ ├── base_depth_dataset.py │ ├── diode_dataset.py │ ├── eth3d_dataset.py │ ├── hypersim_dataset.py │ ├── kitti_dataset.py │ ├── mixed_sampler.py │ ├── nyu_dataset.py │ ├── scannet_dataset.py │ ├── tartanair_dataset.py │ └── vkitti_dataset.py └── util │ ├── alignment.py │ ├── config_util.py │ ├── data_loader.py │ ├── depth_transform.py │ ├── logging_util.py │ ├── loss.py │ ├── lr_scheduler.py │ ├── metric.py │ ├── multi_res_noise.py │ ├── seeding.py │ └── slurm_util.py └── train_scripts ├── train_merge_b_depth.sh └── train_merge_l_depth.sh /LICENSE: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. 
For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. 
Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 
134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "[]" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright [yyyy] [name of copyright owner] 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 
193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. 202 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 |
2 |

More Than Generation: Unifying Generation and Depth Estimation via Text-to-Image Diffusion Models

3 | 4 | Hongkai Lin, 5 | Dingkang Liang, 6 | Mingyang Du, 7 | Xin Zhou, 8 | Xiang Bai 9 | 10 | Huazhong University of Science & Technology 11 | 12 | ($\dagger$) Corresponding author. 13 | 14 | [![Paper](https://img.shields.io/badge/Arxiv-2510.23574-b31b1b.svg?logo=arXiv)](https://arxiv.org/abs/2510.23574) 15 | [![Website](https://img.shields.io/badge/Homepage-project-orange.svg?logo=googlehome)](https://h-embodvis.github.io/MERGE) 16 | [![Code License](https://img.shields.io/badge/Code%20License-Apache_2.0-green.svg)](https://github.com/h-embodvis/MERGE/blob/main/LICENSE) 17 | 18 |
19 | 20 | ![MERGE teaser.](asset/images/teasor.png) 21 | We present MERGE, a simple unified diffusion model for image generation and depth estimation. Its core lies in leveraging streamlined converters and the rich visual priors stored in generative image models. Our model, derived from fixed generative image models with pluggable converters fine-tuned on synthetic data, delivers powerful zero-shot depth estimation capability. 22 | 23 | --- 24 | ## 📢 **News** 25 | - **[21/Oct/2025]** The training and inference code is now available! 26 | - **[18/Sep/2025]** MERGE is accepted to **NeurIPS 2025**! 🥳🥳🥳 27 | 28 | --- 29 | ## 🛠️ Setup 30 | This installation was tested on: Ubuntu 20.04 LTS, Python 3.9.21, CUDA 11.8, NVIDIA H20-80GB. 31 | 32 | 1. Clone the repository (requires git): 33 | ``` 34 | git clone https://github.com/HongkLin/MERGE 35 | cd MERGE 36 | ``` 37 | 38 | 2. Install dependencies (requires conda): 39 | ``` 40 | conda create -n merge python=3.9.21 -y 41 | conda activate merge 42 | conda install pytorch==2.3.1 torchvision==0.18.1 torchaudio==2.3.1 pytorch-cuda=11.8 -c pytorch -c nvidia 43 | pip install -r requirements.txt 44 | ``` 45 | --- 46 | ## 🔥 Training 47 | 1. Follow [Marigold](https://github.com/prs-eth/Marigold) to prepare the depth training data ([Hypersim](https://github.com/apple/ml-hypersim) and [Virtual KITTI 2](https://europe.naverlabs.com/research/computer-vision/proxy-virtual-worlds-vkitti-2/)). The default dataset structure is as follows: 48 | ``` 49 | datasets/ 50 | hypersim/ 51 | test/ 52 | train/ 53 | ai_001_001/ 54 | ... 55 | ai_055_010/ 56 | val/ 57 | vkitti/ 58 | depth/ 59 | Scene01/ 60 | ... 61 | Scene20/ 62 | rgb/ 63 | ``` 64 | 65 | 2. Download the pre-trained [PixArt-α](https://huggingface.co/PixArt-alpha/PixArt-XL-2-512x512) and [FLUX.1 [dev]](https://huggingface.co/black-forest-labs/FLUX.1-dev) weights, then point `pretrained_model_name_or_path` to the downloaded models. 66 | 3. Run the training command! 🚀 67 | ``` 68 | conda activate merge 69 | 70 | # Training MERGE-B model 71 | bash train_scripts/train_merge_b_depth.sh 72 | 73 | # Training MERGE-L model 74 | bash train_scripts/train_merge_l_depth.sh 75 | 76 | ``` 77 | --- 78 | ## 🕹️ Inference 79 | 1. Place your images in a directory, for example, under `./data` (where we have prepared several examples). 80 | 2. Run the inference command: 81 | ``` 82 | # for MERGE-B 83 | python inference_merge_base_depth.py --pretrained_model_path PATH/PixArt-XL-2-512x512 --model_weights PATH/merge_base_depth --image_path ./data/demo_1.png 84 | 85 | # for MERGE-L 86 | python inference_merge_large_depth.py --pretrained_model_path PATH/FLUX.1-dev --model_weights PATH/merge_large_depth --image_path ./data/demo_1.png 87 | ``` 88 | 89 | ### Choose your model 90 | Below are the released models and their corresponding configurations: 91 | |CHECKPOINT_DIR|PRETRAINED_MODEL|TASK_NAME| 92 | |:--:|:--:|:--:| 93 | | [`merge-base-depth-v1`](https://huggingface.co/hongk1998/merge-base-depth-v1) | PixArt-XL-2-512x512 | depth | 94 | | [`merge-large-depth-v1`](https://huggingface.co/hongk1998/merge-large-depth-v1) | FLUX.1-dev | depth | 95 |
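For convenience, the released checkpoints and their pre-trained backbones can also be pulled directly from the Hugging Face Hub and passed to the inference scripts above. The snippet below is only a minimal sketch: the `checkpoints/` layout is purely illustrative, and it assumes the `huggingface-cli` tool (provided by the `huggingface_hub` package) is available in the environment.
```
# Illustrative example (MERGE-B): download the converter weights and the PixArt-α backbone,
# then run inference with the same arguments as above.
pip install -U huggingface_hub   # only if the CLI is not already installed

huggingface-cli download PixArt-alpha/PixArt-XL-2-512x512 --local-dir checkpoints/PixArt-XL-2-512x512
huggingface-cli download hongk1998/merge-base-depth-v1 --local-dir checkpoints/merge-base-depth-v1

python inference_merge_base_depth.py --pretrained_model_path checkpoints/PixArt-XL-2-512x512 --model_weights checkpoints/merge-base-depth-v1 --image_path ./data/demo_1.png
```
For MERGE-L, pair [`merge-large-depth-v1`](https://huggingface.co/hongk1998/merge-large-depth-v1) with the FLUX.1-dev backbone and `inference_merge_large_depth.py` in the same way.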
96 | --- 97 | ## ⚖️ Main Results 98 | ### **Zero-shot Depth Estimation Results** 99 | 100 | ![Zero-shot depth estimation results.](asset/images/depth_results.png) 101 |
102 | 103 | ### **Zero-shot Normal Estimation Results** 104 |
105 | ![Zero-shot normal estimation results.](asset/images/normal_results.png) 106 |
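For reference, the evaluation settings in `config/train_merge.yaml` align each prediction to the ground truth with `alignment: least_square` before computing metrics such as `abs_relative_difference` and `delta1_acc`. The snippet below is a minimal NumPy sketch of that kind of scale/shift alignment, not the repository's own implementation (see `src/util/alignment.py` for that).
```
# Minimal sketch (not the repository code): align an affine-invariant depth prediction
# to metric ground truth with a least-squares scale and shift before computing metrics.
import numpy as np

def align_least_square(pred, gt, valid_mask):
    """Find s, t minimizing ||s * pred + t - gt||^2 over valid pixels."""
    p = pred[valid_mask].astype(np.float64).ravel()
    g = gt[valid_mask].astype(np.float64).ravel()
    A = np.stack([p, np.ones_like(p)], axis=1)      # [N, 2] design matrix
    (s, t), *_ = np.linalg.lstsq(A, g, rcond=None)  # least-squares solution
    return s * pred + t, s, t
```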
107 | 108 | 109 | --- 110 | 111 | ## 📖 BibTeX 112 | If you find this repository useful in your research, please consider giving a star ⭐ and a citation: 113 | ``` 114 | @inproceedings{lin2025merge, 115 | title={More Than Generation: Unifying Generation and Depth Estimation via Text-to-Image Diffusion Models}, 116 | author={Lin, Hongkai and Liang, Dingkang and Du, Mingyang and Zhou, Xin and Bai, Xiang}, 117 | booktitle={Advances in Neural Information Processing Systems}, 118 | year={2025}, 119 | } 120 | ``` 121 | 122 | 123 | ## 🤗 Acknowledgements 124 | - Thanks to [Diffusers](https://github.com/huggingface/diffusers) for their wonderful technical support and awesome collaboration! 125 | - Thanks to [Hugging Face](https://github.com/huggingface) for generously sponsoring the demo! 126 | - Thanks to [DiT](https://github.com/facebookresearch/DiT) for their wonderful work and codebase! 127 | - Thanks to [PixArt-α](https://github.com/PixArt-alpha/PixArt-alpha) for their wonderful work and codebase! 128 | - Thanks to [FLUX](https://huggingface.co/black-forest-labs/FLUX.1-dev) and [Marigold](https://github.com/prs-eth/Marigold) for their wonderful work! -------------------------------------------------------------------------------- /asset/images/depth_results.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/H-EmbodVis/MERGE/93e81be69663b00175a18f405f9239e7acae1ba8/asset/images/depth_results.png -------------------------------------------------------------------------------- /asset/images/normal_results.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/H-EmbodVis/MERGE/93e81be69663b00175a18f405f9239e7acae1ba8/asset/images/normal_results.png -------------------------------------------------------------------------------- /asset/images/teasor.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/H-EmbodVis/MERGE/93e81be69663b00175a18f405f9239e7acae1ba8/asset/images/teasor.png -------------------------------------------------------------------------------- /config/dataset/data_diode_all.yaml: -------------------------------------------------------------------------------- 1 | name: diode 2 | disp_name: diode_val_all 3 | dir: diode 4 | filenames: data_split/diode/diode_val_all_filename_list.txt -------------------------------------------------------------------------------- /config/dataset/data_eth3d.yaml: -------------------------------------------------------------------------------- 1 | name: eth3d 2 | disp_name: eth3d_full 3 | dir: eth3d 4 | filenames: data_split/eth3d/eth3d_filename_list.txt -------------------------------------------------------------------------------- /config/dataset/data_hypersim_train.yaml: -------------------------------------------------------------------------------- 1 | name: hypersim 2 | disp_name: hypersim_train 3 | dir: hypersim/train 4 | filenames: data_split/hypersim/filename_list_train_filtered.txt -------------------------------------------------------------------------------- /config/dataset/data_hypersim_val.yaml: -------------------------------------------------------------------------------- 1 | name: hypersim 2 | disp_name: hypersim_val 3 | dir: hypersim/val 4 | filenames: data_split/hypersim/filename_list_val_filtered.txt -------------------------------------------------------------------------------- /config/dataset/data_kitti_eigen_test.yaml:
-------------------------------------------------------------------------------- 1 | name: kitti 2 | disp_name: kitti_eigen_test_full 3 | dir: kitti 4 | filenames: data_split/kitti/eigen_test_files_with_gt.txt 5 | kitti_bm_crop: true 6 | valid_mask_crop: eigen -------------------------------------------------------------------------------- /config/dataset/data_kitti_val.yaml: -------------------------------------------------------------------------------- 1 | name: kitti 2 | disp_name: kitti_val800_from_eigen_train 3 | dir: kitti 4 | filenames: data_split/kitti/eigen_val_from_train_800.txt 5 | kitti_bm_crop: true 6 | valid_mask_crop: eigen -------------------------------------------------------------------------------- /config/dataset/data_nyu_test.yaml: -------------------------------------------------------------------------------- 1 | name: nyu_v2 2 | disp_name: nyu_test_full 3 | dir: nyuv2 4 | filenames: data_split/nyu/labeled/filename_list_test.txt 5 | eigen_valid_mask: true -------------------------------------------------------------------------------- /config/dataset/data_nyu_train.yaml: -------------------------------------------------------------------------------- 1 | name: nyu_v2 2 | disp_name: nyu_train_full 3 | dir: nyuv2 4 | filenames: data_split/nyu/labeled/filename_list_train.txt 5 | eigen_valid_mask: true -------------------------------------------------------------------------------- /config/dataset/data_scannet_val.yaml: -------------------------------------------------------------------------------- 1 | name: scannet 2 | disp_name: scannet_val_800_1 3 | dir: scannet 4 | filenames: data_split/scannet/scannet_val_sampled_list_800_1.txt -------------------------------------------------------------------------------- /config/dataset/data_vkitti_train.yaml: -------------------------------------------------------------------------------- 1 | name: vkitti 2 | disp_name: vkitti_train 3 | dir: vkitti 4 | filenames: data_split/vkitti/vkitti_train.txt 5 | kitti_bm_crop: true 6 | valid_mask_crop: null # no valid_mask_crop for training -------------------------------------------------------------------------------- /config/dataset/data_vkitti_val.yaml: -------------------------------------------------------------------------------- 1 | name: vkitti 2 | disp_name: vkitti_val 3 | dir: vkitti/vkitti.tar 4 | filenames: data_split/vkitti/vkitti_val.txt 5 | kitti_bm_crop: true 6 | valid_mask_crop: eigen -------------------------------------------------------------------------------- /config/dataset/dataset_train.yaml: -------------------------------------------------------------------------------- 1 | dataset: 2 | train: 3 | name: mixed 4 | prob_ls: [0.9, 0.1] 5 | dataset_list: 6 | - name: hypersim 7 | disp_name: hypersim_train 8 | dir: hypersim/train 9 | filenames: data_split/hypersim/filename_list_train_filtered.txt 10 | resize_to_hw: 11 | - 480 12 | - 640 13 | - name: vkitti 14 | disp_name: vkitti_train 15 | dir: vkitti 16 | filenames: data_split/vkitti/vkitti_train.txt 17 | kitti_bm_crop: true 18 | valid_mask_crop: null 19 | 20 | -------------------------------------------------------------------------------- /config/dataset/dataset_train_hypr.yaml: -------------------------------------------------------------------------------- 1 | dataset: 2 | train: 3 | name: hypersim 4 | disp_name: hypersim_train 5 | dir: hypersim/train 6 | filenames: data_split/hypersim/filename_list_train_filtered.txt 7 | resize_to_hw: 8 | - 480 9 | - 640 
-------------------------------------------------------------------------------- /config/dataset/dataset_val.yaml: -------------------------------------------------------------------------------- 1 | dataset: 2 | val: 3 | # - name: hypersim 4 | # disp_name: hypersim_val 5 | # dir: hypersim/hypersim_processed_val.tar 6 | # filenames: data_split/hypersim/filename_list_val_filtered.txt 7 | # resize_to_hw: 8 | # - 480 9 | # - 640 10 | 11 | # - name: nyu_v2 12 | # disp_name: nyu_train_full 13 | # dir: nyuv2/nyu_labeled_extracted.tar 14 | # filenames: data_split/nyu/labeled/filename_list_train.txt 15 | # eigen_valid_mask: true 16 | 17 | # - name: kitti 18 | # disp_name: kitti_val800_from_eigen_train 19 | # dir: kitti/kitti_sampled_val_800.tar 20 | # filenames: data_split/kitti/eigen_val_from_train_800.txt 21 | # kitti_bm_crop: true 22 | # valid_mask_crop: eigen 23 | 24 | # Smaller subsets for faster validation during training 25 | # The first dataset is used to calculate main eval metric. 26 | # - name: hypersim 27 | # disp_name: hypersim_val_small_80 28 | # dir: hypersim/hypersim_processed_val.tar 29 | # filenames: data_split/hypersim/filename_list_val_filtered_small_80.txt 30 | # resize_to_hw: 31 | # - 480 32 | # - 640 33 | # 34 | - name: nyu_v2 35 | disp_name: nyu_train_small_100 36 | dir: nyuv2 37 | filenames: data_split/nyu/labeled/filename_list_train_small_100.txt 38 | eigen_valid_mask: true 39 | task: val 40 | # 41 | # - name: kitti 42 | # disp_name: kitti_val_from_train_sub_100 43 | # dir: kitti/kitti_sampled_val_800.tar 44 | # filenames: data_split/kitti/eigen_val_from_train_sub_100.txt 45 | # kitti_bm_crop: true 46 | # valid_mask_crop: eigen -------------------------------------------------------------------------------- /config/dataset/dataset_vis.yaml: -------------------------------------------------------------------------------- 1 | dataset: 2 | vis: 3 | - name: hypersim 4 | disp_name: hypersim_vis 5 | dir: hypersim/val 6 | filenames: data_split/hypersim/selected_vis_sample.txt 7 | resize_to_hw: 8 | - 480 9 | - 640 10 | -------------------------------------------------------------------------------- /config/logging.yaml: -------------------------------------------------------------------------------- 1 | logging: 2 | filename: logging.log 3 | format: ' %(asctime)s - %(levelname)s -%(filename)s - %(funcName)s >> %(message)s' 4 | console_level: 20 5 | file_level: 10 6 | -------------------------------------------------------------------------------- /config/train_merge.yaml: -------------------------------------------------------------------------------- 1 | base_config: 2 | - config/logging.yaml 3 | - config/wandb.yaml 4 | - config/dataset/dataset_train.yaml 5 | - config/dataset/dataset_val.yaml 6 | - config/dataset/dataset_vis.yaml 7 | 8 | 9 | depth_normalization: 10 | type: scale_shift_depth 11 | clip: true 12 | norm_min: -1.0 13 | norm_max: 1.0 14 | min_max_quantile: 0.02 15 | 16 | augmentation: 17 | lr_flip_p: 0.5 18 | 19 | dataloader: 20 | num_workers: 2 21 | effective_batch_size: 32 22 | max_train_batch_size: 2 23 | seed: 2024 # to ensure continuity when resuming from checkpoint 24 | 25 | # Training settings 26 | multi_res_noise: 27 | strength: 0.9 28 | annealed: true 29 | downscale_strategy: original 30 | 31 | gt_depth_type: depth_raw_norm 32 | gt_mask_type: valid_mask_raw 33 | 34 | # Validation (and visualization) settings 35 | validation: 36 | denoising_steps: 50 37 | ensemble_size: 1 # simplified setting for on-training validation 38 | processing_res: 0 39 | 
match_input_res: false 40 | resample_method: bilinear 41 | main_val_metric: abs_relative_difference 42 | main_val_metric_goal: minimize 43 | init_seed: 2024 44 | 45 | eval: 46 | alignment: least_square 47 | align_max_res: null 48 | eval_metrics: 49 | - abs_relative_difference 50 | - squared_relative_difference 51 | - rmse_linear 52 | - rmse_log 53 | - log10 54 | - delta1_acc 55 | - delta2_acc 56 | - delta3_acc 57 | - i_rmse 58 | - silog_rmse 59 | -------------------------------------------------------------------------------- /config/wandb.yaml: -------------------------------------------------------------------------------- 1 | wandb: 2 | # entity: your_entity 3 | project: merge -------------------------------------------------------------------------------- /data/demo_1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/H-EmbodVis/MERGE/93e81be69663b00175a18f405f9239e7acae1ba8/data/demo_1.png -------------------------------------------------------------------------------- /data/val_01.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/H-EmbodVis/MERGE/93e81be69663b00175a18f405f9239e7acae1ba8/data/val_01.png -------------------------------------------------------------------------------- /data/val_02.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/H-EmbodVis/MERGE/93e81be69663b00175a18f405f9239e7acae1ba8/data/val_02.png -------------------------------------------------------------------------------- /data_split/hypersim/filename_list_val_filtered_small_80.txt: -------------------------------------------------------------------------------- 1 | ai_003_010/rgb_cam_00_fr0047.png ai_003_010/depth_plane_cam_00_fr0047.png 2 | ai_003_010/rgb_cam_00_fr0048.png ai_003_010/depth_plane_cam_00_fr0048.png 3 | ai_003_010/rgb_cam_01_fr0098.png ai_003_010/depth_plane_cam_01_fr0098.png 4 | ai_004_003/rgb_cam_01_fr0008.png ai_004_003/depth_plane_cam_01_fr0008.png 5 | ai_004_004/rgb_cam_00_fr0025.png ai_004_004/depth_plane_cam_00_fr0025.png 6 | ai_004_004/rgb_cam_00_fr0046.png ai_004_004/depth_plane_cam_00_fr0046.png 7 | ai_004_004/rgb_cam_00_fr0049.png ai_004_004/depth_plane_cam_00_fr0049.png 8 | ai_004_004/rgb_cam_01_fr0023.png ai_004_004/depth_plane_cam_01_fr0023.png 9 | ai_005_005/rgb_cam_00_fr0032.png ai_005_005/depth_plane_cam_00_fr0032.png 10 | ai_006_007/rgb_cam_00_fr0022.png ai_006_007/depth_plane_cam_00_fr0022.png 11 | ai_006_007/rgb_cam_00_fr0095.png ai_006_007/depth_plane_cam_00_fr0095.png 12 | ai_007_001/rgb_cam_00_fr0044.png ai_007_001/depth_plane_cam_00_fr0044.png 13 | ai_007_001/rgb_cam_00_fr0048.png ai_007_001/depth_plane_cam_00_fr0048.png 14 | ai_009_007/rgb_cam_00_fr0017.png ai_009_007/depth_plane_cam_00_fr0017.png 15 | ai_009_007/rgb_cam_00_fr0097.png ai_009_007/depth_plane_cam_00_fr0097.png 16 | ai_009_009/rgb_cam_00_fr0094.png ai_009_009/depth_plane_cam_00_fr0094.png 17 | ai_015_001/rgb_cam_00_fr0058.png ai_015_001/depth_plane_cam_00_fr0058.png 18 | ai_015_001/rgb_cam_00_fr0089.png ai_015_001/depth_plane_cam_00_fr0089.png 19 | ai_017_007/rgb_cam_01_fr0064.png ai_017_007/depth_plane_cam_01_fr0064.png 20 | ai_018_005/rgb_cam_00_fr0014.png ai_018_005/depth_plane_cam_00_fr0014.png 21 | ai_018_005/rgb_cam_00_fr0059.png ai_018_005/depth_plane_cam_00_fr0059.png 22 | ai_022_010/rgb_cam_00_fr0097.png ai_022_010/depth_plane_cam_00_fr0097.png 23 | ai_022_010/rgb_cam_00_fr0099.png 
ai_022_010/depth_plane_cam_00_fr0099.png 24 | ai_023_003/rgb_cam_00_fr0013.png ai_023_003/depth_plane_cam_00_fr0013.png 25 | ai_023_003/rgb_cam_00_fr0015.png ai_023_003/depth_plane_cam_00_fr0015.png 26 | ai_023_003/rgb_cam_00_fr0036.png ai_023_003/depth_plane_cam_00_fr0036.png 27 | ai_023_003/rgb_cam_00_fr0095.png ai_023_003/depth_plane_cam_00_fr0095.png 28 | ai_023_003/rgb_cam_01_fr0029.png ai_023_003/depth_plane_cam_01_fr0029.png 29 | ai_023_003/rgb_cam_01_fr0036.png ai_023_003/depth_plane_cam_01_fr0036.png 30 | ai_023_003/rgb_cam_01_fr0071.png ai_023_003/depth_plane_cam_01_fr0071.png 31 | ai_032_007/rgb_cam_00_fr0031.png ai_032_007/depth_plane_cam_00_fr0031.png 32 | ai_032_007/rgb_cam_00_fr0040.png ai_032_007/depth_plane_cam_00_fr0040.png 33 | ai_032_007/rgb_cam_00_fr0075.png ai_032_007/depth_plane_cam_00_fr0075.png 34 | ai_035_003/rgb_cam_00_fr0054.png ai_035_003/depth_plane_cam_00_fr0054.png 35 | ai_035_004/rgb_cam_00_fr0077.png ai_035_004/depth_plane_cam_00_fr0077.png 36 | ai_038_009/rgb_cam_00_fr0031.png ai_038_009/depth_plane_cam_00_fr0031.png 37 | ai_038_009/rgb_cam_01_fr0010.png ai_038_009/depth_plane_cam_01_fr0010.png 38 | ai_038_009/rgb_cam_01_fr0088.png ai_038_009/depth_plane_cam_01_fr0088.png 39 | ai_039_003/rgb_cam_01_fr0042.png ai_039_003/depth_plane_cam_01_fr0042.png 40 | ai_039_003/rgb_cam_01_fr0097.png ai_039_003/depth_plane_cam_01_fr0097.png 41 | ai_044_001/rgb_cam_00_fr0043.png ai_044_001/depth_plane_cam_00_fr0043.png 42 | ai_044_001/rgb_cam_01_fr0018.png ai_044_001/depth_plane_cam_01_fr0018.png 43 | ai_044_003/rgb_cam_01_fr0082.png ai_044_003/depth_plane_cam_01_fr0082.png 44 | ai_044_003/rgb_cam_01_fr0087.png ai_044_003/depth_plane_cam_01_fr0087.png 45 | ai_044_003/rgb_cam_02_fr0086.png ai_044_003/depth_plane_cam_02_fr0086.png 46 | ai_044_003/rgb_cam_03_fr0022.png ai_044_003/depth_plane_cam_03_fr0022.png 47 | ai_044_003/rgb_cam_03_fr0063.png ai_044_003/depth_plane_cam_03_fr0063.png 48 | ai_045_008/rgb_cam_00_fr0015.png ai_045_008/depth_plane_cam_00_fr0015.png 49 | ai_045_008/rgb_cam_00_fr0030.png ai_045_008/depth_plane_cam_00_fr0030.png 50 | ai_045_008/rgb_cam_01_fr0029.png ai_045_008/depth_plane_cam_01_fr0029.png 51 | ai_045_008/rgb_cam_01_fr0052.png ai_045_008/depth_plane_cam_01_fr0052.png 52 | ai_045_008/rgb_cam_01_fr0088.png ai_045_008/depth_plane_cam_01_fr0088.png 53 | ai_047_009/rgb_cam_00_fr0097.png ai_047_009/depth_plane_cam_00_fr0097.png 54 | ai_048_001/rgb_cam_00_fr0014.png ai_048_001/depth_plane_cam_00_fr0014.png 55 | ai_048_001/rgb_cam_00_fr0088.png ai_048_001/depth_plane_cam_00_fr0088.png 56 | ai_048_001/rgb_cam_01_fr0045.png ai_048_001/depth_plane_cam_01_fr0045.png 57 | ai_048_001/rgb_cam_02_fr0031.png ai_048_001/depth_plane_cam_02_fr0031.png 58 | ai_048_001/rgb_cam_03_fr0005.png ai_048_001/depth_plane_cam_03_fr0005.png 59 | ai_048_001/rgb_cam_03_fr0045.png ai_048_001/depth_plane_cam_03_fr0045.png 60 | ai_048_001/rgb_cam_03_fr0054.png ai_048_001/depth_plane_cam_03_fr0054.png 61 | ai_048_001/rgb_cam_03_fr0061.png ai_048_001/depth_plane_cam_03_fr0061.png 62 | ai_050_002/rgb_cam_01_fr0016.png ai_050_002/depth_plane_cam_01_fr0016.png 63 | ai_050_002/rgb_cam_02_fr0053.png ai_050_002/depth_plane_cam_02_fr0053.png 64 | ai_050_002/rgb_cam_03_fr0082.png ai_050_002/depth_plane_cam_03_fr0082.png 65 | ai_050_002/rgb_cam_04_fr0033.png ai_050_002/depth_plane_cam_04_fr0033.png 66 | ai_051_004/rgb_cam_00_fr0028.png ai_051_004/depth_plane_cam_00_fr0028.png 67 | ai_051_004/rgb_cam_01_fr0065.png ai_051_004/depth_plane_cam_01_fr0065.png 68 | ai_051_004/rgb_cam_02_fr0054.png 
ai_051_004/depth_plane_cam_02_fr0054.png 69 | ai_051_004/rgb_cam_02_fr0056.png ai_051_004/depth_plane_cam_02_fr0056.png 70 | ai_051_004/rgb_cam_03_fr0037.png ai_051_004/depth_plane_cam_03_fr0037.png 71 | ai_051_004/rgb_cam_04_fr0083.png ai_051_004/depth_plane_cam_04_fr0083.png 72 | ai_051_004/rgb_cam_05_fr0003.png ai_051_004/depth_plane_cam_05_fr0003.png 73 | ai_052_001/rgb_cam_00_fr0008.png ai_052_001/depth_plane_cam_00_fr0008.png 74 | ai_052_003/rgb_cam_00_fr0097.png ai_052_003/depth_plane_cam_00_fr0097.png 75 | ai_052_003/rgb_cam_01_fr0081.png ai_052_003/depth_plane_cam_01_fr0081.png 76 | ai_052_007/rgb_cam_01_fr0001.png ai_052_007/depth_plane_cam_01_fr0001.png 77 | ai_053_003/rgb_cam_00_fr0005.png ai_053_003/depth_plane_cam_00_fr0005.png 78 | ai_053_005/rgb_cam_00_fr0080.png ai_053_005/depth_plane_cam_00_fr0080.png 79 | ai_055_009/rgb_cam_01_fr0070.png ai_055_009/depth_plane_cam_01_fr0070.png 80 | ai_055_009/rgb_cam_01_fr0086.png ai_055_009/depth_plane_cam_01_fr0086.png -------------------------------------------------------------------------------- /data_split/hypersim/selected_vis_sample.txt: -------------------------------------------------------------------------------- 1 | ai_015_004/rgb_cam_00_fr0002.png ai_015_004/depth_plane_cam_00_fr0002.png (val) 2 | ai_044_003/rgb_cam_01_fr0063.png ai_044_003/depth_plane_cam_01_fr0063.png (val) 3 | ai_052_003/rgb_cam_01_fr0076.png ai_052_003/depth_plane_cam_01_fr0076.png (val) -------------------------------------------------------------------------------- /data_split/kitti/eigen_val_from_train_sub_100.txt: -------------------------------------------------------------------------------- 1 | 2011_09_26/2011_09_26_drive_0001_sync/image_02/data/0000000046.png 2011_09_26_drive_0001_sync/proj_depth/groundtruth/image_02/0000000046.png 721.5377 2 | 2011_09_26/2011_09_26_drive_0005_sync/image_02/data/0000000148.png 2011_09_26_drive_0005_sync/proj_depth/groundtruth/image_02/0000000148.png 721.5377 3 | 2011_09_26/2011_09_26_drive_0014_sync/image_02/data/0000000076.png 2011_09_26_drive_0014_sync/proj_depth/groundtruth/image_02/0000000076.png 721.5377 4 | 2011_09_26/2011_09_26_drive_0015_sync/image_02/data/0000000019.png 2011_09_26_drive_0015_sync/proj_depth/groundtruth/image_02/0000000019.png 721.5377 5 | 2011_09_26/2011_09_26_drive_0015_sync/image_02/data/0000000194.png 2011_09_26_drive_0015_sync/proj_depth/groundtruth/image_02/0000000194.png 721.5377 6 | 2011_09_26/2011_09_26_drive_0018_sync/image_02/data/0000000106.png 2011_09_26_drive_0018_sync/proj_depth/groundtruth/image_02/0000000106.png 721.5377 7 | 2011_09_26/2011_09_26_drive_0019_sync/image_02/data/0000000263.png 2011_09_26_drive_0019_sync/proj_depth/groundtruth/image_02/0000000263.png 721.5377 8 | 2011_09_26/2011_09_26_drive_0019_sync/image_02/data/0000000274.png 2011_09_26_drive_0019_sync/proj_depth/groundtruth/image_02/0000000274.png 721.5377 9 | 2011_09_26/2011_09_26_drive_0022_sync/image_02/data/0000000015.png 2011_09_26_drive_0022_sync/proj_depth/groundtruth/image_02/0000000015.png 721.5377 10 | 2011_09_26/2011_09_26_drive_0022_sync/image_02/data/0000000123.png 2011_09_26_drive_0022_sync/proj_depth/groundtruth/image_02/0000000123.png 721.5377 11 | 2011_09_26/2011_09_26_drive_0022_sync/image_02/data/0000000149.png 2011_09_26_drive_0022_sync/proj_depth/groundtruth/image_02/0000000149.png 721.5377 12 | 2011_09_26/2011_09_26_drive_0022_sync/image_02/data/0000000308.png 2011_09_26_drive_0022_sync/proj_depth/groundtruth/image_02/0000000308.png 721.5377 13 | 
2011_09_26/2011_09_26_drive_0022_sync/image_02/data/0000000553.png 2011_09_26_drive_0022_sync/proj_depth/groundtruth/image_02/0000000553.png 721.5377 14 | 2011_09_26/2011_09_26_drive_0022_sync/image_02/data/0000000691.png 2011_09_26_drive_0022_sync/proj_depth/groundtruth/image_02/0000000691.png 721.5377 15 | 2011_09_26/2011_09_26_drive_0028_sync/image_02/data/0000000270.png 2011_09_26_drive_0028_sync/proj_depth/groundtruth/image_02/0000000270.png 721.5377 16 | 2011_09_26/2011_09_26_drive_0035_sync/image_02/data/0000000085.png 2011_09_26_drive_0035_sync/proj_depth/groundtruth/image_02/0000000085.png 721.5377 17 | 2011_09_26/2011_09_26_drive_0039_sync/image_02/data/0000000326.png 2011_09_26_drive_0039_sync/proj_depth/groundtruth/image_02/0000000326.png 721.5377 18 | 2011_09_26/2011_09_26_drive_0051_sync/image_02/data/0000000429.png 2011_09_26_drive_0051_sync/proj_depth/groundtruth/image_02/0000000429.png 721.5377 19 | 2011_09_26/2011_09_26_drive_0057_sync/image_02/data/0000000010.png 2011_09_26_drive_0057_sync/proj_depth/groundtruth/image_02/0000000010.png 721.5377 20 | 2011_09_26/2011_09_26_drive_0060_sync/image_02/data/0000000020.png 2011_09_26_drive_0060_sync/proj_depth/groundtruth/image_02/0000000020.png 721.5377 21 | 2011_09_26/2011_09_26_drive_0061_sync/image_02/data/0000000223.png 2011_09_26_drive_0061_sync/proj_depth/groundtruth/image_02/0000000223.png 721.5377 22 | 2011_09_26/2011_09_26_drive_0061_sync/image_02/data/0000000262.png 2011_09_26_drive_0061_sync/proj_depth/groundtruth/image_02/0000000262.png 721.5377 23 | 2011_09_26/2011_09_26_drive_0061_sync/image_02/data/0000000291.png 2011_09_26_drive_0061_sync/proj_depth/groundtruth/image_02/0000000291.png 721.5377 24 | 2011_09_26/2011_09_26_drive_0061_sync/image_02/data/0000000523.png 2011_09_26_drive_0061_sync/proj_depth/groundtruth/image_02/0000000523.png 721.5377 25 | 2011_09_26/2011_09_26_drive_0061_sync/image_02/data/0000000524.png 2011_09_26_drive_0061_sync/proj_depth/groundtruth/image_02/0000000524.png 721.5377 26 | 2011_09_26/2011_09_26_drive_0070_sync/image_02/data/0000000063.png 2011_09_26_drive_0070_sync/proj_depth/groundtruth/image_02/0000000063.png 721.5377 27 | 2011_09_26/2011_09_26_drive_0070_sync/image_02/data/0000000320.png 2011_09_26_drive_0070_sync/proj_depth/groundtruth/image_02/0000000320.png 721.5377 28 | 2011_09_26/2011_09_26_drive_0087_sync/image_02/data/0000000313.png 2011_09_26_drive_0087_sync/proj_depth/groundtruth/image_02/0000000313.png 721.5377 29 | 2011_09_26/2011_09_26_drive_0087_sync/image_02/data/0000000316.png 2011_09_26_drive_0087_sync/proj_depth/groundtruth/image_02/0000000316.png 721.5377 30 | 2011_09_26/2011_09_26_drive_0087_sync/image_02/data/0000000363.png 2011_09_26_drive_0087_sync/proj_depth/groundtruth/image_02/0000000363.png 721.5377 31 | 2011_09_26/2011_09_26_drive_0087_sync/image_02/data/0000000438.png 2011_09_26_drive_0087_sync/proj_depth/groundtruth/image_02/0000000438.png 721.5377 32 | 2011_09_26/2011_09_26_drive_0091_sync/image_02/data/0000000137.png 2011_09_26_drive_0091_sync/proj_depth/groundtruth/image_02/0000000137.png 721.5377 33 | 2011_09_26/2011_09_26_drive_0091_sync/image_02/data/0000000143.png 2011_09_26_drive_0091_sync/proj_depth/groundtruth/image_02/0000000143.png 721.5377 34 | 2011_09_26/2011_09_26_drive_0091_sync/image_02/data/0000000278.png 2011_09_26_drive_0091_sync/proj_depth/groundtruth/image_02/0000000278.png 721.5377 35 | 2011_09_26/2011_09_26_drive_0091_sync/image_02/data/0000000312.png 2011_09_26_drive_0091_sync/proj_depth/groundtruth/image_02/0000000312.png 
721.5377 36 | 2011_09_26/2011_09_26_drive_0095_sync/image_02/data/0000000160.png 2011_09_26_drive_0095_sync/proj_depth/groundtruth/image_02/0000000160.png 721.5377 37 | 2011_09_26/2011_09_26_drive_0104_sync/image_02/data/0000000011.png 2011_09_26_drive_0104_sync/proj_depth/groundtruth/image_02/0000000011.png 721.5377 38 | 2011_09_26/2011_09_26_drive_0113_sync/image_02/data/0000000052.png 2011_09_26_drive_0113_sync/proj_depth/groundtruth/image_02/0000000052.png 721.5377 39 | 2011_09_26/2011_09_26_drive_0113_sync/image_02/data/0000000055.png 2011_09_26_drive_0113_sync/proj_depth/groundtruth/image_02/0000000055.png 721.5377 40 | 2011_09_29/2011_09_29_drive_0004_sync/image_02/data/0000000065.png 2011_09_29_drive_0004_sync/proj_depth/groundtruth/image_02/0000000065.png 718.3351 41 | 2011_09_30/2011_09_30_drive_0020_sync/image_02/data/0000000325.png 2011_09_30_drive_0020_sync/proj_depth/groundtruth/image_02/0000000325.png 707.0912 42 | 2011_09_30/2011_09_30_drive_0020_sync/image_02/data/0000000959.png 2011_09_30_drive_0020_sync/proj_depth/groundtruth/image_02/0000000959.png 707.0912 43 | 2011_09_30/2011_09_30_drive_0020_sync/image_02/data/0000001004.png 2011_09_30_drive_0020_sync/proj_depth/groundtruth/image_02/0000001004.png 707.0912 44 | 2011_09_30/2011_09_30_drive_0020_sync/image_02/data/0000001054.png 2011_09_30_drive_0020_sync/proj_depth/groundtruth/image_02/0000001054.png 707.0912 45 | 2011_09_30/2011_09_30_drive_0028_sync/image_02/data/0000000545.png 2011_09_30_drive_0028_sync/proj_depth/groundtruth/image_02/0000000545.png 707.0912 46 | 2011_09_30/2011_09_30_drive_0028_sync/image_02/data/0000000920.png 2011_09_30_drive_0028_sync/proj_depth/groundtruth/image_02/0000000920.png 707.0912 47 | 2011_09_30/2011_09_30_drive_0028_sync/image_02/data/0000001593.png 2011_09_30_drive_0028_sync/proj_depth/groundtruth/image_02/0000001593.png 707.0912 48 | 2011_09_30/2011_09_30_drive_0028_sync/image_02/data/0000001692.png 2011_09_30_drive_0028_sync/proj_depth/groundtruth/image_02/0000001692.png 707.0912 49 | 2011_09_30/2011_09_30_drive_0028_sync/image_02/data/0000001806.png 2011_09_30_drive_0028_sync/proj_depth/groundtruth/image_02/0000001806.png 707.0912 50 | 2011_09_30/2011_09_30_drive_0028_sync/image_02/data/0000001905.png 2011_09_30_drive_0028_sync/proj_depth/groundtruth/image_02/0000001905.png 707.0912 51 | 2011_09_30/2011_09_30_drive_0028_sync/image_02/data/0000002714.png 2011_09_30_drive_0028_sync/proj_depth/groundtruth/image_02/0000002714.png 707.0912 52 | 2011_09_30/2011_09_30_drive_0028_sync/image_02/data/0000002812.png 2011_09_30_drive_0028_sync/proj_depth/groundtruth/image_02/0000002812.png 707.0912 53 | 2011_09_30/2011_09_30_drive_0028_sync/image_02/data/0000002838.png 2011_09_30_drive_0028_sync/proj_depth/groundtruth/image_02/0000002838.png 707.0912 54 | 2011_09_30/2011_09_30_drive_0028_sync/image_02/data/0000003402.png 2011_09_30_drive_0028_sync/proj_depth/groundtruth/image_02/0000003402.png 707.0912 55 | 2011_09_30/2011_09_30_drive_0028_sync/image_02/data/0000003700.png 2011_09_30_drive_0028_sync/proj_depth/groundtruth/image_02/0000003700.png 707.0912 56 | 2011_09_30/2011_09_30_drive_0028_sync/image_02/data/0000004016.png 2011_09_30_drive_0028_sync/proj_depth/groundtruth/image_02/0000004016.png 707.0912 57 | 2011_09_30/2011_09_30_drive_0028_sync/image_02/data/0000004276.png 2011_09_30_drive_0028_sync/proj_depth/groundtruth/image_02/0000004276.png 707.0912 58 | 2011_09_30/2011_09_30_drive_0028_sync/image_02/data/0000004664.png 
2011_09_30_drive_0028_sync/proj_depth/groundtruth/image_02/0000004664.png 707.0912 59 | 2011_09_30/2011_09_30_drive_0028_sync/image_02/data/0000004772.png 2011_09_30_drive_0028_sync/proj_depth/groundtruth/image_02/0000004772.png 707.0912 60 | 2011_09_30/2011_09_30_drive_0028_sync/image_02/data/0000004782.png 2011_09_30_drive_0028_sync/proj_depth/groundtruth/image_02/0000004782.png 707.0912 61 | 2011_09_30/2011_09_30_drive_0028_sync/image_02/data/0000005095.png 2011_09_30_drive_0028_sync/proj_depth/groundtruth/image_02/0000005095.png 707.0912 62 | 2011_09_30/2011_09_30_drive_0033_sync/image_02/data/0000000319.png 2011_09_30_drive_0033_sync/proj_depth/groundtruth/image_02/0000000319.png 707.0912 63 | 2011_09_30/2011_09_30_drive_0033_sync/image_02/data/0000000355.png 2011_09_30_drive_0033_sync/proj_depth/groundtruth/image_02/0000000355.png 707.0912 64 | 2011_09_30/2011_09_30_drive_0033_sync/image_02/data/0000000500.png 2011_09_30_drive_0033_sync/proj_depth/groundtruth/image_02/0000000500.png 707.0912 65 | 2011_09_30/2011_09_30_drive_0033_sync/image_02/data/0000000682.png 2011_09_30_drive_0033_sync/proj_depth/groundtruth/image_02/0000000682.png 707.0912 66 | 2011_09_30/2011_09_30_drive_0033_sync/image_02/data/0000000710.png 2011_09_30_drive_0033_sync/proj_depth/groundtruth/image_02/0000000710.png 707.0912 67 | 2011_09_30/2011_09_30_drive_0033_sync/image_02/data/0000000896.png 2011_09_30_drive_0033_sync/proj_depth/groundtruth/image_02/0000000896.png 707.0912 68 | 2011_09_30/2011_09_30_drive_0033_sync/image_02/data/0000001197.png 2011_09_30_drive_0033_sync/proj_depth/groundtruth/image_02/0000001197.png 707.0912 69 | 2011_09_30/2011_09_30_drive_0033_sync/image_02/data/0000001508.png 2011_09_30_drive_0033_sync/proj_depth/groundtruth/image_02/0000001508.png 707.0912 70 | 2011_09_30/2011_09_30_drive_0033_sync/image_02/data/0000001512.png 2011_09_30_drive_0033_sync/proj_depth/groundtruth/image_02/0000001512.png 707.0912 71 | 2011_09_30/2011_09_30_drive_0034_sync/image_02/data/0000000029.png 2011_09_30_drive_0034_sync/proj_depth/groundtruth/image_02/0000000029.png 707.0912 72 | 2011_09_30/2011_09_30_drive_0034_sync/image_02/data/0000000171.png 2011_09_30_drive_0034_sync/proj_depth/groundtruth/image_02/0000000171.png 707.0912 73 | 2011_09_30/2011_09_30_drive_0034_sync/image_02/data/0000000193.png 2011_09_30_drive_0034_sync/proj_depth/groundtruth/image_02/0000000193.png 707.0912 74 | 2011_09_30/2011_09_30_drive_0034_sync/image_02/data/0000000389.png 2011_09_30_drive_0034_sync/proj_depth/groundtruth/image_02/0000000389.png 707.0912 75 | 2011_09_30/2011_09_30_drive_0034_sync/image_02/data/0000001141.png 2011_09_30_drive_0034_sync/proj_depth/groundtruth/image_02/0000001141.png 707.0912 76 | 2011_10_03/2011_10_03_drive_0034_sync/image_02/data/0000000138.png 2011_10_03_drive_0034_sync/proj_depth/groundtruth/image_02/0000000138.png 718.856 77 | 2011_10_03/2011_10_03_drive_0034_sync/image_02/data/0000000593.png 2011_10_03_drive_0034_sync/proj_depth/groundtruth/image_02/0000000593.png 718.856 78 | 2011_10_03/2011_10_03_drive_0034_sync/image_02/data/0000001046.png 2011_10_03_drive_0034_sync/proj_depth/groundtruth/image_02/0000001046.png 718.856 79 | 2011_10_03/2011_10_03_drive_0034_sync/image_02/data/0000001151.png 2011_10_03_drive_0034_sync/proj_depth/groundtruth/image_02/0000001151.png 718.856 80 | 2011_10_03/2011_10_03_drive_0034_sync/image_02/data/0000001255.png 2011_10_03_drive_0034_sync/proj_depth/groundtruth/image_02/0000001255.png 718.856 81 | 
2011_10_03/2011_10_03_drive_0034_sync/image_02/data/0000001283.png 2011_10_03_drive_0034_sync/proj_depth/groundtruth/image_02/0000001283.png 718.856 82 | 2011_10_03/2011_10_03_drive_0034_sync/image_02/data/0000001737.png 2011_10_03_drive_0034_sync/proj_depth/groundtruth/image_02/0000001737.png 718.856 83 | 2011_10_03/2011_10_03_drive_0034_sync/image_02/data/0000001999.png 2011_10_03_drive_0034_sync/proj_depth/groundtruth/image_02/0000001999.png 718.856 84 | 2011_10_03/2011_10_03_drive_0034_sync/image_02/data/0000002012.png 2011_10_03_drive_0034_sync/proj_depth/groundtruth/image_02/0000002012.png 718.856 85 | 2011_10_03/2011_10_03_drive_0034_sync/image_02/data/0000002089.png 2011_10_03_drive_0034_sync/proj_depth/groundtruth/image_02/0000002089.png 718.856 86 | 2011_10_03/2011_10_03_drive_0034_sync/image_02/data/0000002324.png 2011_10_03_drive_0034_sync/proj_depth/groundtruth/image_02/0000002324.png 718.856 87 | 2011_10_03/2011_10_03_drive_0034_sync/image_02/data/0000002902.png 2011_10_03_drive_0034_sync/proj_depth/groundtruth/image_02/0000002902.png 718.856 88 | 2011_10_03/2011_10_03_drive_0034_sync/image_02/data/0000002971.png 2011_10_03_drive_0034_sync/proj_depth/groundtruth/image_02/0000002971.png 718.856 89 | 2011_10_03/2011_10_03_drive_0034_sync/image_02/data/0000003299.png 2011_10_03_drive_0034_sync/proj_depth/groundtruth/image_02/0000003299.png 718.856 90 | 2011_10_03/2011_10_03_drive_0034_sync/image_02/data/0000003366.png 2011_10_03_drive_0034_sync/proj_depth/groundtruth/image_02/0000003366.png 718.856 91 | 2011_10_03/2011_10_03_drive_0034_sync/image_02/data/0000003427.png 2011_10_03_drive_0034_sync/proj_depth/groundtruth/image_02/0000003427.png 718.856 92 | 2011_10_03/2011_10_03_drive_0034_sync/image_02/data/0000003440.png 2011_10_03_drive_0034_sync/proj_depth/groundtruth/image_02/0000003440.png 718.856 93 | 2011_10_03/2011_10_03_drive_0034_sync/image_02/data/0000004060.png 2011_10_03_drive_0034_sync/proj_depth/groundtruth/image_02/0000004060.png 718.856 94 | 2011_10_03/2011_10_03_drive_0042_sync/image_02/data/0000000525.png 2011_10_03_drive_0042_sync/proj_depth/groundtruth/image_02/0000000525.png 718.856 95 | 2011_10_03/2011_10_03_drive_0042_sync/image_02/data/0000000538.png 2011_10_03_drive_0042_sync/proj_depth/groundtruth/image_02/0000000538.png 718.856 96 | 2011_10_03/2011_10_03_drive_0042_sync/image_02/data/0000000648.png 2011_10_03_drive_0042_sync/proj_depth/groundtruth/image_02/0000000648.png 718.856 97 | 2011_10_03/2011_10_03_drive_0042_sync/image_02/data/0000000776.png 2011_10_03_drive_0042_sync/proj_depth/groundtruth/image_02/0000000776.png 718.856 98 | 2011_10_03/2011_10_03_drive_0042_sync/image_02/data/0000000779.png 2011_10_03_drive_0042_sync/proj_depth/groundtruth/image_02/0000000779.png 718.856 99 | 2011_10_03/2011_10_03_drive_0042_sync/image_02/data/0000001087.png 2011_10_03_drive_0042_sync/proj_depth/groundtruth/image_02/0000001087.png 718.856 100 | 2011_10_03/2011_10_03_drive_0042_sync/image_02/data/0000001107.png 2011_10_03_drive_0042_sync/proj_depth/groundtruth/image_02/0000001107.png 718.856 -------------------------------------------------------------------------------- /data_split/nyu/labeled/filename_list_train_small_100.txt: -------------------------------------------------------------------------------- 1 | train/bathroom_0007/rgb_0649.png train/bathroom_0007/depth_0649.png train/bathroom_0007/filled_0649.png 2 | train/bathroom_0010/rgb_0653.png train/bathroom_0010/depth_0653.png train/bathroom_0010/filled_0653.png 3 | train/bathroom_0041/rgb_0719.png 
train/bathroom_0041/depth_0719.png train/bathroom_0041/filled_0719.png 4 | train/bathroom_0045/rgb_0729.png train/bathroom_0045/depth_0729.png train/bathroom_0045/filled_0729.png 5 | train/bathroom_0048/rgb_0736.png train/bathroom_0048/depth_0736.png train/bathroom_0048/filled_0736.png 6 | train/bathroom_0056/rgb_0505.png train/bathroom_0056/depth_0505.png train/bathroom_0056/filled_0505.png 7 | train/bedroom_0004/rgb_0178.png train/bedroom_0004/depth_0178.png train/bedroom_0004/filled_0178.png 8 | train/bedroom_0016/rgb_0071.png train/bedroom_0016/depth_0071.png train/bedroom_0016/filled_0071.png 9 | train/bedroom_0025/rgb_0910.png train/bedroom_0025/depth_0910.png train/bedroom_0025/filled_0910.png 10 | train/bedroom_0026/rgb_0914.png train/bedroom_0026/depth_0914.png train/bedroom_0026/filled_0914.png 11 | train/bedroom_0031/rgb_0929.png train/bedroom_0031/depth_0929.png train/bedroom_0031/filled_0929.png 12 | train/bedroom_0034/rgb_0939.png train/bedroom_0034/depth_0939.png train/bedroom_0034/filled_0939.png 13 | train/bedroom_0040/rgb_0954.png train/bedroom_0040/depth_0954.png train/bedroom_0040/filled_0954.png 14 | train/bedroom_0042/rgb_0958.png train/bedroom_0042/depth_0958.png train/bedroom_0042/filled_0958.png 15 | train/bedroom_0050/rgb_0978.png train/bedroom_0050/depth_0978.png train/bedroom_0050/filled_0978.png 16 | train/bedroom_0051/rgb_0984.png train/bedroom_0051/depth_0984.png train/bedroom_0051/filled_0984.png 17 | train/bedroom_0056/rgb_0997.png train/bedroom_0056/depth_0997.png train/bedroom_0056/filled_0997.png 18 | train/bedroom_0060/rgb_1008.png train/bedroom_0060/depth_1008.png train/bedroom_0060/filled_1008.png 19 | train/bedroom_0067/rgb_1029.png train/bedroom_0067/depth_1029.png train/bedroom_0067/filled_1029.png 20 | train/bedroom_0072/rgb_1045.png train/bedroom_0072/depth_1045.png train/bedroom_0072/filled_1045.png 21 | train/bedroom_0072/rgb_1046.png train/bedroom_0072/depth_1046.png train/bedroom_0072/filled_1046.png 22 | train/bedroom_0079/rgb_1062.png train/bedroom_0079/depth_1062.png train/bedroom_0079/filled_1062.png 23 | train/bedroom_0081/rgb_1072.png train/bedroom_0081/depth_1072.png train/bedroom_0081/filled_1072.png 24 | train/bedroom_0096/rgb_1112.png train/bedroom_0096/depth_1112.png train/bedroom_0096/filled_1112.png 25 | train/bedroom_0118/rgb_1173.png train/bedroom_0118/depth_1173.png train/bedroom_0118/filled_1173.png 26 | train/bedroom_0129/rgb_1197.png train/bedroom_0129/depth_1197.png train/bedroom_0129/filled_1197.png 27 | train/bedroom_0136/rgb_0527.png train/bedroom_0136/depth_0527.png train/bedroom_0136/filled_0527.png 28 | train/bookstore_0000/rgb_0105.png train/bookstore_0000/depth_0105.png train/bookstore_0000/filled_0105.png 29 | train/bookstore_0000/rgb_0107.png train/bookstore_0000/depth_0107.png train/bookstore_0000/filled_0107.png 30 | train/bookstore_0002/rgb_0101.png train/bookstore_0002/depth_0101.png train/bookstore_0002/filled_0101.png 31 | train/bookstore_0002/rgb_0103.png train/bookstore_0002/depth_0103.png train/bookstore_0002/filled_0103.png 32 | train/classroom_0010/rgb_0305.png train/classroom_0010/depth_0305.png train/classroom_0010/filled_0305.png 33 | train/classroom_0012/rgb_0309.png train/classroom_0012/depth_0309.png train/classroom_0012/filled_0309.png 34 | train/conference_room_0001/rgb_0339.png train/conference_room_0001/depth_0339.png train/conference_room_0001/filled_0339.png 35 | train/conference_room_0001/rgb_0341.png train/conference_room_0001/depth_0341.png train/conference_room_0001/filled_0341.png 36 | 
train/dining_room_0002/rgb_1346.png train/dining_room_0002/depth_1346.png train/dining_room_0002/filled_1346.png 37 | train/dining_room_0008/rgb_1363.png train/dining_room_0008/depth_1363.png train/dining_room_0008/filled_1363.png 38 | train/dining_room_0012/rgb_1371.png train/dining_room_0012/depth_1371.png train/dining_room_0012/filled_1371.png 39 | train/dining_room_0014/rgb_1377.png train/dining_room_0014/depth_1377.png train/dining_room_0014/filled_1377.png 40 | train/dining_room_0015/rgb_1379.png train/dining_room_0015/depth_1379.png train/dining_room_0015/filled_1379.png 41 | train/dining_room_0016/rgb_1382.png train/dining_room_0016/depth_1382.png train/dining_room_0016/filled_1382.png 42 | train/dining_room_0031/rgb_1425.png train/dining_room_0031/depth_1425.png train/dining_room_0031/filled_1425.png 43 | train/dining_room_0031/rgb_1426.png train/dining_room_0031/depth_1426.png train/dining_room_0031/filled_1426.png 44 | train/dining_room_0033/rgb_1436.png train/dining_room_0033/depth_1436.png train/dining_room_0033/filled_1436.png 45 | train/dining_room_0037/rgb_0548.png train/dining_room_0037/depth_0548.png train/dining_room_0037/filled_0548.png 46 | train/furniture_store_0001/rgb_0224.png train/furniture_store_0001/depth_0224.png train/furniture_store_0001/filled_0224.png 47 | train/furniture_store_0001/rgb_0237.png train/furniture_store_0001/depth_0237.png train/furniture_store_0001/filled_0237.png 48 | train/furniture_store_0002/rgb_0249.png train/furniture_store_0002/depth_0249.png train/furniture_store_0002/filled_0249.png 49 | train/home_office_0005/rgb_0368.png train/home_office_0005/depth_0368.png train/home_office_0005/filled_0368.png 50 | train/home_office_0006/rgb_0374.png train/home_office_0006/depth_0374.png train/home_office_0006/filled_0374.png 51 | train/home_office_0008/rgb_0380.png train/home_office_0008/depth_0380.png train/home_office_0008/filled_0380.png 52 | train/home_office_0013/rgb_0554.png train/home_office_0013/depth_0554.png train/home_office_0013/filled_0554.png 53 | train/kitchen_0010/rgb_0138.png train/kitchen_0010/depth_0138.png train/kitchen_0010/filled_0138.png 54 | train/kitchen_0019/rgb_0750.png train/kitchen_0019/depth_0750.png train/kitchen_0019/filled_0750.png 55 | train/kitchen_0019/rgb_0757.png train/kitchen_0019/depth_0757.png train/kitchen_0019/filled_0757.png 56 | train/kitchen_0028/rgb_0788.png train/kitchen_0028/depth_0788.png train/kitchen_0028/filled_0788.png 57 | train/kitchen_0028/rgb_0793.png train/kitchen_0028/depth_0793.png train/kitchen_0028/filled_0793.png 58 | train/kitchen_0029/rgb_0799.png train/kitchen_0029/depth_0799.png train/kitchen_0029/filled_0799.png 59 | train/kitchen_0033/rgb_0815.png train/kitchen_0033/depth_0815.png train/kitchen_0033/filled_0815.png 60 | train/kitchen_0033/rgb_0816.png train/kitchen_0033/depth_0816.png train/kitchen_0033/filled_0816.png 61 | train/kitchen_0037/rgb_0832.png train/kitchen_0037/depth_0832.png train/kitchen_0037/filled_0832.png 62 | train/kitchen_0041/rgb_0849.png train/kitchen_0041/depth_0849.png train/kitchen_0041/filled_0849.png 63 | train/kitchen_0047/rgb_0875.png train/kitchen_0047/depth_0875.png train/kitchen_0047/filled_0875.png 64 | train/kitchen_0050/rgb_0887.png train/kitchen_0050/depth_0887.png train/kitchen_0050/filled_0887.png 65 | train/kitchen_0051/rgb_0892.png train/kitchen_0051/depth_0892.png train/kitchen_0051/filled_0892.png 66 | train/kitchen_0051/rgb_0893.png train/kitchen_0051/depth_0893.png train/kitchen_0051/filled_0893.png 67 | 
train/kitchen_0052/rgb_0899.png train/kitchen_0052/depth_0899.png train/kitchen_0052/filled_0899.png 68 | train/kitchen_0059/rgb_0573.png train/kitchen_0059/depth_0573.png train/kitchen_0059/filled_0573.png 69 | train/living_room_0000/rgb_0050.png train/living_room_0000/depth_0050.png train/living_room_0000/filled_0050.png 70 | train/living_room_0010/rgb_0156.png train/living_room_0010/depth_0156.png train/living_room_0010/filled_0156.png 71 | train/living_room_0010/rgb_0158.png train/living_room_0010/depth_0158.png train/living_room_0010/filled_0158.png 72 | train/living_room_0010/rgb_0159.png train/living_room_0010/depth_0159.png train/living_room_0010/filled_0159.png 73 | train/living_room_0011/rgb_0162.png train/living_room_0011/depth_0162.png train/living_room_0011/filled_0162.png 74 | train/living_room_0019/rgb_0258.png train/living_room_0019/depth_0258.png train/living_room_0019/filled_0258.png 75 | train/living_room_0042/rgb_1251.png train/living_room_0042/depth_1251.png train/living_room_0042/filled_1251.png 76 | train/living_room_0046/rgb_1268.png train/living_room_0046/depth_1268.png train/living_room_0046/filled_1268.png 77 | train/living_room_0047/rgb_1272.png train/living_room_0047/depth_1272.png train/living_room_0047/filled_1272.png 78 | train/living_room_0058/rgb_1301.png train/living_room_0058/depth_1301.png train/living_room_0058/filled_1301.png 79 | train/living_room_0062/rgb_1310.png train/living_room_0062/depth_1310.png train/living_room_0062/filled_1310.png 80 | train/living_room_0063/rgb_1313.png train/living_room_0063/depth_1313.png train/living_room_0063/filled_1313.png 81 | train/living_room_0083/rgb_0588.png train/living_room_0083/depth_0588.png train/living_room_0083/filled_0588.png 82 | train/living_room_0086/rgb_0601.png train/living_room_0086/depth_0601.png train/living_room_0086/filled_0601.png 83 | train/office_0003/rgb_0004.png train/office_0003/depth_0004.png train/office_0003/filled_0004.png 84 | train/office_0023/rgb_0623.png train/office_0023/depth_0623.png train/office_0023/filled_0623.png 85 | train/office_0024/rgb_0627.png train/office_0024/depth_0627.png train/office_0024/filled_0627.png 86 | train/office_kitchen_0003/rgb_0415.png train/office_kitchen_0003/depth_0415.png train/office_kitchen_0003/filled_0415.png 87 | train/playroom_0002/rgb_0418.png train/playroom_0002/depth_0418.png train/playroom_0002/filled_0418.png 88 | train/playroom_0003/rgb_0423.png train/playroom_0003/depth_0423.png train/playroom_0003/filled_0423.png 89 | train/playroom_0003/rgb_0424.png train/playroom_0003/depth_0424.png train/playroom_0003/filled_0424.png 90 | train/playroom_0003/rgb_0425.png train/playroom_0003/depth_0425.png train/playroom_0003/filled_0425.png 91 | train/playroom_0004/rgb_0426.png train/playroom_0004/depth_0426.png train/playroom_0004/filled_0426.png 92 | train/printer_room_0001/rgb_0451.png train/printer_room_0001/depth_0451.png train/printer_room_0001/filled_0451.png 93 | train/reception_room_0001/rgb_0456.png train/reception_room_0001/depth_0456.png train/reception_room_0001/filled_0456.png 94 | train/reception_room_0002/rgb_0459.png train/reception_room_0002/depth_0459.png train/reception_room_0002/filled_0459.png 95 | train/reception_room_0004/rgb_0468.png train/reception_room_0004/depth_0468.png train/reception_room_0004/filled_0468.png 96 | train/student_lounge_0001/rgb_0641.png train/student_lounge_0001/depth_0641.png train/student_lounge_0001/filled_0641.png 97 | train/study_0003/rgb_0478.png train/study_0003/depth_0478.png 
train/study_0003/filled_0478.png 98 | train/study_0005/rgb_0485.png train/study_0005/depth_0485.png train/study_0005/filled_0485.png 99 | train/study_0008/rgb_0646.png train/study_0008/depth_0646.png train/study_0008/filled_0646.png 100 | train/study_room_0004/rgb_0274.png train/study_room_0004/depth_0274.png train/study_room_0004/filled_0274.png -------------------------------------------------------------------------------- /inference_merge_base_depth.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | from PIL import Image 4 | import os 5 | import sys 6 | from matplotlib import pyplot as plt 7 | import argparse 8 | 9 | sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), '..'))) 10 | 11 | from merge.pipeline.merge_transformer import xTransformerModel, MERGEPixArtTransformer 12 | from merge.pipeline.pipeline_merge import MERGEPixArtPipeline 13 | 14 | cmap = plt.get_cmap('Spectral') 15 | 16 | def main(args): 17 | 18 | weight_dtype = torch.float32 19 | 20 | fixed_transformer = xTransformerModel.from_pretrained( 21 | args.pretrained_model_path, 22 | subfolder="transformer", torch_dtype=weight_dtype 23 | ) 24 | fixed_transformer.requires_grad_(False) 25 | 26 | depth_converters = xTransformerModel.from_pretrained( 27 | args.model_weights, 28 | subfolder="depth_converters", 29 | torch_dtype=weight_dtype 30 | ) 31 | depth_converters.requires_grad_(False) 32 | 33 | merge_transformer = MERGEPixArtTransformer(fixed_transformer, depth_converters) 34 | del fixed_transformer, depth_converters 35 | 36 | merge_model = MERGEPixArtPipeline.from_pretrained( 37 | args.pretrained_model_path, 38 | transformer=merge_transformer, 39 | torch_dtype=weight_dtype, 40 | use_safetensors=True 41 | ).to("cuda") 42 | 43 | # for depth estimation 44 | image = Image.open(args.image_path) 45 | width, height = image.size 46 | depth_image = merge_model( 47 | image=image, 48 | prompt='', 49 | num_inference_steps=args.denoising_step, 50 | height=height, 51 | width=width, 52 | mode='merge' 53 | ).images 54 | depth_image = torch.mean(depth_image, dim=1).squeeze().cpu().numpy() 55 | depth_image = (cmap(depth_image) * 255).astype(np.uint8) 56 | Image.fromarray(depth_image).save("./merge_base_depth_demo.png") 57 | 58 | # for text-to-image 59 | image = merge_model( 60 | prompt=args.prompt, 61 | num_inference_steps=args.denoising_step, 62 | guidance_scale=4.5, 63 | mode='t2i', 64 | ).images[0] 65 | image.save("./merge_base_t2i_demo.png") 66 | 67 | if __name__ == "__main__": 68 | parser = argparse.ArgumentParser() 69 | parser.add_argument( 70 | "--pretrained_model_path", 71 | type=str, 72 | required=True, 73 | help="Path to pretrained text-to-image model.", 74 | ) 75 | parser.add_argument( 76 | "--model_weights", 77 | type=str, 78 | required=True, 79 | help="Path to converter weight.", 80 | ) 81 | parser.add_argument( 82 | "--image_path", 83 | type=str, 84 | required=True, 85 | help="Path to input image.", 86 | ) 87 | parser.add_argument( 88 | "--prompt", 89 | type=str, 90 | default='a apple', 91 | required=False, 92 | help="Prompt for text-to-image.", 93 | ) 94 | parser.add_argument( 95 | "--denoising_step", 96 | type=int, 97 | default=20, 98 | help="Denoising step.", 99 | ) 100 | 101 | args = parser.parse_args() 102 | main(args) 103 | -------------------------------------------------------------------------------- /inference_merge_large_depth.py: -------------------------------------------------------------------------------- 1 | 
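# NOTE: this FLUX-based script mirrors inference_merge_base_depth.py above, but loads the
# transformer in bfloat16 and passes the RGB input through `control_image`. An example
# invocation (the model paths below are placeholders, not part of the repo) could be:
#   python inference_merge_large_depth.py \
#       --pretrained_model_path /path/to/flux_base_model \
#       --model_weights /path/to/depth_converter_weights \
#       --image_path data/demo_1.png \
#       --denoising_step 20
# The colorized depth map is written to ./merge_large_depth_demo.png and the text-to-image
# sample to merge_large_t2i.png.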
import numpy as np 2 | import torch 3 | from PIL import Image 4 | import os 5 | import sys 6 | from matplotlib import pyplot as plt 7 | import argparse 8 | 9 | sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), '..'))) 10 | 11 | from merge.pipeline.merge_transformer_flux import xFluxTransformer2DModel, MERGEFluxTransformerModel 12 | from merge.pipeline.pipeline_merge_flux import MERGEFluxPipeline 13 | 14 | cmap = plt.get_cmap('Spectral') 15 | 16 | 17 | def main(args): 18 | 19 | weight_dtype = torch.bfloat16 20 | 21 | fixed_transformer = xFluxTransformer2DModel.from_pretrained( 22 | args.pretrained_model_path, subfolder="transformer", torch_dtype=weight_dtype 23 | ) 24 | fixed_transformer.requires_grad_(False) 25 | 26 | depth_converter = xFluxTransformer2DModel.from_pretrained( 27 | args.model_weights, subfolder="depth_converters", torch_dtype=weight_dtype 28 | ) 29 | depth_converter.requires_grad_(False) 30 | 31 | 32 | merge_flux_transformer = MERGEFluxTransformerModel(fixed_transformer, depth_converter) 33 | del fixed_transformer, depth_converter 34 | 35 | model = MERGEFluxPipeline.from_pretrained( 36 | args.pretrained_model_path, 37 | transformer=merge_flux_transformer, 38 | torch_dtype=weight_dtype, 39 | ).to("cuda") 40 | 41 | # for depth estimation 42 | image = Image.open(args.image_path) 43 | width, height = image.size 44 | depth_image = model( 45 | prompt='', 46 | control_image=image, 47 | num_inference_steps=args.denoising_step, 48 | guidance_scale=0, 49 | max_sequence_length=512, 50 | output_type='pt', 51 | height=height, 52 | width=width, 53 | ).images 54 | depth_image = torch.mean(depth_image, dim=1).squeeze().to(torch.float32).cpu().numpy() 55 | depth_image = (cmap(depth_image) * 255).astype(np.uint8) 56 | Image.fromarray(depth_image).save("./merge_large_depth_demo.png") 57 | 58 | # for text-to-image 59 | image = model( 60 | prompt=args.prompt, 61 | height=1024, 62 | width=1024, 63 | num_inference_steps=args.denoising_step, 64 | guidance_scale=3.5, 65 | max_sequence_length=512, 66 | generator=torch.Generator("cpu").manual_seed(0), 67 | use_merge=False 68 | ).images[0] 69 | image.save("merge_large_t2i.png") 70 | 71 | if __name__ == "__main__": 72 | parser = argparse.ArgumentParser() 73 | 74 | parser.add_argument( 75 | "--pretrained_model_path", 76 | type=str, 77 | required=True, 78 | help="Path to pretrained model or model identifier from huggingface.co/models.", 79 | ) 80 | parser.add_argument( 81 | "--model_weights", 82 | type=str, 83 | required=True, 84 | help="Path to converter weight.", 85 | ) 86 | parser.add_argument( 87 | "--image_path", 88 | type=str, 89 | required=True, 90 | help="Path to input image.", 91 | ) 92 | parser.add_argument( 93 | "--prompt", 94 | type=str, 95 | default="A cat holding a sign that says hello world", 96 | required=False, 97 | help="Prompt for text-to-image.", 98 | ) 99 | parser.add_argument( 100 | "--denoising_step", 101 | type=int, 102 | default=20, 103 | help="ensemble size of the model.", 104 | ) 105 | 106 | args = parser.parse_args() 107 | main(args) -------------------------------------------------------------------------------- /merge/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/H-EmbodVis/MERGE/93e81be69663b00175a18f405f9239e7acae1ba8/merge/__init__.py -------------------------------------------------------------------------------- /merge/pipeline/embeddings.py: 
-------------------------------------------------------------------------------- 1 | # Copyright 2024 The HuggingFace Team. All rights reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | import torch 15 | from torch import nn 16 | from diffusers.models.embeddings import get_2d_sincos_pos_embed, get_2d_rotary_pos_embed 17 | 18 | class PatchEmbed(nn.Module): 19 | """2D Image to Patch Embedding with support for SD3 cropping.""" 20 | 21 | def __init__( 22 | self, 23 | height=224, 24 | width=224, 25 | patch_size=16, 26 | in_channels=3, 27 | embed_dim=768, 28 | layer_norm=False, 29 | flatten=True, 30 | bias=True, 31 | interpolation_scale=1, 32 | pos_embed_type="sincos", 33 | pos_embed_max_size=None, # For SD3 cropping 34 | ): 35 | super().__init__() 36 | 37 | num_patches = (height // patch_size) * (width // patch_size) 38 | self.flatten = flatten 39 | self.layer_norm = layer_norm 40 | self.pos_embed_max_size = pos_embed_max_size 41 | 42 | self.proj = nn.Conv2d( 43 | in_channels, embed_dim, kernel_size=(patch_size, patch_size), stride=patch_size, bias=bias 44 | ) 45 | if layer_norm: 46 | self.norm = nn.LayerNorm(embed_dim, elementwise_affine=False, eps=1e-6) 47 | else: 48 | self.norm = None 49 | 50 | self.patch_size = patch_size 51 | self.height, self.width = height // patch_size, width // patch_size 52 | self.base_size = height // patch_size 53 | self.interpolation_scale = interpolation_scale 54 | 55 | # Calculate positional embeddings based on max size or default 56 | if pos_embed_max_size: 57 | grid_size = pos_embed_max_size 58 | else: 59 | grid_size = int(num_patches**0.5) 60 | 61 | if pos_embed_type is None: 62 | self.pos_embed = None 63 | elif pos_embed_type == "sincos": 64 | pos_embed = get_2d_sincos_pos_embed( 65 | embed_dim, grid_size, base_size=self.base_size, interpolation_scale=self.interpolation_scale 66 | ) 67 | persistent = True if pos_embed_max_size else False 68 | self.register_buffer("pos_embed", torch.from_numpy(pos_embed).float().unsqueeze(0), persistent=persistent) 69 | else: 70 | raise ValueError(f"Unsupported pos_embed_type: {pos_embed_type}") 71 | 72 | def cropped_pos_embed(self, height, width): 73 | """Crops positional embeddings for SD3 compatibility.""" 74 | if self.pos_embed_max_size is None: 75 | raise ValueError("`pos_embed_max_size` must be set for cropping.") 76 | 77 | height = height // self.patch_size 78 | width = width // self.patch_size 79 | if height > self.pos_embed_max_size: 80 | raise ValueError( 81 | f"Height ({height}) cannot be greater than `pos_embed_max_size`: {self.pos_embed_max_size}." 82 | ) 83 | if width > self.pos_embed_max_size: 84 | raise ValueError( 85 | f"Width ({width}) cannot be greater than `pos_embed_max_size`: {self.pos_embed_max_size}." 
86 | ) 87 | 88 | top = (self.pos_embed_max_size - height) // 2 89 | left = (self.pos_embed_max_size - width) // 2 90 | spatial_pos_embed = self.pos_embed.reshape(1, self.pos_embed_max_size, self.pos_embed_max_size, -1) 91 | spatial_pos_embed = spatial_pos_embed[:, top : top + height, left : left + width, :] 92 | spatial_pos_embed = spatial_pos_embed.reshape(1, -1, spatial_pos_embed.shape[-1]) 93 | return spatial_pos_embed 94 | 95 | def forward(self, latent): 96 | if self.pos_embed_max_size is not None: 97 | height, width = latent.shape[-2:] 98 | else: 99 | height, width = latent.shape[-2] // self.patch_size, latent.shape[-1] // self.patch_size 100 | 101 | latent = self.proj(latent) 102 | if self.flatten: 103 | latent = latent.flatten(2).transpose(1, 2) # BCHW -> BNC 104 | if self.layer_norm: 105 | latent = self.norm(latent) 106 | if self.pos_embed is None: 107 | return latent.to(latent.dtype) 108 | # Interpolate or crop positional embeddings as needed 109 | if self.pos_embed_max_size: 110 | pos_embed = self.cropped_pos_embed(height, width) 111 | else: 112 | if self.height != height or self.width != width: 113 | pos_embed = get_2d_sincos_pos_embed( 114 | embed_dim=self.pos_embed.shape[-1], 115 | grid_size=(height, width), 116 | base_size=self.base_size, 117 | interpolation_scale=self.interpolation_scale, 118 | ) 119 | pos_embed = torch.from_numpy(pos_embed).float().unsqueeze(0).to(latent.device) 120 | else: 121 | pos_embed = self.pos_embed 122 | 123 | return (latent + pos_embed).to(latent.dtype) 124 | -------------------------------------------------------------------------------- /merge/pipeline/layers.py: -------------------------------------------------------------------------------- 1 | from typing import Tuple, Dict, Optional 2 | import torch 3 | from torch import nn 4 | from torch.nn import functional as F 5 | from diffusers.models.attention import FeedForward 6 | from diffusers.models.activations import GELU, GEGLU, ApproximateGELU 7 | from diffusers.utils import deprecate 8 | from diffusers.models.normalization import PixArtAlphaCombinedTimestepSizeEmbeddings 9 | 10 | class MLP(nn.Module): 11 | """Very simple multi-layer perceptron (also called FFN)""" 12 | 13 | def __init__( 14 | self, 15 | input_dim, 16 | hidden_dim, 17 | output_dim, 18 | num_layers, 19 | sigmoid_output: bool = False, 20 | affine_func=nn.Linear, 21 | act_fn="relu" 22 | ): 23 | super().__init__() 24 | self.num_layers = num_layers 25 | h = [hidden_dim] * (num_layers - 1) 26 | self.layers = nn.ModuleList( 27 | affine_func(n, k) for n, k in zip([input_dim] + h, h + [output_dim]) 28 | ) 29 | if act_fn == "gelu_tanh": 30 | self.act = nn.GELU(approximate="tanh") 31 | elif act_fn == "gelu": 32 | self.act = nn.GELU() 33 | elif act_fn == "silu": 34 | self.act = nn.SiLU() 35 | elif act_fn == "relu": 36 | self.act = nn.ReLU() 37 | self.sigmoid_output = sigmoid_output 38 | 39 | def forward(self, x: torch.Tensor): 40 | for i, layer in enumerate(self.layers): 41 | x = self.act(layer(x)) if i < self.num_layers - 1 else layer(x) 42 | if self.sigmoid_output: 43 | x = F.sigmoid(x) 44 | return x 45 | 46 | class ResidualMLP(nn.Module): 47 | 48 | def __init__( 49 | self, 50 | input_dim, 51 | hidden_dim, 52 | output_dim, 53 | num_mlp, 54 | num_layer_per_mlp, 55 | sigmoid_output: bool = False, 56 | affine_func=nn.Linear, 57 | act_fn="relu" 58 | ): 59 | super().__init__() 60 | self.num_mlp = num_mlp 61 | self.in2hidden_dim = affine_func(input_dim, hidden_dim) 62 | self.hidden2out_dim = affine_func(hidden_dim, output_dim) 63 | 
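# NOTE: forward() applies the `num_mlp` blocks below with residual connections, i.e.
# x <- x + mlp_i(x), sandwiched between the input/output projections defined just above.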
self.mlp_list = nn.ModuleList( 64 | MLP( 65 | hidden_dim, 66 | hidden_dim, 67 | hidden_dim, 68 | num_layer_per_mlp, 69 | affine_func=affine_func, 70 | act_fn=act_fn 71 | ) for _ in range(num_mlp) 72 | ) 73 | self.sigmoid_output = sigmoid_output 74 | 75 | def forward(self, x: torch.Tensor): 76 | x = self.in2hidden_dim(x) 77 | for mlp in self.mlp_list: 78 | out = mlp(x) 79 | x = x + out 80 | out = self.hidden2out_dim(x) 81 | return out 82 | 83 | 84 | class _FeedForward(nn.Module): 85 | def __init__( 86 | self, 87 | dim: int, 88 | dim_out: Optional[int] = None, 89 | mult: int = 4, 90 | dropout: float = 0.0, 91 | activation_fn: str = "geglu", 92 | final_dropout: bool = False, 93 | inner_dim=None, 94 | bias: bool = True, 95 | ): 96 | super().__init__() 97 | if inner_dim is None: 98 | inner_dim = int(dim * mult) 99 | dim_out = dim_out if dim_out is not None else dim 100 | 101 | if activation_fn == "gelu": 102 | act_fn = GELU(dim, inner_dim, bias=bias) 103 | if activation_fn == "gelu-approximate": 104 | act_fn = GELU(dim, inner_dim, approximate="tanh", bias=bias) 105 | elif activation_fn == "geglu": 106 | act_fn = GEGLU(dim, inner_dim, bias=bias) 107 | elif activation_fn == "geglu-approximate": 108 | act_fn = ApproximateGELU(dim, inner_dim, bias=bias) 109 | 110 | self.net = nn.ModuleList([]) 111 | # project in 112 | self.net.append(act_fn) 113 | # project dropout 114 | self.net.append(nn.Dropout(dropout)) 115 | # project out 116 | self.net.append(nn.Linear(inner_dim, dim_out, bias=bias)) 117 | # FF as used in Vision Transformer, MLP-Mixer, etc. have a final dropout 118 | if final_dropout: 119 | self.net.append(nn.Dropout(dropout)) 120 | self.final_dropout = True 121 | else: 122 | self.final_dropout = False 123 | def forward(self, hidden_states: torch.Tensor, *args, **kwargs) -> torch.Tensor: 124 | if len(args) > 0 or kwargs.get("scale", None) is not None: 125 | deprecation_message = "The `scale` argument is deprecated and will be ignored. Please remove it, as passing it will raise an error in the future. `scale` should directly be passed while calling the underlying pipeline component i.e., via `cross_attention_kwargs`." 126 | deprecate("scale", "1.0.0", deprecation_message) 127 | hidden_states = self.net[0](hidden_states) 128 | hidden_states = self.net[1](hidden_states) 129 | hidden_states = self.net[2](hidden_states) 130 | 131 | if self.final_dropout: 132 | hidden_states = self.net[3](hidden_states) 133 | return hidden_states 134 | 135 | class _AdaLayerNormSingle(nn.Module): 136 | r""" 137 | Norm layer adaptive layer norm single (adaLN-single). 138 | 139 | As proposed in PixArt-Alpha (see: https://arxiv.org/abs/2310.00426; Section 2.3). 140 | 141 | Parameters: 142 | embedding_dim (`int`): The size of each embedding vector. 143 | use_additional_conditions (`bool`): To use additional conditions for normalization or not. 
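timestep_dim (`int`, *optional*): Width of the modulation output; the linear head below emits `6 * timestep_dim` values. Falls back to `embedding_dim` when `None`.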
144 | """ 145 | 146 | def __init__(self, embedding_dim: int, use_additional_conditions: bool = False, timestep_dim: int = None): 147 | super().__init__() 148 | 149 | if timestep_dim is None: 150 | timestep_dim = embedding_dim 151 | 152 | self.emb = PixArtAlphaCombinedTimestepSizeEmbeddings( 153 | embedding_dim, size_emb_dim=embedding_dim // 3, use_additional_conditions=use_additional_conditions 154 | ) 155 | 156 | self.silu = nn.SiLU() 157 | self.linear = nn.Linear(embedding_dim, 6 * timestep_dim, bias=True) 158 | 159 | def forward( 160 | self, 161 | timestep: torch.Tensor, 162 | added_cond_kwargs: Optional[Dict[str, torch.Tensor]] = None, 163 | batch_size: Optional[int] = None, 164 | hidden_dtype: Optional[torch.dtype] = None, 165 | ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: 166 | # No modulation happening here. 167 | added_cond_kwargs = added_cond_kwargs or {"resolution": None, "aspect_ratio": None} 168 | embedded_timestep = self.emb(timestep, **added_cond_kwargs, batch_size=batch_size, hidden_dtype=hidden_dtype) 169 | return self.linear(self.silu(embedded_timestep)), embedded_timestep -------------------------------------------------------------------------------- /merge/pipeline/merge_transformer.py: -------------------------------------------------------------------------------- 1 | from typing import Any, Dict, Optional, List, Union, Tuple 2 | import matplotlib 3 | from diffusers import UNet2DConditionModel, Transformer2DModel 4 | 5 | matplotlib.use('Agg') 6 | 7 | import torch 8 | from torch import nn 9 | 10 | from diffusers.models import PixArtTransformer2DModel 11 | from diffusers.configuration_utils import ConfigMixin, register_to_config 12 | from diffusers.models.modeling_utils import ModelMixin 13 | from diffusers.models.embeddings import PatchEmbed, PixArtAlphaTextProjection 14 | from diffusers.models.modeling_outputs import Transformer2DModelOutput 15 | 16 | from .transformer_blocks import MERGETransformerBlock 17 | 18 | from merge.pipeline.layers import _AdaLayerNormSingle 19 | 20 | 21 | class xTransformerModel(ModelMixin, ConfigMixin): 22 | 23 | _supports_gradient_checkpointing = True 24 | _no_split_modules = ["BasicTransformerBlock", "PatchEmbed"] 25 | 26 | @register_to_config 27 | def __init__( 28 | self, 29 | num_attention_heads: int = 16, 30 | attention_head_dim: int = 72, 31 | in_channels: int = 4, 32 | out_channels: Optional[int] = 8, 33 | num_layers: int = 2, 34 | dropout: float = 0.0, 35 | norm_num_groups: int = 32, 36 | cross_attention_dim: Optional[int] = 1152, 37 | attention_bias: bool = True, 38 | sample_size: int = 64, 39 | patch_size: int = 2, 40 | activation_fn: str = "gelu-approximate", 41 | num_embeds_ada_norm: Optional[int] = 1000, 42 | upcast_attention: bool = False, 43 | norm_type: str = "ada_norm_single", 44 | norm_elementwise_affine: bool = False, 45 | norm_eps: float = 1e-6, 46 | interpolation_scale: Optional[int] = None, 47 | use_additional_conditions: Optional[bool] = None, 48 | caption_channels: Optional[int] = None, 49 | attention_type: Optional[str] = "default", 50 | is_converter: Optional[bool] = False, 51 | ff_mult: Optional[int] = 4, 52 | GRE: Optional[bool] = True, 53 | ): 54 | super().__init__() 55 | 56 | # Set some common variables used across the board. 
57 | self.attention_head_dim = attention_head_dim 58 | self.inner_dim = 1152 59 | self.attn_inner_dim = self.config.num_attention_heads * self.config.attention_head_dim 60 | self.out_channels = in_channels if out_channels is None else out_channels 61 | if use_additional_conditions is None: 62 | if sample_size == 128: 63 | use_additional_conditions = True 64 | else: 65 | use_additional_conditions = False 66 | self.use_additional_conditions = use_additional_conditions 67 | 68 | self.gradient_checkpointing = False 69 | 70 | # 2. Initialize the position embedding and transformer blocks. 71 | self.height = self.config.sample_size 72 | self.width = self.config.sample_size 73 | 74 | interpolation_scale = ( 75 | self.config.interpolation_scale 76 | if self.config.interpolation_scale is not None 77 | else max(self.config.sample_size // 64, 1) 78 | ) 79 | self.pos_embed = PatchEmbed( 80 | height=self.config.sample_size, 81 | width=self.config.sample_size, 82 | patch_size=self.config.patch_size, 83 | in_channels=self.config.in_channels, 84 | embed_dim=self.inner_dim, 85 | interpolation_scale=interpolation_scale, 86 | ) 87 | 88 | # 2. Initialize transformer blocks. 89 | self.transformer_blocks = nn.ModuleList( 90 | [ 91 | MERGETransformerBlock( 92 | self.inner_dim, 93 | self.config.num_attention_heads, 94 | self.config.attention_head_dim, 95 | dropout=self.config.dropout, 96 | cross_attention_dim=self.config.cross_attention_dim, 97 | activation_fn=self.config.activation_fn, 98 | num_embeds_ada_norm=self.config.num_embeds_ada_norm, 99 | attention_bias=self.config.attention_bias, 100 | upcast_attention=self.config.upcast_attention, 101 | norm_type=norm_type, 102 | norm_elementwise_affine=self.config.norm_elementwise_affine, 103 | norm_eps=self.config.norm_eps, 104 | attention_type=self.config.attention_type, 105 | ff_mult=self.config.ff_mult 106 | ) 107 | for _ in range(self.config.num_layers) 108 | ] 109 | ) 110 | 111 | # 3. Output blocks. 
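# NOTE: LayerNorm, a timestep-conditioned scale/shift (the 2-row scale_shift_table), and a linear
# head to patch_size**2 * out_channels; MERGEPixArtTransformer.output() unpatchifies the result
# back to the latent resolution.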
112 | self.norm_out = nn.LayerNorm(self.inner_dim, elementwise_affine=False, eps=1e-6) 113 | self.scale_shift_table = nn.Parameter(torch.randn(2, self.inner_dim) / self.inner_dim**0.5) 114 | self.proj_out = nn.Linear(self.inner_dim, self.config.patch_size * self.config.patch_size * self.out_channels) 115 | 116 | if is_converter: 117 | self.adaln_single = None 118 | self.caption_projection = None 119 | else: 120 | self.adaln_single = _AdaLayerNormSingle( 121 | self.inner_dim, 122 | use_additional_conditions=self.use_additional_conditions, 123 | ) 124 | 125 | if self.config.caption_channels is not None: 126 | self.caption_projection = PixArtAlphaTextProjection( 127 | in_features=self.config.caption_channels, hidden_size=self.inner_dim 128 | ) 129 | 130 | @classmethod 131 | def from_transformer( 132 | cls, 133 | transformer: PixArtTransformer2DModel, 134 | converter_init_type='pretrained', 135 | share_num=2, 136 | **kwargs 137 | ): 138 | xtransformer = xTransformerModel(**kwargs) 139 | 140 | source_state_dict = transformer.state_dict() 141 | target_state_dict = xtransformer.state_dict() 142 | 143 | # load pretrained param exclude transformer blocks 144 | for name, param in source_state_dict.items(): 145 | if 'transformer_blocks' not in name and name in target_state_dict and target_state_dict[ 146 | name].shape == param.shape: 147 | target_state_dict[name].data.copy_(param.data) 148 | 149 | # init converter's transformer block from fixed transformer 150 | if converter_init_type=='pretrained': 151 | pretrained_converter_id = list(range(0, xtransformer.config.num_layers, share_num)) 152 | source_state_dict = nn.ModuleList([ 153 | transformer.transformer_blocks[i] 154 | for i in pretrained_converter_id 155 | ]).state_dict() 156 | target_state_dict = xtransformer.transformer_blocks.state_dict() 157 | 158 | for name, param in source_state_dict.items(): 159 | if name in target_state_dict and target_state_dict[name].shape == param.shape: 160 | target_state_dict[name].data.copy_(param.data) 161 | 162 | return xtransformer 163 | 164 | def _replace_in_out_proj_conv(self): 165 | # replace the in_proj layer to accept 8 in_channels 166 | _in_weight = self.pos_embed.proj.weight.clone() # [320, 4, 3, 3] 167 | _in_bias = self.pos_embed.proj.bias.clone() # [320] 168 | _in_weight = _in_weight.repeat((1, 2, 1, 1)) # Keep selected channel(s) 169 | # half the activation magnitude 170 | _in_weight *= 0.5 171 | # new conv_in channel 172 | _n_convin_out_channel = self.pos_embed.proj.out_channels 173 | _new_conv_in = nn.Conv2d( 174 | 8, _n_convin_out_channel, 175 | kernel_size=(self.config.patch_size, self.config.patch_size), 176 | stride=(self.config.patch_size, self.config.patch_size) 177 | ) 178 | _new_conv_in.weight = nn.Parameter(_in_weight) 179 | _new_conv_in.bias = nn.Parameter(_in_bias) 180 | self.pos_embed.proj = _new_conv_in 181 | 182 | self.register_to_config(in_channels=8) 183 | 184 | def get_trainable_params(self): 185 | trainable_params = sum(p.numel() for p in self.parameters() if p.requires_grad) 186 | print(f"trainable params: {trainable_params}") 187 | return trainable_params 188 | 189 | class MERGEPixArtTransformer(ModelMixin, ConfigMixin): 190 | def __init__( 191 | self, 192 | fixed_transformer: xTransformerModel, 193 | converter: xTransformerModel, 194 | training=False 195 | ): 196 | super().__init__() 197 | 198 | self.gradient_checkpointing = False 199 | self.register_to_config(**fixed_transformer.config) 200 | self.training = training 201 | 202 | self.fixed_transformer = fixed_transformer 203 | 
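# NOTE: the "fixed" transformer is the pretrained text-to-image backbone (loaded with
# requires_grad_(False) in the inference scripts); the converter assigned below carries the
# depth-specific weights, and _pecp_forward runs one of its blocks (each shared across
# `share_num` consecutive layers) before every fixed block.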
self.converter = converter 204 | 205 | self.mini_blocks_num = converter.config.num_layers 206 | 207 | def _set_gradient_checkpointing(self, module, value=False): 208 | if hasattr(module, "gradient_checkpointing"): 209 | module.gradient_checkpointing = value 210 | 211 | def get_input( 212 | self, 213 | transformer, 214 | hidden_states: torch.Tensor, 215 | encoder_hidden_states: Optional[torch.Tensor] = None, 216 | timestep: Optional[torch.LongTensor] = None, 217 | added_cond_kwargs: Dict[str, torch.Tensor] = None, 218 | attention_mask: Optional[torch.Tensor] = None, 219 | encoder_attention_mask: Optional[torch.Tensor] = None, 220 | converter=None, 221 | ): 222 | if transformer.use_additional_conditions and added_cond_kwargs is None: 223 | raise ValueError("`added_cond_kwargs` cannot be None when using additional conditions for `adaln_single`.") 224 | 225 | # ensure attention_mask is a bias, and give it a singleton query_tokens dimension. 226 | # we may have done this conversion already, e.g. if we came here via UNet2DConditionModel#forward. 227 | # we can tell by counting dims; if ndim == 2: it's a mask rather than a bias. 228 | # expects mask of shape: 229 | # [batch, key_tokens] 230 | # adds singleton query_tokens dimension: 231 | # [batch, 1, key_tokens] 232 | # this helps to broadcast it as a bias over attention scores, which will be in one of the following shapes: 233 | # [batch, heads, query_tokens, key_tokens] (e.g. torch sdp attn) 234 | # [batch * heads, query_tokens, key_tokens] (e.g. xformers or classic attn) 235 | if attention_mask is not None and attention_mask.ndim == 2: 236 | # assume that mask is expressed as: 237 | # (1 = keep, 0 = discard) 238 | # convert mask into a bias that can be added to attention scores: 239 | # (keep = +0, discard = -10000.0) 240 | attention_mask = (1 - attention_mask.to(hidden_states.dtype)) * -10000.0 241 | attention_mask = attention_mask.unsqueeze(1) 242 | 243 | # 1. 
Input 244 | batch_size = hidden_states.shape[0] 245 | height, width = ( 246 | hidden_states.shape[-2] // transformer.config.patch_size, 247 | hidden_states.shape[-1] // transformer.config.patch_size, 248 | ) 249 | 250 | if converter is not None: 251 | hidden_states = converter.pos_embed(hidden_states) 252 | else: 253 | hidden_states = transformer.pos_embed(hidden_states) 254 | 255 | timestep, embedded_timestep = transformer.adaln_single( 256 | timestep, added_cond_kwargs, batch_size=batch_size, hidden_dtype=hidden_states.dtype 257 | ) 258 | 259 | encoder_hidden_states = transformer.caption_projection(encoder_hidden_states) 260 | encoder_hidden_states = encoder_hidden_states.view(batch_size, -1, hidden_states.shape[-1]) 261 | 262 | # 263 | if encoder_attention_mask is not None and encoder_attention_mask.ndim == 2: 264 | encoder_attention_mask = (1 - encoder_attention_mask.to(hidden_states.dtype)) * -10000.0 265 | encoder_attention_mask = encoder_attention_mask.unsqueeze(1) 266 | 267 | return ( 268 | (height, width), 269 | hidden_states, 270 | attention_mask, 271 | encoder_hidden_states, 272 | encoder_attention_mask, 273 | timestep, 274 | embedded_timestep, 275 | ) 276 | 277 | def forward( 278 | self, 279 | dense_hidden_states: Optional[torch.Tensor] = None, 280 | image_hidden_states: Optional[torch.Tensor] = None, 281 | encoder_hidden_states: Optional[torch.Tensor] = None, 282 | timestep: Optional[torch.LongTensor] = None, 283 | added_cond_kwargs: Dict[str, torch.Tensor] = None, 284 | cross_attention_kwargs: Dict[str, Any] = None, 285 | attention_mask: Optional[torch.Tensor] = None, 286 | encoder_attention_mask: Optional[torch.Tensor] = None, 287 | return_dict: bool = True, 288 | ): 289 | with torch.cuda.amp.autocast(): 290 | if self.training: 291 | assert dense_hidden_states is not None, \ 292 | f'Only dense_hidden_states for perception is supported during training.' 293 | output = self.train_forward( 294 | dense_hidden_states, 295 | encoder_hidden_states=encoder_hidden_states, 296 | timestep=timestep, 297 | added_cond_kwargs=added_cond_kwargs, 298 | encoder_attention_mask=encoder_attention_mask, 299 | return_dict=return_dict, 300 | ) 301 | else: 302 | assert (dense_hidden_states is not None) != (image_hidden_states is not None), \ 303 | f'Only one type of input is supported: ' \ 304 | f'image_hidden_states for generation, dense_hidden_states for perception.' 
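# NOTE: exactly one branch runs at inference: image_hidden_states goes through the fixed
# transformer alone (_gen_forward, text-to-image), while dense_hidden_states takes the
# converter-augmented path (_pecp_forward, depth).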
305 | output = self.test_forward( 306 | dense_hidden_states=dense_hidden_states, 307 | image_hidden_states=image_hidden_states, 308 | encoder_hidden_states=encoder_hidden_states, 309 | timestep=timestep, 310 | added_cond_kwargs=added_cond_kwargs, 311 | cross_attention_kwargs=cross_attention_kwargs, 312 | attention_mask=attention_mask, 313 | encoder_attention_mask=encoder_attention_mask, 314 | return_dict=return_dict, 315 | ) 316 | return output 317 | 318 | def train_forward( 319 | self, 320 | dense_hidden_states, 321 | encoder_hidden_states, 322 | timestep: Optional[torch.LongTensor] = None, 323 | added_cond_kwargs: Dict[str, torch.Tensor] = None, 324 | cross_attention_kwargs: Dict[str, Any] = None, 325 | attention_mask: Optional[torch.Tensor] = None, 326 | encoder_attention_mask: Optional[torch.Tensor] = None, 327 | return_dict: bool = True, 328 | ): 329 | 330 | pecp_noise_output = self._pecp_forward( 331 | dense_hidden_states, 332 | encoder_hidden_states, 333 | timestep, 334 | added_cond_kwargs, 335 | cross_attention_kwargs, 336 | attention_mask, 337 | encoder_attention_mask, 338 | return_dict, 339 | ) 340 | return pecp_noise_output 341 | 342 | 343 | def test_forward( 344 | self, 345 | dense_hidden_states: Optional[torch.Tensor] = None, 346 | image_hidden_states: Optional[torch.Tensor] = None, 347 | encoder_hidden_states: Optional[torch.Tensor] = None, 348 | timestep: Optional[torch.LongTensor] = None, 349 | added_cond_kwargs: Dict[str, torch.Tensor] = None, 350 | cross_attention_kwargs: Dict[str, Any] = None, 351 | attention_mask: Optional[torch.Tensor] = None, 352 | encoder_attention_mask: Optional[torch.Tensor] = None, 353 | return_dict: bool = True, 354 | ): 355 | if image_hidden_states is not None: 356 | gen_noise_output = self._gen_forward( 357 | image_hidden_states=image_hidden_states, 358 | encoder_hidden_states=encoder_hidden_states, 359 | timestep=timestep, 360 | added_cond_kwargs=added_cond_kwargs, 361 | cross_attention_kwargs=cross_attention_kwargs, 362 | attention_mask=attention_mask, 363 | encoder_attention_mask=encoder_attention_mask, 364 | return_dict=return_dict 365 | ) 366 | return gen_noise_output 367 | else: 368 | pecp_noise_output = self._pecp_forward( 369 | dense_hidden_states, 370 | encoder_hidden_states=encoder_hidden_states, 371 | timestep=timestep, 372 | added_cond_kwargs=added_cond_kwargs, 373 | cross_attention_kwargs=cross_attention_kwargs, 374 | attention_mask=attention_mask, 375 | encoder_attention_mask=encoder_attention_mask, 376 | return_dict=return_dict, 377 | ) 378 | return pecp_noise_output 379 | 380 | def _gen_forward( 381 | self, 382 | image_hidden_states, 383 | encoder_hidden_states: Optional[torch.Tensor] = None, 384 | timestep: Optional[torch.LongTensor] = None, 385 | added_cond_kwargs: Dict[str, torch.Tensor] = None, 386 | cross_attention_kwargs: Dict[str, Any] = None, 387 | attention_mask: Optional[torch.Tensor] = None, 388 | encoder_attention_mask: Optional[torch.Tensor] = None, 389 | return_dict: bool = True, 390 | ): 391 | ( 392 | (height_image, width_image), 393 | hidden_states_image, 394 | attention_mask_image, 395 | encoder_hidden_states_image, 396 | encoder_attention_mask_image, 397 | timestep_image, 398 | embedded_timestep_image 399 | ) = self.get_input( 400 | self.fixed_transformer, 401 | image_hidden_states, 402 | encoder_hidden_states, 403 | timestep, 404 | added_cond_kwargs, 405 | attention_mask, 406 | encoder_attention_mask 407 | ) 408 | 409 | # 2. 
Blocks 410 | 411 | for block_index, transformer_block in enumerate(self.fixed_transformer.transformer_blocks): 412 | hidden_states_image = transformer_block( 413 | hidden_states_image, 414 | attention_mask=attention_mask_image, 415 | encoder_hidden_states=encoder_hidden_states_image, 416 | encoder_attention_mask=encoder_attention_mask_image, 417 | timestep=timestep_image, 418 | cross_attention_kwargs=cross_attention_kwargs, 419 | class_labels=None, 420 | ) 421 | 422 | # 3. Output 423 | gen_noise_output = self.output( 424 | self.fixed_transformer, 425 | hidden_states_image, 426 | embedded_timestep_image, 427 | height_image, 428 | width_image, 429 | return_dict 430 | ) 431 | 432 | return gen_noise_output[0] 433 | 434 | def _pecp_forward( 435 | self, 436 | dense_hidden_states: torch.Tensor, 437 | encoder_hidden_states: torch.Tensor, 438 | timestep: Optional[torch.LongTensor] = None, 439 | added_cond_kwargs: Dict[str, torch.Tensor] = None, 440 | cross_attention_kwargs: Dict[str, Any] = None, 441 | attention_mask: Optional[torch.Tensor] = None, 442 | encoder_attention_mask: Optional[torch.Tensor] = None, 443 | return_dict: bool = True, 444 | ): 445 | 446 | ( 447 | (height_dense, width_dense), 448 | hidden_states_dense, 449 | attention_mask_dense, 450 | encoder_hidden_states_dense, 451 | encoder_attention_mask_dense, 452 | timestep, 453 | embedded_timestep, 454 | ) = self.get_input( 455 | self.fixed_transformer, 456 | dense_hidden_states, 457 | encoder_hidden_states, 458 | timestep, 459 | added_cond_kwargs, 460 | attention_mask, 461 | encoder_attention_mask, 462 | converter=self.converter 463 | ) 464 | 465 | # 2. Blocks 466 | num_group = self.converter.config.num_layers 467 | share_num = self.fixed_transformer.config.num_layers// num_group 468 | for block_index, transformer_block in enumerate(self.fixed_transformer.transformer_blocks): 469 | 470 | hidden_states_dense = self.converter.transformer_blocks[block_index//share_num]( 471 | hidden_states_dense, 472 | attention_mask=attention_mask_dense, 473 | encoder_hidden_states=encoder_hidden_states_dense, 474 | encoder_attention_mask=encoder_attention_mask_dense, 475 | timestep=timestep, 476 | cross_attention_kwargs=cross_attention_kwargs, 477 | class_labels=None, 478 | ) 479 | 480 | hidden_states_dense = transformer_block( 481 | hidden_states_dense, 482 | attention_mask=attention_mask_dense, 483 | encoder_hidden_states=encoder_hidden_states_dense, 484 | encoder_attention_mask=encoder_attention_mask_dense, 485 | timestep=timestep, 486 | cross_attention_kwargs=cross_attention_kwargs, 487 | class_labels=None, 488 | ) 489 | # 3. 
Output 490 | pecp_noise_output = self.output( 491 | self.converter, 492 | hidden_states_dense, 493 | embedded_timestep, 494 | height_dense, 495 | width_dense, 496 | return_dict 497 | ) 498 | 499 | return pecp_noise_output[0] 500 | 501 | 502 | def output( 503 | self, 504 | transformer, 505 | hidden_states, 506 | embedded_timestep, 507 | height, 508 | width, 509 | return_dict 510 | ): 511 | shift, scale = ( 512 | transformer.scale_shift_table[None] + embedded_timestep[:, None].to( 513 | transformer.scale_shift_table.device) 514 | ).chunk(2, dim=1) 515 | hidden_states = transformer.norm_out(hidden_states) 516 | # Modulation 517 | hidden_states = hidden_states * (1 + scale.to(hidden_states.device)) + shift.to(hidden_states.device) 518 | hidden_states = transformer.proj_out(hidden_states) 519 | hidden_states = hidden_states.squeeze(1) 520 | 521 | # unpatchify 522 | hidden_states = hidden_states.reshape( 523 | shape=(-1, height, width, transformer.config.patch_size, transformer.config.patch_size, 524 | transformer.out_channels) 525 | ) 526 | hidden_states = torch.einsum("nhwpqc->nchpwq", hidden_states) 527 | output = hidden_states.reshape( 528 | shape=(-1, transformer.out_channels, height * transformer.config.patch_size, 529 | width * transformer.config.patch_size) 530 | ) 531 | 532 | if not return_dict: 533 | return (output,) 534 | 535 | return Transformer2DModelOutput(sample=output) -------------------------------------------------------------------------------- /merge/pipeline/transformer_attentions.py: -------------------------------------------------------------------------------- 1 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 2 | # See the License for the specific language governing permissions and 3 | # limitations under the License. 4 | from torch import einsum 5 | from typing import Callable, List, Optional, Union 6 | 7 | import torch 8 | from torch import nn 9 | import torch.nn.functional as F 10 | 11 | from einops import rearrange, repeat 12 | 13 | from diffusers.utils.import_utils import is_xformers_available 14 | from diffusers.models.attention_processor import Attention, SpatialNorm, AttnProcessor2_0, AttnProcessor 15 | from diffusers.models.attention import logger 16 | from diffusers.utils import deprecate 17 | 18 | class MERGEAttnProcessor: 19 | r""" 20 | Processor for implementing scaled dot-product attention (enabled by default if you're using PyTorch 2.0). 21 | """ 22 | 23 | def __init__(self): 24 | if not hasattr(F, "scaled_dot_product_attention"): 25 | raise ImportError("AttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.") 26 | 27 | def __call__( 28 | self, 29 | attn, 30 | hidden_states, 31 | encoder_hidden_states, 32 | attention_mask: Optional[torch.Tensor] = None, 33 | temb: Optional[torch.Tensor] = None, 34 | *args, 35 | **kwargs, 36 | ) -> torch.Tensor: 37 | context = encoder_hidden_states 38 | if len(args) > 0 or kwargs.get("scale", None) is not None: 39 | deprecation_message = "The `scale` argument is deprecated and will be ignored. Please remove it, as passing it will raise an error in the future. `scale` should directly be passed while calling the underlying pipeline component i.e., via `cross_attention_kwargs`." 
40 | deprecate("scale", "1.0.0", deprecation_message) 41 | 42 | residual = hidden_states 43 | if attn.spatial_norm is not None: 44 | hidden_states = attn.spatial_norm(hidden_states, temb) 45 | 46 | input_ndim = hidden_states.ndim 47 | 48 | if input_ndim == 4: 49 | batch_size, channel, height, width = hidden_states.shape 50 | hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2) 51 | 52 | batch_size, sequence_length, _ = ( 53 | hidden_states.shape if context is None else context.shape 54 | ) 55 | 56 | if attention_mask is not None: 57 | attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size) 58 | # scaled_dot_product_attention expects attention_mask shape to be 59 | # (batch, heads, source_length, target_length) 60 | attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1]) 61 | 62 | if attn.group_norm is not None: 63 | hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2) 64 | 65 | query = attn.to_q(hidden_states) 66 | 67 | if context is None: 68 | context = hidden_states 69 | elif attn.norm_cross: 70 | context = attn.norm_context(context) 71 | 72 | key = attn.to_k(context) 73 | value = attn.to_v(context) 74 | 75 | inner_dim = key.shape[-1] 76 | head_dim = inner_dim // attn.heads 77 | 78 | query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) 79 | 80 | key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) 81 | value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) 82 | 83 | hidden_states = F.scaled_dot_product_attention( 84 | query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False 85 | ) 86 | hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim) 87 | hidden_states = hidden_states.to(query.dtype) 88 | 89 | # linear proj 90 | hidden_states = attn.to_out[0](hidden_states) 91 | # dropout 92 | hidden_states = attn.to_out[1](hidden_states) 93 | 94 | if input_ndim == 4: 95 | hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width) 96 | 97 | if attn.residual_connection: 98 | hidden_states = hidden_states + residual 99 | 100 | hidden_states = hidden_states / attn.rescale_output_factor 101 | 102 | return hidden_states -------------------------------------------------------------------------------- /merge/pipeline/transformer_blocks.py: -------------------------------------------------------------------------------- 1 | from typing import Any, Dict, Optional, List 2 | 3 | import torch 4 | from torch import nn 5 | import torch.nn.functional as F 6 | 7 | from einops import rearrange 8 | 9 | from diffusers.models.attention import logger, _chunked_feed_forward, BasicTransformerBlock 10 | from diffusers.models.attention_processor import Attention 11 | from diffusers.models.normalization import AdaLayerNorm, AdaLayerNormContinuous 12 | from .transformer_attentions import MERGEAttnProcessor 13 | 14 | from merge.pipeline.layers import _FeedForward 15 | 16 | class MERGETransformerBlock(BasicTransformerBlock): 17 | def __init__( 18 | self, 19 | dim: int, 20 | num_attention_heads: int, 21 | attention_head_dim: int, 22 | inner_dim: Optional[int] = None, 23 | dropout=0.0, 24 | cross_attention_dim: Optional[int] = None, 25 | activation_fn: str = "geglu", 26 | num_embeds_ada_norm: Optional[int] = None, 27 | attention_bias: bool = False, 28 | only_cross_attention: bool = False, 29 | double_self_attention: bool = False, 30 | upcast_attention: bool = 
False, 31 | norm_elementwise_affine: bool = True, 32 | norm_type: str = "layer_norm", 33 | norm_eps: float = 1e-5, 34 | final_dropout: bool = False, 35 | attention_type: str = "default", 36 | positional_embeddings: Optional[str] = None, 37 | num_positional_embeddings: Optional[int] = None, 38 | ada_norm_continous_conditioning_embedding_dim: Optional[int] = None, 39 | ada_norm_bias: Optional[int] = None, 40 | ff_inner_dim: Optional[int] = None, 41 | ff_mult: Optional[int] = 4, 42 | ff_bias: bool = True, 43 | attention_out_bias: bool = True, 44 | ): 45 | super().__init__( 46 | dim, 47 | num_attention_heads, 48 | attention_head_dim, 49 | dropout, 50 | cross_attention_dim, 51 | activation_fn, 52 | num_embeds_ada_norm, 53 | attention_bias, 54 | only_cross_attention, 55 | double_self_attention, 56 | upcast_attention, 57 | norm_elementwise_affine, 58 | norm_type, 59 | norm_eps, 60 | final_dropout, 61 | attention_type, 62 | positional_embeddings, 63 | num_positional_embeddings, 64 | ada_norm_continous_conditioning_embedding_dim, 65 | ada_norm_bias, 66 | ff_inner_dim, 67 | ff_bias, 68 | attention_out_bias, 69 | ) 70 | 71 | 72 | self.attn1 = Attention( 73 | query_dim=dim, 74 | heads=num_attention_heads, 75 | dim_head=attention_head_dim, 76 | dropout=dropout, 77 | bias=attention_bias, 78 | cross_attention_dim=cross_attention_dim if only_cross_attention else None, 79 | upcast_attention=upcast_attention, 80 | out_bias=attention_out_bias 81 | ) 82 | 83 | # 2. Cross-Attn 84 | if cross_attention_dim is not None or double_self_attention: 85 | # We currently only use AdaLayerNormZero for self attention where there will only be one attention block. 86 | # I.e. the number of returned modulation chunks from AdaLayerZero would not make sense if returned during 87 | # the second cross attention block. 88 | if norm_type == "ada_norm": 89 | self.norm2 = AdaLayerNorm(dim, num_embeds_ada_norm) 90 | elif norm_type == "ada_norm_continuous": 91 | self.norm2 = AdaLayerNormContinuous( 92 | dim, 93 | ada_norm_continous_conditioning_embedding_dim, 94 | norm_elementwise_affine, 95 | norm_eps, 96 | ada_norm_bias, 97 | "rms_norm", 98 | ) 99 | else: 100 | self.norm2 = nn.LayerNorm(dim, norm_eps, norm_elementwise_affine) 101 | 102 | self.attn2 = Attention( 103 | query_dim=dim, 104 | cross_attention_dim=cross_attention_dim if not double_self_attention else None, 105 | heads=num_attention_heads, 106 | dim_head=attention_head_dim, 107 | dropout=dropout, 108 | bias=attention_bias, 109 | upcast_attention=upcast_attention, 110 | out_bias=attention_out_bias, 111 | processor=MERGEAttnProcessor() 112 | ) # is self-attn if encoder_hidden_states is none 113 | else: 114 | # merge revise 115 | self.norm2 = nn.LayerNorm(dim, norm_eps, norm_elementwise_affine) 116 | self.attn2 = None 117 | 118 | self.ff = _FeedForward( 119 | dim, 120 | mult=ff_mult, 121 | dropout=dropout, 122 | activation_fn=activation_fn, 123 | final_dropout=final_dropout, 124 | inner_dim=ff_inner_dim, 125 | bias=ff_bias, 126 | ) 127 | 128 | # 5. Scale-shift for PixArt-Alpha. 
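# NOTE: a learnable (6, dim) table added to the projected timestep embedding and chunked into
# shift/scale/gate pairs for the self-attention and feed-forward branches in forward().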
129 | if norm_type == "ada_norm_single": 130 | self.scale_shift_table = nn.Parameter(torch.randn(6, dim) / dim ** 0.5) 131 | 132 | # let chunk size default to None 133 | self._chunk_size = None 134 | self._chunk_dim = 0 135 | 136 | def forward( 137 | self, 138 | hidden_states: torch.Tensor, 139 | attention_mask: Optional[torch.Tensor] = None, 140 | encoder_hidden_states: Optional[torch.Tensor] = None, 141 | encoder_attention_mask: Optional[torch.Tensor] = None, 142 | timestep: Optional[torch.LongTensor] = None, 143 | cross_attention_kwargs: Dict[str, Any] = None, 144 | class_labels: Optional[torch.LongTensor] = None, 145 | added_cond_kwargs: Optional[Dict[str, torch.Tensor]] = None, 146 | ) -> torch.Tensor: 147 | if cross_attention_kwargs is not None: 148 | if cross_attention_kwargs.get("scale", None) is not None: 149 | logger.warning("Passing `scale` to `cross_attention_kwargs` is deprecated. `scale` will be ignored.") 150 | 151 | # Notice that normalization is always applied before the real computation in the following blocks. 152 | # 0. Self-Attention 153 | batch_size = hidden_states.shape[0] 154 | 155 | if self.norm_type == "ada_norm": 156 | norm_hidden_states = self.norm1(hidden_states, timestep) 157 | elif self.norm_type == "ada_norm_zero": 158 | norm_hidden_states, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.norm1( 159 | hidden_states, timestep, class_labels, hidden_dtype=hidden_states.dtype 160 | ) 161 | elif self.norm_type in ["layer_norm", "layer_norm_i2vgen"]: 162 | norm_hidden_states = self.norm1(hidden_states) 163 | elif self.norm_type == "ada_norm_continuous": 164 | norm_hidden_states = self.norm1(hidden_states, added_cond_kwargs["pooled_text_emb"]) 165 | elif self.norm_type == "ada_norm_single": 166 | shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = ( 167 | self.scale_shift_table[None] + timestep.reshape(batch_size, 6, -1) 168 | ).chunk(6, dim=1) 169 | norm_hidden_states = self.norm1(hidden_states) 170 | norm_hidden_states = norm_hidden_states * (1 + scale_msa) + shift_msa 171 | norm_hidden_states = norm_hidden_states.squeeze(1) 172 | else: 173 | raise ValueError("Incorrect norm used") 174 | 175 | if self.pos_embed is not None: 176 | norm_hidden_states = self.pos_embed(norm_hidden_states) 177 | 178 | # 1. Prepare GLIGEN inputs 179 | cross_attention_kwargs = cross_attention_kwargs.copy() if cross_attention_kwargs is not None else {} 180 | gligen_kwargs = cross_attention_kwargs.pop("gligen", None) 181 | 182 | attn_output = self.attn1( 183 | norm_hidden_states, 184 | encoder_hidden_states=encoder_hidden_states if self.only_cross_attention else None, 185 | attention_mask=attention_mask, 186 | **cross_attention_kwargs, 187 | ) 188 | if self.norm_type == "ada_norm_zero": 189 | attn_output = gate_msa.unsqueeze(1) * attn_output 190 | elif self.norm_type == "ada_norm_single": 191 | attn_output = gate_msa * attn_output 192 | 193 | hidden_states = attn_output + hidden_states 194 | if hidden_states.ndim == 4: 195 | hidden_states = hidden_states.squeeze(1) 196 | 197 | # 1.2 GLIGEN Control 198 | if gligen_kwargs is not None: 199 | hidden_states = self.fuser(hidden_states, gligen_kwargs["objs"]) 200 | 201 | # 3. 
Cross-Attention 202 | if self.attn2 is not None: 203 | if self.norm_type == "ada_norm": 204 | norm_hidden_states = self.norm2(hidden_states, timestep) 205 | elif self.norm_type in ["ada_norm_zero", "layer_norm", "layer_norm_i2vgen"]: 206 | norm_hidden_states = self.norm2(hidden_states) 207 | elif self.norm_type == "ada_norm_single": 208 | # For PixArt norm2 isn't applied here: 209 | # https://github.com/PixArt-alpha/PixArt-alpha/blob/0f55e922376d8b797edd44d25d0e7464b260dcab/diffusion/model/nets/PixArtMS.py#L70C1-L76C103 210 | norm_hidden_states = hidden_states 211 | elif self.norm_type == "ada_norm_continuous": 212 | norm_hidden_states = self.norm2(hidden_states, added_cond_kwargs["pooled_text_emb"]) 213 | else: 214 | raise ValueError("Incorrect norm") 215 | 216 | if self.pos_embed is not None and self.norm_type != "ada_norm_single": 217 | norm_hidden_states = self.pos_embed(norm_hidden_states) 218 | 219 | attn_output = self.attn2( 220 | norm_hidden_states, 221 | encoder_hidden_states=encoder_hidden_states, 222 | attention_mask=encoder_attention_mask, 223 | **cross_attention_kwargs, 224 | ) 225 | hidden_states = attn_output + hidden_states 226 | 227 | # 4. Feed-forward 228 | # i2vgen doesn't have this norm 🤷‍♂️ 229 | if self.norm_type == "ada_norm_continuous": 230 | norm_hidden_states = self.norm3(hidden_states, added_cond_kwargs["pooled_text_emb"]) 231 | elif not self.norm_type == "ada_norm_single": 232 | norm_hidden_states = self.norm3(hidden_states) 233 | 234 | if self.norm_type == "ada_norm_zero": 235 | norm_hidden_states = norm_hidden_states * (1 + scale_mlp[:, None]) + shift_mlp[:, None] 236 | 237 | if self.norm_type == "ada_norm_single": 238 | norm_hidden_states = self.norm2(hidden_states) 239 | norm_hidden_states = norm_hidden_states * (1 + scale_mlp) + shift_mlp 240 | 241 | if self._chunk_size is not None: 242 | # "feed_forward_chunk_size" can be used to save memory 243 | ff_output = _chunked_feed_forward(self.ff, norm_hidden_states, self._chunk_dim, self._chunk_size) 244 | else: 245 | ff_output = self.ff(norm_hidden_states) 246 | 247 | if self.norm_type == "ada_norm_zero": 248 | ff_output = gate_mlp.unsqueeze(1) * ff_output 249 | elif self.norm_type == "ada_norm_single": 250 | ff_output = gate_mlp * ff_output 251 | 252 | hidden_states = ff_output + hidden_states 253 | if hidden_states.ndim == 4: 254 | hidden_states = hidden_states.squeeze(1) 255 | 256 | hidden_states = hidden_states.clip(-65504, 65504) 257 | return hidden_states 258 | -------------------------------------------------------------------------------- /merge/pipeline/util/batchsize.py: -------------------------------------------------------------------------------- 1 | # Copyright 2023 Bingxin Ke, ETH Zurich. All rights reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # -------------------------------------------------------------------------- 15 | # If you find this code useful, we kindly ask you to cite our paper in your work. 
16 | # Please find bibtex at: https://github.com/prs-eth/Marigold#-citation 17 | # More information about the method can be found at https://marigoldmonodepth.github.io 18 | # -------------------------------------------------------------------------- 19 | 20 | 21 | import torch 22 | import math 23 | 24 | 25 | # Search table for suggested max. inference batch size 26 | bs_search_table = [ 27 | # tested on A100-PCIE-80GB 28 | {"res": 768, "total_vram": 79, "bs": 35, "dtype": torch.float32}, 29 | {"res": 1024, "total_vram": 79, "bs": 20, "dtype": torch.float32}, 30 | # tested on A100-PCIE-40GB 31 | {"res": 768, "total_vram": 39, "bs": 15, "dtype": torch.float32}, 32 | {"res": 1024, "total_vram": 39, "bs": 8, "dtype": torch.float32}, 33 | {"res": 768, "total_vram": 39, "bs": 30, "dtype": torch.float16}, 34 | {"res": 1024, "total_vram": 39, "bs": 15, "dtype": torch.float16}, 35 | # tested on RTX3090, RTX4090 36 | {"res": 512, "total_vram": 23, "bs": 20, "dtype": torch.float32}, 37 | {"res": 768, "total_vram": 23, "bs": 7, "dtype": torch.float32}, 38 | {"res": 1024, "total_vram": 23, "bs": 3, "dtype": torch.float32}, 39 | {"res": 512, "total_vram": 23, "bs": 40, "dtype": torch.float16}, 40 | {"res": 768, "total_vram": 23, "bs": 18, "dtype": torch.float16}, 41 | {"res": 1024, "total_vram": 23, "bs": 10, "dtype": torch.float16}, 42 | # tested on GTX1080Ti 43 | {"res": 512, "total_vram": 10, "bs": 5, "dtype": torch.float32}, 44 | {"res": 768, "total_vram": 10, "bs": 2, "dtype": torch.float32}, 45 | {"res": 512, "total_vram": 10, "bs": 10, "dtype": torch.float16}, 46 | {"res": 768, "total_vram": 10, "bs": 5, "dtype": torch.float16}, 47 | {"res": 1024, "total_vram": 10, "bs": 3, "dtype": torch.float16}, 48 | ] 49 | 50 | 51 | def find_batch_size(ensemble_size: int, input_res: int, dtype: torch.dtype) -> int: 52 | """ 53 | Automatically search for suitable operating batch size. 54 | 55 | Args: 56 | ensemble_size (`int`): 57 | Number of predictions to be ensembled. 58 | input_res (`int`): 59 | Operating resolution of the input image. 60 | 61 | Returns: 62 | `int`: Operating batch size. 63 | """ 64 | if not torch.cuda.is_available(): 65 | return 1 66 | 67 | total_vram = torch.cuda.mem_get_info()[1] / 1024.0**3 68 | filtered_bs_search_table = [s for s in bs_search_table if s["dtype"] == dtype] 69 | for settings in sorted( 70 | filtered_bs_search_table, 71 | key=lambda k: (k["res"], -k["total_vram"]), 72 | ): 73 | if input_res <= settings["res"] and total_vram >= settings["total_vram"]: 74 | bs = settings["bs"] 75 | if bs > ensemble_size: 76 | bs = ensemble_size 77 | elif bs > math.ceil(ensemble_size / 2) and bs < ensemble_size: 78 | bs = math.ceil(ensemble_size / 2) 79 | return bs 80 | 81 | return 1 82 | -------------------------------------------------------------------------------- /merge/pipeline/util/ensemble.py: -------------------------------------------------------------------------------- 1 | # Copyright 2023 Bingxin Ke, ETH Zurich. All rights reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
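# --------------------------------------------------------------------------
# A minimal usage sketch for `find_batch_size` above: it looks up the largest
# tested batch size for the available VRAM, operating resolution and dtype,
# then caps it at the ensemble size (without CUDA it simply returns 1). The
# import path is inferred from the repository layout and assumes the
# repository root is on PYTHONPATH.
import torch
from merge.pipeline.util.batchsize import find_batch_size

bs = find_batch_size(ensemble_size=10, input_res=768, dtype=torch.float16)
print(f"suggested inference batch size: {bs}")
# --------------------------------------------------------------------------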
12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # -------------------------------------------------------------------------- 15 | # If you find this code useful, we kindly ask you to cite our paper in your work. 16 | # Please find bibtex at: https://github.com/prs-eth/Marigold#-citation 17 | # More information about the method can be found at https://marigoldmonodepth.github.io 18 | # -------------------------------------------------------------------------- 19 | 20 | 21 | from functools import partial 22 | from typing import Optional, Tuple 23 | 24 | import numpy as np 25 | import torch 26 | 27 | from .image_util import get_tv_resample_method, resize_max_res 28 | 29 | 30 | def inter_distances(tensors: torch.Tensor): 31 | """ 32 | To calculate the distance between each two depth maps. 33 | """ 34 | distances = [] 35 | for i, j in torch.combinations(torch.arange(tensors.shape[0])): 36 | arr1 = tensors[i : i + 1] 37 | arr2 = tensors[j : j + 1] 38 | distances.append(arr1 - arr2) 39 | dist = torch.concatenate(distances, dim=0) 40 | return dist 41 | 42 | 43 | def ensemble_depth( 44 | depth: torch.Tensor, 45 | scale_invariant: bool = True, 46 | shift_invariant: bool = True, 47 | output_uncertainty: bool = False, 48 | reduction: str = "median", 49 | regularizer_strength: float = 0.02, 50 | max_iter: int = 2, 51 | tol: float = 1e-3, 52 | max_res: int = 1024, 53 | ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]: 54 | """ 55 | Ensembles depth maps represented by the `depth` tensor with expected shape `(B, 1, H, W)`, where B is the 56 | number of ensemble members for a given prediction of size `(H x W)`. Even though the function is designed for 57 | depth maps, it can also be used with disparity maps as long as the input tensor values are non-negative. The 58 | alignment happens when the predictions have one or more degrees of freedom, that is when they are either 59 | affine-invariant (`scale_invariant=True` and `shift_invariant=True`), or just scale-invariant (only 60 | `scale_invariant=True`). For absolute predictions (`scale_invariant=False` and `shift_invariant=False`) 61 | alignment is skipped and only ensembling is performed. 62 | 63 | Args: 64 | depth (`torch.Tensor`): 65 | Input ensemble depth maps. 66 | scale_invariant (`bool`, *optional*, defaults to `True`): 67 | Whether to treat predictions as scale-invariant. 68 | shift_invariant (`bool`, *optional*, defaults to `True`): 69 | Whether to treat predictions as shift-invariant. 70 | output_uncertainty (`bool`, *optional*, defaults to `False`): 71 | Whether to output uncertainty map. 72 | reduction (`str`, *optional*, defaults to `"median"`): 73 | Reduction method used to ensemble aligned predictions. The accepted values are: `"mean"` and 74 | `"median"`. 75 | regularizer_strength (`float`, *optional*, defaults to `0.02`): 76 | Strength of the regularizer that pulls the aligned predictions to the unit range from 0 to 1. 77 | max_iter (`int`, *optional*, defaults to `2`): 78 | Maximum number of the alignment solver steps. Refer to `scipy.optimize.minimize` function, `options` 79 | argument. 80 | tol (`float`, *optional*, defaults to `1e-3`): 81 | Alignment solver tolerance. The solver stops when the tolerance is reached. 82 | max_res (`int`, *optional*, defaults to `1024`): 83 | Resolution at which the alignment is performed; `None` matches the `processing_resolution`. 
84 | Returns: 85 | A tensor of aligned and ensembled depth maps and optionally a tensor of uncertainties of the same shape: 86 | `(1, 1, H, W)`. 87 | """ 88 | if depth.dim() != 4 or depth.shape[1] != 1: 89 | raise ValueError(f"Expecting 4D tensor of shape [B,1,H,W]; got {depth.shape}.") 90 | if reduction not in ("mean", "median"): 91 | raise ValueError(f"Unrecognized reduction method: {reduction}.") 92 | if not scale_invariant and shift_invariant: 93 | raise ValueError("Pure shift-invariant ensembling is not supported.") 94 | 95 | def init_param(depth: torch.Tensor): 96 | init_min = depth.reshape(ensemble_size, -1).min(dim=1).values 97 | init_max = depth.reshape(ensemble_size, -1).max(dim=1).values 98 | 99 | if scale_invariant and shift_invariant: 100 | init_s = 1.0 / (init_max - init_min).clamp(min=1e-6) 101 | init_t = -init_s * init_min 102 | param = torch.cat((init_s, init_t)).cpu().numpy() 103 | elif scale_invariant: 104 | init_s = 1.0 / init_max.clamp(min=1e-6) 105 | param = init_s.cpu().numpy() 106 | else: 107 | raise ValueError("Unrecognized alignment.") 108 | 109 | return param 110 | 111 | def align(depth: torch.Tensor, param: np.ndarray) -> torch.Tensor: 112 | if scale_invariant and shift_invariant: 113 | s, t = np.split(param, 2) 114 | s = torch.from_numpy(s).to(depth).view(ensemble_size, 1, 1, 1) 115 | t = torch.from_numpy(t).to(depth).view(ensemble_size, 1, 1, 1) 116 | out = depth * s + t 117 | elif scale_invariant: 118 | s = torch.from_numpy(param).to(depth).view(ensemble_size, 1, 1, 1) 119 | out = depth * s 120 | else: 121 | raise ValueError("Unrecognized alignment.") 122 | return out 123 | 124 | def ensemble( 125 | depth_aligned: torch.Tensor, return_uncertainty: bool = False 126 | ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]: 127 | uncertainty = None 128 | if reduction == "mean": 129 | prediction = torch.mean(depth_aligned, dim=0, keepdim=True) 130 | if return_uncertainty: 131 | uncertainty = torch.std(depth_aligned, dim=0, keepdim=True) 132 | elif reduction == "median": 133 | prediction = torch.median(depth_aligned, dim=0, keepdim=True).values 134 | if return_uncertainty: 135 | uncertainty = torch.median( 136 | torch.abs(depth_aligned - prediction), dim=0, keepdim=True 137 | ).values 138 | else: 139 | raise ValueError(f"Unrecognized reduction method: {reduction}.") 140 | return prediction, uncertainty 141 | 142 | def cost_fn(param: np.ndarray, depth: torch.Tensor) -> float: 143 | cost = 0.0 144 | depth_aligned = align(depth, param) 145 | 146 | for i, j in torch.combinations(torch.arange(ensemble_size)): 147 | diff = depth_aligned[i] - depth_aligned[j] 148 | cost += (diff**2).mean().sqrt().item() 149 | 150 | if regularizer_strength > 0: 151 | prediction, _ = ensemble(depth_aligned, return_uncertainty=False) 152 | err_near = (0.0 - prediction.min()).abs().item() 153 | err_far = (1.0 - prediction.max()).abs().item() 154 | cost += (err_near + err_far) * regularizer_strength 155 | 156 | return cost 157 | 158 | def compute_param(depth: torch.Tensor): 159 | import scipy 160 | 161 | depth_to_align = depth.to(torch.float32) 162 | if max_res is not None and max(depth_to_align.shape[2:]) > max_res: 163 | depth_to_align = resize_max_res( 164 | depth_to_align, max_res, get_tv_resample_method("nearest-exact") 165 | ) 166 | 167 | param = init_param(depth_to_align) 168 | 169 | res = scipy.optimize.minimize( 170 | partial(cost_fn, depth=depth_to_align), 171 | param, 172 | method="BFGS", 173 | tol=tol, 174 | options={"maxiter": max_iter, "disp": False}, 175 | ) 176 | 177 | return res.x 
178 | 179 | requires_aligning = scale_invariant or shift_invariant 180 | ensemble_size = depth.shape[0] 181 | 182 | if requires_aligning: 183 | param = compute_param(depth) 184 | depth = align(depth, param) 185 | 186 | depth, uncertainty = ensemble(depth, return_uncertainty=output_uncertainty) 187 | 188 | depth_max = depth.max() 189 | if scale_invariant and shift_invariant: 190 | depth_min = depth.min() 191 | elif scale_invariant: 192 | depth_min = 0 193 | else: 194 | raise ValueError("Unrecognized alignment.") 195 | depth_range = (depth_max - depth_min).clamp(min=1e-6) 196 | depth = (depth - depth_min) / depth_range 197 | if output_uncertainty: 198 | uncertainty /= depth_range 199 | 200 | return depth, uncertainty # [1,1,H,W], [1,1,H,W] 201 | -------------------------------------------------------------------------------- /merge/pipeline/util/image_util.py: -------------------------------------------------------------------------------- 1 | # Copyright 2023 Bingxin Ke, ETH Zurich. All rights reserved. 2 | # Last modified: 2024-05-24 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | # -------------------------------------------------------------------------- 16 | # If you find this code useful, we kindly ask you to cite our paper in your work. 17 | # Please find bibtex at: https://github.com/prs-eth/Marigold#-citation 18 | # More information about the method can be found at https://marigoldmonodepth.github.io 19 | # -------------------------------------------------------------------------- 20 | 21 | 22 | import matplotlib 23 | import numpy as np 24 | import torch 25 | from torchvision.transforms import InterpolationMode 26 | from torchvision.transforms.functional import resize 27 | 28 | 29 | def colorize_depth_maps( 30 | depth_map, min_depth, max_depth, cmap="Spectral", valid_mask=None 31 | ): 32 | """ 33 | Colorize depth maps. 
34 | """ 35 | assert len(depth_map.shape) >= 2, "Invalid dimension" 36 | 37 | if isinstance(depth_map, torch.Tensor): 38 | depth = depth_map.detach().squeeze().numpy() 39 | elif isinstance(depth_map, np.ndarray): 40 | depth = depth_map.copy().squeeze() 41 | # reshape to [ (B,) H, W ] 42 | if depth.ndim < 3: 43 | depth = depth[np.newaxis, :, :] 44 | 45 | # colorize 46 | cm = matplotlib.colormaps[cmap] 47 | depth = ((depth - min_depth) / (max_depth - min_depth)).clip(0, 1) 48 | img_colored_np = cm(depth, bytes=False)[:, :, :, 0:3] # value from 0 to 1 49 | img_colored_np = np.rollaxis(img_colored_np, 3, 1) 50 | 51 | if valid_mask is not None: 52 | if isinstance(depth_map, torch.Tensor): 53 | valid_mask = valid_mask.detach().numpy() 54 | valid_mask = valid_mask.squeeze() # [H, W] or [B, H, W] 55 | if valid_mask.ndim < 3: 56 | valid_mask = valid_mask[np.newaxis, np.newaxis, :, :] 57 | else: 58 | valid_mask = valid_mask[:, np.newaxis, :, :] 59 | valid_mask = np.repeat(valid_mask, 3, axis=1) 60 | img_colored_np[~valid_mask] = 0 61 | 62 | if isinstance(depth_map, torch.Tensor): 63 | img_colored = torch.from_numpy(img_colored_np).float() 64 | elif isinstance(depth_map, np.ndarray): 65 | img_colored = img_colored_np 66 | 67 | return img_colored 68 | 69 | 70 | def chw2hwc(chw): 71 | assert 3 == len(chw.shape) 72 | if isinstance(chw, torch.Tensor): 73 | hwc = torch.permute(chw, (1, 2, 0)) 74 | elif isinstance(chw, np.ndarray): 75 | hwc = np.moveaxis(chw, 0, -1) 76 | return hwc 77 | 78 | 79 | def resize_max_res( 80 | img: torch.Tensor, 81 | max_edge_resolution: int, 82 | resample_method: InterpolationMode = InterpolationMode.BILINEAR, 83 | ) -> torch.Tensor: 84 | """ 85 | Resize image to limit maximum edge length while keeping aspect ratio. 86 | 87 | Args: 88 | img (`torch.Tensor`): 89 | Image tensor to be resized. Expected shape: [B, C, H, W] 90 | max_edge_resolution (`int`): 91 | Maximum edge length (pixel). 92 | resample_method (`PIL.Image.Resampling`): 93 | Resampling method used to resize images. 94 | 95 | Returns: 96 | `torch.Tensor`: Resized image. 
97 | """ 98 | assert 4 == img.dim(), f"Invalid input shape {img.shape}" 99 | 100 | original_height, original_width = img.shape[-2:] 101 | downscale_factor = min( 102 | max_edge_resolution / original_width, max_edge_resolution / original_height 103 | ) 104 | 105 | new_width = int(original_width * downscale_factor) 106 | new_height = int(original_height * downscale_factor) 107 | 108 | resized_img = resize(img, (new_height, new_width), resample_method, antialias=True) 109 | return resized_img 110 | 111 | 112 | def get_tv_resample_method(method_str: str) -> InterpolationMode: 113 | resample_method_dict = { 114 | "bilinear": InterpolationMode.BILINEAR, 115 | "bicubic": InterpolationMode.BICUBIC, 116 | "nearest": InterpolationMode.NEAREST_EXACT, 117 | "nearest-exact": InterpolationMode.NEAREST_EXACT, 118 | } 119 | resample_method = resample_method_dict.get(method_str, None) 120 | if resample_method is None: 121 | raise ValueError(f"Unknown resampling method: {resample_method}") 122 | else: 123 | return resample_method 124 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | mkl==2024.0 2 | mmcv==1.7.0 3 | diffusers==0.32.1 4 | timm==1.0.12 5 | accelerate==1.2.1 6 | tensorboard==2.17.0 7 | transformers==4.47.1 8 | ftfy==6.3.1 9 | protobuf==3.20.2 10 | gradio==4.1.1 11 | yapf==0.40.1 12 | bs4==0.0.2 13 | einops==0.8.0 14 | optimum==1.23.3 15 | scipy==1.13.1 16 | Pillow==10.4.0 17 | matplotlib==3.9.4 18 | omegaconf 19 | datasets 20 | sentencepiece~=0.1.99 21 | peft==0.14.0 22 | beautifulsoup4==4.12.3 23 | bitsandbytes 24 | wandb 25 | -------------------------------------------------------------------------------- /src/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/H-EmbodVis/MERGE/93e81be69663b00175a18f405f9239e7acae1ba8/src/__init__.py -------------------------------------------------------------------------------- /src/dataset/__init__.py: -------------------------------------------------------------------------------- 1 | # Last modified: 2024-04-16 2 | # 3 | # Copyright 2023 Bingxin Ke, ETH Zurich. All rights reserved. 4 | # 5 | # Licensed under the Apache License, Version 2.0 (the "License"); 6 | # you may not use this file except in compliance with the License. 7 | # You may obtain a copy of the License at 8 | # 9 | # http://www.apache.org/licenses/LICENSE-2.0 10 | # 11 | # Unless required by applicable law or agreed to in writing, software 12 | # distributed under the License is distributed on an "AS IS" BASIS, 13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | # See the License for the specific language governing permissions and 15 | # limitations under the License. 16 | # -------------------------------------------------------------------------- 17 | # If you find this code useful, we kindly ask you to cite our paper in your work. 18 | # Please find bibtex at: https://github.com/prs-eth/Marigold#-citation 19 | # If you use or adapt this code, please attribute to https://github.com/prs-eth/marigold. 
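# --------------------------------------------------------------------------
# A minimal sketch that puts the pipeline utilities defined earlier under
# merge/pipeline/util/ together: `ensemble_depth` aligns and merges several
# affine-invariant predictions, and `colorize_depth_maps` / `chw2hwc` turn the
# result into an RGB visualization. The random tensors stand in for real
# per-sample predictions, and the import paths are inferred from the
# repository layout.
import torch
from merge.pipeline.util.ensemble import ensemble_depth
from merge.pipeline.util.image_util import chw2hwc, colorize_depth_maps

preds = torch.rand(3, 1, 64, 64)  # three fake predictions of the same scene
merged, uncertainty = ensemble_depth(preds, output_uncertainty=True)  # each [1, 1, 64, 64]
colored = colorize_depth_maps(merged, min_depth=0.0, max_depth=1.0)   # [1, 3, 64, 64], values in [0, 1]
colored_hwc = chw2hwc(colored[0])                                     # [64, 64, 3] for saving/plotting
print(merged.shape, uncertainty.shape, colored_hwc.shape)
# --------------------------------------------------------------------------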
20 | # More information about the method can be found at https://marigoldmonodepth.github.io 21 | # -------------------------------------------------------------------------- 22 | 23 | import os 24 | 25 | from .base_depth_dataset import BaseDepthDataset, get_pred_name, DatasetMode # noqa: F401 26 | from .diode_dataset import DIODEDataset 27 | from .eth3d_dataset import ETH3DDataset 28 | from .hypersim_dataset import HypersimDataset 29 | from .kitti_dataset import KITTIDataset 30 | from .nyu_dataset import NYUDataset 31 | from .scannet_dataset import ScanNetDataset 32 | from .vkitti_dataset import VirtualKITTIDataset 33 | from .tartanair_dataset import TartanairDataset 34 | 35 | 36 | dataset_name_class_dict = { 37 | "hypersim": HypersimDataset, 38 | "vkitti": VirtualKITTIDataset, 39 | "nyu_v2": NYUDataset, 40 | "kitti": KITTIDataset, 41 | "eth3d": ETH3DDataset, 42 | "diode": DIODEDataset, 43 | "scannet": ScanNetDataset, 44 | } 45 | 46 | 47 | def get_dataset( 48 | cfg_data_split, base_data_dir: str, mode: DatasetMode, **kwargs 49 | ) -> BaseDepthDataset: 50 | if "mixed" == cfg_data_split.name: 51 | assert DatasetMode.TRAIN == mode, "Only training mode supports mixed datasets." 52 | dataset_ls = [ 53 | get_dataset(_cfg, base_data_dir, mode, **kwargs) 54 | for _cfg in cfg_data_split.dataset_list 55 | ] 56 | return dataset_ls 57 | elif cfg_data_split.name in dataset_name_class_dict.keys(): 58 | dataset_class = dataset_name_class_dict[cfg_data_split.name] 59 | dataset = dataset_class( 60 | mode=mode, 61 | filename_ls_path=cfg_data_split.filenames, 62 | dataset_dir=os.path.join(base_data_dir, cfg_data_split.dir), 63 | **cfg_data_split, 64 | **kwargs, 65 | ) 66 | else: 67 | raise NotImplementedError 68 | 69 | return dataset 70 | -------------------------------------------------------------------------------- /src/dataset/base_depth_dataset.py: -------------------------------------------------------------------------------- 1 | # Last modified: 2024-04-30 2 | # 3 | # Copyright 2023 Bingxin Ke, ETH Zurich. All rights reserved. 4 | # 5 | # Licensed under the Apache License, Version 2.0 (the "License"); 6 | # you may not use this file except in compliance with the License. 7 | # You may obtain a copy of the License at 8 | # 9 | # http://www.apache.org/licenses/LICENSE-2.0 10 | # 11 | # Unless required by applicable law or agreed to in writing, software 12 | # distributed under the License is distributed on an "AS IS" BASIS, 13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | # See the License for the specific language governing permissions and 15 | # limitations under the License. 16 | # -------------------------------------------------------------------------- 17 | # If you find this code useful, we kindly ask you to cite our paper in your work. 18 | # Please find bibtex at: https://github.com/prs-eth/Marigold#-citation 19 | # If you use or adapt this code, please attribute to https://github.com/prs-eth/marigold. 
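# --------------------------------------------------------------------------
# A minimal usage sketch for `get_dataset` above: it builds a dataset (or a
# list of datasets for a "mixed" split) from an OmegaConf node providing at
# least `name`, `filenames` and `dir`. The exact structure of the YAML files
# under config/dataset/ is not reproduced here, so treat this as a sketch: if
# the split config is nested inside the file, select the relevant sub-node
# before passing it in. The data directory below is a placeholder.
from omegaconf import OmegaConf
from src.dataset import DatasetMode, get_dataset

split_cfg = OmegaConf.load("config/dataset/data_nyu_test.yaml")
dataset = get_dataset(split_cfg, base_data_dir="/path/to/datasets", mode=DatasetMode.EVAL)
# --------------------------------------------------------------------------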
20 | # More information about the method can be found at https://marigoldmonodepth.github.io 21 | # -------------------------------------------------------------------------- 22 | 23 | import io 24 | import os 25 | import random 26 | import tarfile 27 | from enum import Enum 28 | from typing import Union 29 | 30 | import numpy as np 31 | import torch 32 | from PIL import Image 33 | from torch.utils.data import Dataset 34 | from torchvision.transforms import InterpolationMode, Resize 35 | 36 | from src.util.depth_transform import DepthNormalizerBase 37 | 38 | 39 | class DatasetMode(Enum): 40 | RGB_ONLY = "rgb_only" 41 | EVAL = "evaluate" 42 | TRAIN = "train" 43 | 44 | 45 | class DepthFileNameMode(Enum): 46 | """Prediction file naming modes""" 47 | 48 | id = 1 # id.png 49 | rgb_id = 2 # rgb_id.png 50 | i_d_rgb = 3 # i_d_1_rgb.png 51 | rgb_i_d = 4 52 | 53 | 54 | def read_image_from_tar(tar_obj, img_rel_path): 55 | image = tar_obj.extractfile("./" + img_rel_path) 56 | image = image.read() 57 | image = Image.open(io.BytesIO(image)) 58 | 59 | def default_rgb_transform(x): 60 | return x / 255.0 * 2 - 1 61 | 62 | 63 | class BaseDepthDataset(Dataset): 64 | def __init__( 65 | self, 66 | mode: DatasetMode, 67 | filename_ls_path: str, 68 | dataset_dir: str, 69 | disp_name: str, 70 | min_depth: float, 71 | max_depth: float, 72 | has_filled_depth: bool, 73 | name_mode: DepthFileNameMode, 74 | depth_transform: Union[DepthNormalizerBase, None] = None, 75 | augmentation_args: dict = None, 76 | resize_to_hw=None, 77 | move_invalid_to_far_plane: bool = True, 78 | rgb_transform=default_rgb_transform, # [0, 255] -> [-1, 1], 79 | **kwargs, 80 | ) -> None: 81 | super().__init__() 82 | self.mode = mode 83 | # dataset info 84 | self.filename_ls_path = filename_ls_path 85 | self.dataset_dir = dataset_dir 86 | assert os.path.exists( 87 | self.dataset_dir 88 | ), f"Dataset does not exist at: {self.dataset_dir}" 89 | self.disp_name = disp_name 90 | self.has_filled_depth = has_filled_depth 91 | self.name_mode: DepthFileNameMode = name_mode 92 | self.min_depth = min_depth 93 | self.max_depth = max_depth 94 | 95 | #eval inf 96 | self.alignment_max_res = None 97 | self.processing_res = None 98 | 99 | # training arguments 100 | self.depth_transform: DepthNormalizerBase = depth_transform 101 | self.augm_args = augmentation_args 102 | self.resize_to_hw = resize_to_hw 103 | self.rgb_transform = rgb_transform 104 | self.move_invalid_to_far_plane = move_invalid_to_far_plane 105 | 106 | # Load filenames 107 | with open(self.filename_ls_path, "r") as f: 108 | self.filenames = [ 109 | s.split() for s in f.readlines() 110 | ] # [['rgb.png', 'depth.tif'], [], ...] 
111 | 112 | # Tar dataset 113 | self.tar_obj = None 114 | self.is_tar = ( 115 | True 116 | if os.path.isfile(dataset_dir) and tarfile.is_tarfile(dataset_dir) 117 | else False 118 | ) 119 | 120 | def __len__(self): 121 | return len(self.filenames) 122 | 123 | def __getitem__(self, index): 124 | rasters, other = self._get_data_item(index) 125 | 126 | if DatasetMode.TRAIN == self.mode: 127 | rasters = self._training_preprocess(rasters) 128 | # merge 129 | outputs = rasters 130 | outputs.update(other) 131 | return outputs 132 | 133 | def _get_data_item(self, index): 134 | rgb_rel_path, depth_rel_path, filled_rel_path = self._get_data_path(index=index) 135 | 136 | rasters = {} 137 | 138 | # RGB data 139 | rasters.update(self._load_rgb_data(rgb_rel_path=rgb_rel_path)) 140 | 141 | # Depth data 142 | if DatasetMode.RGB_ONLY != self.mode: 143 | # load data 144 | depth_data = self._load_depth_data( 145 | depth_rel_path=depth_rel_path, filled_rel_path=filled_rel_path 146 | ) 147 | rasters.update(depth_data) 148 | # valid mask 149 | rasters["valid_mask_raw"] = self._get_valid_mask( 150 | rasters["depth_raw_linear"] 151 | ).clone() 152 | rasters["valid_mask_filled"] = self._get_valid_mask( 153 | rasters["depth_filled_linear"] 154 | ).clone() 155 | 156 | #return the range of depth value of the dataset 157 | other = { 158 | "index": index, 159 | "rgb_relative_path": rgb_rel_path, 160 | } 161 | 162 | return rasters, other 163 | 164 | def _load_rgb_data(self, rgb_rel_path): 165 | # Read RGB data 166 | rgb = self._read_rgb_file(rgb_rel_path) 167 | rgb_norm = rgb / 255.0 * 2.0 - 1.0 # [0, 255] -> [-1, 1] 168 | 169 | outputs = { 170 | "rgb_int": torch.from_numpy(rgb).int(), 171 | "rgb_norm": torch.from_numpy(rgb_norm).float(), 172 | } 173 | return outputs 174 | 175 | def _load_depth_data(self, depth_rel_path, filled_rel_path): 176 | # Read depth data 177 | outputs = {} 178 | depth_raw = self._read_depth_file(depth_rel_path).squeeze() 179 | depth_raw_linear = torch.from_numpy(depth_raw).float().unsqueeze(0) # [1, H, W] 180 | outputs["depth_raw_linear"] = depth_raw_linear.clone() 181 | 182 | if self.has_filled_depth: 183 | depth_filled = self._read_depth_file(filled_rel_path).squeeze() 184 | depth_filled_linear = torch.from_numpy(depth_filled).float().unsqueeze(0) 185 | outputs["depth_filled_linear"] = depth_filled_linear 186 | else: 187 | outputs["depth_filled_linear"] = depth_raw_linear.clone() 188 | 189 | return outputs 190 | 191 | def _get_data_path(self, index): 192 | filename_line = self.filenames[index] 193 | 194 | # Get data path 195 | rgb_rel_path = filename_line[0] 196 | 197 | depth_rel_path, filled_rel_path = None, None 198 | if DatasetMode.RGB_ONLY != self.mode: 199 | depth_rel_path = filename_line[1] 200 | if self.has_filled_depth: 201 | filled_rel_path = filename_line[2] 202 | return rgb_rel_path, depth_rel_path, filled_rel_path 203 | 204 | def _read_image(self, img_rel_path) -> np.ndarray: 205 | if self.is_tar: 206 | if self.tar_obj is None: 207 | self.tar_obj = tarfile.open(self.dataset_dir) 208 | image_to_read = self.tar_obj.extractfile("./" + img_rel_path) 209 | image_to_read = image_to_read.read() 210 | image_to_read = io.BytesIO(image_to_read) 211 | else: 212 | image_to_read = os.path.join(self.dataset_dir, img_rel_path) 213 | image = Image.open(image_to_read) # [H, W, rgb] 214 | image = np.asarray(image) 215 | return image 216 | 217 | def _read_rgb_file(self, rel_path) -> np.ndarray: 218 | rgb = self._read_image(rel_path) 219 | rgb = np.transpose(rgb, (2, 0, 1)).astype(int) # [rgb, H, W] 
220 | return rgb 221 | 222 | def _read_depth_file(self, rel_path): 223 | depth_in = self._read_image(rel_path) 224 | # Replace code below to decode depth according to dataset definition 225 | depth_decoded = depth_in 226 | 227 | return depth_decoded 228 | 229 | def _get_valid_mask(self, depth: torch.Tensor): 230 | valid_mask = torch.logical_and( 231 | (depth > self.min_depth), (depth < self.max_depth) 232 | ).bool() 233 | return valid_mask 234 | 235 | def _training_preprocess(self, rasters): 236 | # Augmentation 237 | if self.augm_args is not None: 238 | rasters = self._augment_data(rasters) 239 | 240 | # Normalization 241 | rasters["depth_raw_norm"] = self.depth_transform( 242 | rasters["depth_raw_linear"], rasters["valid_mask_raw"] 243 | ).clone() 244 | rasters["depth_filled_norm"] = self.depth_transform( 245 | rasters["depth_filled_linear"], rasters["valid_mask_filled"] 246 | ).clone() 247 | 248 | # Set invalid pixel to far plane 249 | if self.move_invalid_to_far_plane: 250 | if self.depth_transform.far_plane_at_max: 251 | rasters["depth_filled_norm"][~rasters["valid_mask_filled"]] = ( 252 | self.depth_transform.norm_max 253 | ) 254 | else: 255 | rasters["depth_filled_norm"][~rasters["valid_mask_filled"]] = ( 256 | self.depth_transform.norm_min 257 | ) 258 | 259 | # Resize 260 | if self.resize_to_hw is not None: 261 | resize_transform = Resize( 262 | size=self.resize_to_hw, interpolation=InterpolationMode.NEAREST_EXACT 263 | ) 264 | rasters = {k: resize_transform(v) for k, v in rasters.items()} 265 | return rasters 266 | 267 | def _augment_data(self, rasters_dict): 268 | # lr flipping 269 | lr_flip_p = self.augm_args.lr_flip_p 270 | if random.random() < lr_flip_p: 271 | rasters_dict = {k: v.flip(-1) for k, v in rasters_dict.items()} 272 | 273 | return rasters_dict 274 | 275 | def __del__(self): 276 | if hasattr(self, "tar_obj") and self.tar_obj is not None: 277 | self.tar_obj.close() 278 | self.tar_obj = None 279 | 280 | 281 | def get_pred_name(rgb_basename, name_mode, suffix=".png"): 282 | if DepthFileNameMode.rgb_id == name_mode: 283 | pred_basename = "pred_" + rgb_basename.split("_")[1] 284 | elif DepthFileNameMode.i_d_rgb == name_mode: 285 | pred_basename = rgb_basename.replace("_rgb.", "_pred.") 286 | elif DepthFileNameMode.id == name_mode: 287 | pred_basename = "pred_" + rgb_basename 288 | elif DepthFileNameMode.rgb_i_d == name_mode: 289 | pred_basename = "pred_" + "_".join(rgb_basename.split("_")[1:]) 290 | else: 291 | raise NotImplementedError 292 | # change suffix 293 | pred_basename = os.path.splitext(pred_basename)[0] + suffix 294 | 295 | return pred_basename 296 | -------------------------------------------------------------------------------- /src/dataset/diode_dataset.py: -------------------------------------------------------------------------------- 1 | # Last modified: 2024-02-26 2 | # 3 | # Copyright 2023 Bingxin Ke, ETH Zurich. All rights reserved. 4 | # 5 | # Licensed under the Apache License, Version 2.0 (the "License"); 6 | # you may not use this file except in compliance with the License. 7 | # You may obtain a copy of the License at 8 | # 9 | # http://www.apache.org/licenses/LICENSE-2.0 10 | # 11 | # Unless required by applicable law or agreed to in writing, software 12 | # distributed under the License is distributed on an "AS IS" BASIS, 13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | # See the License for the specific language governing permissions and 15 | # limitations under the License. 
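# --------------------------------------------------------------------------
# A small illustration of the naming modes consumed by `get_pred_name` above,
# mapping RGB basenames to prediction basenames (the example filenames are
# made up):
from src.dataset.base_depth_dataset import DepthFileNameMode, get_pred_name

print(get_pred_name("rgb_00042.png", DepthFileNameMode.rgb_id))              # pred_00042.png
print(get_pred_name("scene_03_00_rgb.png", DepthFileNameMode.i_d_rgb))       # scene_03_00_pred.png
print(get_pred_name("0000000005.png", DepthFileNameMode.id, suffix=".npy"))  # pred_0000000005.npy
# --------------------------------------------------------------------------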
16 | # -------------------------------------------------------------------------- 17 | # If you find this code useful, we kindly ask you to cite our paper in your work. 18 | # Please find bibtex at: https://github.com/prs-eth/Marigold#-citation 19 | # If you use or adapt this code, please attribute to https://github.com/prs-eth/marigold. 20 | # More information about the method can be found at https://marigoldmonodepth.github.io 21 | # -------------------------------------------------------------------------- 22 | 23 | import os 24 | import tarfile 25 | from io import BytesIO 26 | 27 | import numpy as np 28 | import torch 29 | 30 | from .base_depth_dataset import BaseDepthDataset, DepthFileNameMode, DatasetMode 31 | 32 | 33 | class DIODEDataset(BaseDepthDataset): 34 | def __init__( 35 | self, 36 | **kwargs, 37 | ) -> None: 38 | super().__init__( 39 | # DIODE data parameter 40 | min_depth=0.6, 41 | max_depth=350., 42 | has_filled_depth=False, 43 | name_mode=DepthFileNameMode.id, 44 | **kwargs, 45 | ) 46 | 47 | self.processing_res = 640 48 | self.mini_ensemble_size = 10 49 | def _read_npy_file(self, rel_path): 50 | if self.is_tar: 51 | if self.tar_obj is None: 52 | self.tar_obj = tarfile.open(self.dataset_dir) 53 | fileobj = self.tar_obj.extractfile("./" + rel_path) 54 | npy_path_or_content = BytesIO(fileobj.read()) 55 | else: 56 | npy_path_or_content = os.path.join(self.dataset_dir, rel_path) 57 | data = np.load(npy_path_or_content).squeeze()[np.newaxis, :, :] 58 | return data 59 | 60 | def _read_depth_file(self, rel_path): 61 | depth = self._read_npy_file(rel_path) 62 | return depth 63 | 64 | def _get_data_path(self, index): 65 | return self.filenames[index] 66 | 67 | def _get_data_item(self, index): 68 | # Special: depth mask is read from data 69 | 70 | rgb_rel_path, depth_rel_path, mask_rel_path = self._get_data_path(index=index) 71 | 72 | rasters = {} 73 | 74 | # RGB data 75 | rasters.update(self._load_rgb_data(rgb_rel_path=rgb_rel_path)) 76 | 77 | # Depth data 78 | if DatasetMode.RGB_ONLY != self.mode: 79 | # load data 80 | depth_data = self._load_depth_data( 81 | depth_rel_path=depth_rel_path, filled_rel_path=None 82 | ) 83 | rasters.update(depth_data) 84 | 85 | # valid mask 86 | mask = self._read_npy_file(mask_rel_path).astype(bool) 87 | mask = torch.from_numpy(mask).bool() 88 | rasters["valid_mask_raw"] = mask.clone() 89 | rasters["valid_mask_filled"] = mask.clone() 90 | 91 | other = {"index": index, "rgb_relative_path": rgb_rel_path} 92 | 93 | return rasters, other 94 | -------------------------------------------------------------------------------- /src/dataset/eth3d_dataset.py: -------------------------------------------------------------------------------- 1 | # Last modified: 2024-02-08 2 | # 3 | # Copyright 2023 Bingxin Ke, ETH Zurich. All rights reserved. 4 | # 5 | # Licensed under the Apache License, Version 2.0 (the "License"); 6 | # you may not use this file except in compliance with the License. 7 | # You may obtain a copy of the License at 8 | # 9 | # http://www.apache.org/licenses/LICENSE-2.0 10 | # 11 | # Unless required by applicable law or agreed to in writing, software 12 | # distributed under the License is distributed on an "AS IS" BASIS, 13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | # See the License for the specific language governing permissions and 15 | # limitations under the License. 
16 | # -------------------------------------------------------------------------- 17 | # If you find this code useful, we kindly ask you to cite our paper in your work. 18 | # Please find bibtex at: https://github.com/prs-eth/Marigold#-citation 19 | # If you use or adapt this code, please attribute to https://github.com/prs-eth/marigold. 20 | # More information about the method can be found at https://marigoldmonodepth.github.io 21 | # -------------------------------------------------------------------------- 22 | 23 | import torch 24 | import tarfile 25 | import os 26 | import numpy as np 27 | 28 | from .base_depth_dataset import BaseDepthDataset, DepthFileNameMode 29 | 30 | 31 | class ETH3DDataset(BaseDepthDataset): 32 | HEIGHT, WIDTH = 4032, 6048 33 | 34 | def __init__( 35 | self, 36 | **kwargs, 37 | ) -> None: 38 | super().__init__( 39 | # ETH3D data parameter 40 | min_depth=1e-5, 41 | max_depth=torch.inf, 42 | has_filled_depth=False, 43 | name_mode=DepthFileNameMode.id, 44 | **kwargs, 45 | ) 46 | self.alignment_max_res = 1024 47 | self.processing_res = 756 48 | self.mini_ensemble_size = 10 49 | 50 | def _read_depth_file(self, rel_path): 51 | # Read special binary data: https://www.eth3d.net/documentation#format-of-multi-view-data-image-formats 52 | if self.is_tar: 53 | if self.tar_obj is None: 54 | self.tar_obj = tarfile.open(self.dataset_dir) 55 | binary_data = self.tar_obj.extractfile("./" + rel_path) 56 | binary_data = binary_data.read() 57 | 58 | else: 59 | depth_path = os.path.join(self.dataset_dir, rel_path) 60 | with open(depth_path, "rb") as file: 61 | binary_data = file.read() 62 | # Convert the binary data to a numpy array of 32-bit floats 63 | depth_decoded = np.frombuffer(binary_data, dtype=np.float32).copy() 64 | 65 | depth_decoded[depth_decoded == torch.inf] = 0.0 66 | 67 | depth_decoded = depth_decoded.reshape((self.HEIGHT, self.WIDTH)) 68 | return depth_decoded 69 | -------------------------------------------------------------------------------- /src/dataset/hypersim_dataset.py: -------------------------------------------------------------------------------- 1 | # Last modified: 2024-02-08 2 | # 3 | # Copyright 2023 Bingxin Ke, ETH Zurich. All rights reserved. 4 | # 5 | # Licensed under the Apache License, Version 2.0 (the "License"); 6 | # you may not use this file except in compliance with the License. 7 | # You may obtain a copy of the License at 8 | # 9 | # http://www.apache.org/licenses/LICENSE-2.0 10 | # 11 | # Unless required by applicable law or agreed to in writing, software 12 | # distributed under the License is distributed on an "AS IS" BASIS, 13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | # See the License for the specific language governing permissions and 15 | # limitations under the License. 16 | # -------------------------------------------------------------------------- 17 | # If you find this code useful, we kindly ask you to cite our paper in your work. 18 | # Please find bibtex at: https://github.com/prs-eth/Marigold#-citation 19 | # If you use or adapt this code, please attribute to https://github.com/prs-eth/marigold. 
20 | # More information about the method can be found at https://marigoldmonodepth.github.io 21 | # -------------------------------------------------------------------------- 22 | 23 | 24 | from .base_depth_dataset import BaseDepthDataset, DepthFileNameMode 25 | 26 | 27 | class HypersimDataset(BaseDepthDataset): 28 | def __init__( 29 | self, 30 | **kwargs, 31 | ) -> None: 32 | super().__init__( 33 | # Hypersim data parameter 34 | min_depth=1e-5, 35 | max_depth=65.0, 36 | has_filled_depth=False, 37 | name_mode=DepthFileNameMode.rgb_i_d, 38 | **kwargs, 39 | ) 40 | 41 | def _read_depth_file(self, rel_path): 42 | depth_in = self._read_image(rel_path) 43 | # Decode Hypersim depth 44 | depth_decoded = depth_in / 1000.0 45 | return depth_decoded 46 | -------------------------------------------------------------------------------- /src/dataset/kitti_dataset.py: -------------------------------------------------------------------------------- 1 | # Last modified: 2024-02-08 2 | # 3 | # Copyright 2023 Bingxin Ke, ETH Zurich. All rights reserved. 4 | # 5 | # Licensed under the Apache License, Version 2.0 (the "License"); 6 | # you may not use this file except in compliance with the License. 7 | # You may obtain a copy of the License at 8 | # 9 | # http://www.apache.org/licenses/LICENSE-2.0 10 | # 11 | # Unless required by applicable law or agreed to in writing, software 12 | # distributed under the License is distributed on an "AS IS" BASIS, 13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | # See the License for the specific language governing permissions and 15 | # limitations under the License. 16 | # -------------------------------------------------------------------------- 17 | # If you find this code useful, we kindly ask you to cite our paper in your work. 18 | # Please find bibtex at: https://github.com/prs-eth/Marigold#-citation 19 | # If you use or adapt this code, please attribute to https://github.com/prs-eth/marigold. 20 | # More information about the method can be found at https://marigoldmonodepth.github.io 21 | # -------------------------------------------------------------------------- 22 | 23 | import torch 24 | 25 | from .base_depth_dataset import BaseDepthDataset, DepthFileNameMode 26 | 27 | 28 | class KITTIDataset(BaseDepthDataset): 29 | def __init__( 30 | self, 31 | kitti_bm_crop, # Crop to KITTI benchmark size 32 | valid_mask_crop, # Evaluation mask. 
[None, garg or eigen] 33 | **kwargs, 34 | ) -> None: 35 | super().__init__( 36 | # KITTI data parameter 37 | min_depth=1e-5, 38 | max_depth=80, 39 | has_filled_depth=False, 40 | name_mode=DepthFileNameMode.id, 41 | **kwargs, 42 | ) 43 | self.kitti_bm_crop = kitti_bm_crop 44 | self.valid_mask_crop = valid_mask_crop 45 | assert self.valid_mask_crop in [ 46 | None, 47 | "garg", # set evaluation mask according to Garg ECCV16 48 | "eigen", # set evaluation mask according to Eigen NIPS14 49 | ], f"Unknown crop type: {self.valid_mask_crop}" 50 | 51 | self.processing_res = 0 52 | self.mini_ensemble_size = 5 53 | # Filter out empty depth 54 | self.filenames = [f for f in self.filenames if "None" != f[1]] 55 | 56 | def _read_depth_file(self, rel_path): 57 | depth_in = self._read_image(rel_path) 58 | # Decode KITTI depth 59 | depth_decoded = depth_in / 256.0 60 | return depth_decoded 61 | 62 | def _load_rgb_data(self, rgb_rel_path): 63 | rgb_data = super()._load_rgb_data(rgb_rel_path) 64 | if self.kitti_bm_crop: 65 | rgb_data = {k: self.kitti_benchmark_crop(v) for k, v in rgb_data.items()} 66 | return rgb_data 67 | 68 | def _load_depth_data(self, depth_rel_path, filled_rel_path): 69 | depth_data = super()._load_depth_data(depth_rel_path, filled_rel_path) 70 | if self.kitti_bm_crop: 71 | depth_data = { 72 | k: self.kitti_benchmark_crop(v) for k, v in depth_data.items() 73 | } 74 | return depth_data 75 | 76 | @staticmethod 77 | def kitti_benchmark_crop(input_img): 78 | """ 79 | Crop images to KITTI benchmark size 80 | Args: 81 | `input_img` (torch.Tensor): Input image to be cropped. 82 | 83 | Returns: 84 | torch.Tensor:Cropped image. 85 | """ 86 | KB_CROP_HEIGHT = 352 87 | KB_CROP_WIDTH = 1216 88 | 89 | height, width = input_img.shape[-2:] 90 | top_margin = int(height - KB_CROP_HEIGHT) 91 | left_margin = int((width - KB_CROP_WIDTH) / 2) 92 | if 2 == len(input_img.shape): 93 | out = input_img[ 94 | top_margin : top_margin + KB_CROP_HEIGHT, 95 | left_margin : left_margin + KB_CROP_WIDTH, 96 | ] 97 | elif 3 == len(input_img.shape): 98 | out = input_img[ 99 | :, 100 | top_margin : top_margin + KB_CROP_HEIGHT, 101 | left_margin : left_margin + KB_CROP_WIDTH, 102 | ] 103 | return out 104 | 105 | def _get_valid_mask(self, depth: torch.Tensor): 106 | # reference: https://github.com/cleinc/bts/blob/master/pytorch/bts_eval.py 107 | valid_mask = super()._get_valid_mask(depth) # [1, H, W] 108 | 109 | if self.valid_mask_crop is not None: 110 | eval_mask = torch.zeros_like(valid_mask.squeeze()).bool() 111 | gt_height, gt_width = eval_mask.shape 112 | 113 | if "garg" == self.valid_mask_crop: 114 | eval_mask[ 115 | int(0.40810811 * gt_height) : int(0.99189189 * gt_height), 116 | int(0.03594771 * gt_width) : int(0.96405229 * gt_width), 117 | ] = 1 118 | elif "eigen" == self.valid_mask_crop: 119 | eval_mask[ 120 | int(0.3324324 * gt_height) : int(0.91351351 * gt_height), 121 | int(0.0359477 * gt_width) : int(0.96405229 * gt_width), 122 | ] = 1 123 | 124 | eval_mask.reshape(valid_mask.shape) 125 | valid_mask = torch.logical_and(valid_mask, eval_mask) 126 | return valid_mask 127 | -------------------------------------------------------------------------------- /src/dataset/mixed_sampler.py: -------------------------------------------------------------------------------- 1 | # Last modified: 2024-04-18 2 | # 3 | # Copyright 2023 Bingxin Ke, ETH Zurich. All rights reserved. 4 | # 5 | # Licensed under the Apache License, Version 2.0 (the "License"); 6 | # you may not use this file except in compliance with the License. 
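# --------------------------------------------------------------------------
# A minimal usage sketch for `KITTIDataset.kitti_benchmark_crop` above: it is
# a plain static method, so it can be exercised on its own. A raw 375 x 1242
# KITTI frame is cropped to the 352 x 1216 benchmark size, keeping the bottom
# rows and centring horizontally.
import torch
from src.dataset.kitti_dataset import KITTIDataset

rgb = torch.zeros(3, 375, 1242)
cropped = KITTIDataset.kitti_benchmark_crop(rgb)
print(cropped.shape)  # torch.Size([3, 352, 1216])
# --------------------------------------------------------------------------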
7 | # You may obtain a copy of the License at 8 | # 9 | # http://www.apache.org/licenses/LICENSE-2.0 10 | # 11 | # Unless required by applicable law or agreed to in writing, software 12 | # distributed under the License is distributed on an "AS IS" BASIS, 13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | # See the License for the specific language governing permissions and 15 | # limitations under the License. 16 | # -------------------------------------------------------------------------- 17 | # If you find this code useful, we kindly ask you to cite our paper in your work. 18 | # Please find bibtex at: https://github.com/prs-eth/Marigold#-citation 19 | # If you use or adapt this code, please attribute to https://github.com/prs-eth/marigold. 20 | # More information about the method can be found at https://marigoldmonodepth.github.io 21 | # -------------------------------------------------------------------------- 22 | 23 | import torch 24 | from torch.utils.data import ( 25 | BatchSampler, 26 | RandomSampler, 27 | SequentialSampler, 28 | ) 29 | 30 | 31 | class MixedBatchSampler(BatchSampler): 32 | """Sample one batch from a selected dataset with given probability. 33 | Compatible with datasets at different resolution 34 | """ 35 | 36 | def __init__( 37 | self, src_dataset_ls, batch_size, drop_last, shuffle, prob=None, generator=None 38 | ): 39 | self.base_sampler = None 40 | self.batch_size = batch_size 41 | self.shuffle = shuffle 42 | self.drop_last = drop_last 43 | self.generator = generator 44 | 45 | self.src_dataset_ls = src_dataset_ls 46 | self.n_dataset = len(self.src_dataset_ls) 47 | 48 | # Dataset length 49 | self.dataset_length = [len(ds) for ds in self.src_dataset_ls] 50 | self.cum_dataset_length = [ 51 | sum(self.dataset_length[:i]) for i in range(self.n_dataset) 52 | ] # cumulative dataset length 53 | 54 | # BatchSamplers for each source dataset 55 | if self.shuffle: 56 | self.src_batch_samplers = [ 57 | BatchSampler( 58 | sampler=RandomSampler( 59 | ds, replacement=False, generator=self.generator 60 | ), 61 | batch_size=self.batch_size, 62 | drop_last=self.drop_last, 63 | ) 64 | for ds in self.src_dataset_ls 65 | ] 66 | else: 67 | self.src_batch_samplers = [ 68 | BatchSampler( 69 | sampler=SequentialSampler(ds), 70 | batch_size=self.batch_size, 71 | drop_last=self.drop_last, 72 | ) 73 | for ds in self.src_dataset_ls 74 | ] 75 | self.raw_batches = [ 76 | list(bs) for bs in self.src_batch_samplers 77 | ] # index in original dataset 78 | self.n_batches = [len(b) for b in self.raw_batches] 79 | self.n_total_batch = sum(self.n_batches) 80 | 81 | # sampling probability 82 | if prob is None: 83 | # if not given, decide by dataset length 84 | self.prob = torch.tensor(self.n_batches) / self.n_total_batch 85 | else: 86 | self.prob = torch.as_tensor(prob) 87 | 88 | def __iter__(self): 89 | """_summary_ 90 | 91 | Yields: 92 | list(int): a batch of indics, corresponding to ConcatDataset of src_dataset_ls 93 | """ 94 | for _ in range(self.n_total_batch): 95 | idx_ds = torch.multinomial( 96 | self.prob, 1, replacement=True, generator=self.generator 97 | ).item() 98 | # if batch list is empty, generate new list 99 | if 0 == len(self.raw_batches[idx_ds]): 100 | self.raw_batches[idx_ds] = list(self.src_batch_samplers[idx_ds]) 101 | # get a batch from list 102 | batch_raw = self.raw_batches[idx_ds].pop() 103 | # shift by cumulative dataset length 104 | shift = self.cum_dataset_length[idx_ds] 105 | batch = [n + shift for n in batch_raw] 106 | 107 | yield batch 
108 | 109 | def __len__(self): 110 | return self.n_total_batch 111 | 112 | 113 | # Unit test 114 | if "__main__" == __name__: 115 | from torch.utils.data import ConcatDataset, DataLoader, Dataset 116 | 117 | class SimpleDataset(Dataset): 118 | def __init__(self, start, len) -> None: 119 | super().__init__() 120 | self.start = start 121 | self.len = len 122 | 123 | def __len__(self): 124 | return self.len 125 | 126 | def __getitem__(self, index): 127 | return self.start + index 128 | 129 | dataset_1 = SimpleDataset(0, 10) 130 | dataset_2 = SimpleDataset(200, 20) 131 | dataset_3 = SimpleDataset(1000, 50) 132 | 133 | concat_dataset = ConcatDataset( 134 | [dataset_1, dataset_2, dataset_3] 135 | ) # will directly concatenate 136 | 137 | mixed_sampler = MixedBatchSampler( 138 | src_dataset_ls=[dataset_1, dataset_2, dataset_3], 139 | batch_size=4, 140 | drop_last=True, 141 | shuffle=False, 142 | prob=[0.6, 0.3, 0.1], 143 | generator=torch.Generator().manual_seed(0), 144 | ) 145 | 146 | loader = DataLoader(concat_dataset, batch_sampler=mixed_sampler) 147 | 148 | for d in loader: 149 | print(d) 150 | -------------------------------------------------------------------------------- /src/dataset/nyu_dataset.py: -------------------------------------------------------------------------------- 1 | # Last modified: 2024-02-08 2 | # 3 | # Copyright 2023 Bingxin Ke, ETH Zurich. All rights reserved. 4 | # 5 | # Licensed under the Apache License, Version 2.0 (the "License"); 6 | # you may not use this file except in compliance with the License. 7 | # You may obtain a copy of the License at 8 | # 9 | # http://www.apache.org/licenses/LICENSE-2.0 10 | # 11 | # Unless required by applicable law or agreed to in writing, software 12 | # distributed under the License is distributed on an "AS IS" BASIS, 13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | # See the License for the specific language governing permissions and 15 | # limitations under the License. 16 | # -------------------------------------------------------------------------- 17 | # If you find this code useful, we kindly ask you to cite our paper in your work. 18 | # Please find bibtex at: https://github.com/prs-eth/Marigold#-citation 19 | # If you use or adapt this code, please attribute to https://github.com/prs-eth/marigold. 
20 | # More information about the method can be found at https://marigoldmonodepth.github.io 21 | # -------------------------------------------------------------------------- 22 | 23 | import torch 24 | 25 | from .base_depth_dataset import BaseDepthDataset, DepthFileNameMode 26 | 27 | 28 | class NYUDataset(BaseDepthDataset): 29 | def __init__( 30 | self, 31 | eigen_valid_mask: bool, 32 | **kwargs, 33 | ) -> None: 34 | super().__init__( 35 | # NYUv2 dataset parameter 36 | min_depth=1e-3, 37 | max_depth=10.0, 38 | has_filled_depth=True, 39 | name_mode=DepthFileNameMode.rgb_id, 40 | **kwargs, 41 | ) 42 | 43 | self.eigen_valid_mask = eigen_valid_mask 44 | 45 | self.processing_res = 0 46 | self.mini_ensemble_size = 10 47 | def _read_depth_file(self, rel_path): 48 | depth_in = self._read_image(rel_path) 49 | # Decode NYU depth 50 | depth_decoded = depth_in / 1000.0 51 | return depth_decoded 52 | 53 | def _get_valid_mask(self, depth: torch.Tensor): 54 | valid_mask = super()._get_valid_mask(depth) 55 | 56 | # Eigen crop for evaluation 57 | if self.eigen_valid_mask: 58 | eval_mask = torch.zeros_like(valid_mask.squeeze()).bool() 59 | eval_mask[45:471, 41:601] = 1 60 | eval_mask.reshape(valid_mask.shape) 61 | valid_mask = torch.logical_and(valid_mask, eval_mask) 62 | 63 | return valid_mask 64 | -------------------------------------------------------------------------------- /src/dataset/scannet_dataset.py: -------------------------------------------------------------------------------- 1 | # Last modified: 2024-02-08 2 | # 3 | # Copyright 2023 Bingxin Ke, ETH Zurich. All rights reserved. 4 | # 5 | # Licensed under the Apache License, Version 2.0 (the "License"); 6 | # you may not use this file except in compliance with the License. 7 | # You may obtain a copy of the License at 8 | # 9 | # http://www.apache.org/licenses/LICENSE-2.0 10 | # 11 | # Unless required by applicable law or agreed to in writing, software 12 | # distributed under the License is distributed on an "AS IS" BASIS, 13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | # See the License for the specific language governing permissions and 15 | # limitations under the License. 16 | # -------------------------------------------------------------------------- 17 | # If you find this code useful, we kindly ask you to cite our paper in your work. 18 | # Please find bibtex at: https://github.com/prs-eth/Marigold#-citation 19 | # If you use or adapt this code, please attribute to https://github.com/prs-eth/marigold. 
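# --------------------------------------------------------------------------
# A minimal sketch constructing the NYUv2 evaluation split above directly,
# assuming the extracted data lives under the placeholder path below; the
# split-file path matches data_split/ in this repository. Each sample is a
# dict containing rgb_int, rgb_norm, the linear depth maps, the valid masks,
# the sample index and its relative RGB path.
import os
from src.dataset import DatasetMode, NYUDataset

nyu_test = NYUDataset(
    mode=DatasetMode.EVAL,
    filename_ls_path="data_split/nyu/labeled/filename_list_test.txt",
    dataset_dir=os.path.join("/path/to/datasets", "nyuv2"),
    disp_name="nyu_test",
    eigen_valid_mask=True,
)
sample = nyu_test[0]
# --------------------------------------------------------------------------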
20 | # More information about the method can be found at https://marigoldmonodepth.github.io 21 | # -------------------------------------------------------------------------- 22 | 23 | from .base_depth_dataset import BaseDepthDataset, DepthFileNameMode 24 | 25 | 26 | class ScanNetDataset(BaseDepthDataset): 27 | def __init__( 28 | self, 29 | **kwargs, 30 | ) -> None: 31 | super().__init__( 32 | # ScanNet data parameter 33 | min_depth=1e-3, 34 | max_depth=10, 35 | has_filled_depth=False, 36 | name_mode=DepthFileNameMode.id, 37 | **kwargs, 38 | ) 39 | 40 | self.processing_res = 0 41 | self.mini_ensemble_size = 10 42 | def _read_depth_file(self, rel_path): 43 | depth_in = self._read_image(rel_path) 44 | # Decode ScanNet depth 45 | depth_decoded = depth_in / 1000.0 46 | return depth_decoded 47 | -------------------------------------------------------------------------------- /src/dataset/tartanair_dataset.py: -------------------------------------------------------------------------------- 1 | # Last modified: 2024-02-08 2 | # 3 | # Copyright 2023 Bingxin Ke, ETH Zurich. All rights reserved. 4 | # 5 | # Licensed under the Apache License, Version 2.0 (the "License"); 6 | # you may not use this file except in compliance with the License. 7 | # You may obtain a copy of the License at 8 | # 9 | # http://www.apache.org/licenses/LICENSE-2.0 10 | # 11 | # Unless required by applicable law or agreed to in writing, software 12 | # distributed under the License is distributed on an "AS IS" BASIS, 13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | # See the License for the specific language governing permissions and 15 | # limitations under the License. 16 | # -------------------------------------------------------------------------- 17 | # If you find this code useful, we kindly ask you to cite our paper in your work. 18 | # Please find bibtex at: https://github.com/prs-eth/Marigold#-citation 19 | # If you use or adapt this code, please attribute to https://github.com/prs-eth/marigold. 20 | # More information about the method can be found at https://marigoldmonodepth.github.io 21 | # -------------------------------------------------------------------------- 22 | 23 | import os 24 | import tarfile 25 | from io import BytesIO 26 | 27 | import numpy as np 28 | 29 | from .base_depth_dataset import BaseDepthDataset, DepthFileNameMode 30 | 31 | 32 | class TartanairDataset(BaseDepthDataset): 33 | def __init__( 34 | self, 35 | **kwargs, 36 | ) -> None: 37 | super().__init__( 38 | # Tartanair data parameter 39 | min_depth=1e-5, 40 | max_depth=70.0, 41 | has_filled_depth=False, 42 | name_mode=DepthFileNameMode.rgb_i_d, 43 | **kwargs, 44 | ) 45 | 46 | def _read_npy_file(self, rel_path): 47 | npy_path_or_content = os.path.join(self.dataset_dir, rel_path) 48 | data = np.load(npy_path_or_content).squeeze()[np.newaxis, :, :] 49 | return data 50 | 51 | def _read_depth_file(self, rel_path): 52 | depth = self._read_npy_file(rel_path) 53 | return depth 54 | 55 | -------------------------------------------------------------------------------- /src/dataset/vkitti_dataset.py: -------------------------------------------------------------------------------- 1 | # Last modified: 2024-02-08 2 | # 3 | # Copyright 2023 Bingxin Ke, ETH Zurich. All rights reserved. 4 | # 5 | # Licensed under the Apache License, Version 2.0 (the "License"); 6 | # you may not use this file except in compliance with the License. 
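# --------------------------------------------------------------------------
# A hedged sketch of the recipe the dataset wrappers above all follow:
# subclass BaseDepthDataset, pin the valid depth range and filename mode, and
# override `_read_depth_file` with the dataset-specific decoding. The class
# below is hypothetical (a dataset whose 16-bit PNGs store centimetres); to be
# selectable from the YAML configs it would also need an entry in
# `dataset_name_class_dict` in src/dataset/__init__.py.
from src.dataset.base_depth_dataset import BaseDepthDataset, DepthFileNameMode


class MyCentimeterDataset(BaseDepthDataset):
    def __init__(self, **kwargs) -> None:
        super().__init__(
            min_depth=1e-3,
            max_depth=100.0,
            has_filled_depth=False,
            name_mode=DepthFileNameMode.id,
            **kwargs,
        )

    def _read_depth_file(self, rel_path):
        depth_in = self._read_image(rel_path)
        return depth_in / 100.0  # centimetres -> metres
# --------------------------------------------------------------------------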
7 | # You may obtain a copy of the License at 8 | # 9 | # http://www.apache.org/licenses/LICENSE-2.0 10 | # 11 | # Unless required by applicable law or agreed to in writing, software 12 | # distributed under the License is distributed on an "AS IS" BASIS, 13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | # See the License for the specific language governing permissions and 15 | # limitations under the License. 16 | # -------------------------------------------------------------------------- 17 | # If you find this code useful, we kindly ask you to cite our paper in your work. 18 | # Please find bibtex at: https://github.com/prs-eth/Marigold#-citation 19 | # If you use or adapt this code, please attribute to https://github.com/prs-eth/marigold. 20 | # More information about the method can be found at https://marigoldmonodepth.github.io 21 | # -------------------------------------------------------------------------- 22 | 23 | import torch 24 | 25 | from .base_depth_dataset import BaseDepthDataset, DepthFileNameMode 26 | from .kitti_dataset import KITTIDataset 27 | 28 | 29 | class VirtualKITTIDataset(BaseDepthDataset): 30 | def __init__( 31 | self, 32 | kitti_bm_crop, # Crop to KITTI benchmark size 33 | valid_mask_crop, # Evaluation mask. [None, garg or eigen] 34 | **kwargs, 35 | ) -> None: 36 | super().__init__( 37 | # virtual KITTI data parameter 38 | min_depth=1e-5, 39 | max_depth=80, # 655.35 40 | has_filled_depth=False, 41 | name_mode=DepthFileNameMode.id, 42 | **kwargs, 43 | ) 44 | self.kitti_bm_crop = kitti_bm_crop 45 | self.valid_mask_crop = valid_mask_crop 46 | assert self.valid_mask_crop in [ 47 | None, 48 | "garg", # set evaluation mask according to Garg ECCV16 49 | "eigen", # set evaluation mask according to Eigen NIPS14 50 | ], f"Unknown crop type: {self.valid_mask_crop}" 51 | 52 | # Filter out empty depth 53 | self.filenames = [f for f in self.filenames if "None" != f[1]] 54 | 55 | def _read_depth_file(self, rel_path): 56 | depth_in = self._read_image(rel_path) 57 | # Decode vKITTI depth 58 | depth_decoded = depth_in / 100.0 59 | return depth_decoded 60 | 61 | def _load_rgb_data(self, rgb_rel_path): 62 | rgb_data = super()._load_rgb_data(rgb_rel_path) 63 | if self.kitti_bm_crop: 64 | rgb_data = { 65 | k: KITTIDataset.kitti_benchmark_crop(v) for k, v in rgb_data.items() 66 | } 67 | return rgb_data 68 | 69 | def _load_depth_data(self, depth_rel_path, filled_rel_path): 70 | depth_data = super()._load_depth_data(depth_rel_path, filled_rel_path) 71 | if self.kitti_bm_crop: 72 | depth_data = { 73 | k: KITTIDataset.kitti_benchmark_crop(v) for k, v in depth_data.items() 74 | } 75 | return depth_data 76 | 77 | def _get_valid_mask(self, depth: torch.Tensor): 78 | # reference: https://github.com/cleinc/bts/blob/master/pytorch/bts_eval.py 79 | valid_mask = super()._get_valid_mask(depth) # [1, H, W] 80 | 81 | if self.valid_mask_crop is not None: 82 | eval_mask = torch.zeros_like(valid_mask.squeeze()).bool() 83 | gt_height, gt_width = eval_mask.shape 84 | 85 | if "garg" == self.valid_mask_crop: 86 | eval_mask[ 87 | int(0.40810811 * gt_height) : int(0.99189189 * gt_height), 88 | int(0.03594771 * gt_width) : int(0.96405229 * gt_width), 89 | ] = 1 90 | elif "eigen" == self.valid_mask_crop: 91 | eval_mask[ 92 | int(0.3324324 * gt_height) : int(0.91351351 * gt_height), 93 | int(0.0359477 * gt_width) : int(0.96405229 * gt_width), 94 | ] = 1 95 | 96 | eval_mask.reshape(valid_mask.shape) 97 | valid_mask = torch.logical_and(valid_mask, eval_mask) 98 | return 
valid_mask 99 | -------------------------------------------------------------------------------- /src/util/alignment.py: -------------------------------------------------------------------------------- 1 | # Author: Bingxin Ke 2 | # Last modified: 2024-01-11 3 | 4 | import numpy as np 5 | import torch 6 | 7 | 8 | def align_depth_least_square( 9 | gt_arr: np.ndarray, 10 | pred_arr: np.ndarray, 11 | valid_mask_arr: np.ndarray, 12 | return_scale_shift=True, 13 | max_resolution=None, 14 | ): 15 | ori_shape = pred_arr.shape # input shape 16 | 17 | gt = gt_arr.squeeze() # [H, W] 18 | pred = pred_arr.squeeze() 19 | valid_mask = valid_mask_arr.squeeze() 20 | 21 | # Downsample 22 | if max_resolution is not None: 23 | scale_factor = np.min(max_resolution / np.array(ori_shape[-2:])) 24 | if scale_factor < 1: 25 | downscaler = torch.nn.Upsample(scale_factor=scale_factor, mode="nearest") 26 | gt = downscaler(torch.as_tensor(gt).unsqueeze(0)).numpy() 27 | pred = downscaler(torch.as_tensor(pred).unsqueeze(0)).numpy() 28 | valid_mask = ( 29 | downscaler(torch.as_tensor(valid_mask).unsqueeze(0).float()) 30 | .bool() 31 | .numpy() 32 | ) 33 | 34 | assert ( 35 | gt.shape == pred.shape == valid_mask.shape 36 | ), f"{gt.shape}, {pred.shape}, {valid_mask.shape}" 37 | 38 | gt_masked = gt[valid_mask].reshape((-1, 1)) 39 | pred_masked = pred[valid_mask].reshape((-1, 1)) 40 | 41 | # numpy solver 42 | _ones = np.ones_like(pred_masked) 43 | A = np.concatenate([pred_masked, _ones], axis=-1) 44 | X = np.linalg.lstsq(A, gt_masked, rcond=None)[0] 45 | scale, shift = X 46 | 47 | aligned_pred = pred_arr * scale + shift 48 | 49 | # restore dimensions 50 | aligned_pred = aligned_pred.reshape(ori_shape) 51 | 52 | if return_scale_shift: 53 | return aligned_pred, scale, shift 54 | else: 55 | return aligned_pred 56 | 57 | 58 | # ******************** disparity space ******************** 59 | def depth2disparity(depth, return_mask=False): 60 | if isinstance(depth, torch.Tensor): 61 | disparity = torch.zeros_like(depth) 62 | elif isinstance(depth, np.ndarray): 63 | disparity = np.zeros_like(depth) 64 | non_negtive_mask = depth > 0 65 | disparity[non_negtive_mask] = 1.0 / depth[non_negtive_mask] 66 | if return_mask: 67 | return disparity, non_negtive_mask 68 | else: 69 | return disparity 70 | 71 | 72 | def disparity2depth(disparity, **kwargs): 73 | return depth2disparity(disparity, **kwargs) 74 | -------------------------------------------------------------------------------- /src/util/config_util.py: -------------------------------------------------------------------------------- 1 | # Author: Bingxin Ke 2 | # Last modified: 2024-02-14 3 | 4 | import omegaconf 5 | from omegaconf import OmegaConf 6 | 7 | 8 | def recursive_load_config(config_path: str) -> OmegaConf: 9 | conf = OmegaConf.load(config_path) 10 | 11 | output_conf = OmegaConf.create({}) 12 | 13 | # Load base config. Later configs on the list will overwrite previous 14 | base_configs = conf.get("base_config", default_value=None) 15 | if base_configs is not None: 16 | assert isinstance(base_configs, omegaconf.listconfig.ListConfig) 17 | for _path in base_configs: 18 | assert ( 19 | _path != config_path 20 | ), "Circulate merging, base_config should not include itself." 
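# Each base may itself declare base_config entries; it is resolved depth-first here before this file's own values are merged on top at the end.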
21 | _base_conf = recursive_load_config(_path) 22 | output_conf = OmegaConf.merge(output_conf, _base_conf) 23 | 24 | # Merge configs and overwrite values 25 | output_conf = OmegaConf.merge(output_conf, conf) 26 | 27 | return output_conf 28 | 29 | 30 | def find_value_in_omegaconf(search_key, config): 31 | result_list = [] 32 | 33 | if isinstance(config, omegaconf.DictConfig): 34 | for key, value in config.items(): 35 | if key == search_key: 36 | result_list.append(value) 37 | elif isinstance(value, (omegaconf.DictConfig, omegaconf.ListConfig)): 38 | result_list.extend(find_value_in_omegaconf(search_key, value)) 39 | elif isinstance(config, omegaconf.ListConfig): 40 | for item in config: 41 | if isinstance(item, (omegaconf.DictConfig, omegaconf.ListConfig)): 42 | result_list.extend(find_value_in_omegaconf(search_key, item)) 43 | 44 | return result_list 45 | 46 | 47 | if "__main__" == __name__: 48 | conf = recursive_load_config("config/train_base.yaml") 49 | print(OmegaConf.to_yaml(conf)) 50 | -------------------------------------------------------------------------------- /src/util/data_loader.py: -------------------------------------------------------------------------------- 1 | # Copied from https://github.com/huggingface/accelerate/blob/e2ae254008061b3e53fc1c97f88d65743a857e75/src/accelerate/data_loader.py 2 | 3 | from torch.utils.data import BatchSampler, DataLoader, IterableDataset 4 | 5 | # kwargs of the DataLoader in min version 1.4.0. 6 | _PYTORCH_DATALOADER_KWARGS = { 7 | "batch_size": 1, 8 | "shuffle": False, 9 | "sampler": None, 10 | "batch_sampler": None, 11 | "num_workers": 0, 12 | "collate_fn": None, 13 | "pin_memory": False, 14 | "drop_last": False, 15 | "timeout": 0, 16 | "worker_init_fn": None, 17 | "multiprocessing_context": None, 18 | "generator": None, 19 | "prefetch_factor": 2, 20 | "persistent_workers": False, 21 | } 22 | 23 | 24 | class SkipBatchSampler(BatchSampler): 25 | """ 26 | A `torch.utils.data.BatchSampler` that skips the first `n` batches of another `torch.utils.data.BatchSampler`. 27 | """ 28 | 29 | def __init__(self, batch_sampler, skip_batches=0): 30 | self.batch_sampler = batch_sampler 31 | self.skip_batches = skip_batches 32 | 33 | def __iter__(self): 34 | for index, samples in enumerate(self.batch_sampler): 35 | if index >= self.skip_batches: 36 | yield samples 37 | 38 | @property 39 | def total_length(self): 40 | return len(self.batch_sampler) 41 | 42 | def __len__(self): 43 | return len(self.batch_sampler) - self.skip_batches 44 | 45 | 46 | class SkipDataLoader(DataLoader): 47 | """ 48 | Subclass of a PyTorch `DataLoader` that will skip the first batches. 49 | 50 | Args: 51 | dataset (`torch.utils.data.dataset.Dataset`): 52 | The dataset to use to build this datalaoder. 53 | skip_batches (`int`, *optional*, defaults to 0): 54 | The number of batches to skip at the beginning. 55 | kwargs: 56 | All other keyword arguments to pass to the regular `DataLoader` initialization. 57 | """ 58 | 59 | def __init__(self, dataset, skip_batches=0, **kwargs): 60 | super().__init__(dataset, **kwargs) 61 | self.skip_batches = skip_batches 62 | 63 | def __iter__(self): 64 | for index, batch in enumerate(super().__iter__()): 65 | if index >= self.skip_batches: 66 | yield batch 67 | 68 | 69 | # Adapted from https://github.com/huggingface/accelerate 70 | def skip_first_batches(dataloader, num_batches=0): 71 | """ 72 | Creates a `torch.utils.data.DataLoader` that will efficiently skip the first `num_batches`. 
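Example (illustrative, assuming an existing `train_dataloader`): `skip_first_batches(train_dataloader, num_batches=100)` returns a loader whose iteration starts at the 101st batch of the original one, which is useful when resuming training mid-epoch.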
73 | """ 74 | dataset = dataloader.dataset 75 | sampler_is_batch_sampler = False 76 | if isinstance(dataset, IterableDataset): 77 | new_batch_sampler = None 78 | else: 79 | sampler_is_batch_sampler = isinstance(dataloader.sampler, BatchSampler) 80 | batch_sampler = ( 81 | dataloader.sampler if sampler_is_batch_sampler else dataloader.batch_sampler 82 | ) 83 | new_batch_sampler = SkipBatchSampler(batch_sampler, skip_batches=num_batches) 84 | 85 | # We ignore all of those since they are all dealt with by our new_batch_sampler 86 | ignore_kwargs = [ 87 | "batch_size", 88 | "shuffle", 89 | "sampler", 90 | "batch_sampler", 91 | "drop_last", 92 | ] 93 | 94 | kwargs = { 95 | k: getattr(dataloader, k, _PYTORCH_DATALOADER_KWARGS[k]) 96 | for k in _PYTORCH_DATALOADER_KWARGS 97 | if k not in ignore_kwargs 98 | } 99 | 100 | # Need to provide batch_size as batch_sampler is None for Iterable dataset 101 | if new_batch_sampler is None: 102 | kwargs["drop_last"] = dataloader.drop_last 103 | kwargs["batch_size"] = dataloader.batch_size 104 | 105 | if new_batch_sampler is None: 106 | # Need to manually skip batches in the dataloader 107 | dataloader = SkipDataLoader(dataset, skip_batches=num_batches, **kwargs) 108 | else: 109 | dataloader = DataLoader(dataset, batch_sampler=new_batch_sampler, **kwargs) 110 | 111 | return dataloader 112 | -------------------------------------------------------------------------------- /src/util/depth_transform.py: -------------------------------------------------------------------------------- 1 | # Author: Bingxin Ke 2 | # Last modified: 2024-04-18 3 | 4 | import torch 5 | import logging 6 | 7 | 8 | def get_depth_normalizer(cfg_normalizer): 9 | if cfg_normalizer is None: 10 | 11 | def identical(x): 12 | return x 13 | 14 | depth_transform = identical 15 | 16 | elif "scale_shift_depth" == cfg_normalizer.type: 17 | depth_transform = ScaleShiftDepthNormalizer( 18 | norm_min=cfg_normalizer.norm_min, 19 | norm_max=cfg_normalizer.norm_max, 20 | min_max_quantile=cfg_normalizer.min_max_quantile, 21 | clip=cfg_normalizer.clip, 22 | ) 23 | else: 24 | raise NotImplementedError 25 | return depth_transform 26 | 27 | 28 | class DepthNormalizerBase: 29 | is_absolute = None 30 | far_plane_at_max = None 31 | 32 | def __init__( 33 | self, 34 | norm_min=-1.0, 35 | norm_max=1.0, 36 | ) -> None: 37 | self.norm_min = norm_min 38 | self.norm_max = norm_max 39 | raise NotImplementedError 40 | 41 | def __call__(self, depth, valid_mask=None, clip=None): 42 | raise NotImplementedError 43 | 44 | def denormalize(self, depth_norm, **kwargs): 45 | # For metric depth: convert prediction back to metric depth 46 | # For relative depth: convert prediction to [0, 1] 47 | raise NotImplementedError 48 | 49 | 50 | class ScaleShiftDepthNormalizer(DepthNormalizerBase): 51 | """ 52 | Use near and far plane to linearly normalize depth, 53 | i.e. d' = d * s + t, 54 | where near plane is mapped to `norm_min`, and far plane is mapped to `norm_max` 55 | Near and far planes are determined by taking quantile values. 
56 | """ 57 | 58 | is_absolute = False 59 | far_plane_at_max = True 60 | 61 | def __init__( 62 | self, norm_min=-1.0, norm_max=1.0, min_max_quantile=0.02, clip=True 63 | ) -> None: 64 | self.norm_min = norm_min 65 | self.norm_max = norm_max 66 | self.norm_range = self.norm_max - self.norm_min 67 | self.min_quantile = min_max_quantile 68 | self.max_quantile = 1.0 - self.min_quantile 69 | self.clip = clip 70 | 71 | def __call__(self, depth_linear, valid_mask=None, clip=None): 72 | clip = clip if clip is not None else self.clip 73 | 74 | if valid_mask is None: 75 | valid_mask = torch.ones_like(depth_linear).bool() 76 | valid_mask = valid_mask & (depth_linear > 0) 77 | 78 | # Take quantiles as min and max 79 | _min, _max = torch.quantile( 80 | depth_linear[valid_mask], 81 | torch.tensor([self.min_quantile, self.max_quantile]), 82 | ) 83 | 84 | # scale and shift 85 | depth_norm_linear = (depth_linear - _min) / ( 86 | _max - _min 87 | ) * self.norm_range + self.norm_min 88 | 89 | if clip: 90 | depth_norm_linear = torch.clip( 91 | depth_norm_linear, self.norm_min, self.norm_max 92 | ) 93 | 94 | return depth_norm_linear 95 | 96 | def scale_back(self, depth_norm): 97 | # scale to [0, 1] 98 | depth_linear = (depth_norm - self.norm_min) / self.norm_range 99 | return depth_linear 100 | 101 | def denormalize(self, depth_norm, **kwargs): 102 | logging.warning(f"{self.__class__} is not revertible without GT") 103 | return self.scale_back(depth_norm=depth_norm) 104 | -------------------------------------------------------------------------------- /src/util/logging_util.py: -------------------------------------------------------------------------------- 1 | # Author: Bingxin Ke 2 | # Last modified: 2024-03-12 3 | 4 | import logging 5 | import os 6 | import sys 7 | import wandb 8 | from tabulate import tabulate 9 | from torch.utils.tensorboard import SummaryWriter 10 | 11 | 12 | def config_logging(cfg_logging, out_dir=None): 13 | file_level = cfg_logging.get("file_level", 10) 14 | console_level = cfg_logging.get("console_level", 10) 15 | 16 | log_formatter = logging.Formatter(cfg_logging["format"]) 17 | 18 | root_logger = logging.getLogger() 19 | root_logger.handlers.clear() 20 | 21 | root_logger.setLevel(min(file_level, console_level)) 22 | 23 | if out_dir is not None: 24 | _logging_file = os.path.join( 25 | out_dir, cfg_logging.get("filename", "logging.log") 26 | ) 27 | file_handler = logging.FileHandler(_logging_file) 28 | file_handler.setFormatter(log_formatter) 29 | file_handler.setLevel(file_level) 30 | root_logger.addHandler(file_handler) 31 | 32 | console_handler = logging.StreamHandler(sys.stdout) 33 | console_handler.setFormatter(log_formatter) 34 | console_handler.setLevel(console_level) 35 | root_logger.addHandler(console_handler) 36 | 37 | # Avoid pollution by packages 38 | logging.getLogger("PIL").setLevel(logging.INFO) 39 | logging.getLogger("matplotlib").setLevel(logging.INFO) 40 | 41 | 42 | class MyTrainingLogger: 43 | """Tensorboard + wandb logger""" 44 | 45 | writer: SummaryWriter 46 | is_initialized = False 47 | 48 | def __init__(self) -> None: 49 | pass 50 | 51 | def set_dir(self, tb_log_dir): 52 | if self.is_initialized: 53 | raise ValueError("Do not initialize writer twice") 54 | self.writer = SummaryWriter(tb_log_dir) 55 | self.is_initialized = True 56 | 57 | def log_dic(self, scalar_dic, global_step, walltime=None): 58 | for k, v in scalar_dic.items(): 59 | self.writer.add_scalar(k, v, global_step=global_step, walltime=walltime) 60 | return 61 | 62 | 63 | # global instance 64 | 
tb_logger = MyTrainingLogger() 65 | 66 | 67 | # -------------- wandb tools -------------- 68 | def init_wandb(enable: bool, **kwargs): 69 | if enable: 70 | run = wandb.init(sync_tensorboard=True, **kwargs) 71 | else: 72 | run = wandb.init(mode="disabled") 73 | return run 74 | 75 | 76 | def log_slurm_job_id(step): 77 | global tb_logger 78 | _jobid = os.getenv("SLURM_JOB_ID") 79 | if _jobid is None: 80 | _jobid = -1 81 | tb_logger.writer.add_scalar("job_id", int(_jobid), global_step=step) 82 | logging.debug(f"Slurm job_id: {_jobid}") 83 | 84 | 85 | def load_wandb_job_id(out_dir): 86 | with open(os.path.join(out_dir, "WANDB_ID"), "r") as f: 87 | wandb_id = f.read() 88 | return wandb_id 89 | 90 | 91 | def save_wandb_job_id(run, out_dir): 92 | with open(os.path.join(out_dir, "WANDB_ID"), "w+") as f: 93 | f.write(run.id) 94 | 95 | 96 | def eval_dic_to_text(val_metrics: dict, dataset_name: str, sample_list_path: str): 97 | eval_text = f"Evaluation metrics:\n\ 98 | on dataset: {dataset_name}\n\ 99 | over samples in: {sample_list_path}\n" 100 | 101 | eval_text += tabulate([val_metrics.keys(), val_metrics.values()]) 102 | return eval_text 103 | -------------------------------------------------------------------------------- /src/util/loss.py: -------------------------------------------------------------------------------- 1 | # Author: Bingxin Ke 2 | # Last modified: 2024-02-22 3 | 4 | import torch 5 | 6 | 7 | def get_loss(loss_name, **kwargs): 8 | if "silog_mse" == loss_name: 9 | criterion = SILogMSELoss(**kwargs) 10 | elif "silog_rmse" == loss_name: 11 | criterion = SILogRMSELoss(**kwargs) 12 | elif "mse_loss" == loss_name: 13 | criterion = torch.nn.MSELoss(**kwargs) 14 | elif "l1_loss" == loss_name: 15 | criterion = torch.nn.L1Loss(**kwargs) 16 | elif "l1_loss_with_mask" == loss_name: 17 | criterion = L1LossWithMask(**kwargs) 18 | elif "mean_abs_rel" == loss_name: 19 | criterion = MeanAbsRelLoss() 20 | else: 21 | raise NotImplementedError 22 | 23 | return criterion 24 | 25 | 26 | class L1LossWithMask: 27 | def __init__(self, batch_reduction=False): 28 | self.batch_reduction = batch_reduction 29 | 30 | def __call__(self, depth_pred, depth_gt, valid_mask=None): 31 | diff = depth_pred - depth_gt 32 | if valid_mask is not None: 33 | diff[~valid_mask] = 0 34 | n = valid_mask.sum((-1, -2)) 35 | else: 36 | n = depth_gt.shape[-2] * depth_gt.shape[-1] 37 | 38 | loss = torch.sum(torch.abs(diff)) / n 39 | if self.batch_reduction: 40 | loss = loss.mean() 41 | return loss 42 | 43 | 44 | class MeanAbsRelLoss: 45 | def __init__(self) -> None: 46 | # super().__init__() 47 | pass 48 | 49 | def __call__(self, pred, gt): 50 | diff = pred - gt 51 | rel_abs = torch.abs(diff / gt) 52 | loss = torch.mean(rel_abs, dim=0) 53 | return loss 54 | 55 | 56 | class SILogMSELoss: 57 | def __init__(self, lamb, log_pred=True, batch_reduction=True): 58 | """Scale Invariant Log MSE Loss 59 | 60 | Args: 61 | lamb (_type_): lambda, lambda=1 -> scale invariant, lambda=0 -> L2 loss 62 | log_pred (bool, optional): True if model prediction is logarithmic depht. 
Will not do log for depth_pred 63 | """ 64 | super(SILogMSELoss, self).__init__() 65 | self.lamb = lamb 66 | self.pred_in_log = log_pred 67 | self.batch_reduction = batch_reduction 68 | 69 | def __call__(self, depth_pred, depth_gt, valid_mask=None): 70 | log_depth_pred = ( 71 | depth_pred if self.pred_in_log else torch.log(torch.clip(depth_pred, 1e-8)) 72 | ) 73 | log_depth_gt = torch.log(depth_gt) 74 | 75 | diff = log_depth_pred - log_depth_gt 76 | if valid_mask is not None: 77 | diff[~valid_mask] = 0 78 | n = valid_mask.sum((-1, -2)) 79 | else: 80 | n = depth_gt.shape[-2] * depth_gt.shape[-1] 81 | 82 | diff2 = torch.pow(diff, 2) 83 | 84 | first_term = torch.sum(diff2, (-1, -2)) / n 85 | second_term = self.lamb * torch.pow(torch.sum(diff, (-1, -2)), 2) / (n**2) 86 | loss = first_term - second_term 87 | if self.batch_reduction: 88 | loss = loss.mean() 89 | return loss 90 | 91 | 92 | class SILogRMSELoss: 93 | def __init__(self, lamb, alpha, log_pred=True): 94 | """Scale Invariant Log RMSE Loss 95 | 96 | Args: 97 | lamb (_type_): lambda, lambda=1 -> scale invariant, lambda=0 -> L2 loss 98 | alpha: 99 | log_pred (bool, optional): True if model prediction is logarithmic depht. Will not do log for depth_pred 100 | """ 101 | super(SILogRMSELoss, self).__init__() 102 | self.lamb = lamb 103 | self.alpha = alpha 104 | self.pred_in_log = log_pred 105 | 106 | def __call__(self, depth_pred, depth_gt, valid_mask): 107 | log_depth_pred = depth_pred if self.pred_in_log else torch.log(depth_pred) 108 | log_depth_gt = torch.log(depth_gt) 109 | # borrowed from https://github.com/aliyun/NeWCRFs 110 | # diff = log_depth_pred[valid_mask] - log_depth_gt[valid_mask] 111 | # return torch.sqrt((diff ** 2).mean() - self.lamb * (diff.mean() ** 2)) * self.alpha 112 | 113 | diff = log_depth_pred - log_depth_gt 114 | if valid_mask is not None: 115 | diff[~valid_mask] = 0 116 | n = valid_mask.sum((-1, -2)) 117 | else: 118 | n = depth_gt.shape[-2] * depth_gt.shape[-1] 119 | 120 | diff2 = torch.pow(diff, 2) 121 | first_term = torch.sum(diff2, (-1, -2)) / n 122 | second_term = self.lamb * torch.pow(torch.sum(diff, (-1, -2)), 2) / (n**2) 123 | loss = torch.sqrt(first_term - second_term).mean() * self.alpha 124 | return loss 125 | -------------------------------------------------------------------------------- /src/util/lr_scheduler.py: -------------------------------------------------------------------------------- 1 | # Author: Bingxin Ke 2 | # Last modified: 2024-02-22 3 | 4 | import numpy as np 5 | 6 | 7 | class IterExponential: 8 | def __init__(self, total_iter_length, final_ratio, warmup_steps=0) -> None: 9 | """ 10 | Customized iteration-wise exponential scheduler. 
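After `warmup_steps` iterations of linear warm-up, the returned LR ratio decays as `final_ratio ** ((n_iter - warmup_steps) / (total_iter_length - warmup_steps))` and is held at `final_ratio` once `total_iter_length` is reached.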
11 | Re-calculate for every step, to reduce error accumulation 12 | 13 | Args: 14 | total_iter_length (int): Expected total iteration number 15 | final_ratio (float): Expected LR ratio at n_iter = total_iter_length 16 | """ 17 | self.total_length = total_iter_length 18 | self.effective_length = total_iter_length - warmup_steps 19 | self.final_ratio = final_ratio 20 | self.warmup_steps = warmup_steps 21 | 22 | def __call__(self, n_iter) -> float: 23 | if n_iter < self.warmup_steps: 24 | alpha = 1.0 * n_iter / self.warmup_steps 25 | elif n_iter >= self.total_length: 26 | alpha = self.final_ratio 27 | else: 28 | actual_iter = n_iter - self.warmup_steps 29 | alpha = np.exp( 30 | actual_iter / self.effective_length * np.log(self.final_ratio) 31 | ) 32 | return alpha 33 | 34 | 35 | if "__main__" == __name__: 36 | lr_scheduler = IterExponential( 37 | total_iter_length=50000, final_ratio=0.01, warmup_steps=200 38 | ) 39 | lr_scheduler = IterExponential( 40 | total_iter_length=50000, final_ratio=0.01, warmup_steps=0 41 | ) 42 | 43 | x = np.arange(100000) 44 | alphas = [lr_scheduler(i) for i in x] 45 | import matplotlib.pyplot as plt 46 | 47 | plt.plot(alphas) 48 | plt.savefig("lr_scheduler.png") 49 | -------------------------------------------------------------------------------- /src/util/metric.py: -------------------------------------------------------------------------------- 1 | # Author: Bingxin Ke 2 | # Last modified: 2024-02-15 3 | 4 | 5 | import pandas as pd 6 | import torch 7 | 8 | 9 | # Adapted from: https://github.com/victoresque/pytorch-template/blob/master/utils/util.py 10 | class MetricTracker: 11 | def __init__(self, *keys, writer=None): 12 | self.writer = writer 13 | self._data = pd.DataFrame(index=keys, columns=["total", "counts", "average"]) 14 | self.reset() 15 | 16 | def reset(self): 17 | for col in self._data.columns: 18 | self._data[col].values[:] = 0 19 | 20 | def update(self, key, value, n=1): 21 | if self.writer is not None: 22 | self.writer.add_scalar(key, value) 23 | self._data.loc[key, "total"] += value * n 24 | self._data.loc[key, "counts"] += n 25 | self._data.loc[key, "average"] = self._data.total[key] / self._data.counts[key] 26 | 27 | def avg(self, key): 28 | return self._data.average[key] 29 | 30 | def result(self): 31 | return dict(self._data.average) 32 | 33 | 34 | def abs_relative_difference(output, target, valid_mask=None): 35 | actual_output = output 36 | actual_target = target 37 | abs_relative_diff = torch.abs(actual_output - actual_target) / actual_target 38 | if valid_mask is not None: 39 | abs_relative_diff[~valid_mask] = 0 40 | n = valid_mask.sum((-1, -2)) 41 | else: 42 | n = output.shape[-1] * output.shape[-2] 43 | abs_relative_diff = torch.sum(abs_relative_diff, (-1, -2)) / n 44 | return abs_relative_diff.mean() * 100 45 | 46 | 47 | def squared_relative_difference(output, target, valid_mask=None): 48 | actual_output = output 49 | actual_target = target 50 | square_relative_diff = ( 51 | torch.pow(torch.abs(actual_output - actual_target), 2) / actual_target 52 | ) 53 | if valid_mask is not None: 54 | square_relative_diff[~valid_mask] = 0 55 | n = valid_mask.sum((-1, -2)) 56 | else: 57 | n = output.shape[-1] * output.shape[-2] 58 | square_relative_diff = torch.sum(square_relative_diff, (-1, -2)) / n 59 | return square_relative_diff.mean() *100 60 | 61 | 62 | def rmse_linear(output, target, valid_mask=None): 63 | actual_output = output 64 | actual_target = target 65 | diff = actual_output - actual_target 66 | if valid_mask is not None: 67 | 
diff[~valid_mask] = 0 68 | n = valid_mask.sum((-1, -2)) 69 | else: 70 | n = output.shape[-1] * output.shape[-2] 71 | diff2 = torch.pow(diff, 2) 72 | mse = torch.sum(diff2, (-1, -2)) / n 73 | rmse = torch.sqrt(mse) 74 | return rmse.mean() * 100 75 | 76 | 77 | def rmse_log(output, target, valid_mask=None): 78 | diff = torch.log(output) - torch.log(target) 79 | if valid_mask is not None: 80 | diff[~valid_mask] = 0 81 | n = valid_mask.sum((-1, -2)) 82 | else: 83 | n = output.shape[-1] * output.shape[-2] 84 | diff2 = torch.pow(diff, 2) 85 | mse = torch.sum(diff2, (-1, -2)) / n # [B] 86 | rmse = torch.sqrt(mse) 87 | return rmse.mean() * 100 88 | 89 | def log10(output, target, valid_mask=None): 90 | if valid_mask is not None: 91 | diff = torch.abs( 92 | torch.log10(output[valid_mask]) - torch.log10(target[valid_mask]) 93 | ) 94 | else: 95 | diff = torch.abs(torch.log10(output) - torch.log10(target)) 96 | return diff.mean() * 100 97 | 98 | 99 | # adapted from: https://github.com/imran3180/depth-map-prediction/blob/master/main.py 100 | def threshold_percentage(output, target, threshold_val, valid_mask=None): 101 | d1 = output / target 102 | d2 = target / output 103 | max_d1_d2 = torch.max(d1, d2) 104 | zero = torch.zeros(*output.shape) 105 | one = torch.ones(*output.shape) 106 | bit_mat = torch.where(max_d1_d2.cpu() < threshold_val, one, zero) 107 | if valid_mask is not None: 108 | bit_mat[~valid_mask] = 0 109 | n = valid_mask.sum((-1, -2)) 110 | else: 111 | n = output.shape[-1] * output.shape[-2] 112 | count_mat = torch.sum(bit_mat, (-1, -2)) 113 | threshold_mat = count_mat / n.cpu() 114 | return threshold_mat.mean() * 100 115 | 116 | 117 | def delta1_acc(pred, gt, valid_mask): 118 | return threshold_percentage(pred, gt, 1.25, valid_mask) 119 | 120 | 121 | def delta2_acc(pred, gt, valid_mask): 122 | return threshold_percentage(pred, gt, 1.25**2, valid_mask) 123 | 124 | 125 | def delta3_acc(pred, gt, valid_mask): 126 | return threshold_percentage(pred, gt, 1.25**3, valid_mask) 127 | 128 | 129 | def i_rmse(output, target, valid_mask=None): 130 | output_inv = 1.0 / output 131 | target_inv = 1.0 / target 132 | diff = output_inv - target_inv 133 | if valid_mask is not None: 134 | diff[~valid_mask] = 0 135 | n = valid_mask.sum((-1, -2)) 136 | else: 137 | n = output.shape[-1] * output.shape[-2] 138 | diff2 = torch.pow(diff, 2) 139 | mse = torch.sum(diff2, (-1, -2)) / n # [B] 140 | rmse = torch.sqrt(mse) 141 | return rmse.mean() * 100 142 | 143 | 144 | def silog_rmse(depth_pred, depth_gt, valid_mask=None): 145 | diff = torch.log(depth_pred) - torch.log(depth_gt) 146 | if valid_mask is not None: 147 | diff[~valid_mask] = 0 148 | n = valid_mask.sum((-1, -2)) 149 | else: 150 | n = depth_gt.shape[-2] * depth_gt.shape[-1] 151 | 152 | diff2 = torch.pow(diff, 2) 153 | 154 | first_term = torch.sum(diff2, (-1, -2)) / n 155 | second_term = torch.pow(torch.sum(diff, (-1, -2)), 2) / (n**2) 156 | loss = torch.sqrt(torch.mean(first_term - second_term)) * 100 157 | return loss 158 | -------------------------------------------------------------------------------- /src/util/multi_res_noise.py: -------------------------------------------------------------------------------- 1 | # Author: Bingxin Ke 2 | # Last modified: 2024-04-18 3 | 4 | import torch 5 | import math 6 | 7 | 8 | # adapted from: https://wandb.ai/johnowhitaker/multires_noise/reports/Multi-Resolution-Noise-for-Diffusion-Model-Training--VmlldzozNjYyOTU2?s=31 9 | def multi_res_noise_like( 10 | x, strength=0.9, downscale_strategy="original", generator=None,
device=None 11 | ): 12 | if torch.is_tensor(strength): 13 | strength = strength.reshape((-1, 1, 1, 1)) 14 | b, c, w, h = x.shape 15 | 16 | if device is None: 17 | device = x.device 18 | 19 | up_sampler = torch.nn.Upsample(size=(w, h), mode="bilinear") 20 | noise = torch.randn(x.shape, device=x.device, generator=generator) 21 | 22 | if "original" == downscale_strategy: 23 | for i in range(10): 24 | r = ( 25 | torch.rand(1, generator=generator, device=device) * 2 + 2 26 | ) # Rather than always going 2x, 27 | w, h = max(1, int(w / (r**i))), max(1, int(h / (r**i))) 28 | noise += ( 29 | up_sampler( 30 | torch.randn(b, c, w, h, generator=generator, device=device).to(x) 31 | ) 32 | * strength**i 33 | ) 34 | if w == 1 or h == 1: 35 | break # Lowest resolution is 1x1 36 | elif "every_layer" == downscale_strategy: 37 | for i in range(int(math.log2(min(w, h)))): 38 | w, h = max(1, int(w / 2)), max(1, int(h / 2)) 39 | noise += ( 40 | up_sampler( 41 | torch.randn(b, c, w, h, generator=generator, device=device).to(x) 42 | ) 43 | * strength**i 44 | ) 45 | elif "power_of_two" == downscale_strategy: 46 | for i in range(10): 47 | r = 2 48 | w, h = max(1, int(w / (r**i))), max(1, int(h / (r**i))) 49 | noise += ( 50 | up_sampler( 51 | torch.randn(b, c, w, h, generator=generator, device=device).to(x) 52 | ) 53 | * strength**i 54 | ) 55 | if w == 1 or h == 1: 56 | break # Lowest resolution is 1x1 57 | elif "random_step" == downscale_strategy: 58 | for i in range(10): 59 | r = ( 60 | torch.rand(1, generator=generator, device=device) * 2 + 2 61 | ) # Rather than always going 2x, 62 | w, h = max(1, int(w / (r))), max(1, int(h / (r))) 63 | noise += ( 64 | up_sampler( 65 | torch.randn(b, c, w, h, generator=generator, device=device).to(x) 66 | ) 67 | * strength**i 68 | ) 69 | if w == 1 or h == 1: 70 | break # Lowest resolution is 1x1 71 | else: 72 | raise ValueError(f"unknown downscale strategy: {downscale_strategy}") 73 | 74 | noise = noise / noise.std() # Scaled back to roughly unit variance 75 | return noise 76 | -------------------------------------------------------------------------------- /src/util/seeding.py: -------------------------------------------------------------------------------- 1 | # Copyright 2023 Bingxin Ke, ETH Zurich. All rights reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # -------------------------------------------------------------------------- 15 | # If you find this code useful, we kindly ask you to cite our paper in your work. 16 | # Please find bibtex at: https://github.com/prs-eth/Marigold#-citation 17 | # More information about the method can be found at https://marigoldmonodepth.github.io 18 | # -------------------------------------------------------------------------- 19 | 20 | 21 | import numpy as np 22 | import random 23 | import torch 24 | import logging 25 | 26 | 27 | def seed_all(seed: int = 0): 28 | """ 29 | Set random seeds of all components. 
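Covers Python's `random`, NumPy, and PyTorch (CPU and all CUDA devices).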
30 | """ 31 | random.seed(seed) 32 | np.random.seed(seed) 33 | torch.manual_seed(seed) 34 | torch.cuda.manual_seed_all(seed) 35 | 36 | 37 | def generate_seed_sequence( 38 | initial_seed: int, 39 | length: int, 40 | min_val=-0x8000_0000_0000_0000, 41 | max_val=0xFFFF_FFFF_FFFF_FFFF, 42 | ): 43 | if initial_seed is None: 44 | logging.warning("initial_seed is None, reproducibility is not guaranteed") 45 | random.seed(initial_seed) 46 | 47 | seed_sequence = [] 48 | 49 | for _ in range(length): 50 | seed = random.randint(min_val, max_val) 51 | 52 | seed_sequence.append(seed) 53 | 54 | return seed_sequence 55 | -------------------------------------------------------------------------------- /src/util/slurm_util.py: -------------------------------------------------------------------------------- 1 | # Author: Bingxin Ke 2 | # Last modified: 2024-02-22 3 | 4 | import os 5 | 6 | 7 | def is_on_slurm(): 8 | cluster_name = os.getenv("SLURM_CLUSTER_NAME") 9 | is_on_slurm = cluster_name is not None 10 | return is_on_slurm 11 | 12 | 13 | def get_local_scratch_dir(): 14 | local_scratch_dir = os.getenv("TMPDIR") 15 | return local_scratch_dir 16 | -------------------------------------------------------------------------------- /train_scripts/train_merge_b_depth.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | accelerate launch --num_processes=8 --main_process_port=36661 merge/train_merge_base_depth.py \ 4 | --dataloader_num_workers=8 --max_train_steps 30000 --learning_rate 1e-4 --train_batch_size 4 \ 5 | --validation_steps 1000 --checkpointing_steps 5000 --checkpoints_total_limit 1 \ 6 | --pretrained_model_name_or_path PATH/PixArt-XL-2-512x512 \ 7 | --output_dir=./outputs/merge_base_depth_b32_30k 8 | 9 | 10 | 11 | 12 | 13 | -------------------------------------------------------------------------------- /train_scripts/train_merge_l_depth.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | accelerate launch --num_processes=8 --main_process_port=36661 ./unideg/train_merge_large_depth.py \ 4 | --mixed_precision=bf16 --train_batch_size=1 --dataloader_num_workers=8 --pretrained_model_name_or_path PATH/FLUX.1-dev \ 5 | --gradient_accumulation_steps=4 --use_8bit_adam --learning_rate=3e-4 --lr_scheduler="linear" --checkpoints_total_limit 1 \ 6 | --max_train_steps=30000 --validation_steps 5000 --checkpointing_steps 5000 --output_dir=./outputs/merge_large_depth_b32_30k 7 | 8 | 9 | 10 | 11 | --------------------------------------------------------------------------------