├── LICENSE ├── README.md ├── assets ├── faces │ ├── 00000014.png │ └── 00002386.png ├── pipeline.png └── real-photos │ └── dog.png ├── configs ├── dataset │ ├── dehaze_train.yaml │ ├── dehaze_val.yaml │ ├── derain_train.yaml │ ├── derain_val.yaml │ ├── enlight_train.yaml │ ├── enlight_val.yaml │ ├── face_test_inp.yaml │ ├── face_test_lq.yaml │ ├── face_train.yaml │ ├── face_val.yaml │ ├── general_deg_codeformer_train.yaml │ ├── general_deg_codeformer_val.yaml │ ├── general_deg_realesrgan_train.yaml │ ├── general_deg_realesrgan_val.yaml │ ├── general_test_lq.yaml │ ├── ocr_train.yaml │ └── ocr_val.yaml ├── model │ ├── cldm.yaml │ ├── cldm_bsr_eval.yaml │ ├── cldm_eval.yaml │ └── swinir.yaml ├── test_cldm.yaml ├── test_cldm_general.yaml ├── train_cldm.yaml ├── train_cldm_general.yaml └── train_swinir.yaml ├── dataset ├── __pycache__ │ ├── batch_transform.cpython-38.pyc │ ├── codeformer.cpython-38.pyc │ └── data_module.cpython-38.pyc ├── batch_transform.py ├── codeformer.py ├── data_module.py ├── realesrgan.py └── test.py ├── evaluate.py ├── inference.py ├── inference_bfr.py ├── inference_bsr.py ├── ldm ├── __pycache__ │ ├── util.cpython-38.pyc │ ├── util.cpython-39.pyc │ ├── xformers_state.cpython-38.pyc │ └── xformers_state.cpython-39.pyc ├── data │ ├── __init__.py │ └── util.py ├── models │ ├── __pycache__ │ │ ├── autoencoder.cpython-38.pyc │ │ └── autoencoder.cpython-39.pyc │ ├── autoencoder.py │ └── diffusion │ │ ├── __init__.py │ │ ├── __pycache__ │ │ ├── __init__.cpython-38.pyc │ │ ├── __init__.cpython-39.pyc │ │ ├── ddim.cpython-38.pyc │ │ ├── ddim.cpython-39.pyc │ │ ├── ddpm.cpython-38.pyc │ │ └── ddpm.cpython-39.pyc │ │ ├── ddim.py │ │ ├── ddpm.py │ │ ├── dpm_solver │ │ ├── __init__.py │ │ ├── dpm_solver.py │ │ └── sampler.py │ │ ├── plms.py │ │ └── sampling_util.py ├── modules │ ├── __pycache__ │ │ ├── attention.cpython-38.pyc │ │ ├── attention.cpython-39.pyc │ │ ├── ema.cpython-38.pyc │ │ └── ema.cpython-39.pyc │ ├── attention.py │ ├── diffusionmodules │ │ ├── __init__.py │ │ ├── __pycache__ │ │ │ ├── __init__.cpython-38.pyc │ │ │ ├── __init__.cpython-39.pyc │ │ │ ├── model.cpython-38.pyc │ │ │ ├── model.cpython-39.pyc │ │ │ ├── openaimodel.cpython-38.pyc │ │ │ ├── openaimodel.cpython-39.pyc │ │ │ ├── util.cpython-38.pyc │ │ │ └── util.cpython-39.pyc │ │ ├── model.py │ │ ├── openaimodel.py │ │ ├── upscaling.py │ │ └── util.py │ ├── distributions │ │ ├── __init__.py │ │ ├── __pycache__ │ │ │ ├── __init__.cpython-38.pyc │ │ │ ├── __init__.cpython-39.pyc │ │ │ ├── distributions.cpython-38.pyc │ │ │ └── distributions.cpython-39.pyc │ │ └── distributions.py │ ├── ema.py │ ├── encoders │ │ ├── __init__.py │ │ ├── __pycache__ │ │ │ ├── __init__.cpython-38.pyc │ │ │ ├── __init__.cpython-39.pyc │ │ │ ├── modules.cpython-38.pyc │ │ │ └── modules.cpython-39.pyc │ │ └── modules.py │ └── midas │ │ ├── __init__.py │ │ ├── api.py │ │ ├── midas │ │ ├── __init__.py │ │ ├── base_model.py │ │ ├── blocks.py │ │ ├── dpt_depth.py │ │ ├── midas_net.py │ │ ├── midas_net_custom.py │ │ ├── transforms.py │ │ └── vit.py │ │ └── utils.py ├── util.py └── xformers_state.py ├── model ├── __pycache__ │ ├── callbacks.cpython-38.pyc │ ├── cldm.cpython-38.pyc │ ├── cldm_bsr.cpython-39.pyc │ ├── cond_fn.cpython-38.pyc │ ├── mixins.cpython-38.pyc │ ├── mixins.cpython-39.pyc │ ├── spaced_sampler.cpython-38.pyc │ ├── swinir.cpython-38.pyc │ └── swinir.cpython-39.pyc ├── callbacks.py ├── cldm.py ├── cldm_bsr.py ├── cond_fn.py ├── mixins.py ├── spaced_sampler.py └── swinir.py ├── requirements.txt ├── scripts ├── 
inference_stage1.py ├── make_file_list.py ├── make_list_celea.py ├── make_stage2_init_weight.py ├── merge_img.py ├── metrics.py ├── rainy.py └── sample_dataset.py ├── test.py ├── train.py ├── utils ├── __pycache__ │ ├── common.cpython-38.pyc │ ├── common.cpython-39.pyc │ ├── degradation.cpython-38.pyc │ ├── face_restoration_helper.cpython-38.pyc │ ├── file.cpython-310.pyc │ ├── file.cpython-38.pyc │ ├── file.cpython-39.pyc │ ├── metrics.cpython-38.pyc │ ├── metrics.cpython-39.pyc │ ├── process.cpython-38.pyc │ └── utils.cpython-38.pyc ├── common.py ├── degradation.py ├── face_restoration_helper.py ├── file.py ├── image │ ├── __init__.py │ ├── __pycache__ │ │ ├── __init__.cpython-38.pyc │ │ ├── __init__.cpython-39.pyc │ │ ├── align_color.cpython-38.pyc │ │ ├── align_color.cpython-39.pyc │ │ ├── common.cpython-38.pyc │ │ ├── common.cpython-39.pyc │ │ ├── diffjpeg.cpython-38.pyc │ │ ├── diffjpeg.cpython-39.pyc │ │ ├── usm_sharp.cpython-38.pyc │ │ └── usm_sharp.cpython-39.pyc │ ├── align_color.py │ ├── common.py │ ├── diffjpeg.py │ └── usm_sharp.py ├── metrics.py ├── pickout_img.py ├── process.py ├── realesrgan │ ├── __pycache__ │ │ ├── realesrganer.cpython-38.pyc │ │ └── rrdbnet.cpython-38.pyc │ ├── realesrganer.py │ └── rrdbnet.py ├── torchinterp1d-master.zip ├── torchinterp1d │ ├── .gitignore │ ├── .gitmodules │ ├── LICENSE │ ├── README.md │ ├── examples │ │ └── test.py │ ├── setup.py │ └── torchinterp1d │ │ ├── __init__.py │ │ └── interp1d.py └── utils.py └── weights └── null_token_1024.pth /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2024 Yixuan 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 
22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # FlowIE: Efficient Image Enhancement via Rectified Flow (CVPR 2024) 2 | 3 | > [Yixuan Zhu](https://eternalevan.github.io/)\*, [Wenliang Zhao](https://wl-zhao.github.io/)\* $\dagger$, [Ao Li](https://rammusleo.github.io/), [Yansong Tang](https://andytang15.github.io/), [Jie Zhou](https://scholar.google.com/citations?user=6a79aPwAAAAJ&hl=en&authuser=1), [Jiwen Lu](http://ivg.au.tsinghua.edu.cn/Jiwen_Lu/) $\ddagger$ 4 | > 5 | > \* Equal contribution   $\dagger$ Project leader   $\ddagger$ Corresponding author 6 | 7 | [**[Paper]**](https://arxiv.org/abs/2406.00508) 8 | 9 | This repository contains the official implementation of the paper "FlowIE: Efficient Image Enhancement via Rectified Flow" (**CVPR 2024, oral presentation**). 10 | 11 | FlowIE is a simple yet highly effective **Flow**-based **I**mage **E**nhancement framework that estimates straight-line paths from an elementary distribution to high-quality images. 12 | ## 📋 To-Do List 13 | 14 | * [x] Release model and inference code. 15 | * [x] Release code for training dataloader. 16 | 17 | 18 | ## 💡 Pipeline 19 | 20 | ![](./assets/pipeline.png) 21 | 22 | 24 | 25 | 26 | ## 😀 Quick Start 27 | ### ⚙️ 1. Installation 28 | 29 | We recommend using an [Anaconda](https://www.anaconda.com/) virtual environment. If you have Anaconda installed, run the following commands to create and activate a virtual environment. 30 | ``` bash 31 | conda env create -f requirements.txt 32 | conda activate FlowIE 33 | ``` 34 | ### 📑 2. Modify the LoRA configuration 35 | Since we use `MemoryEfficientCrossAttention` to accelerate the inference process, you need to slightly modify `lora.py` in the `lora_diffusion` package, which can be done in about two minutes: 36 | - (1) Locate the `lora.py` file in the package directory. You can easily find this file by using "go to definition" on Line 4 of the `./model/cldm.py` file. 37 | - (2) Make the following modifications to Lines 159-161 in `lora.py`: 38 | 39 | Original Code: 40 | ```python 41 | UNET_DEFAULT_TARGET_REPLACE = {"CrossAttention", "Attention", "GEGLU"} 42 | UNET_EXTENDED_TARGET_REPLACE = {"ResnetBlock2D", "CrossAttention", "Attention", "GEGLU"} 43 | ``` 44 | 45 | Modified Code: 46 | ```python 47 | UNET_DEFAULT_TARGET_REPLACE = {"CrossAttention", "Attention", "GEGLU", "MemoryEfficientCrossAttention"} 48 | UNET_EXTENDED_TARGET_REPLACE = {"ResnetBlock2D", "CrossAttention", "Attention", "GEGLU", "MemoryEfficientCrossAttention", "ResBlock"} 49 | ``` 50 | 51 | ### 💾 3. Data Preparation 52 | 53 | We prepare the data in a similar way to [GFPGAN](https://xinntao.github.io/projects/gfpgan) & [DiffBIR](https://github.com/XPixelGroup/DiffBIR). We list the datasets for BFR and BSR as follows: 54 | 55 | For BFR evaluation, please refer to [here](https://xinntao.github.io/projects/gfpgan) for the *BFR-test datasets*, which include *CelebA-Test*, *CelebChild-Test* and *LFW-Test*. The *WIDER-Test* can be found [here](https://drive.google.com/file/d/1g05U86QGqnlN_v9SRRKDTU8033yvQNEa/view). For BFR training, please download the [FFHQ dataset](https://github.com/NVlabs/ffhq-dataset). 56 | 57 | For BSR, we utilize [ImageNet](https://www.image-net.org/index.php) for training. For evaluation, you can refer to [BSRGAN](https://github.com/cszn/BSRGAN/tree/main/testsets) for *RealSRSet*. 58 | 59 | To prepare the training lists, simply run the script: 60 | ```bash 61 | python ./scripts/make_file_list.py --img_folder /data/ILSVRC2012 --save_folder ./dataset/list/imagenet 62 | python ./scripts/make_file_list.py --img_folder /data/FFHQ --save_folder ./dataset/list/ffhq 63 | ``` 64 | The file list looks like this: 65 | ```bash 66 | /path/to/image_1.png 67 | /path/to/image_2.png 68 | /path/to/image_3.png 69 | ... 70 | ```
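For reference, a file-list script of this kind only needs to walk an image folder and write one path per line. Below is a minimal sketch of that idea, not the actual `scripts/make_file_list.py`; the `--img_folder`/`--save_folder` arguments mirror the commands above, while the output file name (`train.list`) and the image extensions are assumptions for illustration:
```python
# Minimal sketch of a file-list builder (illustrative only; the real
# scripts/make_file_list.py in this repository may differ in details).
from argparse import ArgumentParser
from pathlib import Path

if __name__ == "__main__":
    parser = ArgumentParser()
    parser.add_argument("--img_folder", type=str, required=True)
    parser.add_argument("--save_folder", type=str, required=True)
    args = parser.parse_args()

    # Recursively collect image paths (the extension set is an assumption).
    exts = {".png", ".jpg", ".jpeg"}
    paths = sorted(
        str(p.resolve()) for p in Path(args.img_folder).rglob("*")
        if p.suffix.lower() in exts
    )

    # Write one absolute path per line, matching the file-list format shown above.
    save_dir = Path(args.save_folder)
    save_dir.mkdir(parents=True, exist_ok=True)
    (save_dir / "train.list").write_text("\n".join(paths) + "\n")
```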
71 | ### 🗂️ 4. Download Checkpoints 72 | 73 | Please download our pretrained checkpoints from [this link](https://cloud.tsinghua.edu.cn/d/4fa2a0880a9243999561/) and put them under `./weights`. The file directory should be: 74 | 75 | ``` 76 | |-- weights 77 | |--|-- FlowIE_bfr_v1.ckpt 78 | |--|-- FlowIE_bsr_v1.ckpt 79 | ... 80 | ``` 81 | 82 | ### 📊 5. Test & Evaluation 83 | 84 | You can test FlowIE with the following commands: 85 | - **Evaluation for BFR** 86 | ```bash 87 | python inference_bfr.py --ckpt ./weights/FlowIE_bfr_v1.ckpt --has_aligned --input /data/celeba_512_validation_lq/ --output ./outputs/bfr_exp 88 | ``` 89 | - **Evaluation for BSR** 90 | ```bash 91 | python inference_bsr.py --ckpt ./weights/FlowIE_bsr_v1.ckpt --input /data/testdata/ --output ./outputs/bsr_exp --sr_scale 4 92 | ``` 93 | - **Quick Test** 94 | 95 | For a quick test, we collect some test samples in `./assets`. You can run the demo for BFR: 96 | ```bash 97 | python inference_bfr.py --ckpt ./weights/FlowIE_bfr_v1.ckpt --input ./assets/faces --output ./outputs/demo 98 | ``` 99 | And for BSR: 100 | ```bash 101 | python inference_bsr.py --ckpt ./weights/FlowIE_bsr_v1.ckpt --input ./assets/real-photos/ --output ./outputs/bsr_exp --tiled --sr_scale 4 102 | ``` 103 | You can use `--tiled` for patch-based inference and `--sr_scale` to set the super-resolution scale, e.g. 2 or 4. You can set `CUDA_VISIBLE_DEVICES=1` to choose the device. 104 | 114 | 115 | The evaluation process can be done on a single Nvidia GeForce RTX 3090 GPU (24GB VRAM). You can use more GPUs by specifying the GPU ids. 116 | 117 | ### 🔥 6. Training 118 | The key component in FlowIE is a path estimator tuned from [Stable Diffusion v2.1 base](https://huggingface.co/stabilityai/stable-diffusion-2-1-base). Please download it to `./weights`. The other part is the initial module, which can be found in [checkpoints](https://cloud.tsinghua.edu.cn/d/4fa2a0880a9243999561/). 119 | 120 | Before training, you also need to configure the training-related information in `./configs/train_cldm.yaml`. Then run this command to start training: 121 | ```bash 122 | python train.py --config ./configs/train_cldm.yaml 123 | ``` 124 | 125 | 143 | ## 🫰 Acknowledgments 144 | 145 | We would like to express our sincere thanks to the authors of [DiffBIR](https://github.com/XPixelGroup/DiffBIR) for their clear code base and quick responses to our issues. 146 | 147 | We also thank [CodeFormer](https://github.com/sczhou/CodeFormer), [Real-ESRGAN](https://github.com/xinntao/Real-ESRGAN) and [LoRA](https://github.com/cloneofsimo/lora), from which our code partially borrows. 148 | 149 | A new version of FlowIE based on the Diffusion Transformer (DiT) architecture will be released soon! Thanks to the latest DiT-based works, including [PixArt](https://github.com/PixArt-alpha/PixArt-sigma) and [Stable Diffusion 3](https://huggingface.co/stabilityai/stable-diffusion-3-medium). 150 | 151 | ## 🔖 Citation 152 | Please cite us if our work is useful for your research.
153 | 154 | ``` 155 | @misc{zhu2024flowie, 156 | title={FlowIE: Efficient Image Enhancement via Rectified Flow}, 157 | author={Yixuan Zhu and Wenliang Zhao and Ao Li and Yansong Tang and Jie Zhou and Jiwen Lu}, 158 | year={2024}, 159 | eprint={2406.00508}, 160 | archivePrefix={arXiv}, 161 | primaryClass={cs.CV} 162 | } 163 | ``` 164 | ## 🔑 License 165 | 166 | This code is distributed under an [MIT LICENSE](./LICENSE). 167 | -------------------------------------------------------------------------------- /assets/faces/00000014.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/EternalEvan/FlowIE/549bb871cff104c47fd126d23818f54d74ba1f20/assets/faces/00000014.png -------------------------------------------------------------------------------- /assets/faces/00002386.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/EternalEvan/FlowIE/549bb871cff104c47fd126d23818f54d74ba1f20/assets/faces/00002386.png -------------------------------------------------------------------------------- /assets/pipeline.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/EternalEvan/FlowIE/549bb871cff104c47fd126d23818f54d74ba1f20/assets/pipeline.png -------------------------------------------------------------------------------- /assets/real-photos/dog.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/EternalEvan/FlowIE/549bb871cff104c47fd126d23818f54d74ba1f20/assets/real-photos/dog.png -------------------------------------------------------------------------------- /configs/dataset/dehaze_train.yaml: -------------------------------------------------------------------------------- 1 | dataset: 2 | target: dataset.codeformer.CodeformerDataset_Dehaze 3 | params: 4 | # Path to the file list. 5 | is_val: False 6 | out_size: 512 7 | crop_type: random 8 | use_hflip: True 9 | 10 | blur_kernel_size: 41 11 | kernel_list: ['iso', 'aniso'] 12 | kernel_prob: [0.5, 0.5] 13 | blur_sigma: [0.1, 12] 14 | downsample_range: [1, 12] 15 | noise_range: [0, 15] 16 | jpeg_range: [30, 100] 17 | 18 | data_loader: 19 | batch_size: 8 20 | shuffle: true 21 | num_workers: 16 22 | drop_last: true 23 | 24 | batch_transform: 25 | target: dataset.batch_transform.IdentityBatchTransform 26 | -------------------------------------------------------------------------------- /configs/dataset/dehaze_val.yaml: -------------------------------------------------------------------------------- 1 | dataset: 2 | target: dataset.codeformer.CodeformerDataset_Dehaze 3 | params: 4 | # Path to the file list. 5 | is_val: True 6 | out_size: 512 7 | crop_type: center 8 | use_hflip: True 9 | 10 | blur_kernel_size: 41 11 | kernel_list: ['iso', 'aniso'] 12 | kernel_prob: [0.5, 0.5] 13 | blur_sigma: [0.1, 12] 14 | downsample_range: [1, 12] 15 | noise_range: [0, 15] 16 | jpeg_range: [30, 100] 17 | 18 | data_loader: 19 | batch_size: 2 20 | shuffle: false 21 | num_workers: 16 22 | drop_last: true 23 | 24 | batch_transform: 25 | target: dataset.batch_transform.IdentityBatchTransform 26 | -------------------------------------------------------------------------------- /configs/dataset/derain_train.yaml: -------------------------------------------------------------------------------- 1 | dataset: 2 | target: dataset.codeformer.CodeformerDataset_Derain 3 | params: 4 | # Path to the file list. 
5 | is_val: False 6 | out_size: 512 7 | crop_type: random 8 | use_hflip: True 9 | 10 | blur_kernel_size: 41 11 | kernel_list: ['iso', 'aniso'] 12 | kernel_prob: [0.5, 0.5] 13 | blur_sigma: [0.1, 12] 14 | downsample_range: [1, 12] 15 | noise_range: [0, 15] 16 | jpeg_range: [30, 100] 17 | 18 | data_loader: 19 | batch_size: 4 20 | shuffle: true 21 | num_workers: 16 22 | drop_last: true 23 | 24 | batch_transform: 25 | target: dataset.batch_transform.IdentityBatchTransform 26 | -------------------------------------------------------------------------------- /configs/dataset/derain_val.yaml: -------------------------------------------------------------------------------- 1 | dataset: 2 | target: dataset.codeformer.CodeformerDataset_Derain 3 | params: 4 | # Path to the file list. 5 | is_val: True 6 | out_size: 512 7 | crop_type: center 8 | use_hflip: True 9 | 10 | blur_kernel_size: 41 11 | kernel_list: ['iso', 'aniso'] 12 | kernel_prob: [0.5, 0.5] 13 | blur_sigma: [0.1, 12] 14 | downsample_range: [1, 12] 15 | noise_range: [0, 15] 16 | jpeg_range: [30, 100] 17 | 18 | data_loader: 19 | batch_size: 2 20 | shuffle: false 21 | num_workers: 16 22 | drop_last: true 23 | 24 | batch_transform: 25 | target: dataset.batch_transform.IdentityBatchTransform 26 | -------------------------------------------------------------------------------- /configs/dataset/enlight_train.yaml: -------------------------------------------------------------------------------- 1 | dataset: 2 | target: dataset.codeformer.CodeformerDataset_Enlight 3 | params: 4 | # Path to the file list. 5 | is_val: False 6 | out_size: 512 7 | crop_type: random 8 | use_hflip: True 9 | 10 | blur_kernel_size: 41 11 | kernel_list: ['iso', 'aniso'] 12 | kernel_prob: [0.5, 0.5] 13 | blur_sigma: [0.1, 12] 14 | downsample_range: [1, 12] 15 | noise_range: [0, 15] 16 | jpeg_range: [30, 100] 17 | 18 | data_loader: 19 | batch_size: 8 20 | shuffle: true 21 | num_workers: 16 22 | drop_last: true 23 | 24 | batch_transform: 25 | target: dataset.batch_transform.IdentityBatchTransform 26 | -------------------------------------------------------------------------------- /configs/dataset/enlight_val.yaml: -------------------------------------------------------------------------------- 1 | dataset: 2 | target: dataset.codeformer.CodeformerDataset_Enlight 3 | params: 4 | # Path to the file list. 5 | is_val: True 6 | out_size: 512 7 | crop_type: center 8 | use_hflip: True 9 | 10 | blur_kernel_size: 41 11 | kernel_list: ['iso', 'aniso'] 12 | kernel_prob: [0.5, 0.5] 13 | blur_sigma: [0.1, 12] 14 | downsample_range: [1, 12] 15 | noise_range: [0, 15] 16 | jpeg_range: [30, 100] 17 | 18 | data_loader: 19 | batch_size: 2 20 | shuffle: false 21 | num_workers: 16 22 | drop_last: true 23 | 24 | batch_transform: 25 | target: dataset.batch_transform.IdentityBatchTransform 26 | -------------------------------------------------------------------------------- /configs/dataset/face_test_inp.yaml: -------------------------------------------------------------------------------- 1 | dataset: 2 | target: dataset.codeformer.CodeformerDataset_Mask 3 | params: 4 | # Path to the file list. 
5 | #file_list: /home/user001/zwl/data/FFHQ512/train.list 6 | file_list: /home/user001/zwl/zyx/Diffbir/10955/meta/inp.list 7 | # wider :34.391226832176 lfw:38.66 8 | out_size: 512 9 | crop_type: none 10 | use_hflip: False 11 | 12 | blur_kernel_size: 41 13 | kernel_list: ['iso', 'aniso'] 14 | kernel_prob: [0.5, 0.5] 15 | blur_sigma: [0.1, 10] 16 | downsample_range: [0.8, 8] 17 | noise_range: [0, 20] 18 | jpeg_range: [60, 100] 19 | 20 | data_loader: 21 | batch_size: 4 22 | shuffle: false 23 | num_workers: 16 24 | drop_last: true 25 | 26 | batch_transform: 27 | target: dataset.batch_transform.IdentityBatchTransform 28 | -------------------------------------------------------------------------------- /configs/dataset/face_test_lq.yaml: -------------------------------------------------------------------------------- 1 | dataset: 2 | target: dataset.codeformer.CodeformerDatasetLQ 3 | params: 4 | # Path to the file list. 5 | # file_list: /data1/zyx/FFHQ512/test.list 6 | hq_list: /home/user001/zwl/zyx/Diffbir/test_list/1024.list 7 | lq_list: /home/user001/zwl/zyx/Diffbir/test_list/1024.list 8 | #lq_list: /home/user001/zwl/zyx/Diffbir/test_list/inp-z1.list 9 | #hq_list: /home/user001/zwl/data/CelebChild-Test/child.list 10 | #lq_list: /home/user001/zwl/data/CelebChild-Test/child.list 11 | # wider :34.391226832176 lfw:38.66 12 | out_size: 512 13 | crop_type: none 14 | use_hflip: False 15 | 16 | blur_kernel_size: 41 17 | kernel_list: ['iso', 'aniso'] 18 | kernel_prob: [0.5, 0.5] 19 | blur_sigma: [0.1, 10] 20 | downsample_range: [0.8, 8] 21 | noise_range: [0, 20] 22 | jpeg_range: [60, 100] 23 | 24 | data_loader: 25 | batch_size: 2 26 | shuffle: false 27 | num_workers: 16 28 | drop_last: true 29 | 30 | batch_transform: 31 | target: dataset.batch_transform.IdentityBatchTransform 32 | -------------------------------------------------------------------------------- /configs/dataset/face_train.yaml: -------------------------------------------------------------------------------- 1 | dataset: 2 | target: dataset.codeformer.CodeformerDataset 3 | params: 4 | # Path to the file list. 5 | file_list: /home/user001/zwl/data/FFHQ512/train.list 6 | out_size: 512 7 | crop_type: none 8 | use_hflip: true 9 | 10 | blur_kernel_size: 41 11 | kernel_list: ['iso', 'aniso'] 12 | kernel_prob: [0.5, 0.5] 13 | blur_sigma: [0.1, 10] 14 | downsample_range: [0.8, 8] 15 | noise_range: [0, 20] 16 | jpeg_range: [60, 100] 17 | 18 | # color_jitter_prob: 0.3 19 | # color_jitter_shift: 20 20 | # color_jitter_pt_prob: 0.3 21 | # gray_prob: 0.01 22 | 23 | data_loader: 24 | batch_size: 4 25 | shuffle: true 26 | num_workers: 16 27 | drop_last: true 28 | 29 | batch_transform: 30 | target: dataset.batch_transform.IdentityBatchTransform 31 | -------------------------------------------------------------------------------- /configs/dataset/face_val.yaml: -------------------------------------------------------------------------------- 1 | dataset: 2 | target: dataset.codeformer.CodeformerDataset 3 | params: 4 | # Path to the file list. 
5 | file_list: /home/user001/zwl/data/FFHQ512/val.list 6 | #file_list: /data1/zyx/celeba_512_validation_list_files/val_lq.list 7 | 8 | out_size: 512 9 | crop_type: none 10 | use_hflip: False 11 | 12 | blur_kernel_size: 41 13 | kernel_list: ['iso', 'aniso'] 14 | kernel_prob: [0.5, 0.5] 15 | blur_sigma: [0.1, 10] 16 | downsample_range: [0.8, 8] 17 | noise_range: [0, 20] 18 | jpeg_range: [60, 100] 19 | 20 | # color_jitter_prob: 0.3 21 | # color_jitter_shift: 20 22 | # color_jitter_pt_prob: 0.3 23 | # gray_prob: 0.01 24 | 25 | 26 | data_loader: 27 | batch_size: 3 28 | shuffle: false 29 | num_workers: 16 30 | drop_last: true 31 | 32 | batch_transform: 33 | target: dataset.batch_transform.IdentityBatchTransform 34 | -------------------------------------------------------------------------------- /configs/dataset/general_deg_codeformer_train.yaml: -------------------------------------------------------------------------------- 1 | dataset: 2 | target: dataset.codeformer.CodeformerDataset 3 | params: 4 | # Path to the file list. 5 | file_list: /home/user001/zwl/zyx/Diffbir/imagenet_list/train_all.list 6 | out_size: 512 7 | crop_type: center 8 | use_hflip: True 9 | 10 | blur_kernel_size: 41 11 | kernel_list: ['iso', 'aniso'] 12 | kernel_prob: [0.5, 0.5] 13 | blur_sigma: [0.1, 12] 14 | downsample_range: [1, 12] 15 | noise_range: [0, 15] 16 | jpeg_range: [30, 100] 17 | 18 | data_loader: 19 | batch_size: 4 20 | shuffle: true 21 | num_workers: 16 22 | drop_last: true 23 | 24 | batch_transform: 25 | target: dataset.batch_transform.IdentityBatchTransform 26 | -------------------------------------------------------------------------------- /configs/dataset/general_deg_codeformer_val.yaml: -------------------------------------------------------------------------------- 1 | dataset: 2 | target: dataset.codeformer.CodeformerDataset 3 | params: 4 | # Path to the file list. 5 | file_list: /home/user001/zwl/zyx/Diffbir/imagenet_list/val.list 6 | out_size: 512 7 | crop_type: center 8 | use_hflip: True 9 | 10 | blur_kernel_size: 41 11 | kernel_list: ['iso', 'aniso'] 12 | kernel_prob: [0.5, 0.5] 13 | blur_sigma: [0.1, 12] 14 | downsample_range: [1, 12] 15 | noise_range: [0, 15] 16 | jpeg_range: [30, 100] 17 | 18 | data_loader: 19 | batch_size: 2 20 | shuffle: false 21 | num_workers: 16 22 | drop_last: true 23 | 24 | batch_transform: 25 | target: dataset.batch_transform.IdentityBatchTransform 26 | -------------------------------------------------------------------------------- /configs/dataset/general_deg_realesrgan_train.yaml: -------------------------------------------------------------------------------- 1 | dataset: 2 | target: dataset.realesrgan.RealESRGANDataset 3 | params: 4 | # Path to the file list. 
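# (Illustrative note, not part of the original config:) this field is left blank here;
# a hypothetical value, assuming a list produced by scripts/make_file_list.py as
# described in the README, would look like:
#   file_list: ./dataset/list/imagenet/train.list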
5 | file_list: 6 | out_size: 512 7 | crop_type: center 8 | 9 | use_hflip: false 10 | use_rot: false 11 | 12 | blur_kernel_size: 21 13 | kernel_list: ['iso', 'aniso', 'generalized_iso', 'generalized_aniso', 'plateau_iso', 'plateau_aniso'] 14 | kernel_prob: [0.45, 0.25, 0.12, 0.03, 0.12, 0.03] 15 | sinc_prob: 0.1 16 | blur_sigma: [0.2, 3] 17 | betag_range: [0.5, 4] 18 | betap_range: [1, 2] 19 | 20 | blur_kernel_size2: 21 21 | kernel_list2: ['iso', 'aniso', 'generalized_iso', 'generalized_aniso', 'plateau_iso', 'plateau_aniso'] 22 | kernel_prob2: [0.45, 0.25, 0.12, 0.03, 0.12, 0.03] 23 | sinc_prob2: 0.1 24 | blur_sigma2: [0.2, 1.5] 25 | betag_range2: [0.5, 4] 26 | betap_range2: [1, 2] 27 | 28 | final_sinc_prob: 0.8 29 | 30 | data_loader: 31 | batch_size: 32 32 | shuffle: true 33 | num_workers: 16 34 | prefetch_factor: 2 35 | drop_last: true 36 | 37 | batch_transform: 38 | target: dataset.batch_transform.RealESRGANBatchTransform 39 | params: 40 | use_sharpener: false 41 | resize_hq: false 42 | # Queue size of training pool, this should be multiples of batch_size. 43 | queue_size: 256 44 | # the first degradation process 45 | resize_prob: [0.2, 0.7, 0.1] # up, down, keep 46 | resize_range: [0.15, 1.5] 47 | gaussian_noise_prob: 0.5 48 | noise_range: [1, 30] 49 | poisson_scale_range: [0.05, 3] 50 | gray_noise_prob: 0.4 51 | jpeg_range: [30, 95] 52 | 53 | # the second degradation process 54 | stage2_scale: 4 55 | second_blur_prob: 0.8 56 | resize_prob2: [0.3, 0.4, 0.3] # up, down, keep 57 | resize_range2: [0.3, 1.2] 58 | gaussian_noise_prob2: 0.5 59 | noise_range2: [1, 25] 60 | poisson_scale_range2: [0.05, 2.5] 61 | gray_noise_prob2: 0.4 62 | jpeg_range2: [30, 95] 63 | -------------------------------------------------------------------------------- /configs/dataset/general_deg_realesrgan_val.yaml: -------------------------------------------------------------------------------- 1 | dataset: 2 | target: dataset.realesrgan.RealESRGANDataset 3 | params: 4 | # Path to the file list. 5 | file_list: 6 | out_size: 512 7 | crop_type: center 8 | 9 | use_hflip: false 10 | use_rot: false 11 | 12 | blur_kernel_size: 21 13 | kernel_list: ['iso', 'aniso', 'generalized_iso', 'generalized_aniso', 'plateau_iso', 'plateau_aniso'] 14 | kernel_prob: [0.45, 0.25, 0.12, 0.03, 0.12, 0.03] 15 | sinc_prob: 0.1 16 | blur_sigma: [0.2, 3] 17 | betag_range: [0.5, 4] 18 | betap_range: [1, 2] 19 | 20 | blur_kernel_size2: 21 21 | kernel_list2: ['iso', 'aniso', 'generalized_iso', 'generalized_aniso', 'plateau_iso', 'plateau_aniso'] 22 | kernel_prob2: [0.45, 0.25, 0.12, 0.03, 0.12, 0.03] 23 | sinc_prob2: 0.1 24 | blur_sigma2: [0.2, 1.5] 25 | betag_range2: [0.5, 4] 26 | betap_range2: [1, 2] 27 | 28 | final_sinc_prob: 0.8 29 | 30 | data_loader: 31 | batch_size: 32 32 | shuffle: false 33 | num_workers: 16 34 | prefetch_factor: 2 35 | drop_last: true 36 | 37 | batch_transform: 38 | target: dataset.batch_transform.RealESRGANBatchTransform 39 | params: 40 | use_sharpener: false 41 | resize_hq: false 42 | # Queue size of training pool, this should be multiples of batch_size. 
43 | queue_size: 256 44 | # the first degradation process 45 | resize_prob: [0.2, 0.7, 0.1] # up, down, keep 46 | resize_range: [0.15, 1.5] 47 | gaussian_noise_prob: 0.5 48 | noise_range: [1, 30] 49 | poisson_scale_range: [0.05, 3] 50 | gray_noise_prob: 0.4 51 | jpeg_range: [30, 95] 52 | 53 | # the second degradation process 54 | stage2_scale: 4 55 | second_blur_prob: 0.8 56 | resize_prob2: [0.3, 0.4, 0.3] # up, down, keep 57 | resize_range2: [0.3, 1.2] 58 | gaussian_noise_prob2: 0.5 59 | noise_range2: [1, 25] 60 | poisson_scale_range2: [0.05, 2.5] 61 | gray_noise_prob2: 0.4 62 | jpeg_range2: [30, 95] 63 | -------------------------------------------------------------------------------- /configs/dataset/general_test_lq.yaml: -------------------------------------------------------------------------------- 1 | dataset: 2 | target: dataset.codeformer.CodeformerDatasetLQ 3 | params: 4 | # Path to the file list. 5 | # file_list: /data1/zyx/FFHQ512/test.list 6 | #lq_list: /home/user001/zwl/zyx/Diffbir/test_list/custom-1.list 7 | lq_list: /home/user001/zwl/zyx/Diffbir/imagenet_list/custom.list 8 | #hq_list: /home/user001/zwl/zyx/Diffbir/test_list/custom-1.list 9 | hq_list: /home/user001/zwl/zyx/Diffbir/imagenet_list/custom.list 10 | # wider :34.391226832176 lfw:38.66 11 | out_size: 512 12 | crop_type: center 13 | use_hflip: False 14 | 15 | blur_kernel_size: 41 16 | kernel_list: ['iso', 'aniso'] 17 | kernel_prob: [0.5, 0.5] 18 | blur_sigma: [0.1, 10] 19 | downsample_range: [0.8, 8] 20 | noise_range: [0, 20] 21 | jpeg_range: [60, 100] 22 | 23 | data_loader: 24 | batch_size: 4 25 | shuffle: false 26 | num_workers: 16 27 | drop_last: true 28 | 29 | batch_transform: 30 | target: dataset.batch_transform.IdentityBatchTransform 31 | -------------------------------------------------------------------------------- /configs/dataset/ocr_train.yaml: -------------------------------------------------------------------------------- 1 | dataset: 2 | target: dataset.codeformer.CodeformerDatasetLQ 3 | params: 4 | # Path to the file list. 5 | hq_list: /home/user001/zwl/zyx/Diffbir/ocr_list/train_hq.list 6 | lq_list: /home/user001/zwl/zyx/Diffbir/ocr_list/train_lq.list 7 | out_size: 512 8 | crop_type: random 9 | use_hflip: True 10 | 11 | blur_kernel_size: 41 12 | kernel_list: ['iso', 'aniso'] 13 | kernel_prob: [0.5, 0.5] 14 | blur_sigma: [0.1, 12] 15 | downsample_range: [1, 12] 16 | noise_range: [0, 15] 17 | jpeg_range: [30, 100] 18 | 19 | data_loader: 20 | batch_size: 4 21 | shuffle: true 22 | num_workers: 16 23 | drop_last: true 24 | 25 | batch_transform: 26 | target: dataset.batch_transform.IdentityBatchTransform 27 | -------------------------------------------------------------------------------- /configs/dataset/ocr_val.yaml: -------------------------------------------------------------------------------- 1 | dataset: 2 | target: dataset.codeformer.CodeformerDatasetLQ 3 | params: 4 | # Path to the file list. 
5 | hq_list: /home/user001/zwl/zyx/Diffbir/ocr_list/val_hq.list 6 | lq_list: /home/user001/zwl/zyx/Diffbir/ocr_list/val_lq.list 7 | out_size: 512 8 | crop_type: center 9 | use_hflip: True 10 | 11 | blur_kernel_size: 41 12 | kernel_list: ['iso', 'aniso'] 13 | kernel_prob: [0.5, 0.5] 14 | blur_sigma: [0.1, 12] 15 | downsample_range: [1, 12] 16 | noise_range: [0, 15] 17 | jpeg_range: [30, 100] 18 | 19 | data_loader: 20 | batch_size: 2 21 | shuffle: false 22 | num_workers: 16 23 | drop_last: true 24 | 25 | batch_transform: 26 | target: dataset.batch_transform.IdentityBatchTransform 27 | -------------------------------------------------------------------------------- /configs/model/cldm.yaml: -------------------------------------------------------------------------------- 1 | #target: model.cldm.ControlLDM 2 | target: model.cldm.Reflow_ControlLDM 3 | params: 4 | linear_start: 0.00085 5 | linear_end: 0.0120 6 | num_timesteps_cond: 1 7 | log_every_t: 200 8 | timesteps: 1000 9 | first_stage_key: "jpg" 10 | cond_stage_key: "txt" 11 | control_key: "hint" 12 | image_size: 64 13 | channels: 4 14 | cond_stage_trainable: false 15 | conditioning_key: crossattn 16 | monitor: val/loss_simple_ema 17 | scale_factor: 0.18215 18 | use_ema: False 19 | 20 | sd_locked: True 21 | only_mid_control: False 22 | # Learning rate. 23 | learning_rate: 1e-4 24 | lora_rank: 4 25 | 26 | control_stage_config: 27 | target: model.cldm.ControlNet 28 | params: 29 | use_checkpoint: True 30 | image_size: 32 # unused 31 | in_channels: 4 32 | hint_channels: 4 33 | model_channels: 320 34 | attention_resolutions: [ 4, 2, 1 ] 35 | num_res_blocks: 2 36 | channel_mult: [ 1, 2, 4, 4 ] 37 | num_head_channels: 64 # need to fix for flash-attn 38 | use_spatial_transformer: True 39 | use_linear_in_transformer: True 40 | transformer_depth: 1 41 | context_dim: 1024 42 | legacy: False 43 | 44 | unet_config: 45 | target: model.cldm.ControlledUnetModel 46 | params: 47 | use_checkpoint: false 48 | image_size: 32 # unused 49 | in_channels: 4 50 | out_channels: 4 51 | model_channels: 320 52 | attention_resolutions: [ 4, 2, 1 ] 53 | num_res_blocks: 2 54 | channel_mult: [ 1, 2, 4, 4 ] 55 | num_head_channels: 64 # need to fix for flash-attn 56 | use_spatial_transformer: True 57 | use_linear_in_transformer: True 58 | transformer_depth: 1 59 | context_dim: 1024 60 | legacy: False 61 | 62 | first_stage_config: 63 | use_fp16: False 64 | target: ldm.models.autoencoder.AutoencoderKL 65 | params: 66 | embed_dim: 4 67 | monitor: val/rec_loss 68 | ddconfig: 69 | #attn_type: "vanilla-xformers" 70 | double_z: true 71 | z_channels: 4 72 | resolution: 256 73 | in_channels: 3 74 | out_ch: 3 75 | ch: 128 76 | ch_mult: 77 | - 1 78 | - 2 79 | - 4 80 | - 4 81 | num_res_blocks: 2 82 | attn_resolutions: [] 83 | dropout: 0.0 84 | lossconfig: 85 | target: torch.nn.Identity 86 | 87 | cond_stage_config: 88 | target: ldm.modules.encoders.modules.FrozenOpenCLIPEmbedder 89 | params: 90 | freeze: True 91 | layer: "penultimate" 92 | 93 | preprocess_config: 94 | target: model.swinir.SwinIR 95 | params: 96 | img_size: 64 97 | patch_size: 1 98 | in_chans: 3 99 | embed_dim: 180 100 | depths: [6, 6, 6, 6, 6, 6, 6, 6] 101 | num_heads: [6, 6, 6, 6, 6, 6, 6, 6] 102 | window_size: 8 103 | mlp_ratio: 2 104 | sf: 8 105 | img_range: 1.0 106 | upsampler: "nearest+conv" 107 | resi_connection: "1conv" 108 | unshuffle: True 109 | unshuffle_scale: 8 110 | -------------------------------------------------------------------------------- /configs/model/cldm_bsr_eval.yaml: 
-------------------------------------------------------------------------------- 1 | #target: model.cldm.ControlLDM 2 | target: model.cldm_bsr.Reflow_ControlLDM 3 | params: 4 | linear_start: 0.00085 5 | linear_end: 0.0120 6 | num_timesteps_cond: 1 7 | log_every_t: 200 8 | timesteps: 1000 9 | first_stage_key: "jpg" 10 | cond_stage_key: "txt" 11 | control_key: "hint" 12 | image_size: 64 13 | channels: 4 14 | cond_stage_trainable: false 15 | conditioning_key: crossattn 16 | monitor: val/loss_simple_ema 17 | scale_factor: 0.18215 18 | use_ema: False 19 | 20 | sd_locked: False 21 | only_mid_control: False 22 | # Learning rate. 23 | learning_rate: 1e-4 24 | lora_rank: 64 25 | 26 | control_stage_config: 27 | target: model.cldm_bsr.ControlNet 28 | params: 29 | use_checkpoint: True 30 | image_size: 32 # unused 31 | in_channels: 4 32 | hint_channels: 4 33 | model_channels: 320 34 | attention_resolutions: [ 4, 2, 1 ] 35 | num_res_blocks: 2 36 | channel_mult: [ 1, 2, 4, 4 ] 37 | num_head_channels: 64 # need to fix for flash-attn 38 | use_spatial_transformer: True 39 | use_linear_in_transformer: True 40 | transformer_depth: 1 41 | context_dim: 1024 42 | legacy: False 43 | use_fp16: False 44 | 45 | unet_config: 46 | target: model.cldm_bsr.ControlledUnetModel 47 | params: 48 | use_checkpoint: false 49 | image_size: 32 # unused 50 | in_channels: 4 51 | out_channels: 4 52 | model_channels: 320 53 | attention_resolutions: [ 4, 2, 1 ] 54 | num_res_blocks: 2 55 | channel_mult: [ 1, 2, 4, 4 ] 56 | num_head_channels: 64 # need to fix for flash-attn 57 | use_spatial_transformer: True 58 | use_linear_in_transformer: True 59 | transformer_depth: 1 60 | context_dim: 1024 61 | legacy: False 62 | use_fp16: False 63 | 64 | first_stage_config: 65 | use_fp16: True 66 | model_id: stabilityai/sd-vae-ft-ema 67 | target: ldm.models.autoencoder.AutoencoderKL 68 | params: 69 | embed_dim: 4 70 | monitor: val/rec_loss 71 | ddconfig: 72 | #attn_type: "vanilla-xformers" 73 | double_z: true 74 | z_channels: 4 75 | resolution: 256 76 | in_channels: 3 77 | out_ch: 3 78 | ch: 128 79 | ch_mult: 80 | - 1 81 | - 2 82 | - 4 83 | - 4 84 | num_res_blocks: 2 85 | attn_resolutions: [] 86 | dropout: 0.0 87 | lossconfig: 88 | target: torch.nn.Identity 89 | 90 | cond_stage_config: 91 | target: ldm.modules.encoders.modules.FrozenOpenCLIPEmbedder 92 | params: 93 | freeze: True 94 | layer: "penultimate" 95 | 96 | preprocess_config: 97 | target: model.swinir.SwinIR 98 | params: 99 | img_size: 64 100 | patch_size: 1 101 | in_chans: 3 102 | embed_dim: 180 103 | depths: [6, 6, 6, 6, 6, 6, 6, 6] 104 | num_heads: [6, 6, 6, 6, 6, 6, 6, 6] 105 | window_size: 8 106 | mlp_ratio: 2 107 | sf: 8 108 | img_range: 1.0 109 | upsampler: "nearest+conv" 110 | resi_connection: "1conv" 111 | unshuffle: True 112 | unshuffle_scale: 8 113 | -------------------------------------------------------------------------------- /configs/model/cldm_eval.yaml: -------------------------------------------------------------------------------- 1 | #target: model.cldm.ControlLDM 2 | target: model.cldm.Reflow_ControlLDM 3 | params: 4 | linear_start: 0.00085 5 | linear_end: 0.0120 6 | num_timesteps_cond: 1 7 | log_every_t: 200 8 | timesteps: 1000 9 | first_stage_key: "jpg" 10 | cond_stage_key: "txt" 11 | control_key: "hint" 12 | image_size: 64 13 | channels: 4 14 | cond_stage_trainable: false 15 | conditioning_key: crossattn 16 | monitor: val/loss_simple_ema 17 | scale_factor: 0.18215 18 | use_ema: False 19 | 20 | sd_locked: True 21 | only_mid_control: False 22 | # Learning rate. 
23 | learning_rate: 1e-4 24 | lora_rank: 4 25 | 26 | control_stage_config: 27 | target: model.cldm.ControlNet 28 | params: 29 | use_checkpoint: True 30 | image_size: 32 # unused 31 | in_channels: 4 32 | hint_channels: 4 33 | model_channels: 320 34 | attention_resolutions: [ 4, 2, 1 ] 35 | num_res_blocks: 2 36 | channel_mult: [ 1, 2, 4, 4 ] 37 | num_head_channels: 64 # need to fix for flash-attn 38 | use_spatial_transformer: True 39 | use_linear_in_transformer: True 40 | transformer_depth: 1 41 | context_dim: 1024 42 | legacy: False 43 | 44 | unet_config: 45 | target: model.cldm.ControlledUnetModel 46 | params: 47 | use_checkpoint: false 48 | image_size: 32 # unused 49 | in_channels: 4 50 | out_channels: 4 51 | model_channels: 320 52 | attention_resolutions: [ 4, 2, 1 ] 53 | num_res_blocks: 2 54 | channel_mult: [ 1, 2, 4, 4 ] 55 | num_head_channels: 64 # need to fix for flash-attn 56 | use_spatial_transformer: True 57 | use_linear_in_transformer: True 58 | transformer_depth: 1 59 | context_dim: 1024 60 | legacy: False 61 | 62 | first_stage_config: 63 | target: ldm.models.autoencoder.AutoencoderKL 64 | params: 65 | embed_dim: 4 66 | monitor: val/rec_loss 67 | ddconfig: 68 | #attn_type: "vanilla-xformers" 69 | double_z: true 70 | z_channels: 4 71 | resolution: 256 72 | in_channels: 3 73 | out_ch: 3 74 | ch: 128 75 | ch_mult: 76 | - 1 77 | - 2 78 | - 4 79 | - 4 80 | num_res_blocks: 2 81 | attn_resolutions: [] 82 | dropout: 0.0 83 | lossconfig: 84 | target: torch.nn.Identity 85 | 86 | cond_stage_config: 87 | target: ldm.modules.encoders.modules.FrozenOpenCLIPEmbedder 88 | params: 89 | freeze: True 90 | layer: "penultimate" 91 | 92 | preprocess_config: 93 | target: model.swinir.SwinIR 94 | params: 95 | img_size: 64 96 | patch_size: 1 97 | in_chans: 3 98 | embed_dim: 180 99 | depths: [6, 6, 6, 6, 6, 6, 6, 6] 100 | num_heads: [6, 6, 6, 6, 6, 6, 6, 6] 101 | window_size: 8 102 | mlp_ratio: 2 103 | sf: 8 104 | img_range: 1.0 105 | upsampler: "nearest+conv" 106 | resi_connection: "1conv" 107 | unshuffle: True 108 | unshuffle_scale: 8 109 | -------------------------------------------------------------------------------- /configs/model/swinir.yaml: -------------------------------------------------------------------------------- 1 | target: model.swinir.SwinIR 2 | params: 3 | img_size: 64 4 | patch_size: 1 5 | in_chans: 3 6 | embed_dim: 180 7 | depths: [6, 6, 6, 6, 6, 6, 6, 6] 8 | num_heads: [6, 6, 6, 6, 6, 6, 6, 6] 9 | window_size: 8 10 | mlp_ratio: 2 11 | sf: 8 12 | img_range: 1.0 13 | upsampler: "nearest+conv" 14 | resi_connection: "1conv" 15 | unshuffle: True 16 | unshuffle_scale: 8 17 | 18 | hq_key: jpg 19 | lq_key: hint 20 | # Learning rate. 21 | learning_rate: 1e-4 22 | weight_decay: 0 23 | -------------------------------------------------------------------------------- /configs/test_cldm.yaml: -------------------------------------------------------------------------------- 1 | data: 2 | target: dataset.data_module.BIRDataModule 3 | params: 4 | # Path to training set configuration file. 5 | train_config: ./configs/dataset/face_train.yaml 6 | # Path to validation set configuration file. 7 | val_config: ./configs/dataset/face_test_inp.yaml 8 | 9 | model: 10 | # You can set learning rate in the following configuration file. 11 | config: configs/model/cldm_eval.yaml 12 | # Path to the checkpoints or weights you want to resume. At the begining, 13 | # this should be set to the initial weights created by scripts/make_stage2_init_weight.py. 
14 | 15 | resume: ./work_dirs/inpainting4/lightning_logs/version_0/checkpoints/step=79999.ckpt 16 | lightning: 17 | seed: 231 18 | 19 | trainer: 20 | accelerator: ddp 21 | precision: 32 22 | # Indices of GPUs used for training. 23 | gpus: [0] 24 | # Path to save logs and checkpoints. 25 | default_root_dir: ./work_dirs/testcelebamaskhq_1 26 | # Max number of training steps (batches). 27 | max_steps: 250001 28 | # Validation frequency in terms of training steps. 29 | val_check_interval: 500 30 | log_every_n_steps: 50 31 | # Accumulate gradients from multiple batches so as to increase batch size. 32 | accumulate_grad_batches: 1 33 | 34 | callbacks: 35 | - target: model.callbacks.ImageLogger 36 | params: 37 | # Log frequency of image logger. 38 | log_every_n_steps: 250 39 | max_images_each_step: 4 40 | log_images_kwargs: ~ 41 | 42 | - target: model.callbacks.ModelCheckpoint 43 | params: 44 | # Frequency of saving checkpoints. 45 | every_n_train_steps: 5000 46 | save_top_k: -1 47 | filename: "{step}" 48 | -------------------------------------------------------------------------------- /configs/test_cldm_general.yaml: -------------------------------------------------------------------------------- 1 | data: 2 | target: dataset.data_module.BIRDataModule 3 | params: 4 | # Path to training set configuration file. 5 | train_config: ./configs/dataset/derain_train.yaml 6 | 7 | # Path to validation set configuration file. 8 | val_config: ./configs/dataset/derain_val.yaml 9 | 10 | model: 11 | # You can set learning rate in the following configuration file. 12 | config: configs/model/cldm_eval.yaml 13 | # Path to the checkpoints or weights you want to resume. At the begining, 14 | # this should be set to the initial weights created by scripts/make_stage2_init_weight.py. 15 | resume: ./work_dirs/exp1/lightning_logs/version_1/checkpoints/step=19999.ckpt 16 | 17 | lightning: 18 | seed: 231 19 | 20 | trainer: 21 | accelerator: ddp 22 | precision: 32 23 | # Indices of GPUs used for training. 24 | gpus: [0] 25 | # Path to save logs and checkpoints. 26 | default_root_dir: ./work_dirs/ 27 | # Max number of training steps (batches). 28 | max_steps: 250001 29 | # Validation frequency in terms of training steps. 30 | val_check_interval: 500 31 | log_every_n_steps: 50 32 | # Accumulate gradients from multiple batches so as to increase batch size. 33 | accumulate_grad_batches: 1 34 | 35 | callbacks: 36 | - target: model.callbacks.ImageLogger 37 | params: 38 | # Log frequency of image logger. 39 | log_every_n_steps: 250 40 | max_images_each_step: 4 41 | log_images_kwargs: ~ 42 | 43 | - target: model.callbacks.ModelCheckpoint 44 | params: 45 | # Frequency of saving checkpoints. 46 | every_n_train_steps: 5000 47 | save_top_k: -1 48 | filename: "{step}" 49 | -------------------------------------------------------------------------------- /configs/train_cldm.yaml: -------------------------------------------------------------------------------- 1 | data: 2 | target: dataset.data_module.BIRDataModule 3 | params: 4 | # Path to training set configuration file. 5 | train_config: 6 | # Path to validation set configuration file. 7 | val_config: 8 | 9 | model: 10 | # You can set learning rate in the following configuration file. 11 | config: configs/model/cldm.yaml 12 | # Path to the checkpoints or weights you want to resume. At the begining, 13 | # this should be set to the initial weights created by scripts/make_stage2_init_weight.py. 
14 | resume: ./weights/v2-1_512-ema-pruned.ckpt 15 | 16 | lightning: 17 | seed: 231 18 | 19 | trainer: 20 | accelerator: ddp 21 | precision: 32 22 | # Indices of GPUs used for training. 23 | gpus: [1,2,3,4,5,6,7] 24 | # Path to save logs and checkpoints. 25 | default_root_dir: ./work_dirs/exp1 26 | # Max number of training steps (batches). 27 | max_steps: 250001 28 | # Validation frequency in terms of training steps. 29 | val_check_interval: 500 30 | log_every_n_steps: 50 31 | # Accumulate gradients from multiple batches so as to increase batch size. 32 | accumulate_grad_batches: 1 33 | 34 | callbacks: 35 | - target: model.callbacks.ImageLogger 36 | params: 37 | # Log frequency of image logger. 38 | log_every_n_steps: 250 39 | max_images_each_step: 4 40 | log_images_kwargs: ~ 41 | 42 | - target: model.callbacks.ModelCheckpoint 43 | params: 44 | # Frequency of saving checkpoints. 45 | every_n_train_steps: 3000 46 | save_top_k: -1 47 | filename: "{step}" 48 | -------------------------------------------------------------------------------- /configs/train_cldm_general.yaml: -------------------------------------------------------------------------------- 1 | data: 2 | target: dataset.data_module.BIRDataModule 3 | params: 4 | # Path to training set configuration file. 5 | train_config: ./configs/dataset/derain_train.yaml 6 | # Path to validation set configuration file. 7 | val_config: ./configs/dataset/derain_val.yaml 8 | 9 | model: 10 | # You can set learning rate in the following configuration file. 11 | config: ./configs/model/cldm.yaml 12 | resume: ./work_dirs/exp0_general_all/lightning_logs/version_3/checkpoints/step=84999.ckpt 13 | 14 | lightning: 15 | seed: 231 16 | 17 | trainer: 18 | accelerator: ddp 19 | precision: 32 20 | # Indices of GPUs used for training. 21 | gpus: [1,2,3,4,5,6,7] 22 | # Path to save logs and checkpoints. 23 | default_root_dir: ./work_dirs/derain0 24 | # Max number of training steps (batches). 25 | max_steps: 250001 26 | # Validation frequency in terms of training steps. 27 | val_check_interval: 100 28 | log_every_n_steps: 250 29 | # Accumulate gradients from multiple batches so as to increase batch size. 30 | accumulate_grad_batches: 1 31 | 32 | callbacks: 33 | - target: model.callbacks.ImageLogger 34 | params: 35 | # Log frequency of image logger. 36 | log_every_n_steps: 50 37 | max_images_each_step: 4 38 | log_images_kwargs: ~ 39 | 40 | - target: model.callbacks.ModelCheckpoint 41 | params: 42 | # Frequency of saving checkpoints. 43 | every_n_train_steps: 5000 44 | save_top_k: -1 45 | filename: "{step}" 46 | -------------------------------------------------------------------------------- /configs/train_swinir.yaml: -------------------------------------------------------------------------------- 1 | data: 2 | target: dataset.data_module.BIRDataModule 3 | params: 4 | # Path to training set configuration file. 5 | train_config: ./configs/dataset/dehaze_train.yaml 6 | # Path to validation set configuration file. 7 | val_config: ./configs/dataset/dehaze_val.yaml 8 | 9 | model: 10 | # You can set learning rate in the following configuration file. 11 | config: configs/model/swinir.yaml 12 | # Path to the checkpoints or weights you want to resume. 13 | resume: ./weights/general_swinir_v1.ckpt 14 | 15 | lightning: 16 | seed: 231 17 | 18 | trainer: 19 | accelerator: ddp 20 | precision: 32 21 | # Indices of GPUs used for training. 22 | gpus: [1,2,3,5,6,7] 23 | # Path to save logs and checkpoints. 
24 | default_root_dir: ./work_dirs/swin 25 | # Max number of training steps (batches). 26 | max_steps: 150001 27 | # Validation frequency in terms of training steps. 28 | val_check_interval: 250 29 | # Log frequency of tensorboard logger. 30 | log_every_n_steps: 250 31 | # Accumulate gradients from multiple batches so as to increase batch size. 32 | accumulate_grad_batches: 1 33 | 34 | callbacks: 35 | - target: model.callbacks.ImageLogger 36 | params: 37 | # Log frequency of image logger. 38 | log_every_n_steps: 1000 39 | max_images_each_step: 4 40 | log_images_kwargs: ~ 41 | 42 | - target: model.callbacks.ModelCheckpoint 43 | params: 44 | # Frequency of saving checkpoints. 45 | every_n_train_steps: 10000 46 | save_top_k: -1 47 | filename: "{step}" 48 | -------------------------------------------------------------------------------- /dataset/__pycache__/batch_transform.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/EternalEvan/FlowIE/549bb871cff104c47fd126d23818f54d74ba1f20/dataset/__pycache__/batch_transform.cpython-38.pyc -------------------------------------------------------------------------------- /dataset/__pycache__/codeformer.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/EternalEvan/FlowIE/549bb871cff104c47fd126d23818f54d74ba1f20/dataset/__pycache__/codeformer.cpython-38.pyc -------------------------------------------------------------------------------- /dataset/__pycache__/data_module.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/EternalEvan/FlowIE/549bb871cff104c47fd126d23818f54d74ba1f20/dataset/__pycache__/data_module.cpython-38.pyc -------------------------------------------------------------------------------- /dataset/data_module.py: -------------------------------------------------------------------------------- 1 | from typing import Any, Tuple, Mapping 2 | 3 | from pytorch_lightning.utilities.types import EVAL_DATALOADERS, TRAIN_DATALOADERS 4 | import pytorch_lightning as pl 5 | from torch.utils.data import DataLoader, Dataset 6 | from omegaconf import OmegaConf 7 | 8 | from utils.common import instantiate_from_config 9 | from dataset.batch_transform import BatchTransform, IdentityBatchTransform 10 | from torch.utils.data.distributed import DistributedSampler 11 | 12 | 13 | class BIRDataModule(pl.LightningDataModule): 14 | 15 | def __init__( 16 | self, 17 | train_config: str, 18 | val_config: str=None 19 | ) -> "BIRDataModule": 20 | super().__init__() 21 | self.train_config = OmegaConf.load(train_config) 22 | self.val_config = OmegaConf.load(val_config) if val_config else None 23 | 24 | def load_dataset(self, config: Mapping[str, Any]) -> Tuple[Dataset, BatchTransform]: 25 | dataset = instantiate_from_config(config["dataset"]) 26 | batch_transform = ( 27 | instantiate_from_config(config["batch_transform"]) 28 | if config.get("batch_transform") else IdentityBatchTransform() 29 | ) 30 | return dataset, batch_transform 31 | 32 | def setup(self, stage: str) -> None: 33 | if stage == "fit": 34 | self.train_dataset, self.train_batch_transform = self.load_dataset(self.train_config) 35 | if self.val_config: 36 | self.val_dataset, self.val_batch_transform = self.load_dataset(self.val_config) 37 | else: 38 | self.val_dataset, self.val_batch_transform = None, None 39 | else: 40 | raise NotImplementedError(stage) 41 | 42 | def 
train_dataloader(self) -> TRAIN_DATALOADERS: 43 | return DataLoader( 44 | dataset=self.train_dataset, **self.train_config["data_loader"] 45 | ) 46 | 47 | def val_dataloader(self) -> EVAL_DATALOADERS: 48 | if self.val_dataset is None: 49 | return None 50 | return DataLoader( 51 | dataset=self.val_dataset, **self.val_config["data_loader"] 52 | ) 53 | 54 | def on_after_batch_transfer(self, batch: Any, dataloader_idx: int) -> Any: 55 | self.trainer: pl.Trainer 56 | 57 | if self.trainer.training: 58 | return self.train_batch_transform(batch) 59 | elif self.trainer.validating or self.trainer.sanity_checking: 60 | return self.val_batch_transform(batch) 61 | else: 62 | raise RuntimeError( 63 | "Trainer state: \n" 64 | f"training: {self.trainer.training}\n" 65 | f"validating: {self.trainer.validating}\n" 66 | f"testing: {self.trainer.testing}\n" 67 | f"predicting: {self.trainer.predicting}\n" 68 | f"sanity_checking: {self.trainer.sanity_checking}" 69 | ) 70 | 71 | class BIRDataModuleDistributed(pl.LightningDataModule): 72 | 73 | def __init__( 74 | self, 75 | train_config: str, 76 | val_config: str=None 77 | ) -> "BIRDataModule": 78 | super().__init__() 79 | self.train_config = OmegaConf.load(train_config) 80 | self.val_config = OmegaConf.load(val_config) if val_config else None 81 | 82 | def load_dataset(self, config: Mapping[str, Any]) -> Tuple[Dataset, BatchTransform]: 83 | dataset = instantiate_from_config(config["dataset"]) 84 | batch_transform = ( 85 | instantiate_from_config(config["batch_transform"]) 86 | if config.get("batch_transform") else IdentityBatchTransform() 87 | ) 88 | return dataset, batch_transform 89 | 90 | def setup(self, stage: str) -> None: 91 | if stage == "fit": 92 | self.train_dataset, self.train_batch_transform = self.load_dataset(self.train_config) 93 | if self.val_config: 94 | self.val_dataset, self.val_batch_transform = self.load_dataset(self.val_config) 95 | else: 96 | self.val_dataset, self.val_batch_transform = None, None 97 | else: 98 | raise NotImplementedError(stage) 99 | 100 | def train_dataloader(self) -> TRAIN_DATALOADERS: 101 | return DataLoader( 102 | dataset=self.train_dataset, **self.train_config["data_loader"] 103 | ) 104 | 105 | def val_dataloader(self, rank) -> EVAL_DATALOADERS: 106 | if self.val_dataset is None: 107 | return None 108 | sampler = DistributedSampler(self.val_dataset, 109 | rank=rank,shuffle=False, drop_last=False) 110 | return DataLoader( 111 | dataset=self.val_dataset, sampler=sampler, **self.val_config["data_loader"] 112 | ) 113 | 114 | def on_after_batch_transfer(self, batch: Any, dataloader_idx: int) -> Any: 115 | self.trainer: pl.Trainer 116 | 117 | if self.trainer.training: 118 | return self.train_batch_transform(batch) 119 | elif self.trainer.validating or self.trainer.sanity_checking: 120 | return self.val_batch_transform(batch) 121 | else: 122 | raise RuntimeError( 123 | "Trainer state: \n" 124 | f"training: {self.trainer.training}\n" 125 | f"validating: {self.trainer.validating}\n" 126 | f"testing: {self.trainer.testing}\n" 127 | f"predicting: {self.trainer.predicting}\n" 128 | f"sanity_checking: {self.trainer.sanity_checking}" 129 | ) 130 | -------------------------------------------------------------------------------- /dataset/realesrgan.py: -------------------------------------------------------------------------------- 1 | from typing import Dict, Sequence 2 | import math 3 | import random 4 | import time 5 | 6 | import numpy as np 7 | import torch 8 | from torch.utils import data 9 | from PIL import Image 10 | 11 | from 
utils.degradation import circular_lowpass_kernel, random_mixed_kernels 12 | from utils.image import augment, random_crop_arr, center_crop_arr 13 | from utils.file import load_file_list 14 | 15 | 16 | class RealESRGANDataset(data.Dataset): 17 | """ 18 | # TODO: add comment 19 | """ 20 | 21 | def __init__( 22 | self, 23 | file_list: str, 24 | out_size: int, 25 | crop_type: str, 26 | use_hflip: bool, 27 | use_rot: bool, 28 | # blur kernel settings of the first degradation stage 29 | blur_kernel_size: int, 30 | kernel_list: Sequence[str], 31 | kernel_prob: Sequence[float], 32 | blur_sigma: Sequence[float], 33 | betag_range: Sequence[float], 34 | betap_range: Sequence[float], 35 | sinc_prob: float, 36 | # blur kernel settings of the second degradation stage 37 | blur_kernel_size2: int, 38 | kernel_list2: Sequence[str], 39 | kernel_prob2: Sequence[float], 40 | blur_sigma2: Sequence[float], 41 | betag_range2: Sequence[float], 42 | betap_range2: Sequence[float], 43 | sinc_prob2: float, 44 | final_sinc_prob: float 45 | ) -> "RealESRGANDataset": 46 | super(RealESRGANDataset, self).__init__() 47 | self.paths = load_file_list(file_list) 48 | self.out_size = out_size 49 | self.crop_type = crop_type 50 | assert self.crop_type in ["center", "random", "none"], f"invalid crop type: {self.crop_type}" 51 | 52 | self.blur_kernel_size = blur_kernel_size 53 | self.kernel_list = kernel_list 54 | # a list for each kernel probability 55 | self.kernel_prob = kernel_prob 56 | self.blur_sigma = blur_sigma 57 | # betag used in generalized Gaussian blur kernels 58 | self.betag_range = betag_range 59 | # betap used in plateau blur kernels 60 | self.betap_range = betap_range 61 | # the probability for sinc filters 62 | self.sinc_prob = sinc_prob 63 | 64 | self.blur_kernel_size2 = blur_kernel_size2 65 | self.kernel_list2 = kernel_list2 66 | self.kernel_prob2 = kernel_prob2 67 | self.blur_sigma2 = blur_sigma2 68 | self.betag_range2 = betag_range2 69 | self.betap_range2 = betap_range2 70 | self.sinc_prob2 = sinc_prob2 71 | 72 | # a final sinc filter 73 | self.final_sinc_prob = final_sinc_prob 74 | 75 | self.use_hflip = use_hflip 76 | self.use_rot = use_rot 77 | 78 | # kernel size ranges from 7 to 21 79 | self.kernel_range = [2 * v + 1 for v in range(3, 11)] 80 | # TODO: kernel range is now hard-coded, should be in the configure file 81 | # convolving with pulse tensor brings no blurry effect 82 | self.pulse_tensor = torch.zeros(21, 21).float() 83 | self.pulse_tensor[10, 10] = 1 84 | 85 | @torch.no_grad() 86 | def __getitem__(self, index: int) -> Dict[str, torch.Tensor]: 87 | # -------------------------------- Load hq images -------------------------------- # 88 | hq_path = self.paths[index] 89 | success = False 90 | for _ in range(3): 91 | try: 92 | pil_img = Image.open(hq_path).convert("RGB") 93 | success = True 94 | break 95 | except: 96 | time.sleep(1) 97 | assert success, f"failed to load image {hq_path}" 98 | 99 | if self.crop_type == "random": 100 | pil_img = random_crop_arr(pil_img, self.out_size) 101 | elif self.crop_type == "center": 102 | pil_img = center_crop_arr(pil_img, self.out_size) 103 | # self.crop_type is "none" 104 | else: 105 | pil_img = np.array(pil_img) 106 | assert pil_img.shape[:2] == (self.out_size, self.out_size) 107 | # hwc, rgb to bgr, [0, 255] to [0, 1], float32 108 | img_hq = (pil_img[..., ::-1] / 255.0).astype(np.float32) 109 | 110 | # -------------------- Do augmentation for training: flip, rotation -------------------- # 111 | img_hq = augment(img_hq, self.use_hflip, self.use_rot) 112 | 113 
| # ------------------------ Generate kernels (used in the first degradation) ------------------------ # 114 | kernel_size = random.choice(self.kernel_range) 115 | if np.random.uniform() < self.sinc_prob: 116 | # this sinc filter setting is for kernels ranging from [7, 21] 117 | if kernel_size < 13: 118 | omega_c = np.random.uniform(np.pi / 3, np.pi) 119 | else: 120 | omega_c = np.random.uniform(np.pi / 5, np.pi) 121 | kernel = circular_lowpass_kernel(omega_c, kernel_size, pad_to=False) 122 | else: 123 | kernel = random_mixed_kernels( 124 | self.kernel_list, 125 | self.kernel_prob, 126 | kernel_size, 127 | self.blur_sigma, 128 | self.blur_sigma, [-math.pi, math.pi], 129 | self.betag_range, 130 | self.betap_range, 131 | noise_range=None 132 | ) 133 | # pad kernel 134 | pad_size = (21 - kernel_size) // 2 135 | kernel = np.pad(kernel, ((pad_size, pad_size), (pad_size, pad_size))) 136 | 137 | # ------------------------ Generate kernels (used in the second degradation) ------------------------ # 138 | kernel_size = random.choice(self.kernel_range) 139 | if np.random.uniform() < self.sinc_prob2: 140 | if kernel_size < 13: 141 | omega_c = np.random.uniform(np.pi / 3, np.pi) 142 | else: 143 | omega_c = np.random.uniform(np.pi / 5, np.pi) 144 | kernel2 = circular_lowpass_kernel(omega_c, kernel_size, pad_to=False) 145 | else: 146 | kernel2 = random_mixed_kernels( 147 | self.kernel_list2, 148 | self.kernel_prob2, 149 | kernel_size, 150 | self.blur_sigma2, 151 | self.blur_sigma2, [-math.pi, math.pi], 152 | self.betag_range2, 153 | self.betap_range2, 154 | noise_range=None 155 | ) 156 | 157 | # pad kernel 158 | pad_size = (21 - kernel_size) // 2 159 | kernel2 = np.pad(kernel2, ((pad_size, pad_size), (pad_size, pad_size))) 160 | 161 | # ------------------------------------- the final sinc kernel ------------------------------------- # 162 | if np.random.uniform() < self.final_sinc_prob: 163 | kernel_size = random.choice(self.kernel_range) 164 | omega_c = np.random.uniform(np.pi / 3, np.pi) 165 | sinc_kernel = circular_lowpass_kernel(omega_c, kernel_size, pad_to=21) 166 | sinc_kernel = torch.FloatTensor(sinc_kernel) 167 | else: 168 | sinc_kernel = self.pulse_tensor 169 | 170 | # [0, 1], BGR to RGB, HWC to CHW 171 | img_hq = torch.from_numpy( 172 | img_hq[..., ::-1].transpose(2, 0, 1).copy() 173 | ).float() 174 | kernel = torch.FloatTensor(kernel) 175 | kernel2 = torch.FloatTensor(kernel2) 176 | 177 | return { 178 | "hq": img_hq, "kernel1": kernel, "kernel2": kernel2, 179 | "sinc_kernel": sinc_kernel, "txt": "" 180 | } 181 | 182 | def __len__(self) -> int: 183 | return len(self.paths) 184 | -------------------------------------------------------------------------------- /dataset/test.py: -------------------------------------------------------------------------------- 1 | import rawpy 2 | if __name__ == "__main__": 3 | with rawpy.imread("/home/user001/zwl/data/Sony/long/00001_00_10s.ARW") as raw_target: 4 | pass -------------------------------------------------------------------------------- /evaluate.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import math 3 | import os 4 | import cv2 5 | import torch 6 | from torchvision import transforms 7 | from scipy.ndimage import gaussian_filter 8 | from lpips_pytorch import LPIPS, lpips 9 | from skimage.metrics import peak_signal_noise_ratio as psnr 10 | import pytorch_ssim 11 | from argparse import ArgumentParser 12 | from tqdm import tqdm 13 | import pdb 14 | 15 | 16 | import pyiqa 17 | # def 
psnr(img1, img2): 18 | # mse = np.mean((img1 - img2) ** 2 ) 19 | # if mse == 0: 20 | # return 100 21 | # return 20 * math.log10(255.0 / math.sqrt(mse)) 22 | def calculate_psnr(img1, img2): 23 | """Calculate PSNR (Peak Signal-to-Noise Ratio). 24 | 25 | Ref: https://en.wikipedia.org/wiki/Peak_signal-to-noise_ratio 26 | 27 | Args: 28 | img1 (ndarray): Images with range [0, 255]. 29 | img2 (ndarray): Images with range [0, 255]. 30 | crop_border (int): Cropped pixels in each edge of an image. These 31 | pixels are not involved in the PSNR calculation. 32 | input_order (str): Whether the input order is 'HWC' or 'CHW'. 33 | Default: 'HWC'. 34 | test_y_channel (bool): Test on Y channel of YCbCr. Default: False. 35 | 36 | Returns: 37 | float: psnr result. 38 | """ 39 | 40 | assert img1.shape == img2.shape, (f'Image shapes are differnet: {img1.shape}, {img2.shape}.') 41 | img1 = img1.astype(np.float64) 42 | img2 = img2.astype(np.float64) 43 | mse = np.mean((img1 - img2)**2) 44 | if mse == 0: 45 | return float('inf') 46 | 47 | #pdb.set_trace() 48 | return 10. * np.log10(255.*255.0 /mse) 49 | 50 | def _ssim(img1, img2): 51 | """Calculate SSIM (structural similarity) for one channel images. 52 | 53 | It is called by func:`calculate_ssim`. 54 | 55 | Args: 56 | img1 (ndarray): Images with range [0, 255] with order 'HWC'. 57 | img2 (ndarray): Images with range [0, 255] with order 'HWC'. 58 | 59 | Returns: 60 | float: ssim result. 61 | """ 62 | 63 | C1 = (0.01 * 255)**2 64 | C2 = (0.03 * 255)**2 65 | 66 | img1 = img1.astype(np.float64) 67 | img2 = img2.astype(np.float64) 68 | kernel = cv2.getGaussianKernel(11, 1.5) 69 | window = np.outer(kernel, kernel.transpose()) 70 | 71 | mu1 = cv2.filter2D(img1, -1, window)[5:-5, 5:-5] 72 | mu2 = cv2.filter2D(img2, -1, window)[5:-5, 5:-5] 73 | mu1_sq = mu1**2 74 | mu2_sq = mu2**2 75 | mu1_mu2 = mu1 * mu2 76 | sigma1_sq = cv2.filter2D(img1**2, -1, window)[5:-5, 5:-5] - mu1_sq 77 | sigma2_sq = cv2.filter2D(img2**2, -1, window)[5:-5, 5:-5] - mu2_sq 78 | sigma12 = cv2.filter2D(img1 * img2, -1, window)[5:-5, 5:-5] - mu1_mu2 79 | 80 | ssim_map = ((2 * mu1_mu2 + C1) * (2 * sigma12 + C2)) / ((mu1_sq + mu2_sq + C1) * (sigma1_sq + sigma2_sq + C2)) 81 | return ssim_map.mean() 82 | def calculate_ssim(img1, img2): 83 | """Calculate SSIM (structural similarity). 84 | 85 | Ref: 86 | Image quality assessment: From error visibility to structural similarity 87 | 88 | The results are the same as that of the official released MATLAB code in 89 | https://ece.uwaterloo.ca/~z70wang/research/ssim/. 90 | 91 | For three-channel images, SSIM is calculated for each channel and then 92 | averaged. 93 | 94 | Args: 95 | img1 (ndarray): Images with range [0, 255]. 96 | img2 (ndarray): Images with range [0, 255]. 97 | crop_border (int): Cropped pixels in each edge of an image. These 98 | pixels are not involved in the SSIM calculation. 99 | input_order (str): Whether the input order is 'HWC' or 'CHW'. 100 | Default: 'HWC'. 101 | test_y_channel (bool): Test on Y channel of YCbCr. Default: False. 102 | 103 | Returns: 104 | float: ssim result. 
105 | """ 106 | 107 | assert img1.shape == img2.shape, (f'Image shapes are differnet: {img1.shape}, {img2.shape}.') 108 | 109 | img1 = img1.astype(np.float64) 110 | img2 = img2.astype(np.float64) 111 | 112 | 113 | ssims = [] 114 | for i in range(img1.shape[2]): 115 | ssims.append(_ssim(img1[..., i], img2[..., i])) 116 | return np.array(ssims).mean() 117 | 118 | def main(path1, path2, type="all"): 119 | loss_1 = [] 120 | loss_2 = [] 121 | loss_3 = [] 122 | loss_4 = [] 123 | length = 0 124 | 125 | if type=="lpips" or type == "all": 126 | lpips = LPIPS().cuda() 127 | if type=="ssim" or type == "all": 128 | ssim = pytorch_ssim.SSIM(window_size = 11).cuda() 129 | 130 | #iqa_metric = pyiqa.create_metric(type, test_y_channel=False, color_space='ycbcr').cuda() 131 | lpips_metric = pyiqa.create_metric('lpips').cuda() 132 | ssim_metric = pyiqa.create_metric('ssim').cuda() 133 | musiq_metric = pyiqa.create_metric('musiq').cuda() 134 | 135 | for idx ,img in tqdm(enumerate(os.listdir(path1)),total=len(os.listdir(path1))): 136 | imgpath1 = os.path.join(path1,img) 137 | imgpath2 = os.path.join(path2,img) 138 | imgpath2 = imgpath2[:-3]+'png' 139 | 140 | #print(imgpath1) 141 | #img1 = cv2.imread(imgpath1).astype(np.float64) 142 | 143 | 144 | #img2 = cv2.imread(imgpath2).astype(np.float64) 145 | 146 | 147 | # mean_l = [] 148 | # std_l = [] 149 | # for j in range(3): 150 | # mean_l.append(np.mean(img2[:, :, j])) 151 | # std_l.append(np.std(img2[:, :, j])) 152 | # for j in range(3): 153 | # # correct twice 154 | # mean = np.mean(img1[:, :, j]) 155 | # img1[:, :, j] = img1[:, :, j] - mean + mean_l[j] 156 | # std = np.std(img1[:, :, j]) 157 | # img1[:, :, j] = img1[:, :, j] / std * std_l[j] 158 | 159 | # mean = np.mean(img1[:, :, j]) 160 | # img1[:, :, j] = img1[:, :, j] - mean + mean_l[j] 161 | # std = np.std(img1[:, :, j]) 162 | # img1[:, :, j] = img1[:, :, j] / std * std_l[j] 163 | # img1 = cv2.resize(img1,(256,256)) 164 | # img2 = cv2.resize(img2,(256,256)) 165 | # if img1.shape != img2.shape: 166 | # if img1.shape[0]< img2.shape[0]: 167 | # img2 = cv2.resize(img2,img1.shape[:2]) 168 | # else: 169 | # img1 = cv2.resize(img1,img2.shape[:2]) 170 | if type=="psnr": 171 | psnr_score = iqa_metric(imgpath1,imgpath2) 172 | # loss_1 += psnr(img1,img2,data_range=255.0) 173 | loss_1.append(psnr_score.cpu().numpy()) 174 | elif type=="lpips": 175 | lpips_score = lpips_metric(imgpath1,imgpath2) 176 | loss_2.append(lpips_score.cpu().numpy()) 177 | elif type=="ssim": 178 | ssim_score = ssim_metric(imgpath1,imgpath2) 179 | loss_3.append(ssim_score.cpu().numpy()) 180 | elif type=="musiq": 181 | musiq_score = musiq_metric(imgpath1) 182 | loss_4.append(musiq_score.cpu().numpy()) 183 | elif type == "all": 184 | loss_1 += psnr(img1,img2) 185 | loss_2 += lpips(transforms.ToTensor()(img1).cuda(),transforms.ToTensor()(img2).cuda()) 186 | loss_3 += ssim(transforms.ToTensor()(img1).unsqueeze(0).cuda(),transforms.ToTensor()(img2).unsqueeze(0).cuda()) 187 | # loss += criterion(transforms.ToTensor()(img1).cuda(),transforms.ToTensor()(img2).cuda()) 188 | # loss += criterion(transforms.ToTensor()(img1).unsqueeze(0).cuda(),transforms.ToTensor()(img2).unsqueeze(0).cuda()) 189 | length +=1 190 | if type=="psnr" or type == "all": 191 | print("psnr↑",np.mean(loss_1)) 192 | if type=="lpips" or type == "all": 193 | print("lpips↓",np.mean(loss_2)) 194 | if type=="ssim" or type == "all": 195 | print("ssim↑",np.mean(loss_3)) 196 | if type=="musiq" or type == "all": 197 | print("musiq↑",np.mean(loss_4)) 198 | 199 | 200 | 201 | if __name__ == 
"__main__": 202 | parser = ArgumentParser() 203 | parser.add_argument("--input1", type=str, required=True) 204 | parser.add_argument("--input2", type=str, required=True) 205 | parser.add_argument("--type", type=str, default="all") 206 | args = parser.parse_args() 207 | main(args.input1, args.input2, args.type) 208 | 209 | 210 | ''' 211 | DiffBIR 212 | psnr↑ 31.14 213 | lpips↓ 0.2063 214 | ssim↑ 0.6731 215 | 216 | midd 217 | psnr↑ 30.87 218 | lpips↓ 0.2046 219 | ssim↑ 0.6719 220 | 221 | final 222 | psnr↑ 31.17 223 | lpips↓ 0.2248 224 | ssim↑ 0.7220 225 | 226 | ''' -------------------------------------------------------------------------------- /ldm/__pycache__/util.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/EternalEvan/FlowIE/549bb871cff104c47fd126d23818f54d74ba1f20/ldm/__pycache__/util.cpython-38.pyc -------------------------------------------------------------------------------- /ldm/__pycache__/util.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/EternalEvan/FlowIE/549bb871cff104c47fd126d23818f54d74ba1f20/ldm/__pycache__/util.cpython-39.pyc -------------------------------------------------------------------------------- /ldm/__pycache__/xformers_state.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/EternalEvan/FlowIE/549bb871cff104c47fd126d23818f54d74ba1f20/ldm/__pycache__/xformers_state.cpython-38.pyc -------------------------------------------------------------------------------- /ldm/__pycache__/xformers_state.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/EternalEvan/FlowIE/549bb871cff104c47fd126d23818f54d74ba1f20/ldm/__pycache__/xformers_state.cpython-39.pyc -------------------------------------------------------------------------------- /ldm/data/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/EternalEvan/FlowIE/549bb871cff104c47fd126d23818f54d74ba1f20/ldm/data/__init__.py -------------------------------------------------------------------------------- /ldm/data/util.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | from ldm.modules.midas.api import load_midas_transform 4 | 5 | 6 | class AddMiDaS(object): 7 | def __init__(self, model_type): 8 | super().__init__() 9 | self.transform = load_midas_transform(model_type) 10 | 11 | def pt2np(self, x): 12 | x = ((x + 1.0) * .5).detach().cpu().numpy() 13 | return x 14 | 15 | def np2pt(self, x): 16 | x = torch.from_numpy(x) * 2 - 1. 
17 | return x 18 | 19 | def __call__(self, sample): 20 | # sample['jpg'] is tensor hwc in [-1, 1] at this point 21 | x = self.pt2np(sample['jpg']) 22 | x = self.transform({"image": x})["image"] 23 | sample['midas_in'] = x 24 | return sample -------------------------------------------------------------------------------- /ldm/models/__pycache__/autoencoder.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/EternalEvan/FlowIE/549bb871cff104c47fd126d23818f54d74ba1f20/ldm/models/__pycache__/autoencoder.cpython-38.pyc -------------------------------------------------------------------------------- /ldm/models/__pycache__/autoencoder.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/EternalEvan/FlowIE/549bb871cff104c47fd126d23818f54d74ba1f20/ldm/models/__pycache__/autoencoder.cpython-39.pyc -------------------------------------------------------------------------------- /ldm/models/diffusion/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/EternalEvan/FlowIE/549bb871cff104c47fd126d23818f54d74ba1f20/ldm/models/diffusion/__init__.py -------------------------------------------------------------------------------- /ldm/models/diffusion/__pycache__/__init__.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/EternalEvan/FlowIE/549bb871cff104c47fd126d23818f54d74ba1f20/ldm/models/diffusion/__pycache__/__init__.cpython-38.pyc -------------------------------------------------------------------------------- /ldm/models/diffusion/__pycache__/__init__.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/EternalEvan/FlowIE/549bb871cff104c47fd126d23818f54d74ba1f20/ldm/models/diffusion/__pycache__/__init__.cpython-39.pyc -------------------------------------------------------------------------------- /ldm/models/diffusion/__pycache__/ddim.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/EternalEvan/FlowIE/549bb871cff104c47fd126d23818f54d74ba1f20/ldm/models/diffusion/__pycache__/ddim.cpython-38.pyc -------------------------------------------------------------------------------- /ldm/models/diffusion/__pycache__/ddim.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/EternalEvan/FlowIE/549bb871cff104c47fd126d23818f54d74ba1f20/ldm/models/diffusion/__pycache__/ddim.cpython-39.pyc -------------------------------------------------------------------------------- /ldm/models/diffusion/__pycache__/ddpm.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/EternalEvan/FlowIE/549bb871cff104c47fd126d23818f54d74ba1f20/ldm/models/diffusion/__pycache__/ddpm.cpython-38.pyc -------------------------------------------------------------------------------- /ldm/models/diffusion/__pycache__/ddpm.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/EternalEvan/FlowIE/549bb871cff104c47fd126d23818f54d74ba1f20/ldm/models/diffusion/__pycache__/ddpm.cpython-39.pyc 
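A minimal usage sketch for the AddMiDaS transform defined in ldm/data/util.py above; the "dpt_hybrid" model type and the random sample values are illustrative only (load_midas_transform only builds the preprocessing pipeline, so no MiDaS checkpoint is needed here):

import torch
from ldm.data.util import AddMiDaS

add_midas = AddMiDaS(model_type="dpt_hybrid")
sample = {"jpg": torch.rand(384, 384, 3) * 2 - 1}  # hwc tensor in [-1, 1], as expected by __call__
sample = add_midas(sample)
# sample["midas_in"] now holds the resized / normalized array ready for the MiDaS depth model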
-------------------------------------------------------------------------------- /ldm/models/diffusion/dpm_solver/__init__.py: -------------------------------------------------------------------------------- 1 | from .sampler import DPMSolverSampler -------------------------------------------------------------------------------- /ldm/models/diffusion/dpm_solver/sampler.py: -------------------------------------------------------------------------------- 1 | """SAMPLING ONLY.""" 2 | import torch 3 | 4 | from .dpm_solver import NoiseScheduleVP, model_wrapper, DPM_Solver 5 | 6 | 7 | MODEL_TYPES = { 8 | "eps": "noise", 9 | "v": "v" 10 | } 11 | 12 | 13 | class DPMSolverSampler(object): 14 | def __init__(self, model, **kwargs): 15 | super().__init__() 16 | self.model = model 17 | to_torch = lambda x: x.clone().detach().to(torch.float32).to(model.device) 18 | self.register_buffer('alphas_cumprod', to_torch(model.alphas_cumprod)) 19 | 20 | def register_buffer(self, name, attr): 21 | if type(attr) == torch.Tensor: 22 | if attr.device != torch.device("cuda"): 23 | attr = attr.to(torch.device("cuda")) 24 | setattr(self, name, attr) 25 | 26 | @torch.no_grad() 27 | def sample(self, 28 | S, 29 | batch_size, 30 | shape, 31 | conditioning=None, 32 | callback=None, 33 | normals_sequence=None, 34 | img_callback=None, 35 | quantize_x0=False, 36 | eta=0., 37 | mask=None, 38 | x0=None, 39 | temperature=1., 40 | noise_dropout=0., 41 | score_corrector=None, 42 | corrector_kwargs=None, 43 | verbose=True, 44 | x_T=None, 45 | log_every_t=100, 46 | unconditional_guidance_scale=1., 47 | unconditional_conditioning=None, 48 | # this has to come in the same format as the conditioning, # e.g. as encoded tokens, ... 49 | **kwargs 50 | ): 51 | if conditioning is not None: 52 | if isinstance(conditioning, dict): 53 | cbs = conditioning[list(conditioning.keys())[0]].shape[0] 54 | if cbs != batch_size: 55 | print(f"Warning: Got {cbs} conditionings but batch-size is {batch_size}") 56 | else: 57 | if conditioning.shape[0] != batch_size: 58 | print(f"Warning: Got {conditioning.shape[0]} conditionings but batch-size is {batch_size}") 59 | 60 | # sampling 61 | C, H, W = shape 62 | size = (batch_size, C, H, W) 63 | 64 | print(f'Data shape for DPM-Solver sampling is {size}, sampling steps {S}') 65 | 66 | device = self.model.betas.device 67 | if x_T is None: 68 | img = torch.randn(size, device=device) 69 | else: 70 | img = x_T 71 | 72 | ns = NoiseScheduleVP('discrete', alphas_cumprod=self.alphas_cumprod) 73 | 74 | model_fn = model_wrapper( 75 | lambda x, t, c: self.model.apply_model(x, t, c), 76 | ns, 77 | model_type=MODEL_TYPES[self.model.parameterization], 78 | guidance_type="classifier-free", 79 | condition=conditioning, 80 | unconditional_condition=unconditional_conditioning, 81 | guidance_scale=unconditional_guidance_scale, 82 | ) 83 | 84 | dpm_solver = DPM_Solver(model_fn, ns, predict_x0=True, thresholding=False) 85 | x = dpm_solver.sample(img, steps=S, skip_type="time_uniform", method="multistep", order=2, lower_order_final=True) 86 | 87 | return x.to(device), None -------------------------------------------------------------------------------- /ldm/models/diffusion/sampling_util.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | 4 | 5 | def append_dims(x, target_dims): 6 | """Appends dimensions to the end of a tensor until it has target_dims dimensions. 
7 | From https://github.com/crowsonkb/k-diffusion/blob/master/k_diffusion/utils.py""" 8 | dims_to_append = target_dims - x.ndim 9 | if dims_to_append < 0: 10 | raise ValueError(f'input has {x.ndim} dims but target_dims is {target_dims}, which is less') 11 | return x[(...,) + (None,) * dims_to_append] 12 | 13 | 14 | def norm_thresholding(x0, value): 15 | s = append_dims(x0.pow(2).flatten(1).mean(1).sqrt().clamp(min=value), x0.ndim) 16 | return x0 * (value / s) 17 | 18 | 19 | def spatial_norm_thresholding(x0, value): 20 | # b c h w 21 | s = x0.pow(2).mean(1, keepdim=True).sqrt().clamp(min=value) 22 | return x0 * (value / s) -------------------------------------------------------------------------------- /ldm/modules/__pycache__/attention.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/EternalEvan/FlowIE/549bb871cff104c47fd126d23818f54d74ba1f20/ldm/modules/__pycache__/attention.cpython-38.pyc -------------------------------------------------------------------------------- /ldm/modules/__pycache__/attention.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/EternalEvan/FlowIE/549bb871cff104c47fd126d23818f54d74ba1f20/ldm/modules/__pycache__/attention.cpython-39.pyc -------------------------------------------------------------------------------- /ldm/modules/__pycache__/ema.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/EternalEvan/FlowIE/549bb871cff104c47fd126d23818f54d74ba1f20/ldm/modules/__pycache__/ema.cpython-38.pyc -------------------------------------------------------------------------------- /ldm/modules/__pycache__/ema.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/EternalEvan/FlowIE/549bb871cff104c47fd126d23818f54d74ba1f20/ldm/modules/__pycache__/ema.cpython-39.pyc -------------------------------------------------------------------------------- /ldm/modules/diffusionmodules/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/EternalEvan/FlowIE/549bb871cff104c47fd126d23818f54d74ba1f20/ldm/modules/diffusionmodules/__init__.py -------------------------------------------------------------------------------- /ldm/modules/diffusionmodules/__pycache__/__init__.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/EternalEvan/FlowIE/549bb871cff104c47fd126d23818f54d74ba1f20/ldm/modules/diffusionmodules/__pycache__/__init__.cpython-38.pyc -------------------------------------------------------------------------------- /ldm/modules/diffusionmodules/__pycache__/__init__.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/EternalEvan/FlowIE/549bb871cff104c47fd126d23818f54d74ba1f20/ldm/modules/diffusionmodules/__pycache__/__init__.cpython-39.pyc -------------------------------------------------------------------------------- /ldm/modules/diffusionmodules/__pycache__/model.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/EternalEvan/FlowIE/549bb871cff104c47fd126d23818f54d74ba1f20/ldm/modules/diffusionmodules/__pycache__/model.cpython-38.pyc 
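A small worked example of the append_dims helper from ldm/models/diffusion/sampling_util.py above; the shapes and values are illustrative only:

import torch
from ldm.models.diffusion.sampling_util import append_dims

x0 = torch.randn(4, 3, 64, 64)            # a batch of latents
s = x0.pow(2).flatten(1).mean(1).sqrt()   # one scalar per sample, shape (4,)
s = append_dims(s, x0.ndim)               # shape (4, 1, 1, 1), broadcastable against x0
scaled = x0 * (1.0 / s)                   # the same broadcasting pattern used by norm_thresholding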
-------------------------------------------------------------------------------- /ldm/modules/diffusionmodules/__pycache__/model.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/EternalEvan/FlowIE/549bb871cff104c47fd126d23818f54d74ba1f20/ldm/modules/diffusionmodules/__pycache__/model.cpython-39.pyc -------------------------------------------------------------------------------- /ldm/modules/diffusionmodules/__pycache__/openaimodel.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/EternalEvan/FlowIE/549bb871cff104c47fd126d23818f54d74ba1f20/ldm/modules/diffusionmodules/__pycache__/openaimodel.cpython-38.pyc -------------------------------------------------------------------------------- /ldm/modules/diffusionmodules/__pycache__/openaimodel.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/EternalEvan/FlowIE/549bb871cff104c47fd126d23818f54d74ba1f20/ldm/modules/diffusionmodules/__pycache__/openaimodel.cpython-39.pyc -------------------------------------------------------------------------------- /ldm/modules/diffusionmodules/__pycache__/util.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/EternalEvan/FlowIE/549bb871cff104c47fd126d23818f54d74ba1f20/ldm/modules/diffusionmodules/__pycache__/util.cpython-38.pyc -------------------------------------------------------------------------------- /ldm/modules/diffusionmodules/__pycache__/util.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/EternalEvan/FlowIE/549bb871cff104c47fd126d23818f54d74ba1f20/ldm/modules/diffusionmodules/__pycache__/util.cpython-39.pyc -------------------------------------------------------------------------------- /ldm/modules/diffusionmodules/upscaling.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import numpy as np 4 | from functools import partial 5 | 6 | from ldm.modules.diffusionmodules.util import extract_into_tensor, make_beta_schedule 7 | from ldm.util import default 8 | 9 | 10 | class AbstractLowScaleModel(nn.Module): 11 | # for concatenating a downsampled image to the latent representation 12 | def __init__(self, noise_schedule_config=None): 13 | super(AbstractLowScaleModel, self).__init__() 14 | if noise_schedule_config is not None: 15 | self.register_schedule(**noise_schedule_config) 16 | 17 | def register_schedule(self, beta_schedule="linear", timesteps=1000, 18 | linear_start=1e-4, linear_end=2e-2, cosine_s=8e-3): 19 | betas = make_beta_schedule(beta_schedule, timesteps, linear_start=linear_start, linear_end=linear_end, 20 | cosine_s=cosine_s) 21 | alphas = 1. 
- betas 22 | alphas_cumprod = np.cumprod(alphas, axis=0) 23 | alphas_cumprod_prev = np.append(1., alphas_cumprod[:-1]) 24 | 25 | timesteps, = betas.shape 26 | self.num_timesteps = int(timesteps) 27 | self.linear_start = linear_start 28 | self.linear_end = linear_end 29 | assert alphas_cumprod.shape[0] == self.num_timesteps, 'alphas have to be defined for each timestep' 30 | 31 | to_torch = partial(torch.tensor, dtype=torch.float32) 32 | 33 | self.register_buffer('betas', to_torch(betas)) 34 | self.register_buffer('alphas_cumprod', to_torch(alphas_cumprod)) 35 | self.register_buffer('alphas_cumprod_prev', to_torch(alphas_cumprod_prev)) 36 | 37 | # calculations for diffusion q(x_t | x_{t-1}) and others 38 | self.register_buffer('sqrt_alphas_cumprod', to_torch(np.sqrt(alphas_cumprod))) 39 | self.register_buffer('sqrt_one_minus_alphas_cumprod', to_torch(np.sqrt(1. - alphas_cumprod))) 40 | self.register_buffer('log_one_minus_alphas_cumprod', to_torch(np.log(1. - alphas_cumprod))) 41 | self.register_buffer('sqrt_recip_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod))) 42 | self.register_buffer('sqrt_recipm1_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod - 1))) 43 | 44 | def q_sample(self, x_start, t, noise=None): 45 | noise = default(noise, lambda: torch.randn_like(x_start)) 46 | return (extract_into_tensor(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start + 47 | extract_into_tensor(self.sqrt_one_minus_alphas_cumprod, t, x_start.shape) * noise) 48 | 49 | def forward(self, x): 50 | return x, None 51 | 52 | def decode(self, x): 53 | return x 54 | 55 | 56 | class SimpleImageConcat(AbstractLowScaleModel): 57 | # no noise level conditioning 58 | def __init__(self): 59 | super(SimpleImageConcat, self).__init__(noise_schedule_config=None) 60 | self.max_noise_level = 0 61 | 62 | def forward(self, x): 63 | # fix to constant noise level 64 | return x, torch.zeros(x.shape[0], device=x.device).long() 65 | 66 | 67 | class ImageConcatWithNoiseAugmentation(AbstractLowScaleModel): 68 | def __init__(self, noise_schedule_config, max_noise_level=1000, to_cuda=False): 69 | super().__init__(noise_schedule_config=noise_schedule_config) 70 | self.max_noise_level = max_noise_level 71 | 72 | def forward(self, x, noise_level=None): 73 | if noise_level is None: 74 | noise_level = torch.randint(0, self.max_noise_level, (x.shape[0],), device=x.device).long() 75 | else: 76 | assert isinstance(noise_level, torch.Tensor) 77 | z = self.q_sample(x, noise_level) 78 | return z, noise_level 79 | 80 | 81 | 82 | -------------------------------------------------------------------------------- /ldm/modules/distributions/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/EternalEvan/FlowIE/549bb871cff104c47fd126d23818f54d74ba1f20/ldm/modules/distributions/__init__.py -------------------------------------------------------------------------------- /ldm/modules/distributions/__pycache__/__init__.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/EternalEvan/FlowIE/549bb871cff104c47fd126d23818f54d74ba1f20/ldm/modules/distributions/__pycache__/__init__.cpython-38.pyc -------------------------------------------------------------------------------- /ldm/modules/distributions/__pycache__/__init__.cpython-39.pyc: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/EternalEvan/FlowIE/549bb871cff104c47fd126d23818f54d74ba1f20/ldm/modules/distributions/__pycache__/__init__.cpython-39.pyc -------------------------------------------------------------------------------- /ldm/modules/distributions/__pycache__/distributions.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/EternalEvan/FlowIE/549bb871cff104c47fd126d23818f54d74ba1f20/ldm/modules/distributions/__pycache__/distributions.cpython-38.pyc -------------------------------------------------------------------------------- /ldm/modules/distributions/__pycache__/distributions.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/EternalEvan/FlowIE/549bb871cff104c47fd126d23818f54d74ba1f20/ldm/modules/distributions/__pycache__/distributions.cpython-39.pyc -------------------------------------------------------------------------------- /ldm/modules/distributions/distributions.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | 4 | 5 | class AbstractDistribution: 6 | def sample(self): 7 | raise NotImplementedError() 8 | 9 | def mode(self): 10 | raise NotImplementedError() 11 | 12 | 13 | class DiracDistribution(AbstractDistribution): 14 | def __init__(self, value): 15 | self.value = value 16 | 17 | def sample(self): 18 | return self.value 19 | 20 | def mode(self): 21 | return self.value 22 | 23 | 24 | class DiagonalGaussianDistribution(object): 25 | def __init__(self, parameters, deterministic=False): 26 | self.parameters = parameters 27 | self.mean, self.logvar = torch.chunk(parameters, 2, dim=1) 28 | self.logvar = torch.clamp(self.logvar, -30.0, 20.0) 29 | self.deterministic = deterministic 30 | self.std = torch.exp(0.5 * self.logvar) 31 | self.var = torch.exp(self.logvar) 32 | if self.deterministic: 33 | self.var = self.std = torch.zeros_like(self.mean).to(device=self.parameters.device) 34 | 35 | def sample(self): 36 | x = self.mean + self.std * torch.randn(self.mean.shape).to(device=self.parameters.device) 37 | return x 38 | 39 | def kl(self, other=None): 40 | if self.deterministic: 41 | return torch.Tensor([0.]) 42 | else: 43 | if other is None: 44 | return 0.5 * torch.sum(torch.pow(self.mean, 2) 45 | + self.var - 1.0 - self.logvar, 46 | dim=[1, 2, 3]) 47 | else: 48 | return 0.5 * torch.sum( 49 | torch.pow(self.mean - other.mean, 2) / other.var 50 | + self.var / other.var - 1.0 - self.logvar + other.logvar, 51 | dim=[1, 2, 3]) 52 | 53 | def nll(self, sample, dims=[1,2,3]): 54 | if self.deterministic: 55 | return torch.Tensor([0.]) 56 | logtwopi = np.log(2.0 * np.pi) 57 | return 0.5 * torch.sum( 58 | logtwopi + self.logvar + torch.pow(sample - self.mean, 2) / self.var, 59 | dim=dims) 60 | 61 | def mode(self): 62 | return self.mean 63 | 64 | 65 | def normal_kl(mean1, logvar1, mean2, logvar2): 66 | """ 67 | source: https://github.com/openai/guided-diffusion/blob/27c20a8fab9cb472df5d6bdd6c8d11c8f430b924/guided_diffusion/losses.py#L12 68 | Compute the KL divergence between two gaussians. 69 | Shapes are automatically broadcasted, so batches can be compared to 70 | scalars, among other use cases. 
71 | """ 72 | tensor = None 73 | for obj in (mean1, logvar1, mean2, logvar2): 74 | if isinstance(obj, torch.Tensor): 75 | tensor = obj 76 | break 77 | assert tensor is not None, "at least one argument must be a Tensor" 78 | 79 | # Force variances to be Tensors. Broadcasting helps convert scalars to 80 | # Tensors, but it does not work for torch.exp(). 81 | logvar1, logvar2 = [ 82 | x if isinstance(x, torch.Tensor) else torch.tensor(x).to(tensor) 83 | for x in (logvar1, logvar2) 84 | ] 85 | 86 | return 0.5 * ( 87 | -1.0 88 | + logvar2 89 | - logvar1 90 | + torch.exp(logvar1 - logvar2) 91 | + ((mean1 - mean2) ** 2) * torch.exp(-logvar2) 92 | ) 93 | -------------------------------------------------------------------------------- /ldm/modules/ema.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | 4 | 5 | class LitEma(nn.Module): 6 | def __init__(self, model, decay=0.9999, use_num_upates=True): 7 | super().__init__() 8 | if decay < 0.0 or decay > 1.0: 9 | raise ValueError('Decay must be between 0 and 1') 10 | 11 | self.m_name2s_name = {} 12 | self.register_buffer('decay', torch.tensor(decay, dtype=torch.float32)) 13 | self.register_buffer('num_updates', torch.tensor(0, dtype=torch.int) if use_num_upates 14 | else torch.tensor(-1, dtype=torch.int)) 15 | 16 | for name, p in model.named_parameters(): 17 | if p.requires_grad: 18 | # remove as '.'-character is not allowed in buffers 19 | s_name = name.replace('.', '') 20 | self.m_name2s_name.update({name: s_name}) 21 | self.register_buffer(s_name, p.clone().detach().data) 22 | 23 | self.collected_params = [] 24 | 25 | def reset_num_updates(self): 26 | del self.num_updates 27 | self.register_buffer('num_updates', torch.tensor(0, dtype=torch.int)) 28 | 29 | def forward(self, model): 30 | decay = self.decay 31 | 32 | if self.num_updates >= 0: 33 | self.num_updates += 1 34 | decay = min(self.decay, (1 + self.num_updates) / (10 + self.num_updates)) 35 | 36 | one_minus_decay = 1.0 - decay 37 | 38 | with torch.no_grad(): 39 | m_param = dict(model.named_parameters()) 40 | shadow_params = dict(self.named_buffers()) 41 | 42 | for key in m_param: 43 | if m_param[key].requires_grad: 44 | sname = self.m_name2s_name[key] 45 | shadow_params[sname] = shadow_params[sname].type_as(m_param[key]) 46 | shadow_params[sname].sub_(one_minus_decay * (shadow_params[sname] - m_param[key])) 47 | else: 48 | assert not key in self.m_name2s_name 49 | 50 | def copy_to(self, model): 51 | m_param = dict(model.named_parameters()) 52 | shadow_params = dict(self.named_buffers()) 53 | for key in m_param: 54 | if m_param[key].requires_grad: 55 | m_param[key].data.copy_(shadow_params[self.m_name2s_name[key]].data) 56 | else: 57 | assert not key in self.m_name2s_name 58 | 59 | def store(self, parameters): 60 | """ 61 | Save the current parameters for restoring later. 62 | Args: 63 | parameters: Iterable of `torch.nn.Parameter`; the parameters to be 64 | temporarily stored. 65 | """ 66 | self.collected_params = [param.clone() for param in parameters] 67 | 68 | def restore(self, parameters): 69 | """ 70 | Restore the parameters stored with the `store` method. 71 | Useful to validate the model with EMA parameters without affecting the 72 | original optimization process. Store the parameters before the 73 | `copy_to` method. After validation (or model saving), use this to 74 | restore the former parameters. 
75 | Args: 76 | parameters: Iterable of `torch.nn.Parameter`; the parameters to be 77 | updated with the stored parameters. 78 | """ 79 | for c_param, param in zip(self.collected_params, parameters): 80 | param.data.copy_(c_param.data) 81 | -------------------------------------------------------------------------------- /ldm/modules/encoders/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/EternalEvan/FlowIE/549bb871cff104c47fd126d23818f54d74ba1f20/ldm/modules/encoders/__init__.py -------------------------------------------------------------------------------- /ldm/modules/encoders/__pycache__/__init__.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/EternalEvan/FlowIE/549bb871cff104c47fd126d23818f54d74ba1f20/ldm/modules/encoders/__pycache__/__init__.cpython-38.pyc -------------------------------------------------------------------------------- /ldm/modules/encoders/__pycache__/__init__.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/EternalEvan/FlowIE/549bb871cff104c47fd126d23818f54d74ba1f20/ldm/modules/encoders/__pycache__/__init__.cpython-39.pyc -------------------------------------------------------------------------------- /ldm/modules/encoders/__pycache__/modules.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/EternalEvan/FlowIE/549bb871cff104c47fd126d23818f54d74ba1f20/ldm/modules/encoders/__pycache__/modules.cpython-38.pyc -------------------------------------------------------------------------------- /ldm/modules/encoders/__pycache__/modules.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/EternalEvan/FlowIE/549bb871cff104c47fd126d23818f54d74ba1f20/ldm/modules/encoders/__pycache__/modules.cpython-39.pyc -------------------------------------------------------------------------------- /ldm/modules/midas/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/EternalEvan/FlowIE/549bb871cff104c47fd126d23818f54d74ba1f20/ldm/modules/midas/__init__.py -------------------------------------------------------------------------------- /ldm/modules/midas/api.py: -------------------------------------------------------------------------------- 1 | # based on https://github.com/isl-org/MiDaS 2 | 3 | import cv2 4 | import torch 5 | import torch.nn as nn 6 | from torchvision.transforms import Compose 7 | 8 | from ldm.modules.midas.midas.dpt_depth import DPTDepthModel 9 | from ldm.modules.midas.midas.midas_net import MidasNet 10 | from ldm.modules.midas.midas.midas_net_custom import MidasNet_small 11 | from ldm.modules.midas.midas.transforms import Resize, NormalizeImage, PrepareForNet 12 | 13 | 14 | ISL_PATHS = { 15 | "dpt_large": "midas_models/dpt_large-midas-2f21e586.pt", 16 | "dpt_hybrid": "midas_models/dpt_hybrid-midas-501f0c75.pt", 17 | "midas_v21": "", 18 | "midas_v21_small": "", 19 | } 20 | 21 | 22 | def disabled_train(self, mode=True): 23 | """Overwrite model.train with this function to make sure train/eval mode 24 | does not change anymore.""" 25 | return self 26 | 27 | 28 | def load_midas_transform(model_type): 29 | # https://github.com/isl-org/MiDaS/blob/master/run.py 30 | # load transform only 31 | if model_type 
== "dpt_large": # DPT-Large 32 | net_w, net_h = 384, 384 33 | resize_mode = "minimal" 34 | normalization = NormalizeImage(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]) 35 | 36 | elif model_type == "dpt_hybrid": # DPT-Hybrid 37 | net_w, net_h = 384, 384 38 | resize_mode = "minimal" 39 | normalization = NormalizeImage(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]) 40 | 41 | elif model_type == "midas_v21": 42 | net_w, net_h = 384, 384 43 | resize_mode = "upper_bound" 44 | normalization = NormalizeImage(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]) 45 | 46 | elif model_type == "midas_v21_small": 47 | net_w, net_h = 256, 256 48 | resize_mode = "upper_bound" 49 | normalization = NormalizeImage(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]) 50 | 51 | else: 52 | assert False, f"model_type '{model_type}' not implemented, use: --model_type large" 53 | 54 | transform = Compose( 55 | [ 56 | Resize( 57 | net_w, 58 | net_h, 59 | resize_target=None, 60 | keep_aspect_ratio=True, 61 | ensure_multiple_of=32, 62 | resize_method=resize_mode, 63 | image_interpolation_method=cv2.INTER_CUBIC, 64 | ), 65 | normalization, 66 | PrepareForNet(), 67 | ] 68 | ) 69 | 70 | return transform 71 | 72 | 73 | def load_model(model_type): 74 | # https://github.com/isl-org/MiDaS/blob/master/run.py 75 | # load network 76 | model_path = ISL_PATHS[model_type] 77 | if model_type == "dpt_large": # DPT-Large 78 | model = DPTDepthModel( 79 | path=model_path, 80 | backbone="vitl16_384", 81 | non_negative=True, 82 | ) 83 | net_w, net_h = 384, 384 84 | resize_mode = "minimal" 85 | normalization = NormalizeImage(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]) 86 | 87 | elif model_type == "dpt_hybrid": # DPT-Hybrid 88 | model = DPTDepthModel( 89 | path=model_path, 90 | backbone="vitb_rn50_384", 91 | non_negative=True, 92 | ) 93 | net_w, net_h = 384, 384 94 | resize_mode = "minimal" 95 | normalization = NormalizeImage(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]) 96 | 97 | elif model_type == "midas_v21": 98 | model = MidasNet(model_path, non_negative=True) 99 | net_w, net_h = 384, 384 100 | resize_mode = "upper_bound" 101 | normalization = NormalizeImage( 102 | mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225] 103 | ) 104 | 105 | elif model_type == "midas_v21_small": 106 | model = MidasNet_small(model_path, features=64, backbone="efficientnet_lite3", exportable=True, 107 | non_negative=True, blocks={'expand': True}) 108 | net_w, net_h = 256, 256 109 | resize_mode = "upper_bound" 110 | normalization = NormalizeImage( 111 | mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225] 112 | ) 113 | 114 | else: 115 | print(f"model_type '{model_type}' not implemented, use: --model_type large") 116 | assert False 117 | 118 | transform = Compose( 119 | [ 120 | Resize( 121 | net_w, 122 | net_h, 123 | resize_target=None, 124 | keep_aspect_ratio=True, 125 | ensure_multiple_of=32, 126 | resize_method=resize_mode, 127 | image_interpolation_method=cv2.INTER_CUBIC, 128 | ), 129 | normalization, 130 | PrepareForNet(), 131 | ] 132 | ) 133 | 134 | return model.eval(), transform 135 | 136 | 137 | class MiDaSInference(nn.Module): 138 | MODEL_TYPES_TORCH_HUB = [ 139 | "DPT_Large", 140 | "DPT_Hybrid", 141 | "MiDaS_small" 142 | ] 143 | MODEL_TYPES_ISL = [ 144 | "dpt_large", 145 | "dpt_hybrid", 146 | "midas_v21", 147 | "midas_v21_small", 148 | ] 149 | 150 | def __init__(self, model_type): 151 | super().__init__() 152 | assert (model_type in self.MODEL_TYPES_ISL) 153 | model, _ = load_model(model_type) 154 | self.model = model 155 | self.model.train = 
disabled_train 156 | 157 | def forward(self, x): 158 | # x in 0..1 as produced by calling self.transform on a 0..1 float64 numpy array 159 | # NOTE: we expect that the correct transform has been called during dataloading. 160 | with torch.no_grad(): 161 | prediction = self.model(x) 162 | prediction = torch.nn.functional.interpolate( 163 | prediction.unsqueeze(1), 164 | size=x.shape[2:], 165 | mode="bicubic", 166 | align_corners=False, 167 | ) 168 | assert prediction.shape == (x.shape[0], 1, x.shape[2], x.shape[3]) 169 | return prediction 170 | 171 | -------------------------------------------------------------------------------- /ldm/modules/midas/midas/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/EternalEvan/FlowIE/549bb871cff104c47fd126d23818f54d74ba1f20/ldm/modules/midas/midas/__init__.py -------------------------------------------------------------------------------- /ldm/modules/midas/midas/base_model.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | 4 | class BaseModel(torch.nn.Module): 5 | def load(self, path): 6 | """Load model from file. 7 | 8 | Args: 9 | path (str): file path 10 | """ 11 | parameters = torch.load(path, map_location=torch.device('cpu')) 12 | 13 | if "optimizer" in parameters: 14 | parameters = parameters["model"] 15 | 16 | self.load_state_dict(parameters) 17 | -------------------------------------------------------------------------------- /ldm/modules/midas/midas/dpt_depth.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | from .base_model import BaseModel 6 | from .blocks import ( 7 | FeatureFusionBlock, 8 | FeatureFusionBlock_custom, 9 | Interpolate, 10 | _make_encoder, 11 | forward_vit, 12 | ) 13 | 14 | 15 | def _make_fusion_block(features, use_bn): 16 | return FeatureFusionBlock_custom( 17 | features, 18 | nn.ReLU(False), 19 | deconv=False, 20 | bn=use_bn, 21 | expand=False, 22 | align_corners=True, 23 | ) 24 | 25 | 26 | class DPT(BaseModel): 27 | def __init__( 28 | self, 29 | head, 30 | features=256, 31 | backbone="vitb_rn50_384", 32 | readout="project", 33 | channels_last=False, 34 | use_bn=False, 35 | ): 36 | 37 | super(DPT, self).__init__() 38 | 39 | self.channels_last = channels_last 40 | 41 | hooks = { 42 | "vitb_rn50_384": [0, 1, 8, 11], 43 | "vitb16_384": [2, 5, 8, 11], 44 | "vitl16_384": [5, 11, 17, 23], 45 | } 46 | 47 | # Instantiate backbone and reassemble blocks 48 | self.pretrained, self.scratch = _make_encoder( 49 | backbone, 50 | features, 51 | False, # Set to true of you want to train from scratch, uses ImageNet weights 52 | groups=1, 53 | expand=False, 54 | exportable=False, 55 | hooks=hooks[backbone], 56 | use_readout=readout, 57 | ) 58 | 59 | self.scratch.refinenet1 = _make_fusion_block(features, use_bn) 60 | self.scratch.refinenet2 = _make_fusion_block(features, use_bn) 61 | self.scratch.refinenet3 = _make_fusion_block(features, use_bn) 62 | self.scratch.refinenet4 = _make_fusion_block(features, use_bn) 63 | 64 | self.scratch.output_conv = head 65 | 66 | 67 | def forward(self, x): 68 | if self.channels_last == True: 69 | x.contiguous(memory_format=torch.channels_last) 70 | 71 | layer_1, layer_2, layer_3, layer_4 = forward_vit(self.pretrained, x) 72 | 73 | layer_1_rn = self.scratch.layer1_rn(layer_1) 74 | layer_2_rn = self.scratch.layer2_rn(layer_2) 75 | layer_3_rn = 
self.scratch.layer3_rn(layer_3) 76 | layer_4_rn = self.scratch.layer4_rn(layer_4) 77 | 78 | path_4 = self.scratch.refinenet4(layer_4_rn) 79 | path_3 = self.scratch.refinenet3(path_4, layer_3_rn) 80 | path_2 = self.scratch.refinenet2(path_3, layer_2_rn) 81 | path_1 = self.scratch.refinenet1(path_2, layer_1_rn) 82 | 83 | out = self.scratch.output_conv(path_1) 84 | 85 | return out 86 | 87 | 88 | class DPTDepthModel(DPT): 89 | def __init__(self, path=None, non_negative=True, **kwargs): 90 | features = kwargs["features"] if "features" in kwargs else 256 91 | 92 | head = nn.Sequential( 93 | nn.Conv2d(features, features // 2, kernel_size=3, stride=1, padding=1), 94 | Interpolate(scale_factor=2, mode="bilinear", align_corners=True), 95 | nn.Conv2d(features // 2, 32, kernel_size=3, stride=1, padding=1), 96 | nn.ReLU(True), 97 | nn.Conv2d(32, 1, kernel_size=1, stride=1, padding=0), 98 | nn.ReLU(True) if non_negative else nn.Identity(), 99 | nn.Identity(), 100 | ) 101 | 102 | super().__init__(head, **kwargs) 103 | 104 | if path is not None: 105 | self.load(path) 106 | 107 | def forward(self, x): 108 | return super().forward(x).squeeze(dim=1) 109 | 110 | -------------------------------------------------------------------------------- /ldm/modules/midas/midas/midas_net.py: -------------------------------------------------------------------------------- 1 | """MidashNet: Network for monocular depth estimation trained by mixing several datasets. 2 | This file contains code that is adapted from 3 | https://github.com/thomasjpfan/pytorch_refinenet/blob/master/pytorch_refinenet/refinenet/refinenet_4cascade.py 4 | """ 5 | import torch 6 | import torch.nn as nn 7 | 8 | from .base_model import BaseModel 9 | from .blocks import FeatureFusionBlock, Interpolate, _make_encoder 10 | 11 | 12 | class MidasNet(BaseModel): 13 | """Network for monocular depth estimation. 14 | """ 15 | 16 | def __init__(self, path=None, features=256, non_negative=True): 17 | """Init. 18 | 19 | Args: 20 | path (str, optional): Path to saved model. Defaults to None. 21 | features (int, optional): Number of features. Defaults to 256. 22 | backbone (str, optional): Backbone network for encoder. Defaults to resnet50 23 | """ 24 | print("Loading weights: ", path) 25 | 26 | super(MidasNet, self).__init__() 27 | 28 | use_pretrained = False if path is None else True 29 | 30 | self.pretrained, self.scratch = _make_encoder(backbone="resnext101_wsl", features=features, use_pretrained=use_pretrained) 31 | 32 | self.scratch.refinenet4 = FeatureFusionBlock(features) 33 | self.scratch.refinenet3 = FeatureFusionBlock(features) 34 | self.scratch.refinenet2 = FeatureFusionBlock(features) 35 | self.scratch.refinenet1 = FeatureFusionBlock(features) 36 | 37 | self.scratch.output_conv = nn.Sequential( 38 | nn.Conv2d(features, 128, kernel_size=3, stride=1, padding=1), 39 | Interpolate(scale_factor=2, mode="bilinear"), 40 | nn.Conv2d(128, 32, kernel_size=3, stride=1, padding=1), 41 | nn.ReLU(True), 42 | nn.Conv2d(32, 1, kernel_size=1, stride=1, padding=0), 43 | nn.ReLU(True) if non_negative else nn.Identity(), 44 | ) 45 | 46 | if path: 47 | self.load(path) 48 | 49 | def forward(self, x): 50 | """Forward pass. 
51 | 52 | Args: 53 | x (tensor): input data (image) 54 | 55 | Returns: 56 | tensor: depth 57 | """ 58 | 59 | layer_1 = self.pretrained.layer1(x) 60 | layer_2 = self.pretrained.layer2(layer_1) 61 | layer_3 = self.pretrained.layer3(layer_2) 62 | layer_4 = self.pretrained.layer4(layer_3) 63 | 64 | layer_1_rn = self.scratch.layer1_rn(layer_1) 65 | layer_2_rn = self.scratch.layer2_rn(layer_2) 66 | layer_3_rn = self.scratch.layer3_rn(layer_3) 67 | layer_4_rn = self.scratch.layer4_rn(layer_4) 68 | 69 | path_4 = self.scratch.refinenet4(layer_4_rn) 70 | path_3 = self.scratch.refinenet3(path_4, layer_3_rn) 71 | path_2 = self.scratch.refinenet2(path_3, layer_2_rn) 72 | path_1 = self.scratch.refinenet1(path_2, layer_1_rn) 73 | 74 | out = self.scratch.output_conv(path_1) 75 | 76 | return torch.squeeze(out, dim=1) 77 | -------------------------------------------------------------------------------- /ldm/modules/midas/midas/midas_net_custom.py: -------------------------------------------------------------------------------- 1 | """MidashNet: Network for monocular depth estimation trained by mixing several datasets. 2 | This file contains code that is adapted from 3 | https://github.com/thomasjpfan/pytorch_refinenet/blob/master/pytorch_refinenet/refinenet/refinenet_4cascade.py 4 | """ 5 | import torch 6 | import torch.nn as nn 7 | 8 | from .base_model import BaseModel 9 | from .blocks import FeatureFusionBlock, FeatureFusionBlock_custom, Interpolate, _make_encoder 10 | 11 | 12 | class MidasNet_small(BaseModel): 13 | """Network for monocular depth estimation. 14 | """ 15 | 16 | def __init__(self, path=None, features=64, backbone="efficientnet_lite3", non_negative=True, exportable=True, channels_last=False, align_corners=True, 17 | blocks={'expand': True}): 18 | """Init. 19 | 20 | Args: 21 | path (str, optional): Path to saved model. Defaults to None. 22 | features (int, optional): Number of features. Defaults to 256. 23 | backbone (str, optional): Backbone network for encoder. 
Defaults to resnet50 24 | """ 25 | print("Loading weights: ", path) 26 | 27 | super(MidasNet_small, self).__init__() 28 | 29 | use_pretrained = False if path else True 30 | 31 | self.channels_last = channels_last 32 | self.blocks = blocks 33 | self.backbone = backbone 34 | 35 | self.groups = 1 36 | 37 | features1=features 38 | features2=features 39 | features3=features 40 | features4=features 41 | self.expand = False 42 | if "expand" in self.blocks and self.blocks['expand'] == True: 43 | self.expand = True 44 | features1=features 45 | features2=features*2 46 | features3=features*4 47 | features4=features*8 48 | 49 | self.pretrained, self.scratch = _make_encoder(self.backbone, features, use_pretrained, groups=self.groups, expand=self.expand, exportable=exportable) 50 | 51 | self.scratch.activation = nn.ReLU(False) 52 | 53 | self.scratch.refinenet4 = FeatureFusionBlock_custom(features4, self.scratch.activation, deconv=False, bn=False, expand=self.expand, align_corners=align_corners) 54 | self.scratch.refinenet3 = FeatureFusionBlock_custom(features3, self.scratch.activation, deconv=False, bn=False, expand=self.expand, align_corners=align_corners) 55 | self.scratch.refinenet2 = FeatureFusionBlock_custom(features2, self.scratch.activation, deconv=False, bn=False, expand=self.expand, align_corners=align_corners) 56 | self.scratch.refinenet1 = FeatureFusionBlock_custom(features1, self.scratch.activation, deconv=False, bn=False, align_corners=align_corners) 57 | 58 | 59 | self.scratch.output_conv = nn.Sequential( 60 | nn.Conv2d(features, features//2, kernel_size=3, stride=1, padding=1, groups=self.groups), 61 | Interpolate(scale_factor=2, mode="bilinear"), 62 | nn.Conv2d(features//2, 32, kernel_size=3, stride=1, padding=1), 63 | self.scratch.activation, 64 | nn.Conv2d(32, 1, kernel_size=1, stride=1, padding=0), 65 | nn.ReLU(True) if non_negative else nn.Identity(), 66 | nn.Identity(), 67 | ) 68 | 69 | if path: 70 | self.load(path) 71 | 72 | 73 | def forward(self, x): 74 | """Forward pass. 
75 | 76 | Args: 77 | x (tensor): input data (image) 78 | 79 | Returns: 80 | tensor: depth 81 | """ 82 | if self.channels_last==True: 83 | print("self.channels_last = ", self.channels_last) 84 | x.contiguous(memory_format=torch.channels_last) 85 | 86 | 87 | layer_1 = self.pretrained.layer1(x) 88 | layer_2 = self.pretrained.layer2(layer_1) 89 | layer_3 = self.pretrained.layer3(layer_2) 90 | layer_4 = self.pretrained.layer4(layer_3) 91 | 92 | layer_1_rn = self.scratch.layer1_rn(layer_1) 93 | layer_2_rn = self.scratch.layer2_rn(layer_2) 94 | layer_3_rn = self.scratch.layer3_rn(layer_3) 95 | layer_4_rn = self.scratch.layer4_rn(layer_4) 96 | 97 | 98 | path_4 = self.scratch.refinenet4(layer_4_rn) 99 | path_3 = self.scratch.refinenet3(path_4, layer_3_rn) 100 | path_2 = self.scratch.refinenet2(path_3, layer_2_rn) 101 | path_1 = self.scratch.refinenet1(path_2, layer_1_rn) 102 | 103 | out = self.scratch.output_conv(path_1) 104 | 105 | return torch.squeeze(out, dim=1) 106 | 107 | 108 | 109 | def fuse_model(m): 110 | prev_previous_type = nn.Identity() 111 | prev_previous_name = '' 112 | previous_type = nn.Identity() 113 | previous_name = '' 114 | for name, module in m.named_modules(): 115 | if prev_previous_type == nn.Conv2d and previous_type == nn.BatchNorm2d and type(module) == nn.ReLU: 116 | # print("FUSED ", prev_previous_name, previous_name, name) 117 | torch.quantization.fuse_modules(m, [prev_previous_name, previous_name, name], inplace=True) 118 | elif prev_previous_type == nn.Conv2d and previous_type == nn.BatchNorm2d: 119 | # print("FUSED ", prev_previous_name, previous_name) 120 | torch.quantization.fuse_modules(m, [prev_previous_name, previous_name], inplace=True) 121 | # elif previous_type == nn.Conv2d and type(module) == nn.ReLU: 122 | # print("FUSED ", previous_name, name) 123 | # torch.quantization.fuse_modules(m, [previous_name, name], inplace=True) 124 | 125 | prev_previous_type = previous_type 126 | prev_previous_name = previous_name 127 | previous_type = type(module) 128 | previous_name = name -------------------------------------------------------------------------------- /ldm/modules/midas/utils.py: -------------------------------------------------------------------------------- 1 | """Utils for monoDepth.""" 2 | import sys 3 | import re 4 | import numpy as np 5 | import cv2 6 | import torch 7 | 8 | 9 | def read_pfm(path): 10 | """Read pfm file. 11 | 12 | Args: 13 | path (str): path to file 14 | 15 | Returns: 16 | tuple: (data, scale) 17 | """ 18 | with open(path, "rb") as file: 19 | 20 | color = None 21 | width = None 22 | height = None 23 | scale = None 24 | endian = None 25 | 26 | header = file.readline().rstrip() 27 | if header.decode("ascii") == "PF": 28 | color = True 29 | elif header.decode("ascii") == "Pf": 30 | color = False 31 | else: 32 | raise Exception("Not a PFM file: " + path) 33 | 34 | dim_match = re.match(r"^(\d+)\s(\d+)\s$", file.readline().decode("ascii")) 35 | if dim_match: 36 | width, height = list(map(int, dim_match.groups())) 37 | else: 38 | raise Exception("Malformed PFM header.") 39 | 40 | scale = float(file.readline().decode("ascii").rstrip()) 41 | if scale < 0: 42 | # little-endian 43 | endian = "<" 44 | scale = -scale 45 | else: 46 | # big-endian 47 | endian = ">" 48 | 49 | data = np.fromfile(file, endian + "f") 50 | shape = (height, width, 3) if color else (height, width) 51 | 52 | data = np.reshape(data, shape) 53 | data = np.flipud(data) 54 | 55 | return data, scale 56 | 57 | 58 | def write_pfm(path, image, scale=1): 59 | """Write pfm file. 
60 | 61 | Args: 62 | path (str): pathto file 63 | image (array): data 64 | scale (int, optional): Scale. Defaults to 1. 65 | """ 66 | 67 | with open(path, "wb") as file: 68 | color = None 69 | 70 | if image.dtype.name != "float32": 71 | raise Exception("Image dtype must be float32.") 72 | 73 | image = np.flipud(image) 74 | 75 | if len(image.shape) == 3 and image.shape[2] == 3: # color image 76 | color = True 77 | elif ( 78 | len(image.shape) == 2 or len(image.shape) == 3 and image.shape[2] == 1 79 | ): # greyscale 80 | color = False 81 | else: 82 | raise Exception("Image must have H x W x 3, H x W x 1 or H x W dimensions.") 83 | 84 | file.write("PF\n" if color else "Pf\n".encode()) 85 | file.write("%d %d\n".encode() % (image.shape[1], image.shape[0])) 86 | 87 | endian = image.dtype.byteorder 88 | 89 | if endian == "<" or endian == "=" and sys.byteorder == "little": 90 | scale = -scale 91 | 92 | file.write("%f\n".encode() % scale) 93 | 94 | image.tofile(file) 95 | 96 | 97 | def read_image(path): 98 | """Read image and output RGB image (0-1). 99 | 100 | Args: 101 | path (str): path to file 102 | 103 | Returns: 104 | array: RGB image (0-1) 105 | """ 106 | img = cv2.imread(path) 107 | 108 | if img.ndim == 2: 109 | img = cv2.cvtColor(img, cv2.COLOR_GRAY2BGR) 110 | 111 | img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB) / 255.0 112 | 113 | return img 114 | 115 | 116 | def resize_image(img): 117 | """Resize image and make it fit for network. 118 | 119 | Args: 120 | img (array): image 121 | 122 | Returns: 123 | tensor: data ready for network 124 | """ 125 | height_orig = img.shape[0] 126 | width_orig = img.shape[1] 127 | 128 | if width_orig > height_orig: 129 | scale = width_orig / 384 130 | else: 131 | scale = height_orig / 384 132 | 133 | height = (np.ceil(height_orig / scale / 32) * 32).astype(int) 134 | width = (np.ceil(width_orig / scale / 32) * 32).astype(int) 135 | 136 | img_resized = cv2.resize(img, (width, height), interpolation=cv2.INTER_AREA) 137 | 138 | img_resized = ( 139 | torch.from_numpy(np.transpose(img_resized, (2, 0, 1))).contiguous().float() 140 | ) 141 | img_resized = img_resized.unsqueeze(0) 142 | 143 | return img_resized 144 | 145 | 146 | def resize_depth(depth, width, height): 147 | """Resize depth map and bring to CPU (numpy). 148 | 149 | Args: 150 | depth (tensor): depth 151 | width (int): image width 152 | height (int): image height 153 | 154 | Returns: 155 | array: processed depth 156 | """ 157 | depth = torch.squeeze(depth[0, :, :, :]).to("cpu") 158 | 159 | depth_resized = cv2.resize( 160 | depth.numpy(), (width, height), interpolation=cv2.INTER_CUBIC 161 | ) 162 | 163 | return depth_resized 164 | 165 | def write_depth(path, depth, bits=1): 166 | """Write depth map to pfm and png file. 
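Taken together, the helpers above form a small depth-inference pipeline: read_image loads an RGB image in [0, 1], resize_image rescales it so both sides are multiples of 32, resize_depth brings the prediction back to the input resolution, and write_depth (below) serializes it. A minimal sketch, assuming a depth model such as the MidasNet_small defined earlier; file paths are placeholders:

import torch

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = model.to(device).eval()                  # model assumed to be constructed beforehand

img = read_image("input.jpg")                    # HxWx3, RGB, range [0, 1]
sample = resize_image(img).to(device)            # 1x3xH'xW', sides are multiples of 32
with torch.no_grad():
    prediction = model(sample)                   # (1, H', W') depth map
depth = resize_depth(prediction.unsqueeze(1), img.shape[1], img.shape[0])
write_depth("output", depth, bits=2)             # writes output.pfm and a 16-bit output.png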
167 | 168 | Args: 169 | path (str): filepath without extension 170 | depth (array): depth 171 | """ 172 | write_pfm(path + ".pfm", depth.astype(np.float32)) 173 | 174 | depth_min = depth.min() 175 | depth_max = depth.max() 176 | 177 | max_val = (2**(8*bits))-1 178 | 179 | if depth_max - depth_min > np.finfo("float").eps: 180 | out = max_val * (depth - depth_min) / (depth_max - depth_min) 181 | else: 182 | out = np.zeros(depth.shape, dtype=depth.type) 183 | 184 | if bits == 1: 185 | cv2.imwrite(path + ".png", out.astype("uint8")) 186 | elif bits == 2: 187 | cv2.imwrite(path + ".png", out.astype("uint16")) 188 | 189 | return 190 | -------------------------------------------------------------------------------- /ldm/util.py: -------------------------------------------------------------------------------- 1 | import importlib 2 | 3 | import torch 4 | from torch import optim 5 | import numpy as np 6 | 7 | from inspect import isfunction 8 | from PIL import Image, ImageDraw, ImageFont 9 | 10 | 11 | def log_txt_as_img(wh, xc, size=10): 12 | # wh a tuple of (width, height) 13 | # xc a list of captions to plot 14 | b = len(xc) 15 | txts = list() 16 | for bi in range(b): 17 | txt = Image.new("RGB", wh, color="white") 18 | draw = ImageDraw.Draw(txt) 19 | # font = ImageFont.truetype('font/DejaVuSans.ttf', size=size) 20 | font = ImageFont.load_default() 21 | nc = int(40 * (wh[0] / 256)) 22 | lines = "\n".join(xc[bi][start:start + nc] for start in range(0, len(xc[bi]), nc)) 23 | 24 | try: 25 | draw.text((0, 0), lines, fill="black", font=font) 26 | except UnicodeEncodeError: 27 | print("Cant encode string for logging. Skipping.") 28 | 29 | txt = np.array(txt).transpose(2, 0, 1) / 127.5 - 1.0 30 | txts.append(txt) 31 | txts = np.stack(txts) 32 | txts = torch.tensor(txts) 33 | return txts 34 | 35 | 36 | def ismap(x): 37 | if not isinstance(x, torch.Tensor): 38 | return False 39 | return (len(x.shape) == 4) and (x.shape[1] > 3) 40 | 41 | 42 | def isimage(x): 43 | if not isinstance(x,torch.Tensor): 44 | return False 45 | return (len(x.shape) == 4) and (x.shape[1] == 3 or x.shape[1] == 1) 46 | 47 | 48 | def exists(x): 49 | return x is not None 50 | 51 | 52 | def default(val, d): 53 | if exists(val): 54 | return val 55 | return d() if isfunction(d) else d 56 | 57 | 58 | def mean_flat(tensor): 59 | """ 60 | https://github.com/openai/guided-diffusion/blob/27c20a8fab9cb472df5d6bdd6c8d11c8f430b924/guided_diffusion/nn.py#L86 61 | Take the mean over all non-batch dimensions. 
62 | """ 63 | return tensor.mean(dim=list(range(1, len(tensor.shape)))) 64 | 65 | 66 | def count_params(model, verbose=False): 67 | total_params = sum(p.numel() for p in model.parameters()) 68 | if verbose: 69 | print(f"{model.__class__.__name__} has {total_params*1.e-6:.2f} M params.") 70 | return total_params 71 | 72 | 73 | def instantiate_from_config(config): 74 | if not "target" in config: 75 | if config == '__is_first_stage__': 76 | return None 77 | elif config == "__is_unconditional__": 78 | return None 79 | raise KeyError("Expected key `target` to instantiate.") 80 | return get_obj_from_str(config["target"])(**config.get("params", dict())) 81 | 82 | 83 | def get_obj_from_str(string, reload=False): 84 | module, cls = string.rsplit(".", 1) 85 | if reload: 86 | module_imp = importlib.import_module(module) 87 | importlib.reload(module_imp) 88 | return getattr(importlib.import_module(module, package=None), cls) 89 | 90 | 91 | class AdamWwithEMAandWings(optim.Optimizer): 92 | # credit to https://gist.github.com/crowsonkb/65f7265353f403714fce3b2595e0b298 93 | def __init__(self, params, lr=1.e-3, betas=(0.9, 0.999), eps=1.e-8, # TODO: check hyperparameters before using 94 | weight_decay=1.e-2, amsgrad=False, ema_decay=0.9999, # ema decay to match previous code 95 | ema_power=1., param_names=()): 96 | """AdamW that saves EMA versions of the parameters.""" 97 | if not 0.0 <= lr: 98 | raise ValueError("Invalid learning rate: {}".format(lr)) 99 | if not 0.0 <= eps: 100 | raise ValueError("Invalid epsilon value: {}".format(eps)) 101 | if not 0.0 <= betas[0] < 1.0: 102 | raise ValueError("Invalid beta parameter at index 0: {}".format(betas[0])) 103 | if not 0.0 <= betas[1] < 1.0: 104 | raise ValueError("Invalid beta parameter at index 1: {}".format(betas[1])) 105 | if not 0.0 <= weight_decay: 106 | raise ValueError("Invalid weight_decay value: {}".format(weight_decay)) 107 | if not 0.0 <= ema_decay <= 1.0: 108 | raise ValueError("Invalid ema_decay value: {}".format(ema_decay)) 109 | defaults = dict(lr=lr, betas=betas, eps=eps, 110 | weight_decay=weight_decay, amsgrad=amsgrad, ema_decay=ema_decay, 111 | ema_power=ema_power, param_names=param_names) 112 | super().__init__(params, defaults) 113 | 114 | def __setstate__(self, state): 115 | super().__setstate__(state) 116 | for group in self.param_groups: 117 | group.setdefault('amsgrad', False) 118 | 119 | @torch.no_grad() 120 | def step(self, closure=None): 121 | """Performs a single optimization step. 122 | Args: 123 | closure (callable, optional): A closure that reevaluates the model 124 | and returns the loss. 
125 | """ 126 | loss = None 127 | if closure is not None: 128 | with torch.enable_grad(): 129 | loss = closure() 130 | 131 | for group in self.param_groups: 132 | params_with_grad = [] 133 | grads = [] 134 | exp_avgs = [] 135 | exp_avg_sqs = [] 136 | ema_params_with_grad = [] 137 | state_sums = [] 138 | max_exp_avg_sqs = [] 139 | state_steps = [] 140 | amsgrad = group['amsgrad'] 141 | beta1, beta2 = group['betas'] 142 | ema_decay = group['ema_decay'] 143 | ema_power = group['ema_power'] 144 | 145 | for p in group['params']: 146 | if p.grad is None: 147 | continue 148 | params_with_grad.append(p) 149 | if p.grad.is_sparse: 150 | raise RuntimeError('AdamW does not support sparse gradients') 151 | grads.append(p.grad) 152 | 153 | state = self.state[p] 154 | 155 | # State initialization 156 | if len(state) == 0: 157 | state['step'] = 0 158 | # Exponential moving average of gradient values 159 | state['exp_avg'] = torch.zeros_like(p, memory_format=torch.preserve_format) 160 | # Exponential moving average of squared gradient values 161 | state['exp_avg_sq'] = torch.zeros_like(p, memory_format=torch.preserve_format) 162 | if amsgrad: 163 | # Maintains max of all exp. moving avg. of sq. grad. values 164 | state['max_exp_avg_sq'] = torch.zeros_like(p, memory_format=torch.preserve_format) 165 | # Exponential moving average of parameter values 166 | state['param_exp_avg'] = p.detach().float().clone() 167 | 168 | exp_avgs.append(state['exp_avg']) 169 | exp_avg_sqs.append(state['exp_avg_sq']) 170 | ema_params_with_grad.append(state['param_exp_avg']) 171 | 172 | if amsgrad: 173 | max_exp_avg_sqs.append(state['max_exp_avg_sq']) 174 | 175 | # update the steps for each param group update 176 | state['step'] += 1 177 | # record the step after step update 178 | state_steps.append(state['step']) 179 | 180 | optim._functional.adamw(params_with_grad, 181 | grads, 182 | exp_avgs, 183 | exp_avg_sqs, 184 | max_exp_avg_sqs, 185 | state_steps, 186 | amsgrad=amsgrad, 187 | beta1=beta1, 188 | beta2=beta2, 189 | lr=group['lr'], 190 | weight_decay=group['weight_decay'], 191 | eps=group['eps'], 192 | maximize=False) 193 | 194 | cur_ema_decay = min(ema_decay, 1 - state['step'] ** -ema_power) 195 | for param, ema_param in zip(params_with_grad, ema_params_with_grad): 196 | ema_param.mul_(cur_ema_decay).add_(param.float(), alpha=1 - cur_ema_decay) 197 | 198 | return loss -------------------------------------------------------------------------------- /ldm/xformers_state.py: -------------------------------------------------------------------------------- 1 | try: 2 | import xformers 3 | import xformers.ops 4 | XFORMERS_IS_AVAILBLE = True 5 | except: 6 | XFORMERS_IS_AVAILBLE = False 7 | print("No module 'xformers'. 
Proceeding without it.") 8 | 9 | 10 | def is_xformers_available() -> bool: 11 | global XFORMERS_IS_AVAILBLE 12 | return XFORMERS_IS_AVAILBLE 13 | 14 | def disable_xformers() -> None: 15 | print("DISABLE XFORMERS!") 16 | global XFORMERS_IS_AVAILBLE 17 | XFORMERS_IS_AVAILBLE = False 18 | 19 | def enable_xformers() -> None: 20 | print("ENABLE XFORMERS!") 21 | global XFORMERS_IS_AVAILBLE 22 | XFORMERS_IS_AVAILBLE = True 23 | 24 | def auto_xformers_status(device): 25 | if 'cuda' in str(device): 26 | enable_xformers() 27 | elif str(device) == 'cpu': 28 | disable_xformers() 29 | else: 30 | raise ValueError(f"Unknown device {device}") 31 | -------------------------------------------------------------------------------- /model/__pycache__/callbacks.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/EternalEvan/FlowIE/549bb871cff104c47fd126d23818f54d74ba1f20/model/__pycache__/callbacks.cpython-38.pyc -------------------------------------------------------------------------------- /model/__pycache__/cldm.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/EternalEvan/FlowIE/549bb871cff104c47fd126d23818f54d74ba1f20/model/__pycache__/cldm.cpython-38.pyc -------------------------------------------------------------------------------- /model/__pycache__/cldm_bsr.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/EternalEvan/FlowIE/549bb871cff104c47fd126d23818f54d74ba1f20/model/__pycache__/cldm_bsr.cpython-39.pyc -------------------------------------------------------------------------------- /model/__pycache__/cond_fn.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/EternalEvan/FlowIE/549bb871cff104c47fd126d23818f54d74ba1f20/model/__pycache__/cond_fn.cpython-38.pyc -------------------------------------------------------------------------------- /model/__pycache__/mixins.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/EternalEvan/FlowIE/549bb871cff104c47fd126d23818f54d74ba1f20/model/__pycache__/mixins.cpython-38.pyc -------------------------------------------------------------------------------- /model/__pycache__/mixins.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/EternalEvan/FlowIE/549bb871cff104c47fd126d23818f54d74ba1f20/model/__pycache__/mixins.cpython-39.pyc -------------------------------------------------------------------------------- /model/__pycache__/spaced_sampler.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/EternalEvan/FlowIE/549bb871cff104c47fd126d23818f54d74ba1f20/model/__pycache__/spaced_sampler.cpython-38.pyc -------------------------------------------------------------------------------- /model/__pycache__/swinir.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/EternalEvan/FlowIE/549bb871cff104c47fd126d23818f54d74ba1f20/model/__pycache__/swinir.cpython-38.pyc -------------------------------------------------------------------------------- /model/__pycache__/swinir.cpython-39.pyc: 
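The module above acts as a single global switch for memory-efficient attention: callers query is_xformers_available() instead of importing xformers directly, and auto_xformers_status() flips the flag based on the target device. A minimal sketch of that pattern; the attention function below is a stand-in, not the repository's attention module (which presumably consults the same flag):

import torch

def attention(q, k, v):
    if is_xformers_available():
        import xformers.ops
        return xformers.ops.memory_efficient_attention(q, k, v)
    # plain softmax attention fallback
    return torch.softmax(q @ k.transpose(-2, -1) / q.shape[-1] ** 0.5, dim=-1) @ v

auto_xformers_status(torch.device("cuda:0" if torch.cuda.is_available() else "cpu"))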
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/EternalEvan/FlowIE/549bb871cff104c47fd126d23818f54d74ba1f20/model/__pycache__/swinir.cpython-39.pyc -------------------------------------------------------------------------------- /model/callbacks.py: -------------------------------------------------------------------------------- 1 | from typing import Dict, Any 2 | import os 3 | 4 | import numpy as np 5 | import pytorch_lightning as pl 6 | from pytorch_lightning.callbacks import ModelCheckpoint 7 | from pytorch_lightning.utilities.types import STEP_OUTPUT 8 | import torch 9 | import torchvision 10 | from PIL import Image 11 | from pytorch_lightning.callbacks import Callback 12 | from pytorch_lightning.utilities.distributed import rank_zero_only 13 | 14 | from .mixins import ImageLoggerMixin 15 | 16 | 17 | __all__ = [ 18 | "ModelCheckpoint", 19 | "ImageLogger" 20 | ] 21 | 22 | class ImageLogger(Callback): 23 | """ 24 | Log images during training or validating. 25 | 26 | TODO: Support validating. 27 | """ 28 | 29 | def __init__( 30 | self, 31 | log_every_n_steps: int=2000, 32 | max_images_each_step: int=4, 33 | log_images_kwargs: Dict[str, Any]=None 34 | ) -> "ImageLogger": 35 | super().__init__() 36 | self.log_every_n_steps = log_every_n_steps 37 | self.max_images_each_step = max_images_each_step 38 | self.log_images_kwargs = log_images_kwargs or dict() 39 | 40 | def on_fit_start(self, trainer: pl.Trainer, pl_module: pl.LightningModule) -> None: 41 | assert isinstance(pl_module, ImageLoggerMixin) 42 | 43 | @rank_zero_only 44 | def on_train_batch_end( 45 | self, trainer: pl.Trainer, pl_module: pl.LightningModule, outputs: STEP_OUTPUT, 46 | batch: Any, batch_idx: int, dataloader_idx: int 47 | ) -> None: 48 | if pl_module.global_step % self.log_every_n_steps == 0: 49 | is_train = pl_module.training 50 | if is_train: 51 | pl_module.freeze() 52 | 53 | with torch.no_grad(): 54 | # returned images should be: nchw, rgb, [0, 1] 55 | images: Dict[str, torch.Tensor] = pl_module.log_images(batch, **self.log_images_kwargs) 56 | 57 | # save images 58 | save_dir = os.path.join(pl_module.logger.save_dir, "image_log", "train") 59 | os.makedirs(save_dir, exist_ok=True) 60 | for image_key in images: 61 | image = images[image_key].detach().cpu() 62 | N = min(self.max_images_each_step, len(image)) 63 | grid = torchvision.utils.make_grid(image[:N], nrow=4) 64 | # chw -> hwc (hw if gray) 65 | grid = grid.transpose(0, 1).transpose(1, 2).squeeze(-1).numpy() 66 | grid = (grid * 255).clip(0, 255).astype(np.uint8) 67 | filename = "{}_step-{:06}_e-{:06}_b-{:06}.png".format( 68 | image_key, pl_module.global_step, pl_module.current_epoch, batch_idx 69 | ) 70 | path = os.path.join(save_dir, filename) 71 | Image.fromarray(grid).save(path) 72 | 73 | if is_train: 74 | pl_module.unfreeze() 75 | -------------------------------------------------------------------------------- /model/cond_fn.py: -------------------------------------------------------------------------------- 1 | from typing import overload, Optional 2 | import torch 3 | from torch.nn import functional as F 4 | 5 | 6 | class Guidance: 7 | 8 | def __init__( 9 | self, 10 | scale: float, 11 | t_start: int, 12 | t_stop: int, 13 | space: str, 14 | repeat: int 15 | ) -> "Guidance": 16 | """ 17 | Initialize latent image guidance. 18 | 19 | Args: 20 | scale (float): Gradient scale (denoted as `s` in our paper). 
The larger the gradient scale, 21 | the closer the final result will be to the output of the first stage model. 22 | t_start (int), t_stop (int): The timestep to start or stop guidance. Note that the sampling 23 | process starts from t=1000 to t=0, the `t_start` should be larger than `t_stop`. 24 | space (str): The data space for computing loss function (rgb or latent). 25 | repeat (int): Repeat gradient descent for `repeat` times. 26 | 27 | Our latent image guidance is based on [GDP](https://github.com/Fayeben/GenerativeDiffusionPrior). 28 | Thanks for their work! 29 | """ 30 | self.scale = scale 31 | self.t_start = t_start 32 | self.t_stop = t_stop 33 | self.target = None 34 | self.space = space 35 | self.repeat = repeat 36 | 37 | def load_target(self, target: torch.Tensor) -> torch.Tensor: 38 | self.target = target 39 | 40 | def __call__(self, target_x0: torch.Tensor, pred_x0: torch.Tensor, t: int) -> Optional[torch.Tensor]: 41 | if self.t_stop < t and t < self.t_start: 42 | # print("sampling with classifier guidance") 43 | # avoid propagating gradient out of this scope 44 | pred_x0 = pred_x0.detach().clone() 45 | target_x0 = target_x0.detach().clone() 46 | return self.scale * self._forward(target_x0, pred_x0) 47 | else: 48 | return None 49 | 50 | @overload 51 | def _forward(self, target_x0: torch.Tensor, pred_x0: torch.Tensor) -> torch.Tensor: 52 | ... 53 | 54 | 55 | class MSEGuidance(Guidance): 56 | 57 | def __init__( 58 | self, 59 | scale: float, 60 | t_start: int, 61 | t_stop: int, 62 | space: str, 63 | repeat: int 64 | ) -> "MSEGuidance": 65 | super().__init__( 66 | scale, t_start, t_stop, space, repeat 67 | ) 68 | 69 | @torch.enable_grad() 70 | def _forward(self, target_x0: torch.Tensor, pred_x0: torch.Tensor) -> torch.Tensor: 71 | # inputs: [-1, 1], nchw, rgb 72 | pred_x0.requires_grad_(True) 73 | 74 | # This is what we actually use. 75 | loss = (pred_x0 - target_x0).pow(2).mean((1, 2, 3)).sum() 76 | 77 | print(f"loss = {loss.item()}") 78 | return -torch.autograd.grad(loss, pred_x0)[0] 79 | -------------------------------------------------------------------------------- /model/mixins.py: -------------------------------------------------------------------------------- 1 | from typing import overload, Any, Dict 2 | import torch 3 | 4 | 5 | class ImageLoggerMixin: 6 | 7 | @overload 8 | def log_images(self, batch: Any, **kwargs: Dict[str, Any]) -> Dict[str, torch.Tensor]: 9 | ... 
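Inside a sampler, the guidance object is called once per denoising step on the current estimate of x0, and the returned gradient (which already folds in the scale and the descent direction) is added to that estimate, optionally `repeat` times. A minimal sketch of that contract; the hyperparameter values and tensor names are illustrative, and the real integration lives in the sampler code:

import torch

cond_fn = MSEGuidance(scale=1.0, t_start=601, t_stop=-1, space="rgb", repeat=5)
cond_fn.load_target(target_rgb)                  # e.g. stage-1 output in [-1, 1], nchw

def apply_guidance(pred_x0: torch.Tensor, t: int) -> torch.Tensor:
    for _ in range(cond_fn.repeat):
        grad = cond_fn(cond_fn.target, pred_x0, t)
        if grad is None:                         # t is outside the (t_stop, t_start) window
            break
        pred_x0 = pred_x0 + grad                 # one gradient-descent step on the MSE
    return pred_x0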
10 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | --extra-index-url https://download.pytorch.org/whl/cu116 2 | torch==1.13.1+cu116 3 | torchvision==0.14.1+cu116 4 | xformers==0.0.16 5 | pytorch_lightning==1.4.2\ 6 | einops 7 | open-clip-torch==2.24.0 8 | omegaconf 9 | torchmetrics==0.6.0 10 | triton==2.0.0 11 | lora-diffusion==0.1.7 12 | opencv-python-headless 13 | scipy 14 | matplotlib 15 | lpips 16 | gradio 17 | chardet 18 | transformers 19 | facexlib 20 | -------------------------------------------------------------------------------- /scripts/inference_stage1.py: -------------------------------------------------------------------------------- 1 | import sys 2 | sys.path.append(".") 3 | import os 4 | from argparse import ArgumentParser, Namespace 5 | 6 | import pytorch_lightning as pl 7 | from omegaconf import OmegaConf 8 | import torch 9 | from PIL import Image 10 | import numpy as np 11 | from tqdm import tqdm 12 | 13 | from utils.image import auto_resize, pad 14 | from utils.common import load_state_dict, instantiate_from_config 15 | from utils.file import list_image_files, get_file_name_parts 16 | 17 | 18 | def parse_args() -> Namespace: 19 | parser = ArgumentParser() 20 | parser.add_argument("--config", type=str, required=True) 21 | parser.add_argument("--ckpt", type=str, required=True) 22 | parser.add_argument("--input", type=str, required=True) 23 | parser.add_argument("--sr_scale", type=float, default=1) 24 | parser.add_argument("--image_size", type=int, default=512) 25 | parser.add_argument("--show_lq", action="store_true") 26 | parser.add_argument("--resize_back", action="store_true") 27 | parser.add_argument("--output", type=str, required=True) 28 | parser.add_argument("--skip_if_exist", action="store_true") 29 | parser.add_argument("--seed", type=int, default=231) 30 | return parser.parse_args() 31 | 32 | 33 | @torch.no_grad() 34 | def main(): 35 | args = parse_args() 36 | pl.seed_everything(args.seed) 37 | device = "cuda" if torch.cuda.is_available() else "cpu" 38 | 39 | model: pl.LightningModule = instantiate_from_config(OmegaConf.load(args.config)) 40 | load_state_dict(model, torch.load(args.ckpt, map_location="cpu"), strict=True) 41 | model.freeze() 42 | model.to(device) 43 | 44 | assert os.path.isdir(args.input) 45 | 46 | pbar = tqdm(list_image_files(args.input, follow_links=True)) 47 | for file_path in pbar: 48 | pbar.set_description(file_path) 49 | save_path = os.path.join(args.output, os.path.relpath(file_path, args.input)) 50 | parent_path, stem, _ = get_file_name_parts(save_path) 51 | save_path = os.path.join(parent_path, f"{stem}.png") 52 | if os.path.exists(save_path): 53 | if args.skip_if_exist: 54 | print(f"skip {save_path}") 55 | continue 56 | else: 57 | raise RuntimeError(f"{save_path} already exist") 58 | os.makedirs(parent_path, exist_ok=True) 59 | 60 | # load low-quality image and resize 61 | lq = Image.open(file_path).convert("RGB") 62 | if args.sr_scale != 1: 63 | lq = lq.resize( 64 | tuple(int(x * args.sr_scale) for x in lq.size), Image.BICUBIC 65 | ) 66 | lq_resized = auto_resize(lq, args.image_size) 67 | # padding 68 | x = pad(np.array(lq_resized), scale=64) 69 | 70 | x = torch.tensor(x, dtype=torch.float32, device=device) / 255.0 71 | x = x.permute(2, 0, 1).unsqueeze(0).contiguous() 72 | try: 73 | pred = model(x).detach().squeeze(0).permute(1, 2, 0) * 255 74 | pred = pred.clamp(0, 
255).to(torch.uint8).cpu().numpy() 75 | except RuntimeError as e: 76 | print(f"inference failed, error: {e}") 77 | continue 78 | 79 | # remove padding 80 | pred = pred[:lq_resized.height, :lq_resized.width, :] 81 | if args.show_lq: 82 | if args.resize_back: 83 | lq = np.array(lq) 84 | if lq_resized.size != lq.size: 85 | pred = np.array(Image.fromarray(pred).resize(lq.size, Image.LANCZOS)) 86 | else: 87 | lq = np.array(lq_resized) 88 | final_image = Image.fromarray(np.concatenate([lq, pred], axis=1)) 89 | else: 90 | if args.resize_back and lq_resized.size != lq.size: 91 | final_image = Image.fromarray(pred).resize(lq.size, Image.LANCZOS) 92 | else: 93 | final_image = Image.fromarray(pred) 94 | final_image.save(save_path) 95 | 96 | 97 | if __name__ == "__main__": 98 | main() 99 | -------------------------------------------------------------------------------- /scripts/make_file_list.py: -------------------------------------------------------------------------------- 1 | import sys 2 | sys.path.append(".") 3 | import os 4 | from argparse import ArgumentParser 5 | 6 | from utils.file import list_image_files 7 | import cv2 8 | import pdb 9 | 10 | parser = ArgumentParser() 11 | parser.add_argument("--img_folder", type=str, required=True) 12 | parser.add_argument("--val_size", type=int, default=0) 13 | parser.add_argument("--save_folder", type=str, required=True) 14 | parser.add_argument("--follow_links", action="store_true") 15 | args = parser.parse_args() 16 | 17 | files = list_image_files( 18 | args.img_folder, exts=(".jpg", ".png", ".jpeg"), follow_links=args.follow_links, 19 | log_progress=True, log_every_n_files=10000 20 | ) 21 | 22 | print(f"find {len(files)} images in {args.img_folder}") 23 | assert args.val_size < len(files) 24 | 25 | # val_files = files[:args.val_size] 26 | # train_files = files[args.val_size:] 27 | val_files = files 28 | print(len(val_files)) 29 | valid_files = [] 30 | for i,path in enumerate(val_files): 31 | 32 | valid_files.append(path) 33 | 34 | 35 | print('Total files:{}'.format(len(valid_files))) 36 | 37 | 38 | os.makedirs(args.save_folder, exist_ok=True) 39 | 40 | # with open(os.path.join(args.save_folder, "train.list"), "w") as fp: 41 | # for file_path in train_files: 42 | # fp.write(f"{file_path}\n") 43 | 44 | with open(os.path.join(args.save_folder, "train.list"), "w") as fp: 45 | for file_path in valid_files: 46 | fp.write(f"{file_path}\n") 47 | -------------------------------------------------------------------------------- /scripts/make_list_celea.py: -------------------------------------------------------------------------------- 1 | import sys 2 | sys.path.append(".") 3 | import os 4 | from argparse import ArgumentParser 5 | import pandas as pd 6 | 7 | s_img = '/data1/zyx/CelebAMask-HQ/CelebA-HQ-img' 8 | 9 | 10 | 11 | train_count = 0 12 | test_count = 0 13 | val_count = 0 14 | image_list = pd.read_csv('/data1/zyx/CelebAMask-HQ/CelebA-HQ-to-CelebA-mapping.txt', delim_whitespace=True, header=None) 15 | train_paths = [] 16 | val_paths = [] 17 | test_paths = [] 18 | 19 | for idx, x in enumerate(image_list.loc[:, 1]): 20 | print (idx, x) 21 | if idx == 0: 22 | continue 23 | 24 | x = int(x) 25 | if x >= 162771 and x < 182638: 26 | img_path = os.path.join(s_img, str(idx-1)+'.jpg') 27 | val_paths.append(img_path) 28 | val_count += 1 29 | 30 | elif x >= 182638: 31 | img_path = os.path.join(s_img, str(idx-1)+'.jpg') 32 | test_paths.append(img_path) 33 | test_count += 1 34 | else: 35 | img_path = os.path.join(s_img, str(idx-1)+'.jpg') 36 | 
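The scripts above share a simple file-list convention: make_file_list.py walks an image folder and writes one absolute path per line into train.list, which the dataset side can later read back with utils.file.load_file_list. A minimal sketch of that round trip; the directories are placeholders:

from utils.file import list_image_files, load_file_list

files = list_image_files("/path/to/images", exts=(".jpg", ".png", ".jpeg"), follow_links=True)
with open("/path/to/lists/train.list", "w") as fp:
    for file_path in files:
        fp.write(f"{file_path}\n")

paths = load_file_list("/path/to/lists/train.list")   # same paths, one per line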
train_paths.append(img_path) 37 | train_count += 1 38 | 39 | print (train_count + test_count + val_count) 40 | 41 | 42 | 43 | 44 | save_folder = '/data1/zyx/FFHQ512' 45 | 46 | # with open(os.path.join(save_folder, "train.list"), "a") as fp: 47 | # for file_path in train_paths: 48 | # fp.write(f"{file_path}\n") 49 | 50 | # with open(os.path.join(save_folder, "val.list"), "a") as fp: 51 | # for file_path in val_paths: 52 | # fp.write(f"{file_path}\n") 53 | 54 | with open(os.path.join(save_folder, "test.list"), "w") as fp: 55 | for file_path in test_paths: 56 | fp.write(f"{file_path}\n") -------------------------------------------------------------------------------- /scripts/make_stage2_init_weight.py: -------------------------------------------------------------------------------- 1 | import sys 2 | sys.path.append(".") 3 | from argparse import ArgumentParser 4 | from typing import Dict 5 | 6 | import torch 7 | from omegaconf import OmegaConf 8 | 9 | from utils.common import instantiate_from_config 10 | import os 11 | os.environ['CURL_CA_BUNDLE'] = '' 12 | 13 | def load_weight(weight_path: str) -> Dict[str, torch.Tensor]: 14 | weight = torch.load(weight_path) 15 | if "state_dict" in weight: 16 | weight = weight["state_dict"] 17 | 18 | pure_weight = {} 19 | for key, val in weight.items(): 20 | if key.startswith("module."): 21 | key = key[len("module."):] 22 | pure_weight[key] = val 23 | 24 | return pure_weight 25 | 26 | parser = ArgumentParser() 27 | parser.add_argument("--cldm_config", type=str, required=True) 28 | parser.add_argument("--sd_weight", type=str, required=True) 29 | parser.add_argument("--swinir_weight", type=str, required=True) 30 | parser.add_argument("--output", type=str, required=True) 31 | args = parser.parse_args() 32 | 33 | model = instantiate_from_config(OmegaConf.load(args.cldm_config)) 34 | 35 | sd_weights = load_weight(args.sd_weight) 36 | swinir_weights = load_weight(args.swinir_weight) 37 | scratch_weights = model.state_dict() 38 | 39 | init_weights = {} 40 | for weight_name in scratch_weights.keys(): 41 | # find target pretrained weights for this weight 42 | if weight_name.startswith("control_"): 43 | suffix = weight_name[len("control_"):] 44 | target_name = f"model.diffusion_{suffix}" 45 | target_model_weights = sd_weights 46 | elif weight_name.startswith("preprocess_model."): 47 | suffix = weight_name[len("preprocess_model."):] 48 | target_name = suffix 49 | target_model_weights = swinir_weights 50 | elif weight_name.startswith("cond_encoder."): 51 | suffix = weight_name[len("cond_encoder."):] 52 | target_name = F"first_stage_model.{suffix}" 53 | target_model_weights = sd_weights 54 | else: 55 | target_name = weight_name 56 | target_model_weights = sd_weights 57 | 58 | # if target weight exist in pretrained model 59 | print(f"copy weights: {target_name} -> {weight_name}") 60 | if target_name in target_model_weights: 61 | # get pretrained weight 62 | target_weight = target_model_weights[target_name] 63 | target_shape = target_weight.shape 64 | model_shape = scratch_weights[weight_name].shape 65 | # if pretrained weight has the same shape with model weight, we make a copy 66 | if model_shape == target_shape: 67 | init_weights[weight_name] = target_weight.clone() 68 | # else we copy pretrained weight with additional channels initialized to zero 69 | else: 70 | newly_added_channels = model_shape[1] - target_shape[1] 71 | oc, _, h, w = target_shape 72 | zero_weight = torch.zeros((oc, newly_added_channels, h, w)).type_as(target_weight) 73 | init_weights[weight_name] = 
torch.cat((target_weight.clone(), zero_weight), dim=1) 74 | print(f"add zero weight to {target_name} in pretrained weights, newly added channels = {newly_added_channels}") 75 | else: 76 | init_weights[weight_name] = scratch_weights[weight_name].clone() 77 | print(f"These weights are newly added: {weight_name}") 78 | 79 | model.load_state_dict(init_weights, strict=True) 80 | torch.save(model.state_dict(), args.output) 81 | print("Done.") 82 | -------------------------------------------------------------------------------- /scripts/merge_img.py: -------------------------------------------------------------------------------- 1 | from PIL import Image 2 | 3 | def merge_images_horizontally(image_paths, output_path): 4 | images = [Image.open(image_path) for image_path in image_paths] 5 | 6 | # Ensure that all images have the same height 7 | max_height = 512 8 | images = [img if img.height == max_height else img.resize((img.width * max_height // img.height, max_height)) for img in images] 9 | 10 | # Calculate the total width for the merged image 11 | total_width = sum(img.width for img in images) 12 | 13 | # Create a blank image with the total width and the maximum height 14 | merged_image = Image.new('RGB', (total_width, max_height)) 15 | 16 | # Paste each image into the merged image 17 | x_offset = 0 18 | for img in images: 19 | merged_image.paste(img, (x_offset, 0)) 20 | x_offset += img.width 21 | 22 | # Save the merged image 23 | merged_image.save(output_path) 24 | 25 | if __name__ == "__main__": 26 | image_paths_10955 = ["/home/user001/zwl/zyx/CodeFormer-master/inp6/lq/10955.png", 27 | "/home/user001/zwl/zyx/GPEN-main/eval_img_test/10955.png", 28 | "/home/user001/zwl/zyx/CodeFormer-master/inp6/10955.png", 29 | "/home/user001/zwl/zyx/Diffbir/outputs/inp-re3/z2/10955.png", 30 | "/home/user001/zwl/zyx/Diffbir/outputs/inp-re3/hq/10955.png"] 31 | 32 | image_paths_1777 = ["/home/user001/zwl/zyx/CodeFormer-master/inp11/lq/1777.png", 33 | "/home/user001/zwl/zyx/GPEN-main/eval_img_test/1777.png", 34 | "/home/user001/zwl/zyx/CodeFormer-master/inp11/1777.png", 35 | "/home/user001/zwl/zyx/Diffbir/outputs/inp-re10/z2/1777.png", 36 | "/home/user001/zwl/zyx/Diffbir/outputs/inp-re10/hq/1777.png"] 37 | 38 | output_path = "merged_image_10955.png" 39 | 40 | merge_images_horizontally(image_paths_10955, output_path) 41 | -------------------------------------------------------------------------------- /scripts/metrics.py: -------------------------------------------------------------------------------- 1 | import pyiqa 2 | import torch 3 | 4 | # list all available metrics 5 | print(pyiqa.list_models()) 6 | 7 | device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu") 8 | 9 | 10 | 11 | # For FID metric, use directory or precomputed statistics as inputs 12 | # refer to clean-fid for more details: https://github.com/GaParmar/clean-fid 13 | fid_metric = pyiqa.create_metric('fid') 14 | score = fid_metric('/home/zyx/CodeFormer-master/results/celeba_512_validation_lq_0.5/restored_faces/', '/data1/zyx/FFHQ512/FFHQ_512/') 15 | print(score) 16 | #score = fid_metric('./ResultsCalibra/dist_dir/', dataset_name="FFHQ", dataset_res=1024, dataset_split="trainval70k") -------------------------------------------------------------------------------- /scripts/rainy.py: -------------------------------------------------------------------------------- 1 | import cv2 2 | import numpy as np 3 | import os 4 | from os.path import join 5 | import pdb 6 | def compute_rainy_layer(derained_image_path, 
groundtruth_image_path, output_image_dir,filename): 7 | # Load derained and groundtruth images 8 | derained_image = cv2.imread(derained_image_path, cv2.IMREAD_GRAYSCALE) 9 | 10 | gt_image = cv2.imread(groundtruth_image_path, cv2.IMREAD_GRAYSCALE) 11 | w,h = derained_image.shape[0],derained_image.shape[1] 12 | gt_image = cv2.resize(gt_image,(h,w)) 13 | 14 | # Apply thresholding to groundtruth image 15 | # Calculate the difference between the GT and result images 16 | difference_image = gt_image - derained_image 17 | #print(difference_image[:9,:9]) 18 | #pdb.set_trace() 19 | # Apply a threshold to obtain the rainy layer 20 | threshold = 20 # You can adjust this value based on your application 21 | is_rainy = ((difference_image > threshold) * (difference_image <200)) 22 | rainy_layer = np.where(is_rainy, 1, 0) 23 | # Save the rainy layer 24 | output_image_path = join(output_image_dir,filename) 25 | cv2.imwrite(output_image_path, rainy_layer * 255) 26 | print(output_image_path) 27 | #cv2.imwrite(output_image_path, normalized_difference) 28 | 29 | if __name__ == "__main__": 30 | derained_image_path = "/home/user001/zwl/zyx/Diffbir/outputs/swin_derain/rain-001.png" 31 | #derained_image_path = "/home/user001/zwl/zyx/Pretrained-IPT/experiment/results/ipt/results-DIV2K/rain-001_x1_SR.png" 32 | groundtruth_image_path = "/home/user001/zwl/data/Derain/Rain100L/rainy/rain-001.png" 33 | output_image_dir ="/home/user001/zwl/zyx/RCDNet-master/RCDNet_code/for_syn/experiment/RCDNet_test/results//rainy" # "/home/user001/zwl/zyx/Pretrained-IPT/experiment/results/ipt/results-DIV2K/rainy/" 34 | os.makedirs(output_image_dir,exist_ok=True) 35 | derained_image_dir = "/home/user001/zwl/zyx/RCDNet-master/RCDNet_code/for_syn/experiment/RCDNet_test/results/" 36 | groundtruth_image_dir = "/home/user001/zwl/data/Derain/Rain100L/rainy" 37 | for filename in os.listdir(derained_image_dir): 38 | if filename[-3:] != 'png': 39 | continue 40 | derain_image_path = join(derained_image_dir,filename) 41 | groundtruth_image_path = join(groundtruth_image_dir,filename) 42 | compute_rainy_layer(derain_image_path, groundtruth_image_path, output_image_dir,filename) 43 | -------------------------------------------------------------------------------- /scripts/sample_dataset.py: -------------------------------------------------------------------------------- 1 | import sys 2 | sys.path.append(".") 3 | from argparse import ArgumentParser 4 | import os 5 | from typing import Any 6 | 7 | from omegaconf import OmegaConf 8 | from torch.utils.data import DataLoader 9 | import numpy as np 10 | from PIL import Image 11 | import pytorch_lightning as pl 12 | 13 | from utils.common import instantiate_from_config 14 | 15 | 16 | def wrap_dataloader(data_loader: DataLoader) -> Any: 17 | while True: 18 | yield from data_loader 19 | 20 | 21 | pl.seed_everything(231, workers=True) 22 | 23 | parser = ArgumentParser() 24 | parser.add_argument("--config", type=str, required=True) 25 | parser.add_argument("--sample_size", type=int, default=128) 26 | parser.add_argument("--show_gt", action="store_true") 27 | parser.add_argument("--output", type=str, required=True) 28 | args = parser.parse_args() 29 | 30 | config = OmegaConf.load(args.config) 31 | dataset = instantiate_from_config(config.dataset) 32 | transform = instantiate_from_config(config.batch_transform) 33 | data_loader = wrap_dataloader(DataLoader(dataset, batch_size=1, shuffle=True)) 34 | 35 | cnt = 0 36 | os.makedirs(args.output, exist_ok=True) 37 | 38 | for batch in data_loader: 39 | batch = 
transform(batch) 40 | for hq, lq in zip(batch["jpg"], batch["hint"]): 41 | hq = ((hq + 1) * 127.5).numpy().clip(0, 255).astype(np.uint8) 42 | lq = (lq * 255.0).numpy().clip(0, 255).astype(np.uint8) 43 | if args.show_gt: 44 | Image.fromarray(np.concatenate([hq, lq], axis=1)).save(os.path.join(args.output, f"{cnt}.png")) 45 | else: 46 | Image.fromarray(lq).save(os.path.join(args.output, f"{cnt}.png")) 47 | cnt += 1 48 | if cnt >= args.sample_size: 49 | break 50 | if cnt >= args.sample_size: 51 | break 52 | -------------------------------------------------------------------------------- /test.py: -------------------------------------------------------------------------------- 1 | # import os 2 | # os.environ['CUDA_VISIBLE_DEVICES'] = "7" 3 | from argparse import ArgumentParser 4 | import cv2 5 | import os 6 | #os.environ['CUDA_VISIBLE_DEVICE'] = '7' 7 | import pytorch_lightning as pl 8 | from omegaconf import OmegaConf 9 | import torch 10 | from tqdm import tqdm 11 | from utils.common import instantiate_from_config, load_state_dict 12 | from torchvision.utils import save_image 13 | import torch.distributed as dist 14 | from torch.nn.parallel import DistributedDataParallel 15 | 16 | def main() -> None: 17 | 18 | # dist.init_process_group(backend='nccl', init_method='env://') 19 | # rank = torch.cuda.set_device(dist.get_rank()) 20 | # device = torch.device('cuda', rank) 21 | 22 | parser = ArgumentParser() 23 | parser.add_argument("--config", type=str, required=True) 24 | parser.add_argument("--output", type=str, required=True) 25 | parser.add_argument("--watch_step", action='store_true') 26 | # parser.add_argument("--local_rank", type=int, default=0) 27 | args = parser.parse_args() 28 | 29 | config = OmegaConf.load(args.config) 30 | pl.seed_everything(config.lightning.seed, workers=True) 31 | model_config = OmegaConf.load(config.model.config) 32 | #print(model_config) 33 | model_config['params']['output'] = args.output 34 | 35 | data_module = instantiate_from_config(config.data) 36 | data_module.setup(stage="fit") 37 | model = instantiate_from_config(model_config) 38 | if config.model.get("resume"): 39 | model_dict = torch.load(config.model.resume, map_location="cpu") 40 | model_dict = model_dict['state_dict'] if 'state_dict' in model_dict.keys() else model_dict 41 | a,b = model.load_state_dict(model_dict, strict=False) 42 | print("missing_keys:",a) 43 | print("unexpected_keys:",b) 44 | print("{} model has been loaded!".format(config.model.resume)) 45 | load_state_dict(model.preprocess_model, torch.load('/home/user001/zwl/data/flowir_work_dirs/swin_inp0/lightning_logs/version_0/checkpoints/step=39999.ckpt', map_location="cpu"), strict=True) 46 | #swin_dict = torch.load(config.model.resume, map_location="cpu") 47 | #torch.cuda.empty_cache() 48 | # model.to(device) 49 | #model.cuda() 50 | #model = DistributedDataParallel(model,device_ids=[rank],output_device=rank) 51 | model.eval() 52 | 53 | 54 | save_path = args.output 55 | # final_path = os.path.join(save_path,"final") 56 | # midd_path = os.path.join(save_path,"midd") 57 | # os.makedirs(final_path,exist_ok=True) 58 | # os.makedirs(midd_path,exist_ok=True) 59 | 60 | callbacks = [] 61 | for callback_config in config.lightning.callbacks: 62 | callbacks.append(instantiate_from_config(callback_config)) 63 | trainer = pl.Trainer(callbacks=callbacks, **config.lightning.trainer) 64 | 65 | testloader = data_module.val_dataloader() 66 | # for batch_idx, batch in tqdm(enumerate(testloader),total=len(testloader)): 67 | # batch['jpg'] = 
batch['jpg'].to(device) 68 | # batch['hint'] = batch['hint'].to(device) 69 | # imgname_batch = batch['imgname'] 70 | # log = model.module.validation_step(batch,args.watch_step) 71 | trainer.test(model,test_dataloaders=testloader) 72 | #assert False 73 | #images = log['samples'] 74 | #images_midd = log['samples_3'] 75 | # save_batch(images=images, 76 | # imgname_batch=imgname_batch, 77 | # save_path=final_path, 78 | # watch_step=False) 79 | # save_batch(images=images_midd, 80 | # imgname_batch=imgname_batch, 81 | # save_path=midd_path, 82 | # watch_step=False) 83 | 84 | def save_batch(images,imgname_batch, save_path, watch_step=False): 85 | if watch_step: 86 | for list_idx, img_list in enumerate(images): 87 | for img_idx, img in enumerate(img_list): 88 | imgname = str(list_idx)+"_"+imgname_batch[img_idx] 89 | save_img = os.path.join(save_path,imgname) 90 | save_image(img,save_img) 91 | else: 92 | for img_idx, img in enumerate(images): 93 | imgname = imgname_batch[img_idx] 94 | save_img = os.path.join(save_path,imgname) 95 | save_image(img,save_img) 96 | 97 | 98 | if __name__ == "__main__": 99 | main() 100 | 101 | ''' 102 | CUDA_VISIBLE_DEVICES=0 \ 103 | python3 \ 104 | test.py \ 105 | --config configs/test_cldm.yaml \ 106 | --output /home/zyx/DiffBIR-main/outputs/celebamaskhq_reflow_nolora/ 107 | 108 | CUDA_VISIBLE_DEVICES=6 \ 109 | python test.py \ 110 | --config configs/test_cldm.yaml \ 111 | --output outputs/reflow_celeba_lq_than_lq 112 | ''' -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | from argparse import ArgumentParser 2 | 3 | import pytorch_lightning as pl 4 | from omegaconf import OmegaConf 5 | import torch 6 | 7 | from utils.common import instantiate_from_config, load_state_dict 8 | 9 | 10 | def main() -> None: 11 | parser = ArgumentParser() 12 | parser.add_argument("--config", type=str, required=True) 13 | args = parser.parse_args() 14 | 15 | config = OmegaConf.load(args.config) 16 | pl.seed_everything(config.lightning.seed, workers=True) 17 | 18 | data_module = instantiate_from_config(config.data) 19 | model = instantiate_from_config(OmegaConf.load(config.model.config)) 20 | # TODO: resume states saved in checkpoint. 
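    # NOTE: the branch below only restores network weights from `config.model.resume`
    # (a non-strict load via utils.common.load_state_dict); optimizer and trainer state
    # are not resumed yet, which is what the TODO above refers to.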
21 | if config.model.get("resume"): 22 | weights = torch.load(config.model.resume, map_location="cpu") 23 | '''new_weights = {} 24 | for k in weights['state_dict']: 25 | 26 | 27 | if 'lora' not in k: 28 | new_weights[k] = weights['state_dict'][k] 29 | weights['state_dict'] = new_weights''' 30 | load_state_dict(model, weights, strict=False) 31 | 32 | 33 | #load_state_dict(model.preprocess_model, torch.load('/home/user001/zwl/data/flowir_work_dirs/swin_derain0/lightning_logs/version_6/checkpoints/step=69999.ckpt', map_location="cpu"), strict=True) 34 | 35 | callbacks = [] 36 | for callback_config in config.lightning.callbacks: 37 | callbacks.append(instantiate_from_config(callback_config)) 38 | trainer = pl.Trainer(callbacks=callbacks, **config.lightning.trainer) 39 | trainer.fit(model, datamodule=data_module) 40 | #trainer.test() 41 | 42 | 43 | if __name__ == "__main__": 44 | main() 45 | -------------------------------------------------------------------------------- /utils/__pycache__/common.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/EternalEvan/FlowIE/549bb871cff104c47fd126d23818f54d74ba1f20/utils/__pycache__/common.cpython-38.pyc -------------------------------------------------------------------------------- /utils/__pycache__/common.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/EternalEvan/FlowIE/549bb871cff104c47fd126d23818f54d74ba1f20/utils/__pycache__/common.cpython-39.pyc -------------------------------------------------------------------------------- /utils/__pycache__/degradation.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/EternalEvan/FlowIE/549bb871cff104c47fd126d23818f54d74ba1f20/utils/__pycache__/degradation.cpython-38.pyc -------------------------------------------------------------------------------- /utils/__pycache__/face_restoration_helper.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/EternalEvan/FlowIE/549bb871cff104c47fd126d23818f54d74ba1f20/utils/__pycache__/face_restoration_helper.cpython-38.pyc -------------------------------------------------------------------------------- /utils/__pycache__/file.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/EternalEvan/FlowIE/549bb871cff104c47fd126d23818f54d74ba1f20/utils/__pycache__/file.cpython-310.pyc -------------------------------------------------------------------------------- /utils/__pycache__/file.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/EternalEvan/FlowIE/549bb871cff104c47fd126d23818f54d74ba1f20/utils/__pycache__/file.cpython-38.pyc -------------------------------------------------------------------------------- /utils/__pycache__/file.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/EternalEvan/FlowIE/549bb871cff104c47fd126d23818f54d74ba1f20/utils/__pycache__/file.cpython-39.pyc -------------------------------------------------------------------------------- /utils/__pycache__/metrics.cpython-38.pyc: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/EternalEvan/FlowIE/549bb871cff104c47fd126d23818f54d74ba1f20/utils/__pycache__/metrics.cpython-38.pyc -------------------------------------------------------------------------------- /utils/__pycache__/metrics.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/EternalEvan/FlowIE/549bb871cff104c47fd126d23818f54d74ba1f20/utils/__pycache__/metrics.cpython-39.pyc -------------------------------------------------------------------------------- /utils/__pycache__/process.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/EternalEvan/FlowIE/549bb871cff104c47fd126d23818f54d74ba1f20/utils/__pycache__/process.cpython-38.pyc -------------------------------------------------------------------------------- /utils/__pycache__/utils.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/EternalEvan/FlowIE/549bb871cff104c47fd126d23818f54d74ba1f20/utils/__pycache__/utils.cpython-38.pyc -------------------------------------------------------------------------------- /utils/common.py: -------------------------------------------------------------------------------- 1 | from typing import Mapping, Any 2 | import importlib 3 | 4 | from torch import nn 5 | 6 | 7 | def get_obj_from_str(string: str, reload: bool=False) -> object: 8 | module, cls = string.rsplit(".", 1) 9 | if reload: 10 | module_imp = importlib.import_module(module) 11 | importlib.reload(module_imp) 12 | return getattr(importlib.import_module(module, package=None), cls) 13 | 14 | 15 | def instantiate_from_config(config: Mapping[str, Any]) -> object: 16 | if not "target" in config: 17 | raise KeyError("Expected key `target` to instantiate.") 18 | return get_obj_from_str(config["target"])(**config.get("params", dict())) 19 | 20 | 21 | def disabled_train(self: nn.Module) -> nn.Module: 22 | """Overwrite model.train with this function to make sure train/eval mode 23 | does not change anymore.""" 24 | return self 25 | 26 | 27 | def frozen_module(module: nn.Module) -> None: 28 | module.eval() 29 | module.train = disabled_train 30 | for p in module.parameters(): 31 | p.requires_grad = False 32 | 33 | 34 | def load_state_dict(model: nn.Module, state_dict: Mapping[str, Any], strict: bool=False) -> None: 35 | state_dict = state_dict.get("state_dict", state_dict) 36 | 37 | is_model_key_starts_with_module = list(model.state_dict().keys())[0].startswith("module.") 38 | is_state_dict_key_starts_with_module = list(state_dict.keys())[0].startswith("module.") 39 | 40 | if ( 41 | is_model_key_starts_with_module and 42 | (not is_state_dict_key_starts_with_module) 43 | ): 44 | state_dict = {f"module.{key}": value for key, value in state_dict.items()} 45 | if ( 46 | (not is_model_key_starts_with_module) and 47 | is_state_dict_key_starts_with_module 48 | ): 49 | state_dict = {key[len("module."):]: value for key, value in state_dict.items()} 50 | 51 | model.load_state_dict(state_dict, strict=strict) 52 | -------------------------------------------------------------------------------- /utils/file.py: -------------------------------------------------------------------------------- 1 | import os 2 | from typing import List, Tuple 3 | 4 | from urllib.parse import urlparse 5 | from torch.hub import download_url_to_file, get_dir 6 | 7 | 8 | def load_file_list(file_list_path: str) -> List[str]: 9 | files = [] 10 | 
# each line in file list contains a path of an image 11 | with open(file_list_path, "r") as fin: 12 | for line in fin: 13 | path = line.strip() 14 | if path: 15 | files.append(path) 16 | return files 17 | 18 | 19 | def list_image_files( 20 | img_dir: str, 21 | exts: Tuple[str]=(".jpg", ".png", ".jpeg",".arw"), 22 | follow_links: bool=False, 23 | log_progress: bool=False, 24 | log_every_n_files: int=10000, 25 | max_size: int=-1 26 | ) -> List[str]: 27 | files = [] 28 | for dir_path, _, file_names in os.walk(img_dir, followlinks=follow_links): 29 | early_stop = False 30 | for file_name in file_names: 31 | if os.path.splitext(file_name)[1].lower() in exts: 32 | if max_size >= 0 and len(files) >= max_size: 33 | early_stop = True 34 | break 35 | files.append(os.path.join(dir_path, file_name)) 36 | if log_progress and len(files) % log_every_n_files == 0: 37 | print(f"find {len(files)} images in {img_dir}") 38 | if early_stop: 39 | break 40 | return files 41 | 42 | 43 | def get_file_name_parts(file_path: str) -> Tuple[str, str, str]: 44 | parent_path, file_name = os.path.split(file_path) 45 | stem, ext = os.path.splitext(file_name) 46 | return parent_path, stem, ext 47 | 48 | 49 | # https://github.com/XPixelGroup/BasicSR/blob/master/basicsr/utils/download_util.py/ 50 | def load_file_from_url(url, model_dir=None, progress=True, file_name=None): 51 | """Load file form http url, will download models if necessary. 52 | 53 | Ref:https://github.com/1adrianb/face-alignment/blob/master/face_alignment/utils.py 54 | 55 | Args: 56 | url (str): URL to be downloaded. 57 | model_dir (str): The path to save the downloaded model. Should be a full path. If None, use pytorch hub_dir. 58 | Default: None. 59 | progress (bool): Whether to show the download progress. Default: True. 60 | file_name (str): The downloaded file name. If None, use the file name in the url. Default: None. 61 | 62 | Returns: 63 | str: The path to the downloaded file. 
64 | """ 65 | if model_dir is None: # use the pytorch hub_dir 66 | hub_dir = get_dir() 67 | model_dir = os.path.join(hub_dir, 'checkpoints') 68 | 69 | os.makedirs(model_dir, exist_ok=True) 70 | 71 | parts = urlparse(url) 72 | filename = os.path.basename(parts.path) 73 | if file_name is not None: 74 | filename = file_name 75 | cached_file = os.path.abspath(os.path.join(model_dir, filename)) 76 | if not os.path.exists(cached_file): 77 | print(f'Downloading: "{url}" to {cached_file}\n') 78 | download_url_to_file(url, cached_file, hash_prefix=None, progress=progress) 79 | return cached_file 80 | -------------------------------------------------------------------------------- /utils/image/__init__.py: -------------------------------------------------------------------------------- 1 | from .diffjpeg import DiffJPEG 2 | from .usm_sharp import USMSharp 3 | from .common import ( 4 | random_crop_arr, center_crop_arr, augment, 5 | filter2D, rgb2ycbcr_pt, auto_resize, pad 6 | ) 7 | from .align_color import ( 8 | wavelet_reconstruction, adaptive_instance_normalization 9 | ) 10 | 11 | __all__ = [ 12 | "DiffJPEG", 13 | 14 | "USMSharp", 15 | 16 | "random_crop_arr", 17 | "center_crop_arr", 18 | "augment", 19 | "filter2D", 20 | "rgb2ycbcr_pt", 21 | "auto_resize", 22 | "pad", 23 | 24 | "wavelet_reconstruction", 25 | "adaptive_instance_normalization" 26 | ] 27 | -------------------------------------------------------------------------------- /utils/image/__pycache__/__init__.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/EternalEvan/FlowIE/549bb871cff104c47fd126d23818f54d74ba1f20/utils/image/__pycache__/__init__.cpython-38.pyc -------------------------------------------------------------------------------- /utils/image/__pycache__/__init__.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/EternalEvan/FlowIE/549bb871cff104c47fd126d23818f54d74ba1f20/utils/image/__pycache__/__init__.cpython-39.pyc -------------------------------------------------------------------------------- /utils/image/__pycache__/align_color.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/EternalEvan/FlowIE/549bb871cff104c47fd126d23818f54d74ba1f20/utils/image/__pycache__/align_color.cpython-38.pyc -------------------------------------------------------------------------------- /utils/image/__pycache__/align_color.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/EternalEvan/FlowIE/549bb871cff104c47fd126d23818f54d74ba1f20/utils/image/__pycache__/align_color.cpython-39.pyc -------------------------------------------------------------------------------- /utils/image/__pycache__/common.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/EternalEvan/FlowIE/549bb871cff104c47fd126d23818f54d74ba1f20/utils/image/__pycache__/common.cpython-38.pyc -------------------------------------------------------------------------------- /utils/image/__pycache__/common.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/EternalEvan/FlowIE/549bb871cff104c47fd126d23818f54d74ba1f20/utils/image/__pycache__/common.cpython-39.pyc 
-------------------------------------------------------------------------------- /utils/image/__pycache__/diffjpeg.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/EternalEvan/FlowIE/549bb871cff104c47fd126d23818f54d74ba1f20/utils/image/__pycache__/diffjpeg.cpython-38.pyc -------------------------------------------------------------------------------- /utils/image/__pycache__/diffjpeg.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/EternalEvan/FlowIE/549bb871cff104c47fd126d23818f54d74ba1f20/utils/image/__pycache__/diffjpeg.cpython-39.pyc -------------------------------------------------------------------------------- /utils/image/__pycache__/usm_sharp.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/EternalEvan/FlowIE/549bb871cff104c47fd126d23818f54d74ba1f20/utils/image/__pycache__/usm_sharp.cpython-38.pyc -------------------------------------------------------------------------------- /utils/image/__pycache__/usm_sharp.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/EternalEvan/FlowIE/549bb871cff104c47fd126d23818f54d74ba1f20/utils/image/__pycache__/usm_sharp.cpython-39.pyc -------------------------------------------------------------------------------- /utils/image/align_color.py: -------------------------------------------------------------------------------- 1 | ''' 2 | # -------------------------------------------------------------------------------- 3 | # Color fixed script from Li Yi (https://github.com/pkuliyi2015/sd-webui-stablesr/blob/master/srmodule/colorfix.py) 4 | # -------------------------------------------------------------------------------- 5 | ''' 6 | 7 | import torch 8 | from PIL import Image 9 | from torch import Tensor 10 | from torch.nn import functional as F 11 | from torchvision.transforms import ToTensor, ToPILImage 12 | 13 | 14 | def adain_color_fix(target: Image, source: Image): 15 | # Convert images to tensors 16 | to_tensor = ToTensor() 17 | target_tensor = to_tensor(target).unsqueeze(0) 18 | source_tensor = to_tensor(source).unsqueeze(0) 19 | 20 | # Apply adaptive instance normalization 21 | result_tensor = adaptive_instance_normalization(target_tensor, source_tensor) 22 | 23 | # Convert tensor back to image 24 | to_image = ToPILImage() 25 | result_image = to_image(result_tensor.squeeze(0).clamp_(0.0, 1.0)) 26 | 27 | return result_image 28 | 29 | def wavelet_color_fix(target: Image, source: Image): 30 | # Convert images to tensors 31 | to_tensor = ToTensor() 32 | target_tensor = to_tensor(target).unsqueeze(0) 33 | source_tensor = to_tensor(source).unsqueeze(0) 34 | 35 | # Apply wavelet reconstruction 36 | result_tensor = wavelet_reconstruction(target_tensor, source_tensor) 37 | 38 | # Convert tensor back to image 39 | to_image = ToPILImage() 40 | result_image = to_image(result_tensor.squeeze(0).clamp_(0.0, 1.0)) 41 | 42 | return result_image 43 | 44 | def calc_mean_std(feat: Tensor, eps=1e-5): 45 | """Calculate mean and std for adaptive_instance_normalization. 46 | Args: 47 | feat (Tensor): 4D tensor. 48 | eps (float): A small value added to the variance to avoid 49 | divide-by-zero. Default: 1e-5. 50 | """ 51 | size = feat.size() 52 | assert len(size) == 4, 'The input feature should be 4D tensor.' 
53 | b, c = size[:2] 54 | feat_var = feat.reshape(b, c, -1).var(dim=2) + eps 55 | feat_std = feat_var.sqrt().reshape(b, c, 1, 1) 56 | feat_mean = feat.reshape(b, c, -1).mean(dim=2).reshape(b, c, 1, 1) 57 | return feat_mean, feat_std 58 | 59 | def adaptive_instance_normalization(content_feat:Tensor, style_feat:Tensor): 60 | """Adaptive instance normalization. 61 | Adjust the reference features to have similar color and illumination 62 | as those in the degraded features. 63 | Args: 64 | content_feat (Tensor): The reference feature. 65 | style_feat (Tensor): The degraded features. 66 | """ 67 | size = content_feat.size() 68 | style_mean, style_std = calc_mean_std(style_feat) 69 | content_mean, content_std = calc_mean_std(content_feat) 70 | normalized_feat = (content_feat - content_mean.expand(size)) / content_std.expand(size) 71 | return normalized_feat * style_std.expand(size) + style_mean.expand(size) 72 | 73 | def wavelet_blur(image: Tensor, radius: int): 74 | """ 75 | Apply wavelet blur to the input tensor. 76 | """ 77 | # input shape: (1, 3, H, W) 78 | # convolution kernel 79 | kernel_vals = [ 80 | [0.0625, 0.125, 0.0625], 81 | [0.125, 0.25, 0.125], 82 | [0.0625, 0.125, 0.0625], 83 | ] 84 | kernel = torch.tensor(kernel_vals, dtype=image.dtype, device=image.device) 85 | # add channel dimensions to the kernel to make it a 4D tensor 86 | kernel = kernel[None, None] 87 | # repeat the kernel across all input channels 88 | kernel = kernel.repeat(3, 1, 1, 1) 89 | image = F.pad(image, (radius, radius, radius, radius), mode='replicate') 90 | # apply convolution 91 | output = F.conv2d(image, kernel, groups=3, dilation=radius) 92 | return output 93 | 94 | def wavelet_decomposition(image: Tensor, levels=5): 95 | """ 96 | Apply wavelet decomposition to the input tensor. 97 | This function only returns the low frequency & the high frequency. 98 | """ 99 | high_freq = torch.zeros_like(image) 100 | for i in range(levels): 101 | radius = 2 ** i 102 | low_freq = wavelet_blur(image, radius) 103 | high_freq += (image - low_freq) 104 | image = low_freq 105 | 106 | return high_freq, low_freq 107 | 108 | def wavelet_reconstruction(content_feat:Tensor, style_feat:Tensor): 109 | """ 110 | Apply wavelet decomposition, so that the content will have the same color as the style.
111 | """ 112 | # calculate the wavelet decomposition of the content feature 113 | content_high_freq, content_low_freq = wavelet_decomposition(content_feat) 114 | del content_low_freq 115 | # calculate the wavelet decomposition of the style feature 116 | style_high_freq, style_low_freq = wavelet_decomposition(style_feat) 117 | del style_high_freq 118 | # reconstruct the content feature with the style's high frequency 119 | return content_high_freq + style_low_freq -------------------------------------------------------------------------------- /utils/image/usm_sharp.py: -------------------------------------------------------------------------------- 1 | # https://github.com/XPixelGroup/BasicSR/blob/master/basicsr/utils/img_process_util.py 2 | import cv2 3 | import numpy as np 4 | import torch 5 | 6 | from .common import filter2D 7 | 8 | 9 | class USMSharp(torch.nn.Module): 10 | 11 | def __init__(self, radius=50, sigma=0): 12 | super(USMSharp, self).__init__() 13 | if radius % 2 == 0: 14 | radius += 1 15 | self.radius = radius 16 | kernel = cv2.getGaussianKernel(radius, sigma) 17 | kernel = torch.FloatTensor(np.dot(kernel, kernel.transpose())).unsqueeze_(0) 18 | self.register_buffer('kernel', kernel) 19 | 20 | def forward(self, img, weight=0.5, threshold=10): 21 | blur = filter2D(img, self.kernel) 22 | residual = img - blur 23 | 24 | mask = torch.abs(residual) * 255 > threshold 25 | mask = mask.float() 26 | soft_mask = filter2D(mask, self.kernel) 27 | sharp = img + weight * residual 28 | sharp = torch.clip(sharp, 0, 1) 29 | return soft_mask * sharp + (1 - soft_mask) * img 30 | -------------------------------------------------------------------------------- /utils/metrics.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import lpips 3 | 4 | from .image import rgb2ycbcr_pt 5 | from .common import frozen_module 6 | 7 | 8 | # https://github.com/XPixelGroup/BasicSR/blob/033cd6896d898fdd3dcda32e3102a792efa1b8f4/basicsr/metrics/psnr_ssim.py#L52 9 | def calculate_psnr_pt(img, img2, crop_border, test_y_channel=False): 10 | """Calculate PSNR (Peak Signal-to-Noise Ratio) (PyTorch version). 11 | 12 | Reference: https://en.wikipedia.org/wiki/Peak_signal-to-noise_ratio 13 | 14 | Args: 15 | img (Tensor): Images with range [0, 1], shape (n, 3/1, h, w). 16 | img2 (Tensor): Images with range [0, 1], shape (n, 3/1, h, w). 17 | crop_border (int): Cropped pixels in each edge of an image. These pixels are not involved in the calculation. 18 | test_y_channel (bool): Test on Y channel of YCbCr. Default: False. 19 | 20 | Returns: 21 | float: PSNR result. 22 | """ 23 | 24 | assert img.shape == img2.shape, (f'Image shapes are different: {img.shape}, {img2.shape}.') 25 | 26 | if crop_border != 0: 27 | img = img[:, :, crop_border:-crop_border, crop_border:-crop_border] 28 | img2 = img2[:, :, crop_border:-crop_border, crop_border:-crop_border] 29 | 30 | if test_y_channel: 31 | img = rgb2ycbcr_pt(img, y_only=True) 32 | img2 = rgb2ycbcr_pt(img2, y_only=True) 33 | 34 | img = img.to(torch.float64) 35 | img2 = img2.to(torch.float64) 36 | 37 | mse = torch.mean((img - img2)**2, dim=[1, 2, 3]) 38 | return 10. * torch.log10(1. / (mse + 1e-8)) 39 | 40 | 41 | class LPIPS: 42 | 43 | def __init__(self, net: str) -> None: 44 | self.model = lpips.LPIPS(net=net) 45 | frozen_module(self.model) 46 | 47 | @torch.no_grad() 48 | def __call__(self, img1: torch.Tensor, img2: torch.Tensor, normalize: bool) -> torch.Tensor: 49 | """ 50 | Compute LPIPS. 
51 | 52 | Args: 53 | img1 (torch.Tensor): The first image (NCHW, RGB, [-1, 1]). Specify `normalize` if input 54 | image is range in [0, 1]. 55 | img2 (torch.Tensor): The second image (NCHW, RGB, [-1, 1]). Specify `normalize` if input 56 | image is range in [0, 1]. 57 | normalize (bool): If specified, the input images will be normalized from [0, 1] to [-1, 1]. 58 | 59 | Returns: 60 | lpips_values (torch.Tensor): The lpips scores of this batch. 61 | """ 62 | return self.model(img1, img2, normalize=normalize) 63 | 64 | def to(self, device: str) -> "LPIPS": 65 | self.model.to(device) 66 | return self 67 | -------------------------------------------------------------------------------- /utils/pickout_img.py: -------------------------------------------------------------------------------- 1 | import os 2 | import shutil 3 | 4 | if __name__ == "__main__": 5 | gttest = "outputs/gtcelebamaskhq" 6 | os.makedirs(gttest,exist_ok=True) 7 | with open("/data1/zyx/FFHQ512/test.list", 'r') as f: 8 | lines = f.readlines() 9 | for line in lines: 10 | path = line[:-1] 11 | name = path.split('/')[-1] 12 | shutil.copy(path,os.path.join(gttest,name)) 13 | 14 | -------------------------------------------------------------------------------- /utils/process.py: -------------------------------------------------------------------------------- 1 | """Forward processing of raw data to sRGB images. 2 | 3 | Unprocessing Images for Learned Raw Denoising 4 | http://timothybrooks.com/tech/unprocessing 5 | """ 6 | 7 | import numpy as np 8 | import torch 9 | from torchinterp1d import Interp1d 10 | from os.path import join 11 | 12 | device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") 13 | 14 | 15 | def apply_gains(bayer_images, wbs): 16 | """Applies white balance to a batch of Bayer images.""" 17 | N, C, _, _ = bayer_images.shape 18 | outs = bayer_images * wbs.view(N, C, 1, 1) 19 | return outs 20 | 21 | 22 | def apply_ccms(images, ccms): 23 | """Applies color correction matrices.""" 24 | images = images.permute( 25 | 0, 2, 3, 1) # Permute the image tensor to BxHxWxC format from BxCxHxW format 26 | images = images[:, :, :, None, :] 27 | ccms = ccms[:, None, None, :, :] 28 | outs = torch.sum(images * ccms, dim=-1) 29 | # Re-Permute the tensor back to BxCxHxW format 30 | outs = outs.permute(0, 3, 1, 2) 31 | return outs 32 | 33 | 34 | def gamma_compression(images, gamma=2.2): 35 | """Converts from linear to gamma space.""" 36 | outs = torch.clamp(images, min=1e-8) ** (1 / gamma) 37 | # outs = (1 + gamma[0]) * np.power(images, 1.0/gamma[1]) - gamma[0] + gamma[2]*images 38 | outs = torch.clamp((outs*255).int(), min=0, max=255).float() / 255 39 | return outs 40 | 41 | 42 | def binning(bayer_images): 43 | """RGBG -> RGB""" 44 | lin_rgb = torch.stack([ 45 | bayer_images[:,0,...], 46 | torch.mean(bayer_images[:, [1,3], ...], dim=1), 47 | bayer_images[:,2,...]], dim=1) 48 | 49 | return lin_rgb 50 | 51 | 52 | def process(bayer_images, wbs, cam2rgbs, gamma=2.2, CRF=None): 53 | """Processes a batch of Bayer RGBG images into sRGB images.""" 54 | # White balance. 55 | bayer_images = apply_gains(bayer_images, wbs) 56 | # Binning 57 | bayer_images = torch.clamp(bayer_images, min=0.0, max=1.0) 58 | images = binning(bayer_images) 59 | # Color correction. 60 | images = apply_ccms(images, cam2rgbs) 61 | # Gamma compression. 
62 | images = torch.clamp(images, min=0.0, max=1.0) 63 | if CRF is None: 64 | images = gamma_compression(images, gamma) 65 | else: 66 | images = camera_response_function(images, CRF) 67 | 68 | return images 69 | 70 | 71 | def camera_response_function(images, CRF): 72 | E, fs = CRF # unpack CRF data 73 | 74 | outs = torch.zeros_like(images) 75 | device = images.device 76 | 77 | for i in range(images.shape[0]): 78 | img = images[i].view(3, -1) 79 | out = Interp1d()(E.to(device), fs.to(device), img) 80 | outs[i, ...] = out.view(3, images.shape[2], images.shape[3]) 81 | 82 | outs = torch.clamp((outs*255).int(), min=0, max=255).float() / 255 83 | return outs 84 | 85 | 86 | def raw2rgb(packed_raw, raw, CRF=None, gamma=2.2): 87 | """Raw2RGB pipeline (preprocess version)""" 88 | wb = np.array(raw.camera_whitebalance) 89 | wb /= wb[1] 90 | cam2rgb = raw.rgb_camera_matrix[:3, :3] 91 | 92 | if isinstance(packed_raw, np.ndarray): 93 | packed_raw = torch.from_numpy(packed_raw).float() 94 | 95 | wb = torch.from_numpy(wb).float().to(packed_raw.device) 96 | cam2rgb = torch.from_numpy(cam2rgb).float().to(packed_raw.device) 97 | 98 | out = process(packed_raw[None], wbs=wb[None], cam2rgbs=cam2rgb[None], gamma=gamma, CRF=CRF)[0, ...].numpy() 99 | 100 | return out 101 | 102 | 103 | def raw2rgb_v2(packed_raw, wb, ccm, CRF=None, gamma=2.2): # RGBG 104 | packed_raw = torch.from_numpy(packed_raw).float() 105 | wb = torch.from_numpy(wb).float() 106 | cam2rgb = torch.from_numpy(ccm).float() 107 | out = process(packed_raw[None], wbs=wb[None], cam2rgbs=cam2rgb[None], gamma=gamma, CRF=CRF)[0, ...].numpy() 108 | return out 109 | 110 | 111 | def raw2rgb_postprocess(packed_raw, raw, CRF=None): 112 | """Raw2RGB pipeline (postprocess version)""" 113 | assert packed_raw.ndimension() == 4 and packed_raw.shape[0] == 1 114 | wb = np.array(raw.camera_whitebalance) 115 | wb /= wb[1] 116 | cam2rgb = raw.rgb_camera_matrix[:3, :3] 117 | 118 | wb = torch.from_numpy(wb[None]).float().to(packed_raw.device) 119 | cam2rgb = torch.from_numpy(cam2rgb[None]).float().to(packed_raw.device) 120 | out = process(packed_raw, wbs=wb, cam2rgbs=cam2rgb, gamma=2.2, CRF=CRF) 121 | return out 122 | 123 | 124 | def read_wb_ccm(raw): 125 | wb = np.array(raw.camera_whitebalance) 126 | wb /= wb[1] 127 | wb = wb.astype(np.float32) 128 | ccm = raw.rgb_camera_matrix[:3, :3].astype(np.float32) 129 | return wb, ccm 130 | 131 | 132 | def read_emor(address): 133 | def _read_curve(lst): 134 | curve = [l.strip() for l in lst] 135 | curve = ' '.join(curve) 136 | curve = np.array(curve.split()).astype(np.float32) 137 | return curve 138 | 139 | with open(address) as f: 140 | lines = f.readlines() 141 | k = 1 142 | E = _read_curve(lines[k:k+256]) 143 | k += 257 144 | f0 = _read_curve(lines[k:k+256]) 145 | hs = [] 146 | for _ in range(25): 147 | k += 257 148 | hs.append(_read_curve(lines[k:k+256])) 149 | 150 | hs = np.array(hs) 151 | 152 | return E, f0, hs 153 | 154 | 155 | def read_dorf(address): 156 | with open(address) as f: 157 | lines = f.readlines() 158 | curve_names = lines[0::6] 159 | Es = lines[3::6] 160 | Bs = lines[5::6] 161 | 162 | Es = [np.array(E.strip().split()).astype(np.float32) for E in Es] 163 | Bs = [np.array(B.strip().split()).astype(np.float32) for B in Bs] 164 | 165 | return curve_names, Es, Bs 166 | 167 | 168 | def load_CRF(): 169 | # init CRF function 170 | fs = np.loadtxt(join('EMoR', 'CRF_SonyA7S2_5.txt')) 171 | E, _, _ = read_emor(join('EMoR', 'emor.txt')) 172 | E = torch.from_numpy(E).repeat(3, 1) 173 | fs = torch.from_numpy(fs) 174 | CRF = 
(E, fs) 175 | return CRF -------------------------------------------------------------------------------- /utils/realesrgan/__pycache__/realesrganer.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/EternalEvan/FlowIE/549bb871cff104c47fd126d23818f54d74ba1f20/utils/realesrgan/__pycache__/realesrganer.cpython-38.pyc -------------------------------------------------------------------------------- /utils/realesrgan/__pycache__/rrdbnet.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/EternalEvan/FlowIE/549bb871cff104c47fd126d23818f54d74ba1f20/utils/realesrgan/__pycache__/rrdbnet.cpython-38.pyc -------------------------------------------------------------------------------- /utils/realesrgan/rrdbnet.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn as nn 3 | from torch.nn import functional as F 4 | from torch.nn import init as init 5 | from torch.nn.modules.batchnorm import _BatchNorm 6 | 7 | 8 | def default_init_weights(module_list, scale=1, bias_fill=0, **kwargs): 9 | """Initialize network weights. 10 | 11 | Args: 12 | module_list (list[nn.Module] | nn.Module): Modules to be initialized. 13 | scale (float): Scale initialized weights, especially for residual 14 | blocks. Default: 1. 15 | bias_fill (float): The value to fill bias. Default: 0 16 | kwargs (dict): Other arguments for initialization function. 17 | """ 18 | if not isinstance(module_list, list): 19 | module_list = [module_list] 20 | for module in module_list: 21 | for m in module.modules(): 22 | if isinstance(m, nn.Conv2d): 23 | init.kaiming_normal_(m.weight, **kwargs) 24 | m.weight.data *= scale 25 | if m.bias is not None: 26 | m.bias.data.fill_(bias_fill) 27 | elif isinstance(m, nn.Linear): 28 | init.kaiming_normal_(m.weight, **kwargs) 29 | m.weight.data *= scale 30 | if m.bias is not None: 31 | m.bias.data.fill_(bias_fill) 32 | elif isinstance(m, _BatchNorm): 33 | init.constant_(m.weight, 1) 34 | if m.bias is not None: 35 | m.bias.data.fill_(bias_fill) 36 | 37 | 38 | def make_layer(basic_block, num_basic_block, **kwarg): 39 | """Make layers by stacking the same blocks. 40 | 41 | Args: 42 | basic_block (nn.module): nn.module class for basic block. 43 | num_basic_block (int): number of blocks. 44 | 45 | Returns: 46 | nn.Sequential: Stacked blocks in nn.Sequential. 47 | """ 48 | layers = [] 49 | for _ in range(num_basic_block): 50 | layers.append(basic_block(**kwarg)) 51 | return nn.Sequential(*layers) 52 | 53 | 54 | # TODO: may write a cpp file 55 | def pixel_unshuffle(x, scale): 56 | """ Pixel unshuffle. 57 | 58 | Args: 59 | x (Tensor): Input feature with shape (b, c, hh, hw). 60 | scale (int): Downsample ratio. 61 | 62 | Returns: 63 | Tensor: the pixel unshuffled feature. 64 | """ 65 | b, c, hh, hw = x.size() 66 | out_channel = c * (scale**2) 67 | assert hh % scale == 0 and hw % scale == 0 68 | h = hh // scale 69 | w = hw // scale 70 | x_view = x.view(b, c, h, scale, w, scale) 71 | return x_view.permute(0, 1, 3, 5, 2, 4).reshape(b, out_channel, h, w) 72 | 73 | 74 | class ResidualDenseBlock(nn.Module): 75 | """Residual Dense Block. 76 | 77 | Used in RRDB block in ESRGAN. 78 | 79 | Args: 80 | num_feat (int): Channel number of intermediate features. 81 | num_grow_ch (int): Channels for each growth. 
82 | """ 83 | 84 | def __init__(self, num_feat=64, num_grow_ch=32): 85 | super(ResidualDenseBlock, self).__init__() 86 | self.conv1 = nn.Conv2d(num_feat, num_grow_ch, 3, 1, 1) 87 | self.conv2 = nn.Conv2d(num_feat + num_grow_ch, num_grow_ch, 3, 1, 1) 88 | self.conv3 = nn.Conv2d(num_feat + 2 * num_grow_ch, num_grow_ch, 3, 1, 1) 89 | self.conv4 = nn.Conv2d(num_feat + 3 * num_grow_ch, num_grow_ch, 3, 1, 1) 90 | self.conv5 = nn.Conv2d(num_feat + 4 * num_grow_ch, num_feat, 3, 1, 1) 91 | 92 | self.lrelu = nn.LeakyReLU(negative_slope=0.2, inplace=True) 93 | 94 | # initialization 95 | default_init_weights([self.conv1, self.conv2, self.conv3, self.conv4, self.conv5], 0.1) 96 | 97 | def forward(self, x): 98 | x1 = self.lrelu(self.conv1(x)) 99 | x2 = self.lrelu(self.conv2(torch.cat((x, x1), 1))) 100 | x3 = self.lrelu(self.conv3(torch.cat((x, x1, x2), 1))) 101 | x4 = self.lrelu(self.conv4(torch.cat((x, x1, x2, x3), 1))) 102 | x5 = self.conv5(torch.cat((x, x1, x2, x3, x4), 1)) 103 | # Empirically, we use 0.2 to scale the residual for better performance 104 | return x5 * 0.2 + x 105 | 106 | 107 | class RRDB(nn.Module): 108 | """Residual in Residual Dense Block. 109 | 110 | Used in RRDB-Net in ESRGAN. 111 | 112 | Args: 113 | num_feat (int): Channel number of intermediate features. 114 | num_grow_ch (int): Channels for each growth. 115 | """ 116 | 117 | def __init__(self, num_feat, num_grow_ch=32): 118 | super(RRDB, self).__init__() 119 | self.rdb1 = ResidualDenseBlock(num_feat, num_grow_ch) 120 | self.rdb2 = ResidualDenseBlock(num_feat, num_grow_ch) 121 | self.rdb3 = ResidualDenseBlock(num_feat, num_grow_ch) 122 | 123 | def forward(self, x): 124 | out = self.rdb1(x) 125 | out = self.rdb2(out) 126 | out = self.rdb3(out) 127 | # Empirically, we use 0.2 to scale the residual for better performance 128 | return out * 0.2 + x 129 | 130 | 131 | class RRDBNet(nn.Module): 132 | """Networks consisting of Residual in Residual Dense Block, which is used 133 | in ESRGAN. 134 | 135 | ESRGAN: Enhanced Super-Resolution Generative Adversarial Networks. 136 | 137 | We extend ESRGAN for scale x2 and scale x1. 138 | Note: This is one option for scale 1, scale 2 in RRDBNet. 139 | We first employ the pixel-unshuffle (an inverse operation of pixelshuffle to reduce the spatial size 140 | and enlarge the channel size before feeding inputs into the main ESRGAN architecture. 141 | 142 | Args: 143 | num_in_ch (int): Channel number of inputs. 144 | num_out_ch (int): Channel number of outputs. 145 | num_feat (int): Channel number of intermediate features. 146 | Default: 64 147 | num_block (int): Block number in the trunk network. Defaults: 23 148 | num_grow_ch (int): Channels for each growth. Default: 32. 
149 | """ 150 | 151 | def __init__(self, num_in_ch, num_out_ch, scale=4, num_feat=64, num_block=23, num_grow_ch=32): 152 | super(RRDBNet, self).__init__() 153 | self.scale = scale 154 | if scale == 2: 155 | num_in_ch = num_in_ch * 4 156 | elif scale == 1: 157 | num_in_ch = num_in_ch * 16 158 | self.conv_first = nn.Conv2d(num_in_ch, num_feat, 3, 1, 1) 159 | self.body = make_layer(RRDB, num_block, num_feat=num_feat, num_grow_ch=num_grow_ch) 160 | self.conv_body = nn.Conv2d(num_feat, num_feat, 3, 1, 1) 161 | # upsample 162 | self.conv_up1 = nn.Conv2d(num_feat, num_feat, 3, 1, 1) 163 | self.conv_up2 = nn.Conv2d(num_feat, num_feat, 3, 1, 1) 164 | self.conv_hr = nn.Conv2d(num_feat, num_feat, 3, 1, 1) 165 | self.conv_last = nn.Conv2d(num_feat, num_out_ch, 3, 1, 1) 166 | 167 | self.lrelu = nn.LeakyReLU(negative_slope=0.2, inplace=True) 168 | 169 | def forward(self, x): 170 | if self.scale == 2: 171 | feat = pixel_unshuffle(x, scale=2) 172 | elif self.scale == 1: 173 | feat = pixel_unshuffle(x, scale=4) 174 | else: 175 | feat = x 176 | feat = self.conv_first(feat) 177 | body_feat = self.conv_body(self.body(feat)) 178 | feat = feat + body_feat 179 | # upsample 180 | feat = self.lrelu(self.conv_up1(F.interpolate(feat, scale_factor=2, mode='nearest'))) 181 | feat = self.lrelu(self.conv_up2(F.interpolate(feat, scale_factor=2, mode='nearest'))) 182 | out = self.conv_last(self.lrelu(self.conv_hr(feat))) 183 | return out -------------------------------------------------------------------------------- /utils/torchinterp1d-master.zip: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/EternalEvan/FlowIE/549bb871cff104c47fd126d23818f54d74ba1f20/utils/torchinterp1d-master.zip -------------------------------------------------------------------------------- /utils/torchinterp1d/.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | *.egg-info/ 24 | .installed.cfg 25 | *.egg 26 | MANIFEST 27 | 28 | # PyInstaller 29 | # Usually these files are written by a python script from a template 30 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 
31 | *.manifest 32 | *.spec 33 | 34 | # Installer logs 35 | pip-log.txt 36 | pip-delete-this-directory.txt 37 | 38 | # Unit test / coverage reports 39 | htmlcov/ 40 | .tox/ 41 | .coverage 42 | .coverage.* 43 | .cache 44 | nosetests.xml 45 | coverage.xml 46 | *.cover 47 | .hypothesis/ 48 | .pytest_cache/ 49 | 50 | # Translations 51 | *.mo 52 | *.pot 53 | 54 | # Django stuff: 55 | *.log 56 | local_settings.py 57 | db.sqlite3 58 | 59 | # Flask stuff: 60 | instance/ 61 | .webassets-cache 62 | 63 | # Scrapy stuff: 64 | .scrapy 65 | 66 | # Sphinx documentation 67 | docs/_build/ 68 | 69 | # PyBuilder 70 | target/ 71 | 72 | # Jupyter Notebook 73 | .ipynb_checkpoints 74 | 75 | # pyenv 76 | .python-version 77 | 78 | # celery beat schedule file 79 | celerybeat-schedule 80 | 81 | # SageMath parsed files 82 | *.sage.py 83 | 84 | # Environments 85 | .env 86 | .venv 87 | env/ 88 | venv/ 89 | ENV/ 90 | env.bak/ 91 | venv.bak/ 92 | 93 | # Spyder project settings 94 | .spyderproject 95 | .spyproject 96 | 97 | # Rope project settings 98 | .ropeproject 99 | 100 | # mkdocs documentation 101 | /site 102 | 103 | # mypy 104 | .mypy_cache/ 105 | -------------------------------------------------------------------------------- /utils/torchinterp1d/.gitmodules: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/EternalEvan/FlowIE/549bb871cff104c47fd126d23818f54d74ba1f20/utils/torchinterp1d/.gitmodules -------------------------------------------------------------------------------- /utils/torchinterp1d/LICENSE: -------------------------------------------------------------------------------- 1 | BSD 3-Clause License 2 | 3 | Copyright (c) 2019, Inria (Antoine Liutkus) 4 | All rights reserved. 5 | 6 | Redistribution and use in source and binary forms, with or without 7 | modification, are permitted provided that the following conditions are met: 8 | 9 | 1. Redistributions of source code must retain the above copyright notice, this 10 | list of conditions and the following disclaimer. 11 | 12 | 2. Redistributions in binary form must reproduce the above copyright notice, 13 | this list of conditions and the following disclaimer in the documentation 14 | and/or other materials provided with the distribution. 15 | 16 | 3. Neither the name of the copyright holder nor the names of its 17 | contributors may be used to endorse or promote products derived from 18 | this software without specific prior written permission. 19 | 20 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" 21 | AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE 22 | IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE 23 | DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE 24 | FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL 25 | DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR 26 | SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER 27 | CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, 28 | OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE 29 | OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
30 | -------------------------------------------------------------------------------- /utils/torchinterp1d/README.md: -------------------------------------------------------------------------------- 1 | # torchinterp1d 2 | ## CUDA 1-D interpolation for Pytorch 3 | 4 | Requires PyTorch >= 1.6 (due to [torch.searchsorted](https://pytorch.org/docs/master/generated/torch.searchsorted.html)). 5 | 6 | ## Presentation 7 | 8 | This repository implements an `interp1d` function that overrides torch.autograd.Function, enabling 9 | linear 1D interpolation on the GPU for Pytorch. 10 | 11 | ``` 12 | def interp1d(x, y, xnew, out=None) 13 | ``` 14 | 15 | This function returns interpolated values of a set of 1-D functions at the desired query points `xnew`. 16 | 17 | It works similarly to Matlab™ or scipy functions with 18 | the `linear` interpolation mode on, except that it parallelises over any number of desired interpolation problems and exploits CUDA on the GPU. 19 | 20 | ### Parameters for `interp1d` 21 | 22 | * `x` : a (N, ) or (D, N) Pytorch Tensor: 23 | Either 1-D or 2-D. It contains the coordinates of the observed samples. 24 | 25 | * `y` : (N,) or (D, N) Pytorch Tensor. 26 | Either 1-D or 2-D. It contains the actual values that correspond to the coordinates given by `x`. 27 | The length of `y` along its last dimension must be the same as that of `x`. 28 | 29 | * `xnew` : (P,) or (D, P) Pytorch Tensor. 30 | Either 1-D or 2-D. If it is not 1-D, its length along the first dimension must be the same as that of whichever `x` and `y` is 2-D. x-coordinates for which we want the interpolated output. 31 | 32 | * `out` : (D, P) Pytorch Tensor 33 | Tensor for the output. If None: allocated automatically. 34 | 35 | ### Results 36 | 37 | A Pytorch tensor of shape (D, P), containing the interpolated values. 38 | 39 | ## Installation 40 | 41 | Type `pip install -e .` in the root folder of this repo. 42 | 43 | ## Usage 44 | 45 | Simply call `torchinterp1d.interp1d`. 46 | 47 | Try out `python test.py` in the `examples` folder. 48 | ``` 49 | Solving 100000 interpolation problems: each with 100 observations and 30 desired values 50 | CPU: 8060.260ms, GPU: 70.735ms, error: 0.000000%.
51 | ``` 52 | -------------------------------------------------------------------------------- /utils/torchinterp1d/examples/test.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import matplotlib.pyplot as plt 3 | import time 4 | import numpy as np 5 | from torchinterp1d import interp1d 6 | 7 | 8 | if __name__ == "__main__": 9 | # defining the number of tests 10 | ntests = 2 11 | 12 | # problem dimensions 13 | D = 1000 14 | Dnew = 1 15 | N = 100 16 | P = 30 17 | 18 | yq_gpu = None 19 | yq_cpu = None 20 | for ntest in range(ntests): 21 | # draw the data 22 | x = torch.rand(D, N) * 10000 23 | x = x.sort(dim=1)[0] 24 | 25 | y = torch.linspace(0, 1000, D*N).view(D, -1) 26 | y -= y[:, 0, None] 27 | 28 | xnew = torch.rand(Dnew, P)*10000 29 | 30 | print('Solving %d interpolation problems: ' 31 | 'each with %d observations and %d desired values' % (D, N, P)) 32 | 33 | # calling the cpu version 34 | t0_cpu = time.time() 35 | yq_cpu = interp1d(x, y, xnew, yq_cpu) 36 | t1_cpu = time.time() 37 | 38 | display_str = 'CPU: %0.3fms, ' % ((t1_cpu-t0_cpu)*1000) 39 | 40 | if torch.cuda.is_available(): 41 | x = x.to('cuda') 42 | y = y.to('cuda') 43 | xnew = xnew.to('cuda') 44 | 45 | # launching the cuda version 46 | t0 = time.time() 47 | yq_gpu = interp1d(x, y, xnew, yq_gpu) 48 | t1 = time.time() 49 | 50 | # compute the difference between both 51 | error = torch.norm( 52 | yq_cpu - yq_gpu.to('cpu'))/torch.norm(yq_cpu)*100. 53 | 54 | display_str += 'GPU: %0.3fms, error: %f%%.' % ( 55 | (t1-t0)*1000, error) 56 | print(display_str) 57 | 58 | if torch.cuda.is_available(): 59 | # for the last test, plot the result for the first 10 dimensions max 60 | d_plot = min(D, 10) 61 | x = x[:d_plot].cpu().numpy() 62 | y = y[:d_plot].cpu().numpy() 63 | xnew = xnew[:d_plot].cpu().numpy() 64 | yq_cpu = yq_cpu[:d_plot].cpu().numpy() 65 | yq_gpu = yq_gpu[:d_plot].cpu().numpy() 66 | 67 | plt.plot(x.T, y.T, '-', 68 | xnew.T, yq_gpu.T, 'o', 69 | xnew.T, yq_cpu.T, 'x') 70 | not_close = np.nonzero(np.invert(np.isclose(yq_gpu, yq_cpu))) 71 | if not_close[0].size: 72 | plt.scatter(xnew[not_close].T, yq_cpu[not_close].T, 73 | edgecolors='r', s=100, facecolors='none') 74 | plt.grid(True) 75 | plt.show() 76 | -------------------------------------------------------------------------------- /utils/torchinterp1d/setup.py: -------------------------------------------------------------------------------- 1 | from setuptools import setup 2 | # To use a consistent encoding 3 | from codecs import open 4 | from os import path 5 | 6 | # trying to import the required torch package 7 | try: 8 | import torch 9 | except ImportError: 10 | raise Exception('qsketch requires PyTorch to be installed. 
aborting') 11 | 12 | here = path.abspath(path.dirname(__file__)) 13 | 14 | # Get the long description from the README file 15 | with open(path.join(here, 'README.md'), encoding='utf-8') as f: 16 | long_description = f.read() 17 | 18 | # Proceed to setup 19 | setup( 20 | name='torchinterp1d', 21 | version='1.1', 22 | description='An interp1d implementation for pytorch', 23 | long_description=long_description, 24 | long_description_content_type='text/markdown', 25 | author='Antoine Liutkus', 26 | author_email='antoine.liutkus@inria.fr', 27 | packages=['torchinterp1d'], 28 | keywords='interp1d torch', 29 | install_requires=[ 30 | 'torch>=1.6', 31 | ], 32 | ) 33 | -------------------------------------------------------------------------------- /utils/torchinterp1d/torchinterp1d/__init__.py: -------------------------------------------------------------------------------- 1 | from .interp1d import interp1d, Interp1d 2 | -------------------------------------------------------------------------------- /utils/torchinterp1d/torchinterp1d/interp1d.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import contextlib 3 | 4 | class Interp1d(torch.autograd.Function): 5 | @staticmethod 6 | def forward(ctx, x, y, xnew, out=None): 7 | """ 8 | Linear 1D interpolation on the GPU for Pytorch. 9 | This function returns interpolated values of a set of 1-D functions at 10 | the desired query points `xnew`. 11 | This function is working similarly to Matlab™ or scipy functions with 12 | the `linear` interpolation mode on, except that it parallelises over 13 | any number of desired interpolation problems. 14 | The code will run on GPU if all the tensors provided are on a cuda 15 | device. 16 | 17 | Parameters 18 | ---------- 19 | x : (N, ) or (D, N) Pytorch Tensor 20 | A 1-D or 2-D tensor of real values. 21 | y : (N,) or (D, N) Pytorch Tensor 22 | A 1-D or 2-D tensor of real values. The length of `y` along its 23 | last dimension must be the same as that of `x` 24 | xnew : (P,) or (D, P) Pytorch Tensor 25 | A 1-D or 2-D tensor of real values. `xnew` can only be 1-D if 26 | _both_ `x` and `y` are 1-D. Otherwise, its length along the first 27 | dimension must be the same as that of whichever `x` and `y` is 2-D. 28 | out : Pytorch Tensor, same shape as `xnew` 29 | Tensor for the output. If None: allocated automatically. 30 | 31 | """ 32 | # making the vectors at least 2D 33 | is_flat = {} 34 | require_grad = {} 35 | v = {} 36 | device = [] 37 | eps = torch.finfo(y.dtype).eps 38 | for name, vec in {'x': x, 'y': y, 'xnew': xnew}.items(): 39 | assert len(vec.shape) <= 2, 'interp1d: all inputs must be '\ 40 | 'at most 2-D.' 41 | if len(vec.shape) == 1: 42 | v[name] = vec[None, :] 43 | else: 44 | v[name] = vec 45 | is_flat[name] = v[name].shape[0] == 1 46 | require_grad[name] = vec.requires_grad 47 | device = list(set(device + [str(vec.device)])) 48 | assert len(device) == 1, 'All parameters must be on the same device.' 
49 | device = device[0] 50 | 51 | # Checking for the dimensions 52 | assert (v['x'].shape[1] == v['y'].shape[1] 53 | and ( 54 | v['x'].shape[0] == v['y'].shape[0] 55 | or v['x'].shape[0] == 1 56 | or v['y'].shape[0] == 1 57 | ) 58 | ), ("x and y must have the same number of columns, and either " 59 | "the same number of row or one of them having only one " 60 | "row.") 61 | 62 | reshaped_xnew = False 63 | if ((v['x'].shape[0] == 1) and (v['y'].shape[0] == 1) 64 | and (v['xnew'].shape[0] > 1)): 65 | # if there is only one row for both x and y, there is no need to 66 | # loop over the rows of xnew because they will all have to face the 67 | # same interpolation problem. We should just stack them together to 68 | # call interp1d and put them back in place afterwards. 69 | original_xnew_shape = v['xnew'].shape 70 | v['xnew'] = v['xnew'].contiguous().view(1, -1) 71 | reshaped_xnew = True 72 | 73 | # identify the dimensions of output and check if the one provided is ok 74 | D = max(v['x'].shape[0], v['xnew'].shape[0]) 75 | shape_ynew = (D, v['xnew'].shape[-1]) 76 | if out is not None: 77 | if out.numel() != shape_ynew[0]*shape_ynew[1]: 78 | # The output provided is of incorrect shape. 79 | # Going for a new one 80 | out = None 81 | else: 82 | ynew = out.reshape(shape_ynew) 83 | if out is None: 84 | ynew = torch.zeros(*shape_ynew, device=device) 85 | 86 | # moving everything to the desired device in case it was not there 87 | # already (not handling the case things do not fit entirely, user will 88 | # do it if required.) 89 | for name in v: 90 | v[name] = v[name].to(device) 91 | 92 | # calling searchsorted on the x values. 93 | ind = ynew.long() 94 | 95 | # expanding xnew to match the number of rows of x in case only one xnew is 96 | # provided 97 | if v['xnew'].shape[0] == 1: 98 | v['xnew'] = v['xnew'].expand(v['x'].shape[0], -1) 99 | 100 | # the squeeze is because torch.searchsorted does accept either a nd with 101 | # matching shapes for x and xnew or a 1d vector for x. Here we would 102 | # have (1,len) for x sometimes 103 | torch.searchsorted(v['x'].contiguous().squeeze(), 104 | v['xnew'].contiguous(), out=ind) 105 | 106 | # the `-1` is because searchsorted looks for the index where the values 107 | # must be inserted to preserve order. And we want the index of the 108 | # preceeding value. 109 | ind -= 1 110 | # we clamp the index, because the number of intervals is x.shape-1, 111 | # and the left neighbour should hence be at most number of intervals 112 | # -1, i.e. number of columns in x -2 113 | ind = torch.clamp(ind, 0, v['x'].shape[1] - 1 - 1) 114 | 115 | # helper function to select stuff according to the found indices. 116 | def sel(name): 117 | if is_flat[name]: 118 | return v[name].contiguous().view(-1)[ind] 119 | return torch.gather(v[name], 1, ind) 120 | 121 | # activating gradient storing for everything now 122 | enable_grad = False 123 | saved_inputs = [] 124 | for name in ['x', 'y', 'xnew']: 125 | if require_grad[name]: 126 | enable_grad = True 127 | saved_inputs += [v[name]] 128 | else: 129 | saved_inputs += [None, ] 130 | # assuming x are sorted in the dimension 1, computing the slopes for 131 | # the segments 132 | is_flat['slopes'] = is_flat['x'] 133 | # now we have found the indices of the neighbors, we start building the 134 | # output. 
Hence, we start also activating gradient tracking 135 | with torch.enable_grad() if enable_grad else contextlib.suppress(): 136 | v['slopes'] = ( 137 | (v['y'][:, 1:]-v['y'][:, :-1]) 138 | / 139 | (eps + (v['x'][:, 1:]-v['x'][:, :-1])) 140 | ) 141 | 142 | # now build the linear interpolation 143 | ynew = sel('y') + sel('slopes')*( 144 | v['xnew'] - sel('x')) 145 | 146 | if reshaped_xnew: 147 | ynew = ynew.view(original_xnew_shape) 148 | 149 | ctx.save_for_backward(ynew, *saved_inputs) 150 | return ynew 151 | 152 | @staticmethod 153 | def backward(ctx, grad_out): 154 | inputs = ctx.saved_tensors[1:] 155 | gradients = torch.autograd.grad( 156 | ctx.saved_tensors[0], 157 | [i for i in inputs if i is not None], 158 | grad_out, retain_graph=True) 159 | result = [None, ] * 5 160 | pos = 0 161 | for index in range(len(inputs)): 162 | if inputs[index] is not None: 163 | result[index] = gradients[pos] 164 | pos += 1 165 | return (*result,) 166 | 167 | 168 | interp1d = Interp1d.apply -------------------------------------------------------------------------------- /utils/utils.py: -------------------------------------------------------------------------------- 1 | 2 | import os 3 | import sys 4 | import time 5 | import math 6 | import torch 7 | import numpy as np 8 | import scipy 9 | import scipy.io as spio 10 | import yaml 11 | from PIL import Image 12 | import torch.nn.functional as F 13 | 14 | def loadmat(filename): 15 | ''' 16 | this function should be called instead of direct spio.loadmat 17 | as it cures the problem of not properly recovering python dictionaries 18 | from mat files. It calls the function check keys to cure all entries 19 | which are still mat-objects 20 | ''' 21 | def _check_keys(d): 22 | ''' 23 | checks if entries in dictionary are mat-objects. If yes 24 | todict is called to change them to nested dictionaries 25 | ''' 26 | for key in d: 27 | if isinstance(d[key], spio.matlab.mio5_params.mat_struct): 28 | d[key] = _todict(d[key]) 29 | return d 30 | 31 | def _todict(matobj): 32 | ''' 33 | A recursive function which constructs from matobjects nested dictionaries 34 | ''' 35 | d = {} 36 | for strg in matobj._fieldnames: 37 | elem = matobj.__dict__[strg] 38 | if isinstance(elem, spio.matlab.mio5_params.mat_struct): 39 | d[strg] = _todict(elem) 40 | elif isinstance(elem, np.ndarray): 41 | d[strg] = _tolist(elem) 42 | else: 43 | d[strg] = elem 44 | return d 45 | 46 | def _tolist(ndarray): 47 | ''' 48 | A recursive function which constructs lists from cellarrays 49 | (which are loaded as numpy ndarrays), recursing into the elements 50 | if they contain matobjects. 51 | ''' 52 | elem_list = [] 53 | for sub_elem in ndarray: 54 | if isinstance(sub_elem, spio.matlab.mio5_params.mat_struct): 55 | elem_list.append(_todict(sub_elem)) 56 | elif isinstance(sub_elem, np.ndarray): 57 | elem_list.append(_tolist(sub_elem)) 58 | else: 59 | elem_list.append(sub_elem) 60 | return elem_list 61 | data = scipy.io.loadmat(filename, struct_as_record=False, squeeze_me=True) 62 | return _check_keys(data) -------------------------------------------------------------------------------- /weights/null_token_1024.pth: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/EternalEvan/FlowIE/549bb871cff104c47fd126d23818f54d74ba1f20/weights/null_token_1024.pth --------------------------------------------------------------------------------
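Finally, a minimal sketch of calling the `loadmat` wrapper from `utils/utils.py` above; the `.mat` filename is hypothetical, and the snippet only illustrates the expected return type under that assumption.

```
from utils.utils import loadmat

# Load a MATLAB file; nested mat_struct entries are converted into plain
# Python dicts/lists by the _check_keys/_todict helpers shown above.
meta = loadmat('example_metadata.mat')  # hypothetical path
print(type(meta), list(meta.keys()))
```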