├── LICENSE
├── README.md
├── assets
├── faces
│ ├── 00000014.png
│ └── 00002386.png
├── pipeline.png
└── real-photos
│ └── dog.png
├── configs
├── dataset
│ ├── dehaze_train.yaml
│ ├── dehaze_val.yaml
│ ├── derain_train.yaml
│ ├── derain_val.yaml
│ ├── enlight_train.yaml
│ ├── enlight_val.yaml
│ ├── face_test_inp.yaml
│ ├── face_test_lq.yaml
│ ├── face_train.yaml
│ ├── face_val.yaml
│ ├── general_deg_codeformer_train.yaml
│ ├── general_deg_codeformer_val.yaml
│ ├── general_deg_realesrgan_train.yaml
│ ├── general_deg_realesrgan_val.yaml
│ ├── general_test_lq.yaml
│ ├── ocr_train.yaml
│ └── ocr_val.yaml
├── model
│ ├── cldm.yaml
│ ├── cldm_bsr_eval.yaml
│ ├── cldm_eval.yaml
│ └── swinir.yaml
├── test_cldm.yaml
├── test_cldm_general.yaml
├── train_cldm.yaml
├── train_cldm_general.yaml
└── train_swinir.yaml
├── dataset
├── __pycache__
│ ├── batch_transform.cpython-38.pyc
│ ├── codeformer.cpython-38.pyc
│ └── data_module.cpython-38.pyc
├── batch_transform.py
├── codeformer.py
├── data_module.py
├── realesrgan.py
└── test.py
├── evaluate.py
├── inference.py
├── inference_bfr.py
├── inference_bsr.py
├── ldm
├── __pycache__
│ ├── util.cpython-38.pyc
│ ├── util.cpython-39.pyc
│ ├── xformers_state.cpython-38.pyc
│ └── xformers_state.cpython-39.pyc
├── data
│ ├── __init__.py
│ └── util.py
├── models
│ ├── __pycache__
│ │ ├── autoencoder.cpython-38.pyc
│ │ └── autoencoder.cpython-39.pyc
│ ├── autoencoder.py
│ └── diffusion
│ │ ├── __init__.py
│ │ ├── __pycache__
│ │ │ ├── __init__.cpython-38.pyc
│ │ │ ├── __init__.cpython-39.pyc
│ │ │ ├── ddim.cpython-38.pyc
│ │ │ ├── ddim.cpython-39.pyc
│ │ │ ├── ddpm.cpython-38.pyc
│ │ │ └── ddpm.cpython-39.pyc
│ │ ├── ddim.py
│ │ ├── ddpm.py
│ │ ├── dpm_solver
│ │ │ ├── __init__.py
│ │ │ ├── dpm_solver.py
│ │ │ └── sampler.py
│ │ ├── plms.py
│ │ └── sampling_util.py
├── modules
│ ├── __pycache__
│ │ ├── attention.cpython-38.pyc
│ │ ├── attention.cpython-39.pyc
│ │ ├── ema.cpython-38.pyc
│ │ └── ema.cpython-39.pyc
│ ├── attention.py
│ ├── diffusionmodules
│ │ ├── __init__.py
│ │ ├── __pycache__
│ │ │ ├── __init__.cpython-38.pyc
│ │ │ ├── __init__.cpython-39.pyc
│ │ │ ├── model.cpython-38.pyc
│ │ │ ├── model.cpython-39.pyc
│ │ │ ├── openaimodel.cpython-38.pyc
│ │ │ ├── openaimodel.cpython-39.pyc
│ │ │ ├── util.cpython-38.pyc
│ │ │ └── util.cpython-39.pyc
│ │ ├── model.py
│ │ ├── openaimodel.py
│ │ ├── upscaling.py
│ │ └── util.py
│ ├── distributions
│ │ ├── __init__.py
│ │ ├── __pycache__
│ │ │ ├── __init__.cpython-38.pyc
│ │ │ ├── __init__.cpython-39.pyc
│ │ │ ├── distributions.cpython-38.pyc
│ │ │ └── distributions.cpython-39.pyc
│ │ └── distributions.py
│ ├── ema.py
│ ├── encoders
│ │ ├── __init__.py
│ │ ├── __pycache__
│ │ │ ├── __init__.cpython-38.pyc
│ │ │ ├── __init__.cpython-39.pyc
│ │ │ ├── modules.cpython-38.pyc
│ │ │ └── modules.cpython-39.pyc
│ │ └── modules.py
│ └── midas
│ │ ├── __init__.py
│ │ ├── api.py
│ │ ├── midas
│ │ │ ├── __init__.py
│ │ │ ├── base_model.py
│ │ │ ├── blocks.py
│ │ │ ├── dpt_depth.py
│ │ │ ├── midas_net.py
│ │ │ ├── midas_net_custom.py
│ │ │ ├── transforms.py
│ │ │ └── vit.py
│ │ └── utils.py
├── util.py
└── xformers_state.py
├── model
├── __pycache__
│ ├── callbacks.cpython-38.pyc
│ ├── cldm.cpython-38.pyc
│ ├── cldm_bsr.cpython-39.pyc
│ ├── cond_fn.cpython-38.pyc
│ ├── mixins.cpython-38.pyc
│ ├── mixins.cpython-39.pyc
│ ├── spaced_sampler.cpython-38.pyc
│ ├── swinir.cpython-38.pyc
│ └── swinir.cpython-39.pyc
├── callbacks.py
├── cldm.py
├── cldm_bsr.py
├── cond_fn.py
├── mixins.py
├── spaced_sampler.py
└── swinir.py
├── requirements.txt
├── scripts
├── inference_stage1.py
├── make_file_list.py
├── make_list_celea.py
├── make_stage2_init_weight.py
├── merge_img.py
├── metrics.py
├── rainy.py
└── sample_dataset.py
├── test.py
├── train.py
├── utils
├── __pycache__
│ ├── common.cpython-38.pyc
│ ├── common.cpython-39.pyc
│ ├── degradation.cpython-38.pyc
│ ├── face_restoration_helper.cpython-38.pyc
│ ├── file.cpython-310.pyc
│ ├── file.cpython-38.pyc
│ ├── file.cpython-39.pyc
│ ├── metrics.cpython-38.pyc
│ ├── metrics.cpython-39.pyc
│ ├── process.cpython-38.pyc
│ └── utils.cpython-38.pyc
├── common.py
├── degradation.py
├── face_restoration_helper.py
├── file.py
├── image
│ ├── __init__.py
│ ├── __pycache__
│ │ ├── __init__.cpython-38.pyc
│ │ ├── __init__.cpython-39.pyc
│ │ ├── align_color.cpython-38.pyc
│ │ ├── align_color.cpython-39.pyc
│ │ ├── common.cpython-38.pyc
│ │ ├── common.cpython-39.pyc
│ │ ├── diffjpeg.cpython-38.pyc
│ │ ├── diffjpeg.cpython-39.pyc
│ │ ├── usm_sharp.cpython-38.pyc
│ │ └── usm_sharp.cpython-39.pyc
│ ├── align_color.py
│ ├── common.py
│ ├── diffjpeg.py
│ └── usm_sharp.py
├── metrics.py
├── pickout_img.py
├── process.py
├── realesrgan
│ ├── __pycache__
│ │ ├── realesrganer.cpython-38.pyc
│ │ └── rrdbnet.cpython-38.pyc
│ ├── realesrganer.py
│ └── rrdbnet.py
├── torchinterp1d-master.zip
├── torchinterp1d
│ ├── .gitignore
│ ├── .gitmodules
│ ├── LICENSE
│ ├── README.md
│ ├── examples
│ │ └── test.py
│ ├── setup.py
│ └── torchinterp1d
│ │ ├── __init__.py
│ │ └── interp1d.py
└── utils.py
└── weights
└── null_token_1024.pth
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2024 Yixuan
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # FlowIE: Efficient Image Enhancement via Rectified Flow (CVPR 2024)
2 |
3 | > [Yixuan Zhu](https://eternalevan.github.io/)\*, [Wenliang Zhao](https://wl-zhao.github.io/)\* $\dagger$, [Ao Li](https://rammusleo.github.io/), [Yansong Tang](https://andytang15.github.io/), [Jie Zhou](https://scholar.google.com/citations?user=6a79aPwAAAAJ&hl=en&authuser=1), [Jiwen Lu](http://ivg.au.tsinghua.edu.cn/Jiwen_Lu/) $\ddagger$
4 | >
5 | > \* Equal contribution $\dagger$ Project leader $\ddagger$ Corresponding author
6 |
7 | [**[Paper]**](https://arxiv.org/abs/2406.00508)
8 |
9 | The repository contains the official implementation for the paper "FlowIE: Efficient Image Enhancement via Rectified Flow" (**CVPR 2024, oral presentation**).
10 |
11 | FlowIE is a simple yet highly effective **Flow**-based **I**mage **E**nhancement framework that estimates straight-line paths from an elementary distribution to high-quality images.
12 | ## 📋 To-Do List
13 |
14 | * [x] Release model and inference code.
15 | * [x] Release code for training dataloader.
16 |
17 |
18 | ## 💡 Pipeline
19 |
20 | 
21 |
22 |
24 |
25 |
26 | ## 😀Quick Start
27 | ### ⚙️ 1. Installation
28 |
29 | We recommend using an [Anaconda](https://www.anaconda.com/) virtual environment. If you have Anaconda installed, run the following commands to create and activate one.
30 | ``` bash
31 | conda env create -f requirements.txt
32 | conda activate FlowIE
33 | ```
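If the `conda env create` command does not fit your setup, an equivalent manual setup is shown below. Python 3.8 is an assumption here (matching the compiled `cpython-38` files shipped in this repo); adjust as needed.
```bash
# Create and activate the environment manually, then install the pinned packages.
conda create -n FlowIE python=3.8 -y
conda activate FlowIE
pip install -r requirements.txt
```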
34 | ### 📑 2. Modify the LoRA configuration
35 | Since we use `MemoryEfficientCrossAttention` to accelerate inference, you need to slightly modify `lora.py` in the `lora_diffusion` package, which takes about two minutes:
36 | - (1) Locate the `lora.py` file in the package directory. You can find it quickly by using "Go to Definition" on the import in Line 4 of `./model/cldm.py`.
37 | - (2) Make the following modifications to Lines 159-161 in `lora.py`:
38 |
39 | Original Code:
40 | ```python
41 | UNET_DEFAULT_TARGET_REPLACE = {"CrossAttention", "Attention", "GEGLU"}
42 | UNET_EXTENDED_TARGET_REPLACE = {"ResnetBlock2D", "CrossAttention", "Attention", "GEGLU"}
43 | ```
44 |
45 | Modified Code:
46 | ```python
47 | UNET_DEFAULT_TARGET_REPLACE = {"CrossAttention", "Attention", "GEGLU", "MemoryEfficientCrossAttention"}
48 | UNET_EXTENDED_TARGET_REPLACE = {"ResnetBlock2D", "CrossAttention", "Attention", "GEGLU", "MemoryEfficientCrossAttention", "ResBlock"}
49 | ```
50 |
51 | ### 💾 3. Data Preparation
52 |
53 | We prepare the data in a similar way to [GFPGAN](https://xinntao.github.io/projects/gfpgan) & [DiffBIR](https://github.com/XPixelGroup/DiffBIR). We list the datasets for BFR and BSR as follows:
54 |
55 | For BFR evaluation, please refer to [here](https://xinntao.github.io/projects/gfpgan) for the *BFR-test datasets*, which include *CelebA-Test*, *CelebChild-Test* and *LFW-Test*. The *WIDER-Test* can be found [here](https://drive.google.com/file/d/1g05U86QGqnlN_v9SRRKDTU8033yvQNEa/view). For BFR training, please download the [FFHQ dataset](https://github.com/NVlabs/ffhq-dataset).
56 |
57 | For BSR, we utilize [ImageNet](https://www.image-net.org/index.php) for training. For evaluation, you can refer to [BSRGAN](https://github.com/cszn/BSRGAN/tree/main/testsets) for *RealSRSet*.
58 |
59 | To prepare the training file lists, simply run the script:
60 | ```bash
61 | python ./scripts/make_file_list.py --img_folder /data/ILSVRC2012 --save_folder ./dataset/list/imagenet
62 | python ./scripts/make_file_list.py --img_folder /data/FFHQ --save_folder ./dataset/list/ffhq
63 | ```
64 | The file list looks like this:
65 | ```bash
66 | /path/to/image_1.png
67 | /path/to/image_2.png
68 | /path/to/image_3.png
69 | ...
70 | ```
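If you need to adapt the list format, a minimal sketch of a file-list builder is shown below. It is illustrative only; the actual `scripts/make_file_list.py` in this repo may behave differently (e.g. extra filtering or train/val splitting).
```python
import os
from argparse import ArgumentParser

# Illustrative stand-in for scripts/make_file_list.py: collect image paths
# under --img_folder and write them, one absolute path per line,
# to <save_folder>/train.list.
parser = ArgumentParser()
parser.add_argument("--img_folder", required=True)
parser.add_argument("--save_folder", required=True)
args = parser.parse_args()

exts = (".png", ".jpg", ".jpeg")
paths = []
for root, _, files in os.walk(args.img_folder):
    for name in files:
        if name.lower().endswith(exts):
            paths.append(os.path.abspath(os.path.join(root, name)))

os.makedirs(args.save_folder, exist_ok=True)
with open(os.path.join(args.save_folder, "train.list"), "w") as f:
    f.write("\n".join(sorted(paths)) + "\n")
```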
71 | ### 🗂️ 4. Download Checkpoints
72 |
73 | Please download our pretrained checkpoints from [this link](https://cloud.tsinghua.edu.cn/d/4fa2a0880a9243999561/) and put them under `./weights`. The file directory should be:
74 |
75 | ```
76 | |-- weights
77 | |--|-- FlowIE_bfr_v1.ckpt
78 | |--|-- FlowIE_bsr_v1.ckpt
79 | ...
80 | ```
81 |
82 | ### 📊 5. Test & Evaluation
83 |
84 | You can test FlowIE with the following commands:
85 | - **Evaluation for BFR**
86 | ```bash
87 | python inference_bfr.py --ckpt ./weights/FlowIE_bfr_v1.ckpt --has_aligned --input /data/celeba_512_validation_lq/ --output ./outputs/bfr_exp
88 | ```
89 | - **Evaluation for BSR**
90 | ```bash
91 | python inference_bsr.py --ckpt ./weights/FlowIE_bsr_v1.ckpt --input /data/testdata/ --output ./outputs/bsr_exp --sr_scale 4
92 | ```
93 | - **Quick Test**
94 |
95 | For a quick test, we have collected some test samples in `./assets`. You can run the demo for BFR:
96 | ```bash
97 | python inference_bfr.py --ckpt ./weights/FlowIE_bfr_v1.ckpt --input ./assets/faces --output ./outputs/demo
98 | ```
99 | And for BSR:
100 | ```bash
101 | python inference_bsr.py --ckpt ./weights/FlowIE_bsr_v1.ckpt --input ./assets/real-photos/ --output ./outputs/bsr_exp --tiled --sr_scale 4
102 | ```
103 | You can use `--tiled` for patch-based inference and use `--sr_scale` to set the super-resolution scale, such as 2 or 4. You can set `CUDA_VISIBLE_DEVICES=1` to choose the device, as shown below.
104 |
114 |
115 | The evaluation process can be run on a single NVIDIA GeForce RTX 3090 GPU (24 GB VRAM). You can use more GPUs by specifying their ids.
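For example, to run the quick BSR demo on the second GPU only:
```bash
CUDA_VISIBLE_DEVICES=1 python inference_bsr.py --ckpt ./weights/FlowIE_bsr_v1.ckpt --input ./assets/real-photos/ --output ./outputs/bsr_exp --tiled --sr_scale 4
```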
116 |
117 | ### 🔥 6. Training
118 | The key component in FlowIE is a path estimator tuned from [Stable Diffusion v2.1 base](https://huggingface.co/stabilityai/stable-diffusion-2-1-base). Please download it to `./weights`. The other component is the initial module, which can be found in the [checkpoints](https://cloud.tsinghua.edu.cn/d/4fa2a0880a9243999561/).
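For example, the Stable Diffusion checkpoint referenced by `./configs/train_cldm.yaml` (`v2-1_512-ema-pruned.ckpt`) can be fetched as follows, assuming the standard Hugging Face file layout (adjust the URL if the hosting changes):
```bash
# Download the SD v2.1-base weights into ./weights, the path expected by train_cldm.yaml.
wget -P ./weights https://huggingface.co/stabilityai/stable-diffusion-2-1-base/resolve/main/v2-1_512-ema-pruned.ckpt
```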
119 |
120 | Before training, you also need to configure the training-related fields in `./configs/train_cldm.yaml` (see the sketch after the command below). Then run this command to start training:
121 | ```bash
122 | python train.py --config ./configs/train_cldm.yaml
123 | ```
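A minimal sketch of the fields you will typically edit in `./configs/train_cldm.yaml` is shown below. The values are illustrative (the shipped config leaves the dataset paths empty), and the full file contains additional trainer and callback options:
```yaml
data:
  target: dataset.data_module.BIRDataModule
  params:
    train_config: ./configs/dataset/face_train.yaml   # training set config
    val_config: ./configs/dataset/face_val.yaml       # validation set config

model:
  config: configs/model/cldm.yaml
  # Initial weights: the downloaded SD v2.1-base checkpoint, or the weights
  # produced by scripts/make_stage2_init_weight.py.
  resume: ./weights/v2-1_512-ema-pruned.ckpt

lightning:
  seed: 231
  trainer:
    gpus: [0]                         # indices of the GPUs used for training
    default_root_dir: ./work_dirs/exp1  # where logs and checkpoints are saved
```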
124 |
125 |
143 | ## 🫰 Acknowledgments
144 |
145 | We would like to express our sincere thanks to the author of [DiffBIR](https://github.com/XPixelGroup/DiffBIR) for the clear code base and quick response to our issues.
146 |
147 | We also thank [CodeFormer](https://github.com/sczhou/CodeFormer), [Real-ESRGAN](https://github.com/xinntao/Real-ESRGAN) and [LoRA](https://github.com/cloneofsimo/lora), as our code partially borrows from them.
148 |
149 | A new version of FlowIE based on the Diffusion Transformer (DiT) architecture will be released soon! Thanks to the latest DiT-based works, including [PixArt](https://github.com/PixArt-alpha/PixArt-sigma) and [Stable Diffusion 3](https://huggingface.co/stabilityai/stable-diffusion-3-medium).
150 |
151 | ## 🔖 Citation
152 | Please cite us if our work is useful for your research.
153 |
154 | ```
155 | @misc{zhu2024flowie,
156 | title={FlowIE: Efficient Image Enhancement via Rectified Flow},
157 | author={Yixuan Zhu and Wenliang Zhao and Ao Li and Yansong Tang and Jie Zhou and Jiwen Lu},
158 | year={2024},
159 | eprint={2406.00508},
160 | archivePrefix={arXiv},
161 | primaryClass={cs.CV}
162 | }
163 | ```
164 | ## 🔑 License
165 |
166 | This code is distributed under an [MIT LICENSE](./LICENSE).
167 |
--------------------------------------------------------------------------------
/assets/faces/00000014.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/EternalEvan/FlowIE/549bb871cff104c47fd126d23818f54d74ba1f20/assets/faces/00000014.png
--------------------------------------------------------------------------------
/assets/faces/00002386.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/EternalEvan/FlowIE/549bb871cff104c47fd126d23818f54d74ba1f20/assets/faces/00002386.png
--------------------------------------------------------------------------------
/assets/pipeline.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/EternalEvan/FlowIE/549bb871cff104c47fd126d23818f54d74ba1f20/assets/pipeline.png
--------------------------------------------------------------------------------
/assets/real-photos/dog.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/EternalEvan/FlowIE/549bb871cff104c47fd126d23818f54d74ba1f20/assets/real-photos/dog.png
--------------------------------------------------------------------------------
/configs/dataset/dehaze_train.yaml:
--------------------------------------------------------------------------------
1 | dataset:
2 | target: dataset.codeformer.CodeformerDataset_Dehaze
3 | params:
4 | # Path to the file list.
5 | is_val: False
6 | out_size: 512
7 | crop_type: random
8 | use_hflip: True
9 |
10 | blur_kernel_size: 41
11 | kernel_list: ['iso', 'aniso']
12 | kernel_prob: [0.5, 0.5]
13 | blur_sigma: [0.1, 12]
14 | downsample_range: [1, 12]
15 | noise_range: [0, 15]
16 | jpeg_range: [30, 100]
17 |
18 | data_loader:
19 | batch_size: 8
20 | shuffle: true
21 | num_workers: 16
22 | drop_last: true
23 |
24 | batch_transform:
25 | target: dataset.batch_transform.IdentityBatchTransform
26 |
--------------------------------------------------------------------------------
/configs/dataset/dehaze_val.yaml:
--------------------------------------------------------------------------------
1 | dataset:
2 | target: dataset.codeformer.CodeformerDataset_Dehaze
3 | params:
4 | # Path to the file list.
5 | is_val: True
6 | out_size: 512
7 | crop_type: center
8 | use_hflip: True
9 |
10 | blur_kernel_size: 41
11 | kernel_list: ['iso', 'aniso']
12 | kernel_prob: [0.5, 0.5]
13 | blur_sigma: [0.1, 12]
14 | downsample_range: [1, 12]
15 | noise_range: [0, 15]
16 | jpeg_range: [30, 100]
17 |
18 | data_loader:
19 | batch_size: 2
20 | shuffle: false
21 | num_workers: 16
22 | drop_last: true
23 |
24 | batch_transform:
25 | target: dataset.batch_transform.IdentityBatchTransform
26 |
--------------------------------------------------------------------------------
/configs/dataset/derain_train.yaml:
--------------------------------------------------------------------------------
1 | dataset:
2 | target: dataset.codeformer.CodeformerDataset_Derain
3 | params:
4 | # Path to the file list.
5 | is_val: False
6 | out_size: 512
7 | crop_type: random
8 | use_hflip: True
9 |
10 | blur_kernel_size: 41
11 | kernel_list: ['iso', 'aniso']
12 | kernel_prob: [0.5, 0.5]
13 | blur_sigma: [0.1, 12]
14 | downsample_range: [1, 12]
15 | noise_range: [0, 15]
16 | jpeg_range: [30, 100]
17 |
18 | data_loader:
19 | batch_size: 4
20 | shuffle: true
21 | num_workers: 16
22 | drop_last: true
23 |
24 | batch_transform:
25 | target: dataset.batch_transform.IdentityBatchTransform
26 |
--------------------------------------------------------------------------------
/configs/dataset/derain_val.yaml:
--------------------------------------------------------------------------------
1 | dataset:
2 | target: dataset.codeformer.CodeformerDataset_Derain
3 | params:
4 | # Path to the file list.
5 | is_val: True
6 | out_size: 512
7 | crop_type: center
8 | use_hflip: True
9 |
10 | blur_kernel_size: 41
11 | kernel_list: ['iso', 'aniso']
12 | kernel_prob: [0.5, 0.5]
13 | blur_sigma: [0.1, 12]
14 | downsample_range: [1, 12]
15 | noise_range: [0, 15]
16 | jpeg_range: [30, 100]
17 |
18 | data_loader:
19 | batch_size: 2
20 | shuffle: false
21 | num_workers: 16
22 | drop_last: true
23 |
24 | batch_transform:
25 | target: dataset.batch_transform.IdentityBatchTransform
26 |
--------------------------------------------------------------------------------
/configs/dataset/enlight_train.yaml:
--------------------------------------------------------------------------------
1 | dataset:
2 | target: dataset.codeformer.CodeformerDataset_Enlight
3 | params:
4 | # Path to the file list.
5 | is_val: False
6 | out_size: 512
7 | crop_type: random
8 | use_hflip: True
9 |
10 | blur_kernel_size: 41
11 | kernel_list: ['iso', 'aniso']
12 | kernel_prob: [0.5, 0.5]
13 | blur_sigma: [0.1, 12]
14 | downsample_range: [1, 12]
15 | noise_range: [0, 15]
16 | jpeg_range: [30, 100]
17 |
18 | data_loader:
19 | batch_size: 8
20 | shuffle: true
21 | num_workers: 16
22 | drop_last: true
23 |
24 | batch_transform:
25 | target: dataset.batch_transform.IdentityBatchTransform
26 |
--------------------------------------------------------------------------------
/configs/dataset/enlight_val.yaml:
--------------------------------------------------------------------------------
1 | dataset:
2 | target: dataset.codeformer.CodeformerDataset_Enlight
3 | params:
4 | # Path to the file list.
5 | is_val: True
6 | out_size: 512
7 | crop_type: center
8 | use_hflip: True
9 |
10 | blur_kernel_size: 41
11 | kernel_list: ['iso', 'aniso']
12 | kernel_prob: [0.5, 0.5]
13 | blur_sigma: [0.1, 12]
14 | downsample_range: [1, 12]
15 | noise_range: [0, 15]
16 | jpeg_range: [30, 100]
17 |
18 | data_loader:
19 | batch_size: 2
20 | shuffle: false
21 | num_workers: 16
22 | drop_last: true
23 |
24 | batch_transform:
25 | target: dataset.batch_transform.IdentityBatchTransform
26 |
--------------------------------------------------------------------------------
/configs/dataset/face_test_inp.yaml:
--------------------------------------------------------------------------------
1 | dataset:
2 | target: dataset.codeformer.CodeformerDataset_Mask
3 | params:
4 | # Path to the file list.
5 | #file_list: /home/user001/zwl/data/FFHQ512/train.list
6 | file_list: /home/user001/zwl/zyx/Diffbir/10955/meta/inp.list
7 | # wider :34.391226832176 lfw:38.66
8 | out_size: 512
9 | crop_type: none
10 | use_hflip: False
11 |
12 | blur_kernel_size: 41
13 | kernel_list: ['iso', 'aniso']
14 | kernel_prob: [0.5, 0.5]
15 | blur_sigma: [0.1, 10]
16 | downsample_range: [0.8, 8]
17 | noise_range: [0, 20]
18 | jpeg_range: [60, 100]
19 |
20 | data_loader:
21 | batch_size: 4
22 | shuffle: false
23 | num_workers: 16
24 | drop_last: true
25 |
26 | batch_transform:
27 | target: dataset.batch_transform.IdentityBatchTransform
28 |
--------------------------------------------------------------------------------
/configs/dataset/face_test_lq.yaml:
--------------------------------------------------------------------------------
1 | dataset:
2 | target: dataset.codeformer.CodeformerDatasetLQ
3 | params:
4 | # Path to the file list.
5 | # file_list: /data1/zyx/FFHQ512/test.list
6 | hq_list: /home/user001/zwl/zyx/Diffbir/test_list/1024.list
7 | lq_list: /home/user001/zwl/zyx/Diffbir/test_list/1024.list
8 | #lq_list: /home/user001/zwl/zyx/Diffbir/test_list/inp-z1.list
9 | #hq_list: /home/user001/zwl/data/CelebChild-Test/child.list
10 | #lq_list: /home/user001/zwl/data/CelebChild-Test/child.list
11 | # wider :34.391226832176 lfw:38.66
12 | out_size: 512
13 | crop_type: none
14 | use_hflip: False
15 |
16 | blur_kernel_size: 41
17 | kernel_list: ['iso', 'aniso']
18 | kernel_prob: [0.5, 0.5]
19 | blur_sigma: [0.1, 10]
20 | downsample_range: [0.8, 8]
21 | noise_range: [0, 20]
22 | jpeg_range: [60, 100]
23 |
24 | data_loader:
25 | batch_size: 2
26 | shuffle: false
27 | num_workers: 16
28 | drop_last: true
29 |
30 | batch_transform:
31 | target: dataset.batch_transform.IdentityBatchTransform
32 |
--------------------------------------------------------------------------------
/configs/dataset/face_train.yaml:
--------------------------------------------------------------------------------
1 | dataset:
2 | target: dataset.codeformer.CodeformerDataset
3 | params:
4 | # Path to the file list.
5 | file_list: /home/user001/zwl/data/FFHQ512/train.list
6 | out_size: 512
7 | crop_type: none
8 | use_hflip: true
9 |
10 | blur_kernel_size: 41
11 | kernel_list: ['iso', 'aniso']
12 | kernel_prob: [0.5, 0.5]
13 | blur_sigma: [0.1, 10]
14 | downsample_range: [0.8, 8]
15 | noise_range: [0, 20]
16 | jpeg_range: [60, 100]
17 |
18 | # color_jitter_prob: 0.3
19 | # color_jitter_shift: 20
20 | # color_jitter_pt_prob: 0.3
21 | # gray_prob: 0.01
22 |
23 | data_loader:
24 | batch_size: 4
25 | shuffle: true
26 | num_workers: 16
27 | drop_last: true
28 |
29 | batch_transform:
30 | target: dataset.batch_transform.IdentityBatchTransform
31 |
--------------------------------------------------------------------------------
/configs/dataset/face_val.yaml:
--------------------------------------------------------------------------------
1 | dataset:
2 | target: dataset.codeformer.CodeformerDataset
3 | params:
4 | # Path to the file list.
5 | file_list: /home/user001/zwl/data/FFHQ512/val.list
6 | #file_list: /data1/zyx/celeba_512_validation_list_files/val_lq.list
7 |
8 | out_size: 512
9 | crop_type: none
10 | use_hflip: False
11 |
12 | blur_kernel_size: 41
13 | kernel_list: ['iso', 'aniso']
14 | kernel_prob: [0.5, 0.5]
15 | blur_sigma: [0.1, 10]
16 | downsample_range: [0.8, 8]
17 | noise_range: [0, 20]
18 | jpeg_range: [60, 100]
19 |
20 | # color_jitter_prob: 0.3
21 | # color_jitter_shift: 20
22 | # color_jitter_pt_prob: 0.3
23 | # gray_prob: 0.01
24 |
25 |
26 | data_loader:
27 | batch_size: 3
28 | shuffle: false
29 | num_workers: 16
30 | drop_last: true
31 |
32 | batch_transform:
33 | target: dataset.batch_transform.IdentityBatchTransform
34 |
--------------------------------------------------------------------------------
/configs/dataset/general_deg_codeformer_train.yaml:
--------------------------------------------------------------------------------
1 | dataset:
2 | target: dataset.codeformer.CodeformerDataset
3 | params:
4 | # Path to the file list.
5 | file_list: /home/user001/zwl/zyx/Diffbir/imagenet_list/train_all.list
6 | out_size: 512
7 | crop_type: center
8 | use_hflip: True
9 |
10 | blur_kernel_size: 41
11 | kernel_list: ['iso', 'aniso']
12 | kernel_prob: [0.5, 0.5]
13 | blur_sigma: [0.1, 12]
14 | downsample_range: [1, 12]
15 | noise_range: [0, 15]
16 | jpeg_range: [30, 100]
17 |
18 | data_loader:
19 | batch_size: 4
20 | shuffle: true
21 | num_workers: 16
22 | drop_last: true
23 |
24 | batch_transform:
25 | target: dataset.batch_transform.IdentityBatchTransform
26 |
--------------------------------------------------------------------------------
/configs/dataset/general_deg_codeformer_val.yaml:
--------------------------------------------------------------------------------
1 | dataset:
2 | target: dataset.codeformer.CodeformerDataset
3 | params:
4 | # Path to the file list.
5 | file_list: /home/user001/zwl/zyx/Diffbir/imagenet_list/val.list
6 | out_size: 512
7 | crop_type: center
8 | use_hflip: True
9 |
10 | blur_kernel_size: 41
11 | kernel_list: ['iso', 'aniso']
12 | kernel_prob: [0.5, 0.5]
13 | blur_sigma: [0.1, 12]
14 | downsample_range: [1, 12]
15 | noise_range: [0, 15]
16 | jpeg_range: [30, 100]
17 |
18 | data_loader:
19 | batch_size: 2
20 | shuffle: false
21 | num_workers: 16
22 | drop_last: true
23 |
24 | batch_transform:
25 | target: dataset.batch_transform.IdentityBatchTransform
26 |
--------------------------------------------------------------------------------
/configs/dataset/general_deg_realesrgan_train.yaml:
--------------------------------------------------------------------------------
1 | dataset:
2 | target: dataset.realesrgan.RealESRGANDataset
3 | params:
4 | # Path to the file list.
5 | file_list:
6 | out_size: 512
7 | crop_type: center
8 |
9 | use_hflip: false
10 | use_rot: false
11 |
12 | blur_kernel_size: 21
13 | kernel_list: ['iso', 'aniso', 'generalized_iso', 'generalized_aniso', 'plateau_iso', 'plateau_aniso']
14 | kernel_prob: [0.45, 0.25, 0.12, 0.03, 0.12, 0.03]
15 | sinc_prob: 0.1
16 | blur_sigma: [0.2, 3]
17 | betag_range: [0.5, 4]
18 | betap_range: [1, 2]
19 |
20 | blur_kernel_size2: 21
21 | kernel_list2: ['iso', 'aniso', 'generalized_iso', 'generalized_aniso', 'plateau_iso', 'plateau_aniso']
22 | kernel_prob2: [0.45, 0.25, 0.12, 0.03, 0.12, 0.03]
23 | sinc_prob2: 0.1
24 | blur_sigma2: [0.2, 1.5]
25 | betag_range2: [0.5, 4]
26 | betap_range2: [1, 2]
27 |
28 | final_sinc_prob: 0.8
29 |
30 | data_loader:
31 | batch_size: 32
32 | shuffle: true
33 | num_workers: 16
34 | prefetch_factor: 2
35 | drop_last: true
36 |
37 | batch_transform:
38 | target: dataset.batch_transform.RealESRGANBatchTransform
39 | params:
40 | use_sharpener: false
41 | resize_hq: false
42 | # Queue size of training pool, this should be multiples of batch_size.
43 | queue_size: 256
44 | # the first degradation process
45 | resize_prob: [0.2, 0.7, 0.1] # up, down, keep
46 | resize_range: [0.15, 1.5]
47 | gaussian_noise_prob: 0.5
48 | noise_range: [1, 30]
49 | poisson_scale_range: [0.05, 3]
50 | gray_noise_prob: 0.4
51 | jpeg_range: [30, 95]
52 |
53 | # the second degradation process
54 | stage2_scale: 4
55 | second_blur_prob: 0.8
56 | resize_prob2: [0.3, 0.4, 0.3] # up, down, keep
57 | resize_range2: [0.3, 1.2]
58 | gaussian_noise_prob2: 0.5
59 | noise_range2: [1, 25]
60 | poisson_scale_range2: [0.05, 2.5]
61 | gray_noise_prob2: 0.4
62 | jpeg_range2: [30, 95]
63 |
--------------------------------------------------------------------------------
/configs/dataset/general_deg_realesrgan_val.yaml:
--------------------------------------------------------------------------------
1 | dataset:
2 | target: dataset.realesrgan.RealESRGANDataset
3 | params:
4 | # Path to the file list.
5 | file_list:
6 | out_size: 512
7 | crop_type: center
8 |
9 | use_hflip: false
10 | use_rot: false
11 |
12 | blur_kernel_size: 21
13 | kernel_list: ['iso', 'aniso', 'generalized_iso', 'generalized_aniso', 'plateau_iso', 'plateau_aniso']
14 | kernel_prob: [0.45, 0.25, 0.12, 0.03, 0.12, 0.03]
15 | sinc_prob: 0.1
16 | blur_sigma: [0.2, 3]
17 | betag_range: [0.5, 4]
18 | betap_range: [1, 2]
19 |
20 | blur_kernel_size2: 21
21 | kernel_list2: ['iso', 'aniso', 'generalized_iso', 'generalized_aniso', 'plateau_iso', 'plateau_aniso']
22 | kernel_prob2: [0.45, 0.25, 0.12, 0.03, 0.12, 0.03]
23 | sinc_prob2: 0.1
24 | blur_sigma2: [0.2, 1.5]
25 | betag_range2: [0.5, 4]
26 | betap_range2: [1, 2]
27 |
28 | final_sinc_prob: 0.8
29 |
30 | data_loader:
31 | batch_size: 32
32 | shuffle: false
33 | num_workers: 16
34 | prefetch_factor: 2
35 | drop_last: true
36 |
37 | batch_transform:
38 | target: dataset.batch_transform.RealESRGANBatchTransform
39 | params:
40 | use_sharpener: false
41 | resize_hq: false
42 | # Queue size of training pool, this should be multiples of batch_size.
43 | queue_size: 256
44 | # the first degradation process
45 | resize_prob: [0.2, 0.7, 0.1] # up, down, keep
46 | resize_range: [0.15, 1.5]
47 | gaussian_noise_prob: 0.5
48 | noise_range: [1, 30]
49 | poisson_scale_range: [0.05, 3]
50 | gray_noise_prob: 0.4
51 | jpeg_range: [30, 95]
52 |
53 | # the second degradation process
54 | stage2_scale: 4
55 | second_blur_prob: 0.8
56 | resize_prob2: [0.3, 0.4, 0.3] # up, down, keep
57 | resize_range2: [0.3, 1.2]
58 | gaussian_noise_prob2: 0.5
59 | noise_range2: [1, 25]
60 | poisson_scale_range2: [0.05, 2.5]
61 | gray_noise_prob2: 0.4
62 | jpeg_range2: [30, 95]
63 |
--------------------------------------------------------------------------------
/configs/dataset/general_test_lq.yaml:
--------------------------------------------------------------------------------
1 | dataset:
2 | target: dataset.codeformer.CodeformerDatasetLQ
3 | params:
4 | # Path to the file list.
5 | # file_list: /data1/zyx/FFHQ512/test.list
6 | #lq_list: /home/user001/zwl/zyx/Diffbir/test_list/custom-1.list
7 | lq_list: /home/user001/zwl/zyx/Diffbir/imagenet_list/custom.list
8 | #hq_list: /home/user001/zwl/zyx/Diffbir/test_list/custom-1.list
9 | hq_list: /home/user001/zwl/zyx/Diffbir/imagenet_list/custom.list
10 | # wider :34.391226832176 lfw:38.66
11 | out_size: 512
12 | crop_type: center
13 | use_hflip: False
14 |
15 | blur_kernel_size: 41
16 | kernel_list: ['iso', 'aniso']
17 | kernel_prob: [0.5, 0.5]
18 | blur_sigma: [0.1, 10]
19 | downsample_range: [0.8, 8]
20 | noise_range: [0, 20]
21 | jpeg_range: [60, 100]
22 |
23 | data_loader:
24 | batch_size: 4
25 | shuffle: false
26 | num_workers: 16
27 | drop_last: true
28 |
29 | batch_transform:
30 | target: dataset.batch_transform.IdentityBatchTransform
31 |
--------------------------------------------------------------------------------
/configs/dataset/ocr_train.yaml:
--------------------------------------------------------------------------------
1 | dataset:
2 | target: dataset.codeformer.CodeformerDatasetLQ
3 | params:
4 | # Path to the file list.
5 | hq_list: /home/user001/zwl/zyx/Diffbir/ocr_list/train_hq.list
6 | lq_list: /home/user001/zwl/zyx/Diffbir/ocr_list/train_lq.list
7 | out_size: 512
8 | crop_type: random
9 | use_hflip: True
10 |
11 | blur_kernel_size: 41
12 | kernel_list: ['iso', 'aniso']
13 | kernel_prob: [0.5, 0.5]
14 | blur_sigma: [0.1, 12]
15 | downsample_range: [1, 12]
16 | noise_range: [0, 15]
17 | jpeg_range: [30, 100]
18 |
19 | data_loader:
20 | batch_size: 4
21 | shuffle: true
22 | num_workers: 16
23 | drop_last: true
24 |
25 | batch_transform:
26 | target: dataset.batch_transform.IdentityBatchTransform
27 |
--------------------------------------------------------------------------------
/configs/dataset/ocr_val.yaml:
--------------------------------------------------------------------------------
1 | dataset:
2 | target: dataset.codeformer.CodeformerDatasetLQ
3 | params:
4 | # Path to the file list.
5 | hq_list: /home/user001/zwl/zyx/Diffbir/ocr_list/val_hq.list
6 | lq_list: /home/user001/zwl/zyx/Diffbir/ocr_list/val_lq.list
7 | out_size: 512
8 | crop_type: center
9 | use_hflip: True
10 |
11 | blur_kernel_size: 41
12 | kernel_list: ['iso', 'aniso']
13 | kernel_prob: [0.5, 0.5]
14 | blur_sigma: [0.1, 12]
15 | downsample_range: [1, 12]
16 | noise_range: [0, 15]
17 | jpeg_range: [30, 100]
18 |
19 | data_loader:
20 | batch_size: 2
21 | shuffle: false
22 | num_workers: 16
23 | drop_last: true
24 |
25 | batch_transform:
26 | target: dataset.batch_transform.IdentityBatchTransform
27 |
--------------------------------------------------------------------------------
/configs/model/cldm.yaml:
--------------------------------------------------------------------------------
1 | #target: model.cldm.ControlLDM
2 | target: model.cldm.Reflow_ControlLDM
3 | params:
4 | linear_start: 0.00085
5 | linear_end: 0.0120
6 | num_timesteps_cond: 1
7 | log_every_t: 200
8 | timesteps: 1000
9 | first_stage_key: "jpg"
10 | cond_stage_key: "txt"
11 | control_key: "hint"
12 | image_size: 64
13 | channels: 4
14 | cond_stage_trainable: false
15 | conditioning_key: crossattn
16 | monitor: val/loss_simple_ema
17 | scale_factor: 0.18215
18 | use_ema: False
19 |
20 | sd_locked: True
21 | only_mid_control: False
22 | # Learning rate.
23 | learning_rate: 1e-4
24 | lora_rank: 4
25 |
26 | control_stage_config:
27 | target: model.cldm.ControlNet
28 | params:
29 | use_checkpoint: True
30 | image_size: 32 # unused
31 | in_channels: 4
32 | hint_channels: 4
33 | model_channels: 320
34 | attention_resolutions: [ 4, 2, 1 ]
35 | num_res_blocks: 2
36 | channel_mult: [ 1, 2, 4, 4 ]
37 | num_head_channels: 64 # need to fix for flash-attn
38 | use_spatial_transformer: True
39 | use_linear_in_transformer: True
40 | transformer_depth: 1
41 | context_dim: 1024
42 | legacy: False
43 |
44 | unet_config:
45 | target: model.cldm.ControlledUnetModel
46 | params:
47 | use_checkpoint: false
48 | image_size: 32 # unused
49 | in_channels: 4
50 | out_channels: 4
51 | model_channels: 320
52 | attention_resolutions: [ 4, 2, 1 ]
53 | num_res_blocks: 2
54 | channel_mult: [ 1, 2, 4, 4 ]
55 | num_head_channels: 64 # need to fix for flash-attn
56 | use_spatial_transformer: True
57 | use_linear_in_transformer: True
58 | transformer_depth: 1
59 | context_dim: 1024
60 | legacy: False
61 |
62 | first_stage_config:
63 | use_fp16: False
64 | target: ldm.models.autoencoder.AutoencoderKL
65 | params:
66 | embed_dim: 4
67 | monitor: val/rec_loss
68 | ddconfig:
69 | #attn_type: "vanilla-xformers"
70 | double_z: true
71 | z_channels: 4
72 | resolution: 256
73 | in_channels: 3
74 | out_ch: 3
75 | ch: 128
76 | ch_mult:
77 | - 1
78 | - 2
79 | - 4
80 | - 4
81 | num_res_blocks: 2
82 | attn_resolutions: []
83 | dropout: 0.0
84 | lossconfig:
85 | target: torch.nn.Identity
86 |
87 | cond_stage_config:
88 | target: ldm.modules.encoders.modules.FrozenOpenCLIPEmbedder
89 | params:
90 | freeze: True
91 | layer: "penultimate"
92 |
93 | preprocess_config:
94 | target: model.swinir.SwinIR
95 | params:
96 | img_size: 64
97 | patch_size: 1
98 | in_chans: 3
99 | embed_dim: 180
100 | depths: [6, 6, 6, 6, 6, 6, 6, 6]
101 | num_heads: [6, 6, 6, 6, 6, 6, 6, 6]
102 | window_size: 8
103 | mlp_ratio: 2
104 | sf: 8
105 | img_range: 1.0
106 | upsampler: "nearest+conv"
107 | resi_connection: "1conv"
108 | unshuffle: True
109 | unshuffle_scale: 8
110 |
--------------------------------------------------------------------------------
/configs/model/cldm_bsr_eval.yaml:
--------------------------------------------------------------------------------
1 | #target: model.cldm.ControlLDM
2 | target: model.cldm_bsr.Reflow_ControlLDM
3 | params:
4 | linear_start: 0.00085
5 | linear_end: 0.0120
6 | num_timesteps_cond: 1
7 | log_every_t: 200
8 | timesteps: 1000
9 | first_stage_key: "jpg"
10 | cond_stage_key: "txt"
11 | control_key: "hint"
12 | image_size: 64
13 | channels: 4
14 | cond_stage_trainable: false
15 | conditioning_key: crossattn
16 | monitor: val/loss_simple_ema
17 | scale_factor: 0.18215
18 | use_ema: False
19 |
20 | sd_locked: False
21 | only_mid_control: False
22 | # Learning rate.
23 | learning_rate: 1e-4
24 | lora_rank: 64
25 |
26 | control_stage_config:
27 | target: model.cldm_bsr.ControlNet
28 | params:
29 | use_checkpoint: True
30 | image_size: 32 # unused
31 | in_channels: 4
32 | hint_channels: 4
33 | model_channels: 320
34 | attention_resolutions: [ 4, 2, 1 ]
35 | num_res_blocks: 2
36 | channel_mult: [ 1, 2, 4, 4 ]
37 | num_head_channels: 64 # need to fix for flash-attn
38 | use_spatial_transformer: True
39 | use_linear_in_transformer: True
40 | transformer_depth: 1
41 | context_dim: 1024
42 | legacy: False
43 | use_fp16: False
44 |
45 | unet_config:
46 | target: model.cldm_bsr.ControlledUnetModel
47 | params:
48 | use_checkpoint: false
49 | image_size: 32 # unused
50 | in_channels: 4
51 | out_channels: 4
52 | model_channels: 320
53 | attention_resolutions: [ 4, 2, 1 ]
54 | num_res_blocks: 2
55 | channel_mult: [ 1, 2, 4, 4 ]
56 | num_head_channels: 64 # need to fix for flash-attn
57 | use_spatial_transformer: True
58 | use_linear_in_transformer: True
59 | transformer_depth: 1
60 | context_dim: 1024
61 | legacy: False
62 | use_fp16: False
63 |
64 | first_stage_config:
65 | use_fp16: True
66 | model_id: stabilityai/sd-vae-ft-ema
67 | target: ldm.models.autoencoder.AutoencoderKL
68 | params:
69 | embed_dim: 4
70 | monitor: val/rec_loss
71 | ddconfig:
72 | #attn_type: "vanilla-xformers"
73 | double_z: true
74 | z_channels: 4
75 | resolution: 256
76 | in_channels: 3
77 | out_ch: 3
78 | ch: 128
79 | ch_mult:
80 | - 1
81 | - 2
82 | - 4
83 | - 4
84 | num_res_blocks: 2
85 | attn_resolutions: []
86 | dropout: 0.0
87 | lossconfig:
88 | target: torch.nn.Identity
89 |
90 | cond_stage_config:
91 | target: ldm.modules.encoders.modules.FrozenOpenCLIPEmbedder
92 | params:
93 | freeze: True
94 | layer: "penultimate"
95 |
96 | preprocess_config:
97 | target: model.swinir.SwinIR
98 | params:
99 | img_size: 64
100 | patch_size: 1
101 | in_chans: 3
102 | embed_dim: 180
103 | depths: [6, 6, 6, 6, 6, 6, 6, 6]
104 | num_heads: [6, 6, 6, 6, 6, 6, 6, 6]
105 | window_size: 8
106 | mlp_ratio: 2
107 | sf: 8
108 | img_range: 1.0
109 | upsampler: "nearest+conv"
110 | resi_connection: "1conv"
111 | unshuffle: True
112 | unshuffle_scale: 8
113 |
--------------------------------------------------------------------------------
/configs/model/cldm_eval.yaml:
--------------------------------------------------------------------------------
1 | #target: model.cldm.ControlLDM
2 | target: model.cldm.Reflow_ControlLDM
3 | params:
4 | linear_start: 0.00085
5 | linear_end: 0.0120
6 | num_timesteps_cond: 1
7 | log_every_t: 200
8 | timesteps: 1000
9 | first_stage_key: "jpg"
10 | cond_stage_key: "txt"
11 | control_key: "hint"
12 | image_size: 64
13 | channels: 4
14 | cond_stage_trainable: false
15 | conditioning_key: crossattn
16 | monitor: val/loss_simple_ema
17 | scale_factor: 0.18215
18 | use_ema: False
19 |
20 | sd_locked: True
21 | only_mid_control: False
22 | # Learning rate.
23 | learning_rate: 1e-4
24 | lora_rank: 4
25 |
26 | control_stage_config:
27 | target: model.cldm.ControlNet
28 | params:
29 | use_checkpoint: True
30 | image_size: 32 # unused
31 | in_channels: 4
32 | hint_channels: 4
33 | model_channels: 320
34 | attention_resolutions: [ 4, 2, 1 ]
35 | num_res_blocks: 2
36 | channel_mult: [ 1, 2, 4, 4 ]
37 | num_head_channels: 64 # need to fix for flash-attn
38 | use_spatial_transformer: True
39 | use_linear_in_transformer: True
40 | transformer_depth: 1
41 | context_dim: 1024
42 | legacy: False
43 |
44 | unet_config:
45 | target: model.cldm.ControlledUnetModel
46 | params:
47 | use_checkpoint: false
48 | image_size: 32 # unused
49 | in_channels: 4
50 | out_channels: 4
51 | model_channels: 320
52 | attention_resolutions: [ 4, 2, 1 ]
53 | num_res_blocks: 2
54 | channel_mult: [ 1, 2, 4, 4 ]
55 | num_head_channels: 64 # need to fix for flash-attn
56 | use_spatial_transformer: True
57 | use_linear_in_transformer: True
58 | transformer_depth: 1
59 | context_dim: 1024
60 | legacy: False
61 |
62 | first_stage_config:
63 | target: ldm.models.autoencoder.AutoencoderKL
64 | params:
65 | embed_dim: 4
66 | monitor: val/rec_loss
67 | ddconfig:
68 | #attn_type: "vanilla-xformers"
69 | double_z: true
70 | z_channels: 4
71 | resolution: 256
72 | in_channels: 3
73 | out_ch: 3
74 | ch: 128
75 | ch_mult:
76 | - 1
77 | - 2
78 | - 4
79 | - 4
80 | num_res_blocks: 2
81 | attn_resolutions: []
82 | dropout: 0.0
83 | lossconfig:
84 | target: torch.nn.Identity
85 |
86 | cond_stage_config:
87 | target: ldm.modules.encoders.modules.FrozenOpenCLIPEmbedder
88 | params:
89 | freeze: True
90 | layer: "penultimate"
91 |
92 | preprocess_config:
93 | target: model.swinir.SwinIR
94 | params:
95 | img_size: 64
96 | patch_size: 1
97 | in_chans: 3
98 | embed_dim: 180
99 | depths: [6, 6, 6, 6, 6, 6, 6, 6]
100 | num_heads: [6, 6, 6, 6, 6, 6, 6, 6]
101 | window_size: 8
102 | mlp_ratio: 2
103 | sf: 8
104 | img_range: 1.0
105 | upsampler: "nearest+conv"
106 | resi_connection: "1conv"
107 | unshuffle: True
108 | unshuffle_scale: 8
109 |
--------------------------------------------------------------------------------
/configs/model/swinir.yaml:
--------------------------------------------------------------------------------
1 | target: model.swinir.SwinIR
2 | params:
3 | img_size: 64
4 | patch_size: 1
5 | in_chans: 3
6 | embed_dim: 180
7 | depths: [6, 6, 6, 6, 6, 6, 6, 6]
8 | num_heads: [6, 6, 6, 6, 6, 6, 6, 6]
9 | window_size: 8
10 | mlp_ratio: 2
11 | sf: 8
12 | img_range: 1.0
13 | upsampler: "nearest+conv"
14 | resi_connection: "1conv"
15 | unshuffle: True
16 | unshuffle_scale: 8
17 |
18 | hq_key: jpg
19 | lq_key: hint
20 | # Learning rate.
21 | learning_rate: 1e-4
22 | weight_decay: 0
23 |
--------------------------------------------------------------------------------
/configs/test_cldm.yaml:
--------------------------------------------------------------------------------
1 | data:
2 | target: dataset.data_module.BIRDataModule
3 | params:
4 | # Path to training set configuration file.
5 | train_config: ./configs/dataset/face_train.yaml
6 | # Path to validation set configuration file.
7 | val_config: ./configs/dataset/face_test_inp.yaml
8 |
9 | model:
10 | # You can set learning rate in the following configuration file.
11 | config: configs/model/cldm_eval.yaml
12 | # Path to the checkpoint or weights you want to resume from. At the beginning,
13 | # this should be set to the initial weights created by scripts/make_stage2_init_weight.py.
14 |
15 | resume: ./work_dirs/inpainting4/lightning_logs/version_0/checkpoints/step=79999.ckpt
16 | lightning:
17 | seed: 231
18 |
19 | trainer:
20 | accelerator: ddp
21 | precision: 32
22 | # Indices of GPUs used for training.
23 | gpus: [0]
24 | # Path to save logs and checkpoints.
25 | default_root_dir: ./work_dirs/testcelebamaskhq_1
26 | # Max number of training steps (batches).
27 | max_steps: 250001
28 | # Validation frequency in terms of training steps.
29 | val_check_interval: 500
30 | log_every_n_steps: 50
31 | # Accumulate gradients from multiple batches so as to increase batch size.
32 | accumulate_grad_batches: 1
33 |
34 | callbacks:
35 | - target: model.callbacks.ImageLogger
36 | params:
37 | # Log frequency of image logger.
38 | log_every_n_steps: 250
39 | max_images_each_step: 4
40 | log_images_kwargs: ~
41 |
42 | - target: model.callbacks.ModelCheckpoint
43 | params:
44 | # Frequency of saving checkpoints.
45 | every_n_train_steps: 5000
46 | save_top_k: -1
47 | filename: "{step}"
48 |
--------------------------------------------------------------------------------
/configs/test_cldm_general.yaml:
--------------------------------------------------------------------------------
1 | data:
2 | target: dataset.data_module.BIRDataModule
3 | params:
4 | # Path to training set configuration file.
5 | train_config: ./configs/dataset/derain_train.yaml
6 |
7 | # Path to validation set configuration file.
8 | val_config: ./configs/dataset/derain_val.yaml
9 |
10 | model:
11 | # You can set learning rate in the following configuration file.
12 | config: configs/model/cldm_eval.yaml
13 | # Path to the checkpoint or weights you want to resume from. At the beginning,
14 | # this should be set to the initial weights created by scripts/make_stage2_init_weight.py.
15 | resume: ./work_dirs/exp1/lightning_logs/version_1/checkpoints/step=19999.ckpt
16 |
17 | lightning:
18 | seed: 231
19 |
20 | trainer:
21 | accelerator: ddp
22 | precision: 32
23 | # Indices of GPUs used for training.
24 | gpus: [0]
25 | # Path to save logs and checkpoints.
26 | default_root_dir: ./work_dirs/
27 | # Max number of training steps (batches).
28 | max_steps: 250001
29 | # Validation frequency in terms of training steps.
30 | val_check_interval: 500
31 | log_every_n_steps: 50
32 | # Accumulate gradients from multiple batches so as to increase batch size.
33 | accumulate_grad_batches: 1
34 |
35 | callbacks:
36 | - target: model.callbacks.ImageLogger
37 | params:
38 | # Log frequency of image logger.
39 | log_every_n_steps: 250
40 | max_images_each_step: 4
41 | log_images_kwargs: ~
42 |
43 | - target: model.callbacks.ModelCheckpoint
44 | params:
45 | # Frequency of saving checkpoints.
46 | every_n_train_steps: 5000
47 | save_top_k: -1
48 | filename: "{step}"
49 |
--------------------------------------------------------------------------------
/configs/train_cldm.yaml:
--------------------------------------------------------------------------------
1 | data:
2 | target: dataset.data_module.BIRDataModule
3 | params:
4 | # Path to training set configuration file.
5 | train_config:
6 | # Path to validation set configuration file.
7 | val_config:
8 |
9 | model:
10 | # You can set learning rate in the following configuration file.
11 | config: configs/model/cldm.yaml
12 | # Path to the checkpoint or weights you want to resume from. At the beginning,
13 | # this should be set to the initial weights created by scripts/make_stage2_init_weight.py.
14 | resume: ./weights/v2-1_512-ema-pruned.ckpt
15 |
16 | lightning:
17 | seed: 231
18 |
19 | trainer:
20 | accelerator: ddp
21 | precision: 32
22 | # Indices of GPUs used for training.
23 | gpus: [1,2,3,4,5,6,7]
24 | # Path to save logs and checkpoints.
25 | default_root_dir: ./work_dirs/exp1
26 | # Max number of training steps (batches).
27 | max_steps: 250001
28 | # Validation frequency in terms of training steps.
29 | val_check_interval: 500
30 | log_every_n_steps: 50
31 | # Accumulate gradients from multiple batches so as to increase batch size.
32 | accumulate_grad_batches: 1
33 |
34 | callbacks:
35 | - target: model.callbacks.ImageLogger
36 | params:
37 | # Log frequency of image logger.
38 | log_every_n_steps: 250
39 | max_images_each_step: 4
40 | log_images_kwargs: ~
41 |
42 | - target: model.callbacks.ModelCheckpoint
43 | params:
44 | # Frequency of saving checkpoints.
45 | every_n_train_steps: 3000
46 | save_top_k: -1
47 | filename: "{step}"
48 |
--------------------------------------------------------------------------------
/configs/train_cldm_general.yaml:
--------------------------------------------------------------------------------
1 | data:
2 | target: dataset.data_module.BIRDataModule
3 | params:
4 | # Path to training set configuration file.
5 | train_config: ./configs/dataset/derain_train.yaml
6 | # Path to validation set configuration file.
7 | val_config: ./configs/dataset/derain_val.yaml
8 |
9 | model:
10 | # You can set learning rate in the following configuration file.
11 | config: ./configs/model/cldm.yaml
12 | resume: ./work_dirs/exp0_general_all/lightning_logs/version_3/checkpoints/step=84999.ckpt
13 |
14 | lightning:
15 | seed: 231
16 |
17 | trainer:
18 | accelerator: ddp
19 | precision: 32
20 | # Indices of GPUs used for training.
21 | gpus: [1,2,3,4,5,6,7]
22 | # Path to save logs and checkpoints.
23 | default_root_dir: ./work_dirs/derain0
24 | # Max number of training steps (batches).
25 | max_steps: 250001
26 | # Validation frequency in terms of training steps.
27 | val_check_interval: 100
28 | log_every_n_steps: 250
29 | # Accumulate gradients from multiple batches so as to increase batch size.
30 | accumulate_grad_batches: 1
31 |
32 | callbacks:
33 | - target: model.callbacks.ImageLogger
34 | params:
35 | # Log frequency of image logger.
36 | log_every_n_steps: 50
37 | max_images_each_step: 4
38 | log_images_kwargs: ~
39 |
40 | - target: model.callbacks.ModelCheckpoint
41 | params:
42 | # Frequency of saving checkpoints.
43 | every_n_train_steps: 5000
44 | save_top_k: -1
45 | filename: "{step}"
46 |
--------------------------------------------------------------------------------
/configs/train_swinir.yaml:
--------------------------------------------------------------------------------
1 | data:
2 | target: dataset.data_module.BIRDataModule
3 | params:
4 | # Path to training set configuration file.
5 | train_config: ./configs/dataset/dehaze_train.yaml
6 | # Path to validation set configuration file.
7 | val_config: ./configs/dataset/dehaze_val.yaml
8 |
9 | model:
10 | # You can set learning rate in the following configuration file.
11 | config: configs/model/swinir.yaml
12 | # Path to the checkpoints or weights you want to resume.
13 | resume: ./weights/general_swinir_v1.ckpt
14 |
15 | lightning:
16 | seed: 231
17 |
18 | trainer:
19 | accelerator: ddp
20 | precision: 32
21 | # Indices of GPUs used for training.
22 | gpus: [1,2,3,5,6,7]
23 | # Path to save logs and checkpoints.
24 | default_root_dir: ./work_dirs/swin
25 | # Max number of training steps (batches).
26 | max_steps: 150001
27 | # Validation frequency in terms of training steps.
28 | val_check_interval: 250
29 | # Log frequency of tensorboard logger.
30 | log_every_n_steps: 250
31 | # Accumulate gradients from multiple batches so as to increase batch size.
32 | accumulate_grad_batches: 1
33 |
34 | callbacks:
35 | - target: model.callbacks.ImageLogger
36 | params:
37 | # Log frequency of image logger.
38 | log_every_n_steps: 1000
39 | max_images_each_step: 4
40 | log_images_kwargs: ~
41 |
42 | - target: model.callbacks.ModelCheckpoint
43 | params:
44 | # Frequency of saving checkpoints.
45 | every_n_train_steps: 10000
46 | save_top_k: -1
47 | filename: "{step}"
48 |
--------------------------------------------------------------------------------
/dataset/__pycache__/batch_transform.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/EternalEvan/FlowIE/549bb871cff104c47fd126d23818f54d74ba1f20/dataset/__pycache__/batch_transform.cpython-38.pyc
--------------------------------------------------------------------------------
/dataset/__pycache__/codeformer.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/EternalEvan/FlowIE/549bb871cff104c47fd126d23818f54d74ba1f20/dataset/__pycache__/codeformer.cpython-38.pyc
--------------------------------------------------------------------------------
/dataset/__pycache__/data_module.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/EternalEvan/FlowIE/549bb871cff104c47fd126d23818f54d74ba1f20/dataset/__pycache__/data_module.cpython-38.pyc
--------------------------------------------------------------------------------
/dataset/data_module.py:
--------------------------------------------------------------------------------
1 | from typing import Any, Tuple, Mapping
2 |
3 | from pytorch_lightning.utilities.types import EVAL_DATALOADERS, TRAIN_DATALOADERS
4 | import pytorch_lightning as pl
5 | from torch.utils.data import DataLoader, Dataset
6 | from omegaconf import OmegaConf
7 |
8 | from utils.common import instantiate_from_config
9 | from dataset.batch_transform import BatchTransform, IdentityBatchTransform
10 | from torch.utils.data.distributed import DistributedSampler
11 |
12 |
13 | class BIRDataModule(pl.LightningDataModule):
14 |
15 | def __init__(
16 | self,
17 | train_config: str,
18 | val_config: str=None
19 | ) -> None:
20 | super().__init__()
21 | self.train_config = OmegaConf.load(train_config)
22 | self.val_config = OmegaConf.load(val_config) if val_config else None
23 |
24 | def load_dataset(self, config: Mapping[str, Any]) -> Tuple[Dataset, BatchTransform]:
25 | dataset = instantiate_from_config(config["dataset"])
26 | batch_transform = (
27 | instantiate_from_config(config["batch_transform"])
28 | if config.get("batch_transform") else IdentityBatchTransform()
29 | )
30 | return dataset, batch_transform
31 |
32 | def setup(self, stage: str) -> None:
33 | if stage == "fit":
34 | self.train_dataset, self.train_batch_transform = self.load_dataset(self.train_config)
35 | if self.val_config:
36 | self.val_dataset, self.val_batch_transform = self.load_dataset(self.val_config)
37 | else:
38 | self.val_dataset, self.val_batch_transform = None, None
39 | else:
40 | raise NotImplementedError(stage)
41 |
42 | def train_dataloader(self) -> TRAIN_DATALOADERS:
43 | return DataLoader(
44 | dataset=self.train_dataset, **self.train_config["data_loader"]
45 | )
46 |
47 | def val_dataloader(self) -> EVAL_DATALOADERS:
48 | if self.val_dataset is None:
49 | return None
50 | return DataLoader(
51 | dataset=self.val_dataset, **self.val_config["data_loader"]
52 | )
53 |
54 | def on_after_batch_transfer(self, batch: Any, dataloader_idx: int) -> Any:
55 | self.trainer: pl.Trainer
56 |
57 | if self.trainer.training:
58 | return self.train_batch_transform(batch)
59 | elif self.trainer.validating or self.trainer.sanity_checking:
60 | return self.val_batch_transform(batch)
61 | else:
62 | raise RuntimeError(
63 | "Trainer state: \n"
64 | f"training: {self.trainer.training}\n"
65 | f"validating: {self.trainer.validating}\n"
66 | f"testing: {self.trainer.testing}\n"
67 | f"predicting: {self.trainer.predicting}\n"
68 | f"sanity_checking: {self.trainer.sanity_checking}"
69 | )
70 |
71 | class BIRDataModuleDistributed(pl.LightningDataModule):
72 |
73 | def __init__(
74 | self,
75 | train_config: str,
76 | val_config: str=None
77 | ) -> None:
78 | super().__init__()
79 | self.train_config = OmegaConf.load(train_config)
80 | self.val_config = OmegaConf.load(val_config) if val_config else None
81 |
82 | def load_dataset(self, config: Mapping[str, Any]) -> Tuple[Dataset, BatchTransform]:
83 | dataset = instantiate_from_config(config["dataset"])
84 | batch_transform = (
85 | instantiate_from_config(config["batch_transform"])
86 | if config.get("batch_transform") else IdentityBatchTransform()
87 | )
88 | return dataset, batch_transform
89 |
90 | def setup(self, stage: str) -> None:
91 | if stage == "fit":
92 | self.train_dataset, self.train_batch_transform = self.load_dataset(self.train_config)
93 | if self.val_config:
94 | self.val_dataset, self.val_batch_transform = self.load_dataset(self.val_config)
95 | else:
96 | self.val_dataset, self.val_batch_transform = None, None
97 | else:
98 | raise NotImplementedError(stage)
99 |
100 | def train_dataloader(self) -> TRAIN_DATALOADERS:
101 | return DataLoader(
102 | dataset=self.train_dataset, **self.train_config["data_loader"]
103 | )
104 |
105 | def val_dataloader(self, rank) -> EVAL_DATALOADERS:
106 | if self.val_dataset is None:
107 | return None
108 | sampler = DistributedSampler(self.val_dataset,
109 | rank=rank,shuffle=False, drop_last=False)
110 | return DataLoader(
111 | dataset=self.val_dataset, sampler=sampler, **self.val_config["data_loader"]
112 | )
113 |
114 | def on_after_batch_transfer(self, batch: Any, dataloader_idx: int) -> Any:
115 | self.trainer: pl.Trainer
116 |
117 | if self.trainer.training:
118 | return self.train_batch_transform(batch)
119 | elif self.trainer.validating or self.trainer.sanity_checking:
120 | return self.val_batch_transform(batch)
121 | else:
122 | raise RuntimeError(
123 | "Trainer state: \n"
124 | f"training: {self.trainer.training}\n"
125 | f"validating: {self.trainer.validating}\n"
126 | f"testing: {self.trainer.testing}\n"
127 | f"predicting: {self.trainer.predicting}\n"
128 | f"sanity_checking: {self.trainer.sanity_checking}"
129 | )
130 |
--------------------------------------------------------------------------------
/dataset/realesrgan.py:
--------------------------------------------------------------------------------
1 | from typing import Dict, Sequence
2 | import math
3 | import random
4 | import time
5 |
6 | import numpy as np
7 | import torch
8 | from torch.utils import data
9 | from PIL import Image
10 |
11 | from utils.degradation import circular_lowpass_kernel, random_mixed_kernels
12 | from utils.image import augment, random_crop_arr, center_crop_arr
13 | from utils.file import load_file_list
14 |
15 |
16 | class RealESRGANDataset(data.Dataset):
17 | """
18 | # TODO: add comment
19 | """
20 |
21 | def __init__(
22 | self,
23 | file_list: str,
24 | out_size: int,
25 | crop_type: str,
26 | use_hflip: bool,
27 | use_rot: bool,
28 | # blur kernel settings of the first degradation stage
29 | blur_kernel_size: int,
30 | kernel_list: Sequence[str],
31 | kernel_prob: Sequence[float],
32 | blur_sigma: Sequence[float],
33 | betag_range: Sequence[float],
34 | betap_range: Sequence[float],
35 | sinc_prob: float,
36 | # blur kernel settings of the second degradation stage
37 | blur_kernel_size2: int,
38 | kernel_list2: Sequence[str],
39 | kernel_prob2: Sequence[float],
40 | blur_sigma2: Sequence[float],
41 | betag_range2: Sequence[float],
42 | betap_range2: Sequence[float],
43 | sinc_prob2: float,
44 | final_sinc_prob: float
45 | ) -> None:
46 | super(RealESRGANDataset, self).__init__()
47 | self.paths = load_file_list(file_list)
48 | self.out_size = out_size
49 | self.crop_type = crop_type
50 | assert self.crop_type in ["center", "random", "none"], f"invalid crop type: {self.crop_type}"
51 |
52 | self.blur_kernel_size = blur_kernel_size
53 | self.kernel_list = kernel_list
54 | # a list for each kernel probability
55 | self.kernel_prob = kernel_prob
56 | self.blur_sigma = blur_sigma
57 | # betag used in generalized Gaussian blur kernels
58 | self.betag_range = betag_range
59 | # betap used in plateau blur kernels
60 | self.betap_range = betap_range
61 | # the probability for sinc filters
62 | self.sinc_prob = sinc_prob
63 |
64 | self.blur_kernel_size2 = blur_kernel_size2
65 | self.kernel_list2 = kernel_list2
66 | self.kernel_prob2 = kernel_prob2
67 | self.blur_sigma2 = blur_sigma2
68 | self.betag_range2 = betag_range2
69 | self.betap_range2 = betap_range2
70 | self.sinc_prob2 = sinc_prob2
71 |
72 | # a final sinc filter
73 | self.final_sinc_prob = final_sinc_prob
74 |
75 | self.use_hflip = use_hflip
76 | self.use_rot = use_rot
77 |
78 | # kernel size ranges from 7 to 21
79 | self.kernel_range = [2 * v + 1 for v in range(3, 11)]
80 | # TODO: kernel range is now hard-coded, should be in the configure file
81 | # convolving with pulse tensor brings no blurry effect
82 | self.pulse_tensor = torch.zeros(21, 21).float()
83 | self.pulse_tensor[10, 10] = 1
84 |
85 | @torch.no_grad()
86 | def __getitem__(self, index: int) -> Dict[str, torch.Tensor]:
87 | # -------------------------------- Load hq images -------------------------------- #
88 | hq_path = self.paths[index]
89 | success = False
90 | for _ in range(3):
91 | try:
92 | pil_img = Image.open(hq_path).convert("RGB")
93 | success = True
94 | break
95 | except Exception:
96 | time.sleep(1)
97 | assert success, f"failed to load image {hq_path}"
98 |
99 | if self.crop_type == "random":
100 | pil_img = random_crop_arr(pil_img, self.out_size)
101 | elif self.crop_type == "center":
102 | pil_img = center_crop_arr(pil_img, self.out_size)
103 | # self.crop_type is "none"
104 | else:
105 | pil_img = np.array(pil_img)
106 | assert pil_img.shape[:2] == (self.out_size, self.out_size)
107 | # hwc, rgb to bgr, [0, 255] to [0, 1], float32
108 | img_hq = (pil_img[..., ::-1] / 255.0).astype(np.float32)
109 |
110 | # -------------------- Do augmentation for training: flip, rotation -------------------- #
111 | img_hq = augment(img_hq, self.use_hflip, self.use_rot)
112 |
113 | # ------------------------ Generate kernels (used in the first degradation) ------------------------ #
114 | kernel_size = random.choice(self.kernel_range)
115 | if np.random.uniform() < self.sinc_prob:
116 | # this sinc filter setting is for kernels ranging from [7, 21]
117 | if kernel_size < 13:
118 | omega_c = np.random.uniform(np.pi / 3, np.pi)
119 | else:
120 | omega_c = np.random.uniform(np.pi / 5, np.pi)
121 | kernel = circular_lowpass_kernel(omega_c, kernel_size, pad_to=False)
122 | else:
123 | kernel = random_mixed_kernels(
124 | self.kernel_list,
125 | self.kernel_prob,
126 | kernel_size,
127 | self.blur_sigma,
128 | self.blur_sigma, [-math.pi, math.pi],
129 | self.betag_range,
130 | self.betap_range,
131 | noise_range=None
132 | )
133 | # pad kernel
134 | pad_size = (21 - kernel_size) // 2
135 | kernel = np.pad(kernel, ((pad_size, pad_size), (pad_size, pad_size)))
136 |
137 | # ------------------------ Generate kernels (used in the second degradation) ------------------------ #
138 | kernel_size = random.choice(self.kernel_range)
139 | if np.random.uniform() < self.sinc_prob2:
140 | if kernel_size < 13:
141 | omega_c = np.random.uniform(np.pi / 3, np.pi)
142 | else:
143 | omega_c = np.random.uniform(np.pi / 5, np.pi)
144 | kernel2 = circular_lowpass_kernel(omega_c, kernel_size, pad_to=False)
145 | else:
146 | kernel2 = random_mixed_kernels(
147 | self.kernel_list2,
148 | self.kernel_prob2,
149 | kernel_size,
150 | self.blur_sigma2,
151 | self.blur_sigma2, [-math.pi, math.pi],
152 | self.betag_range2,
153 | self.betap_range2,
154 | noise_range=None
155 | )
156 |
157 | # pad kernel
158 | pad_size = (21 - kernel_size) // 2
159 | kernel2 = np.pad(kernel2, ((pad_size, pad_size), (pad_size, pad_size)))
160 |
161 | # ------------------------------------- the final sinc kernel ------------------------------------- #
162 | if np.random.uniform() < self.final_sinc_prob:
163 | kernel_size = random.choice(self.kernel_range)
164 | omega_c = np.random.uniform(np.pi / 3, np.pi)
165 | sinc_kernel = circular_lowpass_kernel(omega_c, kernel_size, pad_to=21)
166 | sinc_kernel = torch.FloatTensor(sinc_kernel)
167 | else:
168 | sinc_kernel = self.pulse_tensor
169 |
170 | # [0, 1], BGR to RGB, HWC to CHW
171 | img_hq = torch.from_numpy(
172 | img_hq[..., ::-1].transpose(2, 0, 1).copy()
173 | ).float()
174 | kernel = torch.FloatTensor(kernel)
175 | kernel2 = torch.FloatTensor(kernel2)
176 |
177 | return {
178 | "hq": img_hq, "kernel1": kernel, "kernel2": kernel2,
179 | "sinc_kernel": sinc_kernel, "txt": ""
180 | }
181 |
182 | def __len__(self) -> int:
183 | return len(self.paths)
184 |
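An illustrative instantiation (the file list path and kernel hyper-parameters below are placeholders; the values actually used for training live in the general_deg_realesrgan_*.yaml configs):

dataset = RealESRGANDataset(
    file_list="path/to/hq_file_list.txt",   # plain-text list of HQ image paths
    out_size=512, crop_type="center", use_hflip=True, use_rot=False,
    # first-stage blur kernels
    blur_kernel_size=21, kernel_list=["iso", "aniso"], kernel_prob=[0.5, 0.5],
    blur_sigma=[0.2, 3.0], betag_range=[0.5, 4.0], betap_range=[1.0, 2.0], sinc_prob=0.1,
    # second-stage blur kernels
    blur_kernel_size2=21, kernel_list2=["iso", "aniso"], kernel_prob2=[0.5, 0.5],
    blur_sigma2=[0.2, 1.5], betag_range2=[0.5, 4.0], betap_range2=[1.0, 2.0], sinc_prob2=0.1,
    final_sinc_prob=0.8,
)
item = dataset[0]
# item["hq"] is a CHW float tensor in [0, 1]; kernel1, kernel2 and sinc_kernel are the
# padded 21x21 kernels consumed later by the two-stage degradation batch transform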
--------------------------------------------------------------------------------
/dataset/test.py:
--------------------------------------------------------------------------------
1 | import rawpy
2 | if __name__ == "__main__":
3 | with rawpy.imread("/home/user001/zwl/data/Sony/long/00001_00_10s.ARW") as raw_target:
4 | pass
--------------------------------------------------------------------------------
/evaluate.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import math
3 | import os
4 | import cv2
5 | import torch
6 | from torchvision import transforms
7 | from scipy.ndimage import gaussian_filter
8 | from lpips_pytorch import LPIPS, lpips
9 | from skimage.metrics import peak_signal_noise_ratio as psnr
10 | import pytorch_ssim
11 | from argparse import ArgumentParser
12 | from tqdm import tqdm
13 | import pdb
14 |
15 |
16 | import pyiqa
17 | # def psnr(img1, img2):
18 | # mse = np.mean((img1 - img2) ** 2 )
19 | # if mse == 0:
20 | # return 100
21 | # return 20 * math.log10(255.0 / math.sqrt(mse))
22 | def calculate_psnr(img1, img2):
23 | """Calculate PSNR (Peak Signal-to-Noise Ratio).
24 |
25 | Ref: https://en.wikipedia.org/wiki/Peak_signal-to-noise_ratio
26 |
27 | Args:
28 | img1 (ndarray): Images with range [0, 255].
29 | img2 (ndarray): Images with range [0, 255].
30 | crop_border (int): Cropped pixels in each edge of an image. These
31 | pixels are not involved in the PSNR calculation.
32 | input_order (str): Whether the input order is 'HWC' or 'CHW'.
33 | Default: 'HWC'.
34 | test_y_channel (bool): Test on Y channel of YCbCr. Default: False.
35 |
36 | Returns:
37 | float: psnr result.
38 | """
39 |
40 | assert img1.shape == img2.shape, (f'Image shapes are different: {img1.shape}, {img2.shape}.')
41 | img1 = img1.astype(np.float64)
42 | img2 = img2.astype(np.float64)
43 | mse = np.mean((img1 - img2)**2)
44 | if mse == 0:
45 | return float('inf')
46 |
47 | #pdb.set_trace()
48 | return 10. * np.log10(255.0 * 255.0 / mse)
49 |
50 | def _ssim(img1, img2):
51 | """Calculate SSIM (structural similarity) for one channel images.
52 |
53 | It is called by func:`calculate_ssim`.
54 |
55 | Args:
56 | img1 (ndarray): Images with range [0, 255] with order 'HWC'.
57 | img2 (ndarray): Images with range [0, 255] with order 'HWC'.
58 |
59 | Returns:
60 | float: ssim result.
61 | """
62 |
63 | C1 = (0.01 * 255)**2
64 | C2 = (0.03 * 255)**2
65 |
66 | img1 = img1.astype(np.float64)
67 | img2 = img2.astype(np.float64)
68 | kernel = cv2.getGaussianKernel(11, 1.5)
69 | window = np.outer(kernel, kernel.transpose())
70 |
71 | mu1 = cv2.filter2D(img1, -1, window)[5:-5, 5:-5]
72 | mu2 = cv2.filter2D(img2, -1, window)[5:-5, 5:-5]
73 | mu1_sq = mu1**2
74 | mu2_sq = mu2**2
75 | mu1_mu2 = mu1 * mu2
76 | sigma1_sq = cv2.filter2D(img1**2, -1, window)[5:-5, 5:-5] - mu1_sq
77 | sigma2_sq = cv2.filter2D(img2**2, -1, window)[5:-5, 5:-5] - mu2_sq
78 | sigma12 = cv2.filter2D(img1 * img2, -1, window)[5:-5, 5:-5] - mu1_mu2
79 |
80 | ssim_map = ((2 * mu1_mu2 + C1) * (2 * sigma12 + C2)) / ((mu1_sq + mu2_sq + C1) * (sigma1_sq + sigma2_sq + C2))
81 | return ssim_map.mean()
82 | def calculate_ssim(img1, img2):
83 | """Calculate SSIM (structural similarity).
84 |
85 | Ref:
86 | Image quality assessment: From error visibility to structural similarity
87 |
88 | The results are the same as that of the official released MATLAB code in
89 | https://ece.uwaterloo.ca/~z70wang/research/ssim/.
90 |
91 | For three-channel images, SSIM is calculated for each channel and then
92 | averaged.
93 |
94 | Args:
95 | img1 (ndarray): Images with range [0, 255].
96 | img2 (ndarray): Images with range [0, 255].
97 | crop_border (int): Cropped pixels in each edge of an image. These
98 | pixels are not involved in the SSIM calculation.
99 | input_order (str): Whether the input order is 'HWC' or 'CHW'.
100 | Default: 'HWC'.
101 | test_y_channel (bool): Test on Y channel of YCbCr. Default: False.
102 |
103 | Returns:
104 | float: ssim result.
105 | """
106 |
107 | assert img1.shape == img2.shape, (f'Image shapes are different: {img1.shape}, {img2.shape}.')
108 |
109 | img1 = img1.astype(np.float64)
110 | img2 = img2.astype(np.float64)
111 |
112 |
113 | ssims = []
114 | for i in range(img1.shape[2]):
115 | ssims.append(_ssim(img1[..., i], img2[..., i]))
116 | return np.array(ssims).mean()
117 |
118 | def main(path1, path2, type="all"):
119 | loss_1 = []
120 | loss_2 = []
121 | loss_3 = []
122 | loss_4 = []
123 | length = 0
124 |
125 | if type=="lpips" or type == "all":
126 | lpips = LPIPS().cuda()
127 | if type=="ssim" or type == "all":
128 | ssim = pytorch_ssim.SSIM(window_size = 11).cuda()
129 |
130 | psnr_metric = pyiqa.create_metric('psnr').cuda()
131 | lpips_metric = pyiqa.create_metric('lpips').cuda()
132 | ssim_metric = pyiqa.create_metric('ssim').cuda()
133 | musiq_metric = pyiqa.create_metric('musiq').cuda()
134 |
135 | for idx, img in tqdm(enumerate(os.listdir(path1)), total=len(os.listdir(path1))):
136 | imgpath1 = os.path.join(path1,img)
137 | imgpath2 = os.path.join(path2,img)
138 | imgpath2 = imgpath2[:-3]+'png'
139 |
140 | #print(imgpath1)
141 | #img1 = cv2.imread(imgpath1).astype(np.float64)
142 |
143 |
144 | #img2 = cv2.imread(imgpath2).astype(np.float64)
145 |
146 |
147 | # mean_l = []
148 | # std_l = []
149 | # for j in range(3):
150 | # mean_l.append(np.mean(img2[:, :, j]))
151 | # std_l.append(np.std(img2[:, :, j]))
152 | # for j in range(3):
153 | # # correct twice
154 | # mean = np.mean(img1[:, :, j])
155 | # img1[:, :, j] = img1[:, :, j] - mean + mean_l[j]
156 | # std = np.std(img1[:, :, j])
157 | # img1[:, :, j] = img1[:, :, j] / std * std_l[j]
158 |
159 | # mean = np.mean(img1[:, :, j])
160 | # img1[:, :, j] = img1[:, :, j] - mean + mean_l[j]
161 | # std = np.std(img1[:, :, j])
162 | # img1[:, :, j] = img1[:, :, j] / std * std_l[j]
163 | # img1 = cv2.resize(img1,(256,256))
164 | # img2 = cv2.resize(img2,(256,256))
165 | # if img1.shape != img2.shape:
166 | # if img1.shape[0]< img2.shape[0]:
167 | # img2 = cv2.resize(img2,img1.shape[:2])
168 | # else:
169 | # img1 = cv2.resize(img1,img2.shape[:2])
170 | if type == "psnr":
171 | psnr_score = psnr_metric(imgpath1, imgpath2)
172 | loss_1.append(psnr_score.cpu().numpy())
173 | elif type == "lpips":
174 | lpips_score = lpips_metric(imgpath1, imgpath2)
175 | loss_2.append(lpips_score.cpu().numpy())
176 | elif type == "ssim":
177 | ssim_score = ssim_metric(imgpath1, imgpath2)
178 | loss_3.append(ssim_score.cpu().numpy())
179 | elif type == "musiq":
180 | musiq_score = musiq_metric(imgpath1)
181 | loss_4.append(musiq_score.cpu().numpy())
182 | elif type == "all":
183 | # compute every metric for this image pair (musiq is no-reference)
184 | loss_1.append(psnr_metric(imgpath1, imgpath2).cpu().numpy())
185 | loss_2.append(lpips_metric(imgpath1, imgpath2).cpu().numpy())
186 | loss_3.append(ssim_metric(imgpath1, imgpath2).cpu().numpy())
187 | loss_4.append(musiq_metric(imgpath1).cpu().numpy())
188 | length += 1
189 |
190 | if type=="psnr" or type == "all":
191 | print("psnr↑",np.mean(loss_1))
192 | if type=="lpips" or type == "all":
193 | print("lpips↓",np.mean(loss_2))
194 | if type=="ssim" or type == "all":
195 | print("ssim↑",np.mean(loss_3))
196 | if type=="musiq" or type == "all":
197 | print("musiq↑",np.mean(loss_4))
198 |
199 |
200 |
201 | if __name__ == "__main__":
202 | parser = ArgumentParser()
203 | parser.add_argument("--input1", type=str, required=True)
204 | parser.add_argument("--input2", type=str, required=True)
205 | parser.add_argument("--type", type=str, default="all")
206 | args = parser.parse_args()
207 | main(args.input1, args.input2, args.type)
208 |
209 |
210 | '''
211 | DiffBIR
212 | psnr↑ 31.14
213 | lpips↓ 0.2063
214 | ssim↑ 0.6731
215 |
216 | midd
217 | psnr↑ 30.87
218 | lpips↓ 0.2046
219 | ssim↑ 0.6719
220 |
221 | final
222 | psnr↑ 31.17
223 | lpips↓ 0.2248
224 | ssim↑ 0.7220
225 |
226 | '''
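An invocation sketch for the script above (directory paths are placeholders); --type selects psnr, lpips, ssim, musiq, or all, and both directories are expected to contain images with matching file names:

python evaluate.py --input1 /path/to/restored_images --input2 /path/to/ground_truth --type lpips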
--------------------------------------------------------------------------------
/ldm/__pycache__/util.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/EternalEvan/FlowIE/549bb871cff104c47fd126d23818f54d74ba1f20/ldm/__pycache__/util.cpython-38.pyc
--------------------------------------------------------------------------------
/ldm/__pycache__/util.cpython-39.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/EternalEvan/FlowIE/549bb871cff104c47fd126d23818f54d74ba1f20/ldm/__pycache__/util.cpython-39.pyc
--------------------------------------------------------------------------------
/ldm/__pycache__/xformers_state.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/EternalEvan/FlowIE/549bb871cff104c47fd126d23818f54d74ba1f20/ldm/__pycache__/xformers_state.cpython-38.pyc
--------------------------------------------------------------------------------
/ldm/__pycache__/xformers_state.cpython-39.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/EternalEvan/FlowIE/549bb871cff104c47fd126d23818f54d74ba1f20/ldm/__pycache__/xformers_state.cpython-39.pyc
--------------------------------------------------------------------------------
/ldm/data/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/EternalEvan/FlowIE/549bb871cff104c47fd126d23818f54d74ba1f20/ldm/data/__init__.py
--------------------------------------------------------------------------------
/ldm/data/util.py:
--------------------------------------------------------------------------------
1 | import torch
2 |
3 | from ldm.modules.midas.api import load_midas_transform
4 |
5 |
6 | class AddMiDaS(object):
7 | def __init__(self, model_type):
8 | super().__init__()
9 | self.transform = load_midas_transform(model_type)
10 |
11 | def pt2np(self, x):
12 | x = ((x + 1.0) * .5).detach().cpu().numpy()
13 | return x
14 |
15 | def np2pt(self, x):
16 | x = torch.from_numpy(x) * 2 - 1.
17 | return x
18 |
19 | def __call__(self, sample):
20 | # sample['jpg'] is tensor hwc in [-1, 1] at this point
21 | x = self.pt2np(sample['jpg'])
22 | x = self.transform({"image": x})["image"]
23 | sample['midas_in'] = x
24 | return sample
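A small usage sketch, assuming "dpt_hybrid" as the MiDaS model type (only the transform is built here, no depth-network weights are needed):

import torch
from ldm.data.util import AddMiDaS

add_midas = AddMiDaS(model_type="dpt_hybrid")
sample = {"jpg": torch.rand(384, 384, 3) * 2 - 1}   # HWC tensor in [-1, 1]
sample = add_midas(sample)
# sample["midas_in"] is the CHW numpy array expected by the MiDaS depth network
print(sample["midas_in"].shape)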
--------------------------------------------------------------------------------
/ldm/models/__pycache__/autoencoder.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/EternalEvan/FlowIE/549bb871cff104c47fd126d23818f54d74ba1f20/ldm/models/__pycache__/autoencoder.cpython-38.pyc
--------------------------------------------------------------------------------
/ldm/models/__pycache__/autoencoder.cpython-39.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/EternalEvan/FlowIE/549bb871cff104c47fd126d23818f54d74ba1f20/ldm/models/__pycache__/autoencoder.cpython-39.pyc
--------------------------------------------------------------------------------
/ldm/models/diffusion/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/EternalEvan/FlowIE/549bb871cff104c47fd126d23818f54d74ba1f20/ldm/models/diffusion/__init__.py
--------------------------------------------------------------------------------
/ldm/models/diffusion/__pycache__/__init__.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/EternalEvan/FlowIE/549bb871cff104c47fd126d23818f54d74ba1f20/ldm/models/diffusion/__pycache__/__init__.cpython-38.pyc
--------------------------------------------------------------------------------
/ldm/models/diffusion/__pycache__/__init__.cpython-39.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/EternalEvan/FlowIE/549bb871cff104c47fd126d23818f54d74ba1f20/ldm/models/diffusion/__pycache__/__init__.cpython-39.pyc
--------------------------------------------------------------------------------
/ldm/models/diffusion/__pycache__/ddim.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/EternalEvan/FlowIE/549bb871cff104c47fd126d23818f54d74ba1f20/ldm/models/diffusion/__pycache__/ddim.cpython-38.pyc
--------------------------------------------------------------------------------
/ldm/models/diffusion/__pycache__/ddim.cpython-39.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/EternalEvan/FlowIE/549bb871cff104c47fd126d23818f54d74ba1f20/ldm/models/diffusion/__pycache__/ddim.cpython-39.pyc
--------------------------------------------------------------------------------
/ldm/models/diffusion/__pycache__/ddpm.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/EternalEvan/FlowIE/549bb871cff104c47fd126d23818f54d74ba1f20/ldm/models/diffusion/__pycache__/ddpm.cpython-38.pyc
--------------------------------------------------------------------------------
/ldm/models/diffusion/__pycache__/ddpm.cpython-39.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/EternalEvan/FlowIE/549bb871cff104c47fd126d23818f54d74ba1f20/ldm/models/diffusion/__pycache__/ddpm.cpython-39.pyc
--------------------------------------------------------------------------------
/ldm/models/diffusion/dpm_solver/__init__.py:
--------------------------------------------------------------------------------
1 | from .sampler import DPMSolverSampler
--------------------------------------------------------------------------------
/ldm/models/diffusion/dpm_solver/sampler.py:
--------------------------------------------------------------------------------
1 | """SAMPLING ONLY."""
2 | import torch
3 |
4 | from .dpm_solver import NoiseScheduleVP, model_wrapper, DPM_Solver
5 |
6 |
7 | MODEL_TYPES = {
8 | "eps": "noise",
9 | "v": "v"
10 | }
11 |
12 |
13 | class DPMSolverSampler(object):
14 | def __init__(self, model, **kwargs):
15 | super().__init__()
16 | self.model = model
17 | to_torch = lambda x: x.clone().detach().to(torch.float32).to(model.device)
18 | self.register_buffer('alphas_cumprod', to_torch(model.alphas_cumprod))
19 |
20 | def register_buffer(self, name, attr):
21 | if type(attr) == torch.Tensor:
22 | if attr.device != torch.device("cuda"):
23 | attr = attr.to(torch.device("cuda"))
24 | setattr(self, name, attr)
25 |
26 | @torch.no_grad()
27 | def sample(self,
28 | S,
29 | batch_size,
30 | shape,
31 | conditioning=None,
32 | callback=None,
33 | normals_sequence=None,
34 | img_callback=None,
35 | quantize_x0=False,
36 | eta=0.,
37 | mask=None,
38 | x0=None,
39 | temperature=1.,
40 | noise_dropout=0.,
41 | score_corrector=None,
42 | corrector_kwargs=None,
43 | verbose=True,
44 | x_T=None,
45 | log_every_t=100,
46 | unconditional_guidance_scale=1.,
47 | unconditional_conditioning=None,
48 | # this has to come in the same format as the conditioning, e.g. as encoded tokens, ...
49 | **kwargs
50 | ):
51 | if conditioning is not None:
52 | if isinstance(conditioning, dict):
53 | cbs = conditioning[list(conditioning.keys())[0]].shape[0]
54 | if cbs != batch_size:
55 | print(f"Warning: Got {cbs} conditionings but batch-size is {batch_size}")
56 | else:
57 | if conditioning.shape[0] != batch_size:
58 | print(f"Warning: Got {conditioning.shape[0]} conditionings but batch-size is {batch_size}")
59 |
60 | # sampling
61 | C, H, W = shape
62 | size = (batch_size, C, H, W)
63 |
64 | print(f'Data shape for DPM-Solver sampling is {size}, sampling steps {S}')
65 |
66 | device = self.model.betas.device
67 | if x_T is None:
68 | img = torch.randn(size, device=device)
69 | else:
70 | img = x_T
71 |
72 | ns = NoiseScheduleVP('discrete', alphas_cumprod=self.alphas_cumprod)
73 |
74 | model_fn = model_wrapper(
75 | lambda x, t, c: self.model.apply_model(x, t, c),
76 | ns,
77 | model_type=MODEL_TYPES[self.model.parameterization],
78 | guidance_type="classifier-free",
79 | condition=conditioning,
80 | unconditional_condition=unconditional_conditioning,
81 | guidance_scale=unconditional_guidance_scale,
82 | )
83 |
84 | dpm_solver = DPM_Solver(model_fn, ns, predict_x0=True, thresholding=False)
85 | x = dpm_solver.sample(img, steps=S, skip_type="time_uniform", method="multistep", order=2, lower_order_final=True)
86 |
87 | return x.to(device), None
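A usage sketch; `model` is assumed to be a loaded LatentDiffusion-style model exposing apply_model, betas, alphas_cumprod, parameterization and device, and `cond` is whatever conditioning apply_model expects (e.g. a cross-attention context tensor):

sampler = DPMSolverSampler(model)
samples, _ = sampler.sample(
    S=20,                      # number of DPM-Solver steps
    batch_size=1,
    shape=(4, 64, 64),         # latent C, H, W
    conditioning=cond,
    unconditional_guidance_scale=1.0,
    unconditional_conditioning=None,
)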
--------------------------------------------------------------------------------
/ldm/models/diffusion/sampling_util.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import numpy as np
3 |
4 |
5 | def append_dims(x, target_dims):
6 | """Appends dimensions to the end of a tensor until it has target_dims dimensions.
7 | From https://github.com/crowsonkb/k-diffusion/blob/master/k_diffusion/utils.py"""
8 | dims_to_append = target_dims - x.ndim
9 | if dims_to_append < 0:
10 | raise ValueError(f'input has {x.ndim} dims but target_dims is {target_dims}, which is less')
11 | return x[(...,) + (None,) * dims_to_append]
12 |
13 |
14 | def norm_thresholding(x0, value):
15 | s = append_dims(x0.pow(2).flatten(1).mean(1).sqrt().clamp(min=value), x0.ndim)
16 | return x0 * (value / s)
17 |
18 |
19 | def spatial_norm_thresholding(x0, value):
20 | # b c h w
21 | s = x0.pow(2).mean(1, keepdim=True).sqrt().clamp(min=value)
22 | return x0 * (value / s)
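A quick numerical illustration of the two helpers:

import torch
from ldm.models.diffusion.sampling_util import append_dims, norm_thresholding

x0 = torch.randn(2, 3, 8, 8)
rms = x0.pow(2).flatten(1).mean(1).sqrt()        # per-sample RMS, shape (2,)
print(append_dims(rms, x0.ndim).shape)           # torch.Size([2, 1, 1, 1])
# samples whose RMS exceeds 1.0 are rescaled down to RMS 1.0, others are untouched
x0_clamped = norm_thresholding(x0, value=1.0)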
--------------------------------------------------------------------------------
/ldm/modules/__pycache__/attention.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/EternalEvan/FlowIE/549bb871cff104c47fd126d23818f54d74ba1f20/ldm/modules/__pycache__/attention.cpython-38.pyc
--------------------------------------------------------------------------------
/ldm/modules/__pycache__/attention.cpython-39.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/EternalEvan/FlowIE/549bb871cff104c47fd126d23818f54d74ba1f20/ldm/modules/__pycache__/attention.cpython-39.pyc
--------------------------------------------------------------------------------
/ldm/modules/__pycache__/ema.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/EternalEvan/FlowIE/549bb871cff104c47fd126d23818f54d74ba1f20/ldm/modules/__pycache__/ema.cpython-38.pyc
--------------------------------------------------------------------------------
/ldm/modules/__pycache__/ema.cpython-39.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/EternalEvan/FlowIE/549bb871cff104c47fd126d23818f54d74ba1f20/ldm/modules/__pycache__/ema.cpython-39.pyc
--------------------------------------------------------------------------------
/ldm/modules/diffusionmodules/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/EternalEvan/FlowIE/549bb871cff104c47fd126d23818f54d74ba1f20/ldm/modules/diffusionmodules/__init__.py
--------------------------------------------------------------------------------
/ldm/modules/diffusionmodules/__pycache__/__init__.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/EternalEvan/FlowIE/549bb871cff104c47fd126d23818f54d74ba1f20/ldm/modules/diffusionmodules/__pycache__/__init__.cpython-38.pyc
--------------------------------------------------------------------------------
/ldm/modules/diffusionmodules/__pycache__/__init__.cpython-39.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/EternalEvan/FlowIE/549bb871cff104c47fd126d23818f54d74ba1f20/ldm/modules/diffusionmodules/__pycache__/__init__.cpython-39.pyc
--------------------------------------------------------------------------------
/ldm/modules/diffusionmodules/__pycache__/model.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/EternalEvan/FlowIE/549bb871cff104c47fd126d23818f54d74ba1f20/ldm/modules/diffusionmodules/__pycache__/model.cpython-38.pyc
--------------------------------------------------------------------------------
/ldm/modules/diffusionmodules/__pycache__/model.cpython-39.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/EternalEvan/FlowIE/549bb871cff104c47fd126d23818f54d74ba1f20/ldm/modules/diffusionmodules/__pycache__/model.cpython-39.pyc
--------------------------------------------------------------------------------
/ldm/modules/diffusionmodules/__pycache__/openaimodel.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/EternalEvan/FlowIE/549bb871cff104c47fd126d23818f54d74ba1f20/ldm/modules/diffusionmodules/__pycache__/openaimodel.cpython-38.pyc
--------------------------------------------------------------------------------
/ldm/modules/diffusionmodules/__pycache__/openaimodel.cpython-39.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/EternalEvan/FlowIE/549bb871cff104c47fd126d23818f54d74ba1f20/ldm/modules/diffusionmodules/__pycache__/openaimodel.cpython-39.pyc
--------------------------------------------------------------------------------
/ldm/modules/diffusionmodules/__pycache__/util.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/EternalEvan/FlowIE/549bb871cff104c47fd126d23818f54d74ba1f20/ldm/modules/diffusionmodules/__pycache__/util.cpython-38.pyc
--------------------------------------------------------------------------------
/ldm/modules/diffusionmodules/__pycache__/util.cpython-39.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/EternalEvan/FlowIE/549bb871cff104c47fd126d23818f54d74ba1f20/ldm/modules/diffusionmodules/__pycache__/util.cpython-39.pyc
--------------------------------------------------------------------------------
/ldm/modules/diffusionmodules/upscaling.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import numpy as np
4 | from functools import partial
5 |
6 | from ldm.modules.diffusionmodules.util import extract_into_tensor, make_beta_schedule
7 | from ldm.util import default
8 |
9 |
10 | class AbstractLowScaleModel(nn.Module):
11 | # for concatenating a downsampled image to the latent representation
12 | def __init__(self, noise_schedule_config=None):
13 | super(AbstractLowScaleModel, self).__init__()
14 | if noise_schedule_config is not None:
15 | self.register_schedule(**noise_schedule_config)
16 |
17 | def register_schedule(self, beta_schedule="linear", timesteps=1000,
18 | linear_start=1e-4, linear_end=2e-2, cosine_s=8e-3):
19 | betas = make_beta_schedule(beta_schedule, timesteps, linear_start=linear_start, linear_end=linear_end,
20 | cosine_s=cosine_s)
21 | alphas = 1. - betas
22 | alphas_cumprod = np.cumprod(alphas, axis=0)
23 | alphas_cumprod_prev = np.append(1., alphas_cumprod[:-1])
24 |
25 | timesteps, = betas.shape
26 | self.num_timesteps = int(timesteps)
27 | self.linear_start = linear_start
28 | self.linear_end = linear_end
29 | assert alphas_cumprod.shape[0] == self.num_timesteps, 'alphas have to be defined for each timestep'
30 |
31 | to_torch = partial(torch.tensor, dtype=torch.float32)
32 |
33 | self.register_buffer('betas', to_torch(betas))
34 | self.register_buffer('alphas_cumprod', to_torch(alphas_cumprod))
35 | self.register_buffer('alphas_cumprod_prev', to_torch(alphas_cumprod_prev))
36 |
37 | # calculations for diffusion q(x_t | x_{t-1}) and others
38 | self.register_buffer('sqrt_alphas_cumprod', to_torch(np.sqrt(alphas_cumprod)))
39 | self.register_buffer('sqrt_one_minus_alphas_cumprod', to_torch(np.sqrt(1. - alphas_cumprod)))
40 | self.register_buffer('log_one_minus_alphas_cumprod', to_torch(np.log(1. - alphas_cumprod)))
41 | self.register_buffer('sqrt_recip_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod)))
42 | self.register_buffer('sqrt_recipm1_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod - 1)))
43 |
44 | def q_sample(self, x_start, t, noise=None):
45 | noise = default(noise, lambda: torch.randn_like(x_start))
46 | return (extract_into_tensor(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start +
47 | extract_into_tensor(self.sqrt_one_minus_alphas_cumprod, t, x_start.shape) * noise)
48 |
49 | def forward(self, x):
50 | return x, None
51 |
52 | def decode(self, x):
53 | return x
54 |
55 |
56 | class SimpleImageConcat(AbstractLowScaleModel):
57 | # no noise level conditioning
58 | def __init__(self):
59 | super(SimpleImageConcat, self).__init__(noise_schedule_config=None)
60 | self.max_noise_level = 0
61 |
62 | def forward(self, x):
63 | # fix to constant noise level
64 | return x, torch.zeros(x.shape[0], device=x.device).long()
65 |
66 |
67 | class ImageConcatWithNoiseAugmentation(AbstractLowScaleModel):
68 | def __init__(self, noise_schedule_config, max_noise_level=1000, to_cuda=False):
69 | super().__init__(noise_schedule_config=noise_schedule_config)
70 | self.max_noise_level = max_noise_level
71 |
72 | def forward(self, x, noise_level=None):
73 | if noise_level is None:
74 | noise_level = torch.randint(0, self.max_noise_level, (x.shape[0],), device=x.device).long()
75 | else:
76 | assert isinstance(noise_level, torch.Tensor)
77 | z = self.q_sample(x, noise_level)
78 | return z, noise_level
79 |
80 |
81 |
82 |
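A sketch of noise-augmenting a low-resolution conditioning image; the schedule values below are illustrative defaults, not taken from a shipped config:

import torch
from ldm.modules.diffusionmodules.upscaling import ImageConcatWithNoiseAugmentation

aug = ImageConcatWithNoiseAugmentation(
    noise_schedule_config={"beta_schedule": "linear", "timesteps": 1000,
                           "linear_start": 1e-4, "linear_end": 2e-2},
    max_noise_level=350,
)
x_low = torch.randn(2, 3, 64, 64)    # low-res image in the model's input range
z, noise_level = aug(x_low)          # z: noised image, noise_level: (2,) long tensor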
--------------------------------------------------------------------------------
/ldm/modules/distributions/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/EternalEvan/FlowIE/549bb871cff104c47fd126d23818f54d74ba1f20/ldm/modules/distributions/__init__.py
--------------------------------------------------------------------------------
/ldm/modules/distributions/__pycache__/__init__.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/EternalEvan/FlowIE/549bb871cff104c47fd126d23818f54d74ba1f20/ldm/modules/distributions/__pycache__/__init__.cpython-38.pyc
--------------------------------------------------------------------------------
/ldm/modules/distributions/__pycache__/__init__.cpython-39.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/EternalEvan/FlowIE/549bb871cff104c47fd126d23818f54d74ba1f20/ldm/modules/distributions/__pycache__/__init__.cpython-39.pyc
--------------------------------------------------------------------------------
/ldm/modules/distributions/__pycache__/distributions.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/EternalEvan/FlowIE/549bb871cff104c47fd126d23818f54d74ba1f20/ldm/modules/distributions/__pycache__/distributions.cpython-38.pyc
--------------------------------------------------------------------------------
/ldm/modules/distributions/__pycache__/distributions.cpython-39.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/EternalEvan/FlowIE/549bb871cff104c47fd126d23818f54d74ba1f20/ldm/modules/distributions/__pycache__/distributions.cpython-39.pyc
--------------------------------------------------------------------------------
/ldm/modules/distributions/distributions.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import numpy as np
3 |
4 |
5 | class AbstractDistribution:
6 | def sample(self):
7 | raise NotImplementedError()
8 |
9 | def mode(self):
10 | raise NotImplementedError()
11 |
12 |
13 | class DiracDistribution(AbstractDistribution):
14 | def __init__(self, value):
15 | self.value = value
16 |
17 | def sample(self):
18 | return self.value
19 |
20 | def mode(self):
21 | return self.value
22 |
23 |
24 | class DiagonalGaussianDistribution(object):
25 | def __init__(self, parameters, deterministic=False):
26 | self.parameters = parameters
27 | self.mean, self.logvar = torch.chunk(parameters, 2, dim=1)
28 | self.logvar = torch.clamp(self.logvar, -30.0, 20.0)
29 | self.deterministic = deterministic
30 | self.std = torch.exp(0.5 * self.logvar)
31 | self.var = torch.exp(self.logvar)
32 | if self.deterministic:
33 | self.var = self.std = torch.zeros_like(self.mean).to(device=self.parameters.device)
34 |
35 | def sample(self):
36 | x = self.mean + self.std * torch.randn(self.mean.shape).to(device=self.parameters.device)
37 | return x
38 |
39 | def kl(self, other=None):
40 | if self.deterministic:
41 | return torch.Tensor([0.])
42 | else:
43 | if other is None:
44 | return 0.5 * torch.sum(torch.pow(self.mean, 2)
45 | + self.var - 1.0 - self.logvar,
46 | dim=[1, 2, 3])
47 | else:
48 | return 0.5 * torch.sum(
49 | torch.pow(self.mean - other.mean, 2) / other.var
50 | + self.var / other.var - 1.0 - self.logvar + other.logvar,
51 | dim=[1, 2, 3])
52 |
53 | def nll(self, sample, dims=[1,2,3]):
54 | if self.deterministic:
55 | return torch.Tensor([0.])
56 | logtwopi = np.log(2.0 * np.pi)
57 | return 0.5 * torch.sum(
58 | logtwopi + self.logvar + torch.pow(sample - self.mean, 2) / self.var,
59 | dim=dims)
60 |
61 | def mode(self):
62 | return self.mean
63 |
64 |
65 | def normal_kl(mean1, logvar1, mean2, logvar2):
66 | """
67 | source: https://github.com/openai/guided-diffusion/blob/27c20a8fab9cb472df5d6bdd6c8d11c8f430b924/guided_diffusion/losses.py#L12
68 | Compute the KL divergence between two gaussians.
69 | Shapes are automatically broadcasted, so batches can be compared to
70 | scalars, among other use cases.
71 | """
72 | tensor = None
73 | for obj in (mean1, logvar1, mean2, logvar2):
74 | if isinstance(obj, torch.Tensor):
75 | tensor = obj
76 | break
77 | assert tensor is not None, "at least one argument must be a Tensor"
78 |
79 | # Force variances to be Tensors. Broadcasting helps convert scalars to
80 | # Tensors, but it does not work for torch.exp().
81 | logvar1, logvar2 = [
82 | x if isinstance(x, torch.Tensor) else torch.tensor(x).to(tensor)
83 | for x in (logvar1, logvar2)
84 | ]
85 |
86 | return 0.5 * (
87 | -1.0
88 | + logvar2
89 | - logvar1
90 | + torch.exp(logvar1 - logvar2)
91 | + ((mean1 - mean2) ** 2) * torch.exp(-logvar2)
92 | )
93 |
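A short sketch of how the diagonal Gaussian posterior is typically used: the encoder emits 2*C channels (mean and log-variance stacked along dim 1), which the distribution splits for sampling and KL regularization:

import torch
from ldm.modules.distributions.distributions import DiagonalGaussianDistribution

params = torch.randn(1, 8, 32, 32)     # 2 * 4 latent channels
posterior = DiagonalGaussianDistribution(params)
z = posterior.sample()                 # (1, 4, 32, 32)
kl = posterior.kl()                    # per-sample KL against a standard normal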
--------------------------------------------------------------------------------
/ldm/modules/ema.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch import nn
3 |
4 |
5 | class LitEma(nn.Module):
6 | def __init__(self, model, decay=0.9999, use_num_upates=True):
7 | super().__init__()
8 | if decay < 0.0 or decay > 1.0:
9 | raise ValueError('Decay must be between 0 and 1')
10 |
11 | self.m_name2s_name = {}
12 | self.register_buffer('decay', torch.tensor(decay, dtype=torch.float32))
13 | self.register_buffer('num_updates', torch.tensor(0, dtype=torch.int) if use_num_upates
14 | else torch.tensor(-1, dtype=torch.int))
15 |
16 | for name, p in model.named_parameters():
17 | if p.requires_grad:
18 | # remove as '.'-character is not allowed in buffers
19 | s_name = name.replace('.', '')
20 | self.m_name2s_name.update({name: s_name})
21 | self.register_buffer(s_name, p.clone().detach().data)
22 |
23 | self.collected_params = []
24 |
25 | def reset_num_updates(self):
26 | del self.num_updates
27 | self.register_buffer('num_updates', torch.tensor(0, dtype=torch.int))
28 |
29 | def forward(self, model):
30 | decay = self.decay
31 |
32 | if self.num_updates >= 0:
33 | self.num_updates += 1
34 | decay = min(self.decay, (1 + self.num_updates) / (10 + self.num_updates))
35 |
36 | one_minus_decay = 1.0 - decay
37 |
38 | with torch.no_grad():
39 | m_param = dict(model.named_parameters())
40 | shadow_params = dict(self.named_buffers())
41 |
42 | for key in m_param:
43 | if m_param[key].requires_grad:
44 | sname = self.m_name2s_name[key]
45 | shadow_params[sname] = shadow_params[sname].type_as(m_param[key])
46 | shadow_params[sname].sub_(one_minus_decay * (shadow_params[sname] - m_param[key]))
47 | else:
48 | assert not key in self.m_name2s_name
49 |
50 | def copy_to(self, model):
51 | m_param = dict(model.named_parameters())
52 | shadow_params = dict(self.named_buffers())
53 | for key in m_param:
54 | if m_param[key].requires_grad:
55 | m_param[key].data.copy_(shadow_params[self.m_name2s_name[key]].data)
56 | else:
57 | assert not key in self.m_name2s_name
58 |
59 | def store(self, parameters):
60 | """
61 | Save the current parameters for restoring later.
62 | Args:
63 | parameters: Iterable of `torch.nn.Parameter`; the parameters to be
64 | temporarily stored.
65 | """
66 | self.collected_params = [param.clone() for param in parameters]
67 |
68 | def restore(self, parameters):
69 | """
70 | Restore the parameters stored with the `store` method.
71 | Useful to validate the model with EMA parameters without affecting the
72 | original optimization process. Store the parameters before the
73 | `copy_to` method. After validation (or model saving), use this to
74 | restore the former parameters.
75 | Args:
76 | parameters: Iterable of `torch.nn.Parameter`; the parameters to be
77 | updated with the stored parameters.
78 | """
79 | for c_param, param in zip(self.collected_params, parameters):
80 | param.data.copy_(c_param.data)
81 |
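A sketch of the intended call pattern (a toy nn.Linear stands in for the diffusion model): update the shadow weights after each optimizer step, then temporarily swap them in for evaluation:

import torch.nn as nn
from ldm.modules.ema import LitEma

model = nn.Linear(4, 4)
ema = LitEma(model, decay=0.9999)

# after every optimizer.step():
ema(model)

# evaluate with EMA weights, then restore the raw training weights
ema.store(model.parameters())
ema.copy_to(model)
# ... run validation ...
ema.restore(model.parameters())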
--------------------------------------------------------------------------------
/ldm/modules/encoders/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/EternalEvan/FlowIE/549bb871cff104c47fd126d23818f54d74ba1f20/ldm/modules/encoders/__init__.py
--------------------------------------------------------------------------------
/ldm/modules/encoders/__pycache__/__init__.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/EternalEvan/FlowIE/549bb871cff104c47fd126d23818f54d74ba1f20/ldm/modules/encoders/__pycache__/__init__.cpython-38.pyc
--------------------------------------------------------------------------------
/ldm/modules/encoders/__pycache__/__init__.cpython-39.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/EternalEvan/FlowIE/549bb871cff104c47fd126d23818f54d74ba1f20/ldm/modules/encoders/__pycache__/__init__.cpython-39.pyc
--------------------------------------------------------------------------------
/ldm/modules/encoders/__pycache__/modules.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/EternalEvan/FlowIE/549bb871cff104c47fd126d23818f54d74ba1f20/ldm/modules/encoders/__pycache__/modules.cpython-38.pyc
--------------------------------------------------------------------------------
/ldm/modules/encoders/__pycache__/modules.cpython-39.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/EternalEvan/FlowIE/549bb871cff104c47fd126d23818f54d74ba1f20/ldm/modules/encoders/__pycache__/modules.cpython-39.pyc
--------------------------------------------------------------------------------
/ldm/modules/midas/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/EternalEvan/FlowIE/549bb871cff104c47fd126d23818f54d74ba1f20/ldm/modules/midas/__init__.py
--------------------------------------------------------------------------------
/ldm/modules/midas/api.py:
--------------------------------------------------------------------------------
1 | # based on https://github.com/isl-org/MiDaS
2 |
3 | import cv2
4 | import torch
5 | import torch.nn as nn
6 | from torchvision.transforms import Compose
7 |
8 | from ldm.modules.midas.midas.dpt_depth import DPTDepthModel
9 | from ldm.modules.midas.midas.midas_net import MidasNet
10 | from ldm.modules.midas.midas.midas_net_custom import MidasNet_small
11 | from ldm.modules.midas.midas.transforms import Resize, NormalizeImage, PrepareForNet
12 |
13 |
14 | ISL_PATHS = {
15 | "dpt_large": "midas_models/dpt_large-midas-2f21e586.pt",
16 | "dpt_hybrid": "midas_models/dpt_hybrid-midas-501f0c75.pt",
17 | "midas_v21": "",
18 | "midas_v21_small": "",
19 | }
20 |
21 |
22 | def disabled_train(self, mode=True):
23 | """Overwrite model.train with this function to make sure train/eval mode
24 | does not change anymore."""
25 | return self
26 |
27 |
28 | def load_midas_transform(model_type):
29 | # https://github.com/isl-org/MiDaS/blob/master/run.py
30 | # load transform only
31 | if model_type == "dpt_large": # DPT-Large
32 | net_w, net_h = 384, 384
33 | resize_mode = "minimal"
34 | normalization = NormalizeImage(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
35 |
36 | elif model_type == "dpt_hybrid": # DPT-Hybrid
37 | net_w, net_h = 384, 384
38 | resize_mode = "minimal"
39 | normalization = NormalizeImage(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
40 |
41 | elif model_type == "midas_v21":
42 | net_w, net_h = 384, 384
43 | resize_mode = "upper_bound"
44 | normalization = NormalizeImage(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
45 |
46 | elif model_type == "midas_v21_small":
47 | net_w, net_h = 256, 256
48 | resize_mode = "upper_bound"
49 | normalization = NormalizeImage(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
50 |
51 | else:
52 | assert False, f"model_type '{model_type}' not implemented, use: --model_type large"
53 |
54 | transform = Compose(
55 | [
56 | Resize(
57 | net_w,
58 | net_h,
59 | resize_target=None,
60 | keep_aspect_ratio=True,
61 | ensure_multiple_of=32,
62 | resize_method=resize_mode,
63 | image_interpolation_method=cv2.INTER_CUBIC,
64 | ),
65 | normalization,
66 | PrepareForNet(),
67 | ]
68 | )
69 |
70 | return transform
71 |
72 |
73 | def load_model(model_type):
74 | # https://github.com/isl-org/MiDaS/blob/master/run.py
75 | # load network
76 | model_path = ISL_PATHS[model_type]
77 | if model_type == "dpt_large": # DPT-Large
78 | model = DPTDepthModel(
79 | path=model_path,
80 | backbone="vitl16_384",
81 | non_negative=True,
82 | )
83 | net_w, net_h = 384, 384
84 | resize_mode = "minimal"
85 | normalization = NormalizeImage(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
86 |
87 | elif model_type == "dpt_hybrid": # DPT-Hybrid
88 | model = DPTDepthModel(
89 | path=model_path,
90 | backbone="vitb_rn50_384",
91 | non_negative=True,
92 | )
93 | net_w, net_h = 384, 384
94 | resize_mode = "minimal"
95 | normalization = NormalizeImage(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
96 |
97 | elif model_type == "midas_v21":
98 | model = MidasNet(model_path, non_negative=True)
99 | net_w, net_h = 384, 384
100 | resize_mode = "upper_bound"
101 | normalization = NormalizeImage(
102 | mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]
103 | )
104 |
105 | elif model_type == "midas_v21_small":
106 | model = MidasNet_small(model_path, features=64, backbone="efficientnet_lite3", exportable=True,
107 | non_negative=True, blocks={'expand': True})
108 | net_w, net_h = 256, 256
109 | resize_mode = "upper_bound"
110 | normalization = NormalizeImage(
111 | mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]
112 | )
113 |
114 | else:
115 | print(f"model_type '{model_type}' not implemented, use: --model_type large")
116 | assert False
117 |
118 | transform = Compose(
119 | [
120 | Resize(
121 | net_w,
122 | net_h,
123 | resize_target=None,
124 | keep_aspect_ratio=True,
125 | ensure_multiple_of=32,
126 | resize_method=resize_mode,
127 | image_interpolation_method=cv2.INTER_CUBIC,
128 | ),
129 | normalization,
130 | PrepareForNet(),
131 | ]
132 | )
133 |
134 | return model.eval(), transform
135 |
136 |
137 | class MiDaSInference(nn.Module):
138 | MODEL_TYPES_TORCH_HUB = [
139 | "DPT_Large",
140 | "DPT_Hybrid",
141 | "MiDaS_small"
142 | ]
143 | MODEL_TYPES_ISL = [
144 | "dpt_large",
145 | "dpt_hybrid",
146 | "midas_v21",
147 | "midas_v21_small",
148 | ]
149 |
150 | def __init__(self, model_type):
151 | super().__init__()
152 | assert (model_type in self.MODEL_TYPES_ISL)
153 | model, _ = load_model(model_type)
154 | self.model = model
155 | self.model.train = disabled_train
156 |
157 | def forward(self, x):
158 | # x in 0..1 as produced by calling self.transform on a 0..1 float64 numpy array
159 | # NOTE: we expect that the correct transform has been called during dataloading.
160 | with torch.no_grad():
161 | prediction = self.model(x)
162 | prediction = torch.nn.functional.interpolate(
163 | prediction.unsqueeze(1),
164 | size=x.shape[2:],
165 | mode="bicubic",
166 | align_corners=False,
167 | )
168 | assert prediction.shape == (x.shape[0], 1, x.shape[2], x.shape[3])
169 | return prediction
170 |
171 |
--------------------------------------------------------------------------------
/ldm/modules/midas/midas/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/EternalEvan/FlowIE/549bb871cff104c47fd126d23818f54d74ba1f20/ldm/modules/midas/midas/__init__.py
--------------------------------------------------------------------------------
/ldm/modules/midas/midas/base_model.py:
--------------------------------------------------------------------------------
1 | import torch
2 |
3 |
4 | class BaseModel(torch.nn.Module):
5 | def load(self, path):
6 | """Load model from file.
7 |
8 | Args:
9 | path (str): file path
10 | """
11 | parameters = torch.load(path, map_location=torch.device('cpu'))
12 |
13 | if "optimizer" in parameters:
14 | parameters = parameters["model"]
15 |
16 | self.load_state_dict(parameters)
17 |
--------------------------------------------------------------------------------
/ldm/modules/midas/midas/dpt_depth.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 |
5 | from .base_model import BaseModel
6 | from .blocks import (
7 | FeatureFusionBlock,
8 | FeatureFusionBlock_custom,
9 | Interpolate,
10 | _make_encoder,
11 | forward_vit,
12 | )
13 |
14 |
15 | def _make_fusion_block(features, use_bn):
16 | return FeatureFusionBlock_custom(
17 | features,
18 | nn.ReLU(False),
19 | deconv=False,
20 | bn=use_bn,
21 | expand=False,
22 | align_corners=True,
23 | )
24 |
25 |
26 | class DPT(BaseModel):
27 | def __init__(
28 | self,
29 | head,
30 | features=256,
31 | backbone="vitb_rn50_384",
32 | readout="project",
33 | channels_last=False,
34 | use_bn=False,
35 | ):
36 |
37 | super(DPT, self).__init__()
38 |
39 | self.channels_last = channels_last
40 |
41 | hooks = {
42 | "vitb_rn50_384": [0, 1, 8, 11],
43 | "vitb16_384": [2, 5, 8, 11],
44 | "vitl16_384": [5, 11, 17, 23],
45 | }
46 |
47 | # Instantiate backbone and reassemble blocks
48 | self.pretrained, self.scratch = _make_encoder(
49 | backbone,
50 | features,
51 | False, # Set to true if you want to train from scratch, uses ImageNet weights
52 | groups=1,
53 | expand=False,
54 | exportable=False,
55 | hooks=hooks[backbone],
56 | use_readout=readout,
57 | )
58 |
59 | self.scratch.refinenet1 = _make_fusion_block(features, use_bn)
60 | self.scratch.refinenet2 = _make_fusion_block(features, use_bn)
61 | self.scratch.refinenet3 = _make_fusion_block(features, use_bn)
62 | self.scratch.refinenet4 = _make_fusion_block(features, use_bn)
63 |
64 | self.scratch.output_conv = head
65 |
66 |
67 | def forward(self, x):
68 | if self.channels_last == True:
69 | x = x.contiguous(memory_format=torch.channels_last)
70 |
71 | layer_1, layer_2, layer_3, layer_4 = forward_vit(self.pretrained, x)
72 |
73 | layer_1_rn = self.scratch.layer1_rn(layer_1)
74 | layer_2_rn = self.scratch.layer2_rn(layer_2)
75 | layer_3_rn = self.scratch.layer3_rn(layer_3)
76 | layer_4_rn = self.scratch.layer4_rn(layer_4)
77 |
78 | path_4 = self.scratch.refinenet4(layer_4_rn)
79 | path_3 = self.scratch.refinenet3(path_4, layer_3_rn)
80 | path_2 = self.scratch.refinenet2(path_3, layer_2_rn)
81 | path_1 = self.scratch.refinenet1(path_2, layer_1_rn)
82 |
83 | out = self.scratch.output_conv(path_1)
84 |
85 | return out
86 |
87 |
88 | class DPTDepthModel(DPT):
89 | def __init__(self, path=None, non_negative=True, **kwargs):
90 | features = kwargs["features"] if "features" in kwargs else 256
91 |
92 | head = nn.Sequential(
93 | nn.Conv2d(features, features // 2, kernel_size=3, stride=1, padding=1),
94 | Interpolate(scale_factor=2, mode="bilinear", align_corners=True),
95 | nn.Conv2d(features // 2, 32, kernel_size=3, stride=1, padding=1),
96 | nn.ReLU(True),
97 | nn.Conv2d(32, 1, kernel_size=1, stride=1, padding=0),
98 | nn.ReLU(True) if non_negative else nn.Identity(),
99 | nn.Identity(),
100 | )
101 |
102 | super().__init__(head, **kwargs)
103 |
104 | if path is not None:
105 | self.load(path)
106 |
107 | def forward(self, x):
108 | return super().forward(x).squeeze(dim=1)
109 |
110 |
--------------------------------------------------------------------------------
/ldm/modules/midas/midas/midas_net.py:
--------------------------------------------------------------------------------
1 | """MidashNet: Network for monocular depth estimation trained by mixing several datasets.
2 | This file contains code that is adapted from
3 | https://github.com/thomasjpfan/pytorch_refinenet/blob/master/pytorch_refinenet/refinenet/refinenet_4cascade.py
4 | """
5 | import torch
6 | import torch.nn as nn
7 |
8 | from .base_model import BaseModel
9 | from .blocks import FeatureFusionBlock, Interpolate, _make_encoder
10 |
11 |
12 | class MidasNet(BaseModel):
13 | """Network for monocular depth estimation.
14 | """
15 |
16 | def __init__(self, path=None, features=256, non_negative=True):
17 | """Init.
18 |
19 | Args:
20 | path (str, optional): Path to saved model. Defaults to None.
21 | features (int, optional): Number of features. Defaults to 256.
22 | non_negative (bool, optional): Clamp the output to be non-negative. Defaults to True.
23 | """
24 | print("Loading weights: ", path)
25 |
26 | super(MidasNet, self).__init__()
27 |
28 | use_pretrained = False if path is None else True
29 |
30 | self.pretrained, self.scratch = _make_encoder(backbone="resnext101_wsl", features=features, use_pretrained=use_pretrained)
31 |
32 | self.scratch.refinenet4 = FeatureFusionBlock(features)
33 | self.scratch.refinenet3 = FeatureFusionBlock(features)
34 | self.scratch.refinenet2 = FeatureFusionBlock(features)
35 | self.scratch.refinenet1 = FeatureFusionBlock(features)
36 |
37 | self.scratch.output_conv = nn.Sequential(
38 | nn.Conv2d(features, 128, kernel_size=3, stride=1, padding=1),
39 | Interpolate(scale_factor=2, mode="bilinear"),
40 | nn.Conv2d(128, 32, kernel_size=3, stride=1, padding=1),
41 | nn.ReLU(True),
42 | nn.Conv2d(32, 1, kernel_size=1, stride=1, padding=0),
43 | nn.ReLU(True) if non_negative else nn.Identity(),
44 | )
45 |
46 | if path:
47 | self.load(path)
48 |
49 | def forward(self, x):
50 | """Forward pass.
51 |
52 | Args:
53 | x (tensor): input data (image)
54 |
55 | Returns:
56 | tensor: depth
57 | """
58 |
59 | layer_1 = self.pretrained.layer1(x)
60 | layer_2 = self.pretrained.layer2(layer_1)
61 | layer_3 = self.pretrained.layer3(layer_2)
62 | layer_4 = self.pretrained.layer4(layer_3)
63 |
64 | layer_1_rn = self.scratch.layer1_rn(layer_1)
65 | layer_2_rn = self.scratch.layer2_rn(layer_2)
66 | layer_3_rn = self.scratch.layer3_rn(layer_3)
67 | layer_4_rn = self.scratch.layer4_rn(layer_4)
68 |
69 | path_4 = self.scratch.refinenet4(layer_4_rn)
70 | path_3 = self.scratch.refinenet3(path_4, layer_3_rn)
71 | path_2 = self.scratch.refinenet2(path_3, layer_2_rn)
72 | path_1 = self.scratch.refinenet1(path_2, layer_1_rn)
73 |
74 | out = self.scratch.output_conv(path_1)
75 |
76 | return torch.squeeze(out, dim=1)
77 |
--------------------------------------------------------------------------------
/ldm/modules/midas/midas/midas_net_custom.py:
--------------------------------------------------------------------------------
1 | """MidashNet: Network for monocular depth estimation trained by mixing several datasets.
2 | This file contains code that is adapted from
3 | https://github.com/thomasjpfan/pytorch_refinenet/blob/master/pytorch_refinenet/refinenet/refinenet_4cascade.py
4 | """
5 | import torch
6 | import torch.nn as nn
7 |
8 | from .base_model import BaseModel
9 | from .blocks import FeatureFusionBlock, FeatureFusionBlock_custom, Interpolate, _make_encoder
10 |
11 |
12 | class MidasNet_small(BaseModel):
13 | """Network for monocular depth estimation.
14 | """
15 |
16 | def __init__(self, path=None, features=64, backbone="efficientnet_lite3", non_negative=True, exportable=True, channels_last=False, align_corners=True,
17 | blocks={'expand': True}):
18 | """Init.
19 |
20 | Args:
21 | path (str, optional): Path to saved model. Defaults to None.
22 | features (int, optional): Number of features. Defaults to 256.
23 | backbone (str, optional): Backbone network for encoder. Defaults to efficientnet_lite3.
24 | """
25 | print("Loading weights: ", path)
26 |
27 | super(MidasNet_small, self).__init__()
28 |
29 | use_pretrained = False if path else True
30 |
31 | self.channels_last = channels_last
32 | self.blocks = blocks
33 | self.backbone = backbone
34 |
35 | self.groups = 1
36 |
37 | features1=features
38 | features2=features
39 | features3=features
40 | features4=features
41 | self.expand = False
42 | if "expand" in self.blocks and self.blocks['expand'] == True:
43 | self.expand = True
44 | features1=features
45 | features2=features*2
46 | features3=features*4
47 | features4=features*8
48 |
49 | self.pretrained, self.scratch = _make_encoder(self.backbone, features, use_pretrained, groups=self.groups, expand=self.expand, exportable=exportable)
50 |
51 | self.scratch.activation = nn.ReLU(False)
52 |
53 | self.scratch.refinenet4 = FeatureFusionBlock_custom(features4, self.scratch.activation, deconv=False, bn=False, expand=self.expand, align_corners=align_corners)
54 | self.scratch.refinenet3 = FeatureFusionBlock_custom(features3, self.scratch.activation, deconv=False, bn=False, expand=self.expand, align_corners=align_corners)
55 | self.scratch.refinenet2 = FeatureFusionBlock_custom(features2, self.scratch.activation, deconv=False, bn=False, expand=self.expand, align_corners=align_corners)
56 | self.scratch.refinenet1 = FeatureFusionBlock_custom(features1, self.scratch.activation, deconv=False, bn=False, align_corners=align_corners)
57 |
58 |
59 | self.scratch.output_conv = nn.Sequential(
60 | nn.Conv2d(features, features//2, kernel_size=3, stride=1, padding=1, groups=self.groups),
61 | Interpolate(scale_factor=2, mode="bilinear"),
62 | nn.Conv2d(features//2, 32, kernel_size=3, stride=1, padding=1),
63 | self.scratch.activation,
64 | nn.Conv2d(32, 1, kernel_size=1, stride=1, padding=0),
65 | nn.ReLU(True) if non_negative else nn.Identity(),
66 | nn.Identity(),
67 | )
68 |
69 | if path:
70 | self.load(path)
71 |
72 |
73 | def forward(self, x):
74 | """Forward pass.
75 |
76 | Args:
77 | x (tensor): input data (image)
78 |
79 | Returns:
80 | tensor: depth
81 | """
82 |         if self.channels_last:
83 |             print("self.channels_last = ", self.channels_last)
84 |             x = x.contiguous(memory_format=torch.channels_last)
85 |
86 |
87 | layer_1 = self.pretrained.layer1(x)
88 | layer_2 = self.pretrained.layer2(layer_1)
89 | layer_3 = self.pretrained.layer3(layer_2)
90 | layer_4 = self.pretrained.layer4(layer_3)
91 |
92 | layer_1_rn = self.scratch.layer1_rn(layer_1)
93 | layer_2_rn = self.scratch.layer2_rn(layer_2)
94 | layer_3_rn = self.scratch.layer3_rn(layer_3)
95 | layer_4_rn = self.scratch.layer4_rn(layer_4)
96 |
97 |
98 | path_4 = self.scratch.refinenet4(layer_4_rn)
99 | path_3 = self.scratch.refinenet3(path_4, layer_3_rn)
100 | path_2 = self.scratch.refinenet2(path_3, layer_2_rn)
101 | path_1 = self.scratch.refinenet1(path_2, layer_1_rn)
102 |
103 | out = self.scratch.output_conv(path_1)
104 |
105 | return torch.squeeze(out, dim=1)
106 |
107 |
108 |
109 | def fuse_model(m):
110 | prev_previous_type = nn.Identity()
111 | prev_previous_name = ''
112 | previous_type = nn.Identity()
113 | previous_name = ''
114 | for name, module in m.named_modules():
115 | if prev_previous_type == nn.Conv2d and previous_type == nn.BatchNorm2d and type(module) == nn.ReLU:
116 | # print("FUSED ", prev_previous_name, previous_name, name)
117 | torch.quantization.fuse_modules(m, [prev_previous_name, previous_name, name], inplace=True)
118 | elif prev_previous_type == nn.Conv2d and previous_type == nn.BatchNorm2d:
119 | # print("FUSED ", prev_previous_name, previous_name)
120 | torch.quantization.fuse_modules(m, [prev_previous_name, previous_name], inplace=True)
121 | # elif previous_type == nn.Conv2d and type(module) == nn.ReLU:
122 | # print("FUSED ", previous_name, name)
123 | # torch.quantization.fuse_modules(m, [previous_name, name], inplace=True)
124 |
125 | prev_previous_type = previous_type
126 | prev_previous_name = previous_name
127 | previous_type = type(module)
128 | previous_name = name
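A minimal usage sketch for fuse_model (hypothetical, not part of the repository): it assumes an eval-mode MidasNet_small instance and simply folds adjacent Conv2d/BatchNorm2d(/ReLU) modules in place before quantized or accelerated inference. With path=None the backbone weights are fetched via torch.hub on first use.

    import torch
    from ldm.modules.midas.midas.midas_net_custom import MidasNet_small, fuse_model

    model = MidasNet_small(path=None, features=64)
    model.eval()                  # fuse_modules is only valid in eval mode
    fuse_model(model)             # folds BatchNorm2d into the preceding Conv2d where such pairs exist
    with torch.no_grad():
        depth = model(torch.randn(1, 3, 256, 256))   # -> (batch, H_out, W_out) depth map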
--------------------------------------------------------------------------------
/ldm/modules/midas/utils.py:
--------------------------------------------------------------------------------
1 | """Utils for monoDepth."""
2 | import sys
3 | import re
4 | import numpy as np
5 | import cv2
6 | import torch
7 |
8 |
9 | def read_pfm(path):
10 | """Read pfm file.
11 |
12 | Args:
13 | path (str): path to file
14 |
15 | Returns:
16 | tuple: (data, scale)
17 | """
18 | with open(path, "rb") as file:
19 |
20 | color = None
21 | width = None
22 | height = None
23 | scale = None
24 | endian = None
25 |
26 | header = file.readline().rstrip()
27 | if header.decode("ascii") == "PF":
28 | color = True
29 | elif header.decode("ascii") == "Pf":
30 | color = False
31 | else:
32 | raise Exception("Not a PFM file: " + path)
33 |
34 | dim_match = re.match(r"^(\d+)\s(\d+)\s$", file.readline().decode("ascii"))
35 | if dim_match:
36 | width, height = list(map(int, dim_match.groups()))
37 | else:
38 | raise Exception("Malformed PFM header.")
39 |
40 | scale = float(file.readline().decode("ascii").rstrip())
41 | if scale < 0:
42 | # little-endian
43 | endian = "<"
44 | scale = -scale
45 | else:
46 | # big-endian
47 | endian = ">"
48 |
49 | data = np.fromfile(file, endian + "f")
50 | shape = (height, width, 3) if color else (height, width)
51 |
52 | data = np.reshape(data, shape)
53 | data = np.flipud(data)
54 |
55 | return data, scale
56 |
57 |
58 | def write_pfm(path, image, scale=1):
59 | """Write pfm file.
60 |
61 | Args:
62 |         path (str): path to file
63 | image (array): data
64 | scale (int, optional): Scale. Defaults to 1.
65 | """
66 |
67 | with open(path, "wb") as file:
68 | color = None
69 |
70 | if image.dtype.name != "float32":
71 | raise Exception("Image dtype must be float32.")
72 |
73 | image = np.flipud(image)
74 |
75 | if len(image.shape) == 3 and image.shape[2] == 3: # color image
76 | color = True
77 | elif (
78 | len(image.shape) == 2 or len(image.shape) == 3 and image.shape[2] == 1
79 | ): # greyscale
80 | color = False
81 | else:
82 | raise Exception("Image must have H x W x 3, H x W x 1 or H x W dimensions.")
83 |
84 |         file.write(("PF\n" if color else "Pf\n").encode())
85 | file.write("%d %d\n".encode() % (image.shape[1], image.shape[0]))
86 |
87 | endian = image.dtype.byteorder
88 |
89 | if endian == "<" or endian == "=" and sys.byteorder == "little":
90 | scale = -scale
91 |
92 | file.write("%f\n".encode() % scale)
93 |
94 | image.tofile(file)
95 |
96 |
97 | def read_image(path):
98 | """Read image and output RGB image (0-1).
99 |
100 | Args:
101 | path (str): path to file
102 |
103 | Returns:
104 | array: RGB image (0-1)
105 | """
106 | img = cv2.imread(path)
107 |
108 | if img.ndim == 2:
109 | img = cv2.cvtColor(img, cv2.COLOR_GRAY2BGR)
110 |
111 | img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB) / 255.0
112 |
113 | return img
114 |
115 |
116 | def resize_image(img):
117 | """Resize image and make it fit for network.
118 |
119 | Args:
120 | img (array): image
121 |
122 | Returns:
123 | tensor: data ready for network
124 | """
125 | height_orig = img.shape[0]
126 | width_orig = img.shape[1]
127 |
128 | if width_orig > height_orig:
129 | scale = width_orig / 384
130 | else:
131 | scale = height_orig / 384
132 |
133 | height = (np.ceil(height_orig / scale / 32) * 32).astype(int)
134 | width = (np.ceil(width_orig / scale / 32) * 32).astype(int)
135 |
136 | img_resized = cv2.resize(img, (width, height), interpolation=cv2.INTER_AREA)
137 |
138 | img_resized = (
139 | torch.from_numpy(np.transpose(img_resized, (2, 0, 1))).contiguous().float()
140 | )
141 | img_resized = img_resized.unsqueeze(0)
142 |
143 | return img_resized
144 |
145 |
146 | def resize_depth(depth, width, height):
147 | """Resize depth map and bring to CPU (numpy).
148 |
149 | Args:
150 | depth (tensor): depth
151 | width (int): image width
152 | height (int): image height
153 |
154 | Returns:
155 | array: processed depth
156 | """
157 | depth = torch.squeeze(depth[0, :, :, :]).to("cpu")
158 |
159 | depth_resized = cv2.resize(
160 | depth.numpy(), (width, height), interpolation=cv2.INTER_CUBIC
161 | )
162 |
163 | return depth_resized
164 |
165 | def write_depth(path, depth, bits=1):
166 | """Write depth map to pfm and png file.
167 |
168 | Args:
169 | path (str): filepath without extension
170 | depth (array): depth
171 | """
172 | write_pfm(path + ".pfm", depth.astype(np.float32))
173 |
174 | depth_min = depth.min()
175 | depth_max = depth.max()
176 |
177 | max_val = (2**(8*bits))-1
178 |
179 | if depth_max - depth_min > np.finfo("float").eps:
180 | out = max_val * (depth - depth_min) / (depth_max - depth_min)
181 | else:
182 |         out = np.zeros(depth.shape, dtype=depth.dtype)
183 |
184 | if bits == 1:
185 | cv2.imwrite(path + ".png", out.astype("uint8"))
186 | elif bits == 2:
187 | cv2.imwrite(path + ".png", out.astype("uint16"))
188 |
189 | return
190 |
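A short, hypothetical round-trip example for the PFM/PNG helpers above (file names and the random depth map are placeholders):

    import numpy as np

    depth = np.random.rand(240, 320).astype(np.float32)   # fake single-channel depth map
    write_depth("example_depth", depth, bits=2)            # writes example_depth.pfm and a 16-bit example_depth.png
    data, scale = read_pfm("example_depth.pfm")
    print(data.shape, scale)                               # (240, 320) 1.0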
--------------------------------------------------------------------------------
/ldm/util.py:
--------------------------------------------------------------------------------
1 | import importlib
2 |
3 | import torch
4 | from torch import optim
5 | import numpy as np
6 |
7 | from inspect import isfunction
8 | from PIL import Image, ImageDraw, ImageFont
9 |
10 |
11 | def log_txt_as_img(wh, xc, size=10):
12 | # wh a tuple of (width, height)
13 | # xc a list of captions to plot
14 | b = len(xc)
15 | txts = list()
16 | for bi in range(b):
17 | txt = Image.new("RGB", wh, color="white")
18 | draw = ImageDraw.Draw(txt)
19 | # font = ImageFont.truetype('font/DejaVuSans.ttf', size=size)
20 | font = ImageFont.load_default()
21 | nc = int(40 * (wh[0] / 256))
22 | lines = "\n".join(xc[bi][start:start + nc] for start in range(0, len(xc[bi]), nc))
23 |
24 | try:
25 | draw.text((0, 0), lines, fill="black", font=font)
26 | except UnicodeEncodeError:
27 |             print("Can't encode string for logging. Skipping.")
28 |
29 | txt = np.array(txt).transpose(2, 0, 1) / 127.5 - 1.0
30 | txts.append(txt)
31 | txts = np.stack(txts)
32 | txts = torch.tensor(txts)
33 | return txts
34 |
35 |
36 | def ismap(x):
37 | if not isinstance(x, torch.Tensor):
38 | return False
39 | return (len(x.shape) == 4) and (x.shape[1] > 3)
40 |
41 |
42 | def isimage(x):
43 | if not isinstance(x,torch.Tensor):
44 | return False
45 | return (len(x.shape) == 4) and (x.shape[1] == 3 or x.shape[1] == 1)
46 |
47 |
48 | def exists(x):
49 | return x is not None
50 |
51 |
52 | def default(val, d):
53 | if exists(val):
54 | return val
55 | return d() if isfunction(d) else d
56 |
57 |
58 | def mean_flat(tensor):
59 | """
60 | https://github.com/openai/guided-diffusion/blob/27c20a8fab9cb472df5d6bdd6c8d11c8f430b924/guided_diffusion/nn.py#L86
61 | Take the mean over all non-batch dimensions.
62 | """
63 | return tensor.mean(dim=list(range(1, len(tensor.shape))))
64 |
65 |
66 | def count_params(model, verbose=False):
67 | total_params = sum(p.numel() for p in model.parameters())
68 | if verbose:
69 | print(f"{model.__class__.__name__} has {total_params*1.e-6:.2f} M params.")
70 | return total_params
71 |
72 |
73 | def instantiate_from_config(config):
74 | if not "target" in config:
75 | if config == '__is_first_stage__':
76 | return None
77 | elif config == "__is_unconditional__":
78 | return None
79 | raise KeyError("Expected key `target` to instantiate.")
80 | return get_obj_from_str(config["target"])(**config.get("params", dict()))
81 |
82 |
83 | def get_obj_from_str(string, reload=False):
84 | module, cls = string.rsplit(".", 1)
85 | if reload:
86 | module_imp = importlib.import_module(module)
87 | importlib.reload(module_imp)
88 | return getattr(importlib.import_module(module, package=None), cls)
89 |
90 |
91 | class AdamWwithEMAandWings(optim.Optimizer):
92 | # credit to https://gist.github.com/crowsonkb/65f7265353f403714fce3b2595e0b298
93 | def __init__(self, params, lr=1.e-3, betas=(0.9, 0.999), eps=1.e-8, # TODO: check hyperparameters before using
94 | weight_decay=1.e-2, amsgrad=False, ema_decay=0.9999, # ema decay to match previous code
95 | ema_power=1., param_names=()):
96 | """AdamW that saves EMA versions of the parameters."""
97 | if not 0.0 <= lr:
98 | raise ValueError("Invalid learning rate: {}".format(lr))
99 | if not 0.0 <= eps:
100 | raise ValueError("Invalid epsilon value: {}".format(eps))
101 | if not 0.0 <= betas[0] < 1.0:
102 | raise ValueError("Invalid beta parameter at index 0: {}".format(betas[0]))
103 | if not 0.0 <= betas[1] < 1.0:
104 | raise ValueError("Invalid beta parameter at index 1: {}".format(betas[1]))
105 | if not 0.0 <= weight_decay:
106 | raise ValueError("Invalid weight_decay value: {}".format(weight_decay))
107 | if not 0.0 <= ema_decay <= 1.0:
108 | raise ValueError("Invalid ema_decay value: {}".format(ema_decay))
109 | defaults = dict(lr=lr, betas=betas, eps=eps,
110 | weight_decay=weight_decay, amsgrad=amsgrad, ema_decay=ema_decay,
111 | ema_power=ema_power, param_names=param_names)
112 | super().__init__(params, defaults)
113 |
114 | def __setstate__(self, state):
115 | super().__setstate__(state)
116 | for group in self.param_groups:
117 | group.setdefault('amsgrad', False)
118 |
119 | @torch.no_grad()
120 | def step(self, closure=None):
121 | """Performs a single optimization step.
122 | Args:
123 | closure (callable, optional): A closure that reevaluates the model
124 | and returns the loss.
125 | """
126 | loss = None
127 | if closure is not None:
128 | with torch.enable_grad():
129 | loss = closure()
130 |
131 | for group in self.param_groups:
132 | params_with_grad = []
133 | grads = []
134 | exp_avgs = []
135 | exp_avg_sqs = []
136 | ema_params_with_grad = []
137 | state_sums = []
138 | max_exp_avg_sqs = []
139 | state_steps = []
140 | amsgrad = group['amsgrad']
141 | beta1, beta2 = group['betas']
142 | ema_decay = group['ema_decay']
143 | ema_power = group['ema_power']
144 |
145 | for p in group['params']:
146 | if p.grad is None:
147 | continue
148 | params_with_grad.append(p)
149 | if p.grad.is_sparse:
150 | raise RuntimeError('AdamW does not support sparse gradients')
151 | grads.append(p.grad)
152 |
153 | state = self.state[p]
154 |
155 | # State initialization
156 | if len(state) == 0:
157 | state['step'] = 0
158 | # Exponential moving average of gradient values
159 | state['exp_avg'] = torch.zeros_like(p, memory_format=torch.preserve_format)
160 | # Exponential moving average of squared gradient values
161 | state['exp_avg_sq'] = torch.zeros_like(p, memory_format=torch.preserve_format)
162 | if amsgrad:
163 | # Maintains max of all exp. moving avg. of sq. grad. values
164 | state['max_exp_avg_sq'] = torch.zeros_like(p, memory_format=torch.preserve_format)
165 | # Exponential moving average of parameter values
166 | state['param_exp_avg'] = p.detach().float().clone()
167 |
168 | exp_avgs.append(state['exp_avg'])
169 | exp_avg_sqs.append(state['exp_avg_sq'])
170 | ema_params_with_grad.append(state['param_exp_avg'])
171 |
172 | if amsgrad:
173 | max_exp_avg_sqs.append(state['max_exp_avg_sq'])
174 |
175 | # update the steps for each param group update
176 | state['step'] += 1
177 | # record the step after step update
178 | state_steps.append(state['step'])
179 |
180 | optim._functional.adamw(params_with_grad,
181 | grads,
182 | exp_avgs,
183 | exp_avg_sqs,
184 | max_exp_avg_sqs,
185 | state_steps,
186 | amsgrad=amsgrad,
187 | beta1=beta1,
188 | beta2=beta2,
189 | lr=group['lr'],
190 | weight_decay=group['weight_decay'],
191 | eps=group['eps'],
192 | maximize=False)
193 |
194 | cur_ema_decay = min(ema_decay, 1 - state['step'] ** -ema_power)
195 | for param, ema_param in zip(params_with_grad, ema_params_with_grad):
196 | ema_param.mul_(cur_ema_decay).add_(param.float(), alpha=1 - cur_ema_decay)
197 |
198 | return loss
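A minimal sketch of how instantiate_from_config and get_obj_from_str are meant to be used; the target below is only an illustration, any importable "module.Class" path works:

    config = {
        "target": "torch.nn.Conv2d",
        "params": {"in_channels": 3, "out_channels": 16, "kernel_size": 3},
    }
    conv = instantiate_from_config(config)   # equivalent to torch.nn.Conv2d(3, 16, 3)
    print(type(conv))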
--------------------------------------------------------------------------------
/ldm/xformers_state.py:
--------------------------------------------------------------------------------
1 | try:
2 | import xformers
3 | import xformers.ops
4 | XFORMERS_IS_AVAILBLE = True
5 | except ImportError:
6 | XFORMERS_IS_AVAILBLE = False
7 | print("No module 'xformers'. Proceeding without it.")
8 |
9 |
10 | def is_xformers_available() -> bool:
11 | global XFORMERS_IS_AVAILBLE
12 | return XFORMERS_IS_AVAILBLE
13 |
14 | def disable_xformers() -> None:
15 | print("DISABLE XFORMERS!")
16 | global XFORMERS_IS_AVAILBLE
17 | XFORMERS_IS_AVAILBLE = False
18 |
19 | def enable_xformers() -> None:
20 | print("ENABLE XFORMERS!")
21 | global XFORMERS_IS_AVAILBLE
22 | XFORMERS_IS_AVAILBLE = True
23 |
24 | def auto_xformers_status(device):
25 | if 'cuda' in str(device):
26 | enable_xformers()
27 | elif str(device) == 'cpu':
28 | disable_xformers()
29 | else:
30 | raise ValueError(f"Unknown device {device}")
31 |
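A usage sketch (assuming the module is imported as below): pick the xformers attention path from the runtime device.

    import torch
    from ldm.xformers_state import auto_xformers_status, is_xformers_available

    device = "cuda" if torch.cuda.is_available() else "cpu"
    auto_xformers_status(device)      # enables xformers on CUDA, disables it on CPU
    print(is_xformers_available())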
--------------------------------------------------------------------------------
/model/__pycache__/callbacks.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/EternalEvan/FlowIE/549bb871cff104c47fd126d23818f54d74ba1f20/model/__pycache__/callbacks.cpython-38.pyc
--------------------------------------------------------------------------------
/model/__pycache__/cldm.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/EternalEvan/FlowIE/549bb871cff104c47fd126d23818f54d74ba1f20/model/__pycache__/cldm.cpython-38.pyc
--------------------------------------------------------------------------------
/model/__pycache__/cldm_bsr.cpython-39.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/EternalEvan/FlowIE/549bb871cff104c47fd126d23818f54d74ba1f20/model/__pycache__/cldm_bsr.cpython-39.pyc
--------------------------------------------------------------------------------
/model/__pycache__/cond_fn.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/EternalEvan/FlowIE/549bb871cff104c47fd126d23818f54d74ba1f20/model/__pycache__/cond_fn.cpython-38.pyc
--------------------------------------------------------------------------------
/model/__pycache__/mixins.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/EternalEvan/FlowIE/549bb871cff104c47fd126d23818f54d74ba1f20/model/__pycache__/mixins.cpython-38.pyc
--------------------------------------------------------------------------------
/model/__pycache__/mixins.cpython-39.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/EternalEvan/FlowIE/549bb871cff104c47fd126d23818f54d74ba1f20/model/__pycache__/mixins.cpython-39.pyc
--------------------------------------------------------------------------------
/model/__pycache__/spaced_sampler.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/EternalEvan/FlowIE/549bb871cff104c47fd126d23818f54d74ba1f20/model/__pycache__/spaced_sampler.cpython-38.pyc
--------------------------------------------------------------------------------
/model/__pycache__/swinir.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/EternalEvan/FlowIE/549bb871cff104c47fd126d23818f54d74ba1f20/model/__pycache__/swinir.cpython-38.pyc
--------------------------------------------------------------------------------
/model/__pycache__/swinir.cpython-39.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/EternalEvan/FlowIE/549bb871cff104c47fd126d23818f54d74ba1f20/model/__pycache__/swinir.cpython-39.pyc
--------------------------------------------------------------------------------
/model/callbacks.py:
--------------------------------------------------------------------------------
1 | from typing import Dict, Any
2 | import os
3 |
4 | import numpy as np
5 | import pytorch_lightning as pl
6 | from pytorch_lightning.callbacks import ModelCheckpoint
7 | from pytorch_lightning.utilities.types import STEP_OUTPUT
8 | import torch
9 | import torchvision
10 | from PIL import Image
11 | from pytorch_lightning.callbacks import Callback
12 | from pytorch_lightning.utilities.distributed import rank_zero_only
13 |
14 | from .mixins import ImageLoggerMixin
15 |
16 |
17 | __all__ = [
18 | "ModelCheckpoint",
19 | "ImageLogger"
20 | ]
21 |
22 | class ImageLogger(Callback):
23 | """
24 | Log images during training or validating.
25 |
26 | TODO: Support validating.
27 | """
28 |
29 | def __init__(
30 | self,
31 | log_every_n_steps: int=2000,
32 | max_images_each_step: int=4,
33 | log_images_kwargs: Dict[str, Any]=None
34 |     ) -> None:
35 | super().__init__()
36 | self.log_every_n_steps = log_every_n_steps
37 | self.max_images_each_step = max_images_each_step
38 | self.log_images_kwargs = log_images_kwargs or dict()
39 |
40 | def on_fit_start(self, trainer: pl.Trainer, pl_module: pl.LightningModule) -> None:
41 | assert isinstance(pl_module, ImageLoggerMixin)
42 |
43 | @rank_zero_only
44 | def on_train_batch_end(
45 | self, trainer: pl.Trainer, pl_module: pl.LightningModule, outputs: STEP_OUTPUT,
46 | batch: Any, batch_idx: int, dataloader_idx: int
47 | ) -> None:
48 | if pl_module.global_step % self.log_every_n_steps == 0:
49 | is_train = pl_module.training
50 | if is_train:
51 | pl_module.freeze()
52 |
53 | with torch.no_grad():
54 | # returned images should be: nchw, rgb, [0, 1]
55 | images: Dict[str, torch.Tensor] = pl_module.log_images(batch, **self.log_images_kwargs)
56 |
57 | # save images
58 | save_dir = os.path.join(pl_module.logger.save_dir, "image_log", "train")
59 | os.makedirs(save_dir, exist_ok=True)
60 | for image_key in images:
61 | image = images[image_key].detach().cpu()
62 | N = min(self.max_images_each_step, len(image))
63 | grid = torchvision.utils.make_grid(image[:N], nrow=4)
64 | # chw -> hwc (hw if gray)
65 | grid = grid.transpose(0, 1).transpose(1, 2).squeeze(-1).numpy()
66 | grid = (grid * 255).clip(0, 255).astype(np.uint8)
67 | filename = "{}_step-{:06}_e-{:06}_b-{:06}.png".format(
68 | image_key, pl_module.global_step, pl_module.current_epoch, batch_idx
69 | )
70 | path = os.path.join(save_dir, filename)
71 | Image.fromarray(grid).save(path)
72 |
73 | if is_train:
74 | pl_module.unfreeze()
75 |
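A hypothetical configuration sketch showing how ImageLogger is typically attached to a Trainer alongside ModelCheckpoint; the step counts are placeholders, and the LightningModule is assumed to implement log_images():

    import pytorch_lightning as pl
    from model.callbacks import ImageLogger, ModelCheckpoint

    callbacks = [
        ModelCheckpoint(every_n_train_steps=10000, save_top_k=-1),
        ImageLogger(log_every_n_steps=2000, max_images_each_step=4),
    ]
    trainer = pl.Trainer(callbacks=callbacks, max_steps=100000)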
--------------------------------------------------------------------------------
/model/cond_fn.py:
--------------------------------------------------------------------------------
1 | from typing import overload, Optional
2 | import torch
3 | from torch.nn import functional as F
4 |
5 |
6 | class Guidance:
7 |
8 | def __init__(
9 | self,
10 | scale: float,
11 | t_start: int,
12 | t_stop: int,
13 | space: str,
14 | repeat: int
15 |     ) -> None:
16 | """
17 | Initialize latent image guidance.
18 |
19 | Args:
20 | scale (float): Gradient scale (denoted as `s` in our paper). The larger the gradient scale,
21 | the closer the final result will be to the output of the first stage model.
22 | t_start (int), t_stop (int): The timestep to start or stop guidance. Note that the sampling
23 | process starts from t=1000 to t=0, the `t_start` should be larger than `t_stop`.
24 | space (str): The data space for computing loss function (rgb or latent).
25 | repeat (int): Repeat gradient descent for `repeat` times.
26 |
27 | Our latent image guidance is based on [GDP](https://github.com/Fayeben/GenerativeDiffusionPrior).
28 | Thanks for their work!
29 | """
30 | self.scale = scale
31 | self.t_start = t_start
32 | self.t_stop = t_stop
33 | self.target = None
34 | self.space = space
35 | self.repeat = repeat
36 |
37 |     def load_target(self, target: torch.Tensor) -> None:
38 | self.target = target
39 |
40 | def __call__(self, target_x0: torch.Tensor, pred_x0: torch.Tensor, t: int) -> Optional[torch.Tensor]:
41 | if self.t_stop < t and t < self.t_start:
42 | # print("sampling with classifier guidance")
43 | # avoid propagating gradient out of this scope
44 | pred_x0 = pred_x0.detach().clone()
45 | target_x0 = target_x0.detach().clone()
46 | return self.scale * self._forward(target_x0, pred_x0)
47 | else:
48 | return None
49 |
50 | @overload
51 | def _forward(self, target_x0: torch.Tensor, pred_x0: torch.Tensor) -> torch.Tensor:
52 | ...
53 |
54 |
55 | class MSEGuidance(Guidance):
56 |
57 | def __init__(
58 | self,
59 | scale: float,
60 | t_start: int,
61 | t_stop: int,
62 | space: str,
63 | repeat: int
64 |     ) -> None:
65 | super().__init__(
66 | scale, t_start, t_stop, space, repeat
67 | )
68 |
69 | @torch.enable_grad()
70 | def _forward(self, target_x0: torch.Tensor, pred_x0: torch.Tensor) -> torch.Tensor:
71 | # inputs: [-1, 1], nchw, rgb
72 | pred_x0.requires_grad_(True)
73 |
74 | # This is what we actually use.
75 | loss = (pred_x0 - target_x0).pow(2).mean((1, 2, 3)).sum()
76 |
77 | print(f"loss = {loss.item()}")
78 | return -torch.autograd.grad(loss, pred_x0)[0]
79 |
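A minimal driving sketch for MSEGuidance; the tensors, shapes and hyper-parameters below are placeholders rather than the sampler's actual values:

    import torch

    cond_fn = MSEGuidance(scale=200.0, t_start=601, t_stop=-1, space="latent", repeat=1)
    cond_fn.load_target(torch.zeros(1, 4, 64, 64))         # e.g. the first-stage output
    pred_x0 = torch.zeros(1, 4, 64, 64)
    grad = cond_fn(cond_fn.target, pred_x0, t=500)          # scaled gradient, or None outside (t_stop, t_start)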
--------------------------------------------------------------------------------
/model/mixins.py:
--------------------------------------------------------------------------------
1 | from typing import overload, Any, Dict
2 | import torch
3 |
4 |
5 | class ImageLoggerMixin:
6 |
7 | @overload
8 | def log_images(self, batch: Any, **kwargs: Dict[str, Any]) -> Dict[str, torch.Tensor]:
9 | ...
10 |
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | --extra-index-url https://download.pytorch.org/whl/cu116
2 | torch==1.13.1+cu116
3 | torchvision==0.14.1+cu116
4 | xformers==0.0.16
5 | pytorch_lightning==1.4.2
6 | einops
7 | open-clip-torch==2.24.0
8 | omegaconf
9 | torchmetrics==0.6.0
10 | triton==2.0.0
11 | lora-diffusion==0.1.7
12 | opencv-python-headless
13 | scipy
14 | matplotlib
15 | lpips
16 | gradio
17 | chardet
18 | transformers
19 | facexlib
20 |
--------------------------------------------------------------------------------
/scripts/inference_stage1.py:
--------------------------------------------------------------------------------
1 | import sys
2 | sys.path.append(".")
3 | import os
4 | from argparse import ArgumentParser, Namespace
5 |
6 | import pytorch_lightning as pl
7 | from omegaconf import OmegaConf
8 | import torch
9 | from PIL import Image
10 | import numpy as np
11 | from tqdm import tqdm
12 |
13 | from utils.image import auto_resize, pad
14 | from utils.common import load_state_dict, instantiate_from_config
15 | from utils.file import list_image_files, get_file_name_parts
16 |
17 |
18 | def parse_args() -> Namespace:
19 | parser = ArgumentParser()
20 | parser.add_argument("--config", type=str, required=True)
21 | parser.add_argument("--ckpt", type=str, required=True)
22 | parser.add_argument("--input", type=str, required=True)
23 | parser.add_argument("--sr_scale", type=float, default=1)
24 | parser.add_argument("--image_size", type=int, default=512)
25 | parser.add_argument("--show_lq", action="store_true")
26 | parser.add_argument("--resize_back", action="store_true")
27 | parser.add_argument("--output", type=str, required=True)
28 | parser.add_argument("--skip_if_exist", action="store_true")
29 | parser.add_argument("--seed", type=int, default=231)
30 | return parser.parse_args()
31 |
32 |
33 | @torch.no_grad()
34 | def main():
35 | args = parse_args()
36 | pl.seed_everything(args.seed)
37 | device = "cuda" if torch.cuda.is_available() else "cpu"
38 |
39 | model: pl.LightningModule = instantiate_from_config(OmegaConf.load(args.config))
40 | load_state_dict(model, torch.load(args.ckpt, map_location="cpu"), strict=True)
41 | model.freeze()
42 | model.to(device)
43 |
44 | assert os.path.isdir(args.input)
45 |
46 | pbar = tqdm(list_image_files(args.input, follow_links=True))
47 | for file_path in pbar:
48 | pbar.set_description(file_path)
49 | save_path = os.path.join(args.output, os.path.relpath(file_path, args.input))
50 | parent_path, stem, _ = get_file_name_parts(save_path)
51 | save_path = os.path.join(parent_path, f"{stem}.png")
52 | if os.path.exists(save_path):
53 | if args.skip_if_exist:
54 | print(f"skip {save_path}")
55 | continue
56 | else:
57 |                 raise RuntimeError(f"{save_path} already exists")
58 | os.makedirs(parent_path, exist_ok=True)
59 |
60 | # load low-quality image and resize
61 | lq = Image.open(file_path).convert("RGB")
62 | if args.sr_scale != 1:
63 | lq = lq.resize(
64 | tuple(int(x * args.sr_scale) for x in lq.size), Image.BICUBIC
65 | )
66 | lq_resized = auto_resize(lq, args.image_size)
67 | # padding
68 | x = pad(np.array(lq_resized), scale=64)
69 |
70 | x = torch.tensor(x, dtype=torch.float32, device=device) / 255.0
71 | x = x.permute(2, 0, 1).unsqueeze(0).contiguous()
72 | try:
73 | pred = model(x).detach().squeeze(0).permute(1, 2, 0) * 255
74 | pred = pred.clamp(0, 255).to(torch.uint8).cpu().numpy()
75 | except RuntimeError as e:
76 | print(f"inference failed, error: {e}")
77 | continue
78 |
79 | # remove padding
80 | pred = pred[:lq_resized.height, :lq_resized.width, :]
81 | if args.show_lq:
82 | if args.resize_back:
83 |                 if lq_resized.size != lq.size:
84 |                     pred = np.array(Image.fromarray(pred).resize(lq.size, Image.LANCZOS))
85 |                 lq = np.array(lq)
86 | else:
87 | lq = np.array(lq_resized)
88 | final_image = Image.fromarray(np.concatenate([lq, pred], axis=1))
89 | else:
90 | if args.resize_back and lq_resized.size != lq.size:
91 | final_image = Image.fromarray(pred).resize(lq.size, Image.LANCZOS)
92 | else:
93 | final_image = Image.fromarray(pred)
94 | final_image.save(save_path)
95 |
96 |
97 | if __name__ == "__main__":
98 | main()
99 |
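An example invocation (the checkpoint and input/output paths are placeholders):

    python scripts/inference_stage1.py \
        --config configs/model/swinir.yaml \
        --ckpt /path/to/stage1_swinir.ckpt \
        --input /path/to/lq_images \
        --output outputs/stage1 \
        --sr_scale 1 --image_size 512 --resize_back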
--------------------------------------------------------------------------------
/scripts/make_file_list.py:
--------------------------------------------------------------------------------
1 | import sys
2 | sys.path.append(".")
3 | import os
4 | from argparse import ArgumentParser
5 |
6 | from utils.file import list_image_files
7 | import cv2
8 | import pdb
9 |
10 | parser = ArgumentParser()
11 | parser.add_argument("--img_folder", type=str, required=True)
12 | parser.add_argument("--val_size", type=int, default=0)
13 | parser.add_argument("--save_folder", type=str, required=True)
14 | parser.add_argument("--follow_links", action="store_true")
15 | args = parser.parse_args()
16 |
17 | files = list_image_files(
18 | args.img_folder, exts=(".jpg", ".png", ".jpeg"), follow_links=args.follow_links,
19 | log_progress=True, log_every_n_files=10000
20 | )
21 |
22 | print(f"found {len(files)} images in {args.img_folder}")
23 | assert args.val_size < len(files)
24 |
25 | # val_files = files[:args.val_size]
26 | # train_files = files[args.val_size:]
27 | val_files = files
28 | print(len(val_files))
29 | valid_files = []
30 | for i,path in enumerate(val_files):
31 |
32 | valid_files.append(path)
33 |
34 |
35 | print('Total files:{}'.format(len(valid_files)))
36 |
37 |
38 | os.makedirs(args.save_folder, exist_ok=True)
39 |
40 | # with open(os.path.join(args.save_folder, "train.list"), "w") as fp:
41 | # for file_path in train_files:
42 | # fp.write(f"{file_path}\n")
43 |
44 | with open(os.path.join(args.save_folder, "train.list"), "w") as fp:
45 | for file_path in valid_files:
46 | fp.write(f"{file_path}\n")
47 |
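Example invocation (paths are placeholders); the script writes <save_folder>/train.list with one image path per line:

    python scripts/make_file_list.py \
        --img_folder /path/to/images \
        --save_folder ./file_lists \
        --follow_links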
--------------------------------------------------------------------------------
/scripts/make_list_celea.py:
--------------------------------------------------------------------------------
1 | import sys
2 | sys.path.append(".")
3 | import os
4 | from argparse import ArgumentParser
5 | import pandas as pd
6 |
7 | s_img = '/data1/zyx/CelebAMask-HQ/CelebA-HQ-img'
8 |
9 |
10 |
11 | train_count = 0
12 | test_count = 0
13 | val_count = 0
14 | image_list = pd.read_csv('/data1/zyx/CelebAMask-HQ/CelebA-HQ-to-CelebA-mapping.txt', delim_whitespace=True, header=None)
15 | train_paths = []
16 | val_paths = []
17 | test_paths = []
18 |
19 | for idx, x in enumerate(image_list.loc[:, 1]):
20 | print (idx, x)
21 | if idx == 0:
22 | continue
23 |
24 | x = int(x)
25 | if x >= 162771 and x < 182638:
26 | img_path = os.path.join(s_img, str(idx-1)+'.jpg')
27 | val_paths.append(img_path)
28 | val_count += 1
29 |
30 | elif x >= 182638:
31 | img_path = os.path.join(s_img, str(idx-1)+'.jpg')
32 | test_paths.append(img_path)
33 | test_count += 1
34 | else:
35 | img_path = os.path.join(s_img, str(idx-1)+'.jpg')
36 | train_paths.append(img_path)
37 | train_count += 1
38 |
39 | print (train_count + test_count + val_count)
40 |
41 |
42 |
43 |
44 | save_folder = '/data1/zyx/FFHQ512'
45 |
46 | # with open(os.path.join(save_folder, "train.list"), "a") as fp:
47 | # for file_path in train_paths:
48 | # fp.write(f"{file_path}\n")
49 |
50 | # with open(os.path.join(save_folder, "val.list"), "a") as fp:
51 | # for file_path in val_paths:
52 | # fp.write(f"{file_path}\n")
53 |
54 | with open(os.path.join(save_folder, "test.list"), "w") as fp:
55 | for file_path in test_paths:
56 | fp.write(f"{file_path}\n")
--------------------------------------------------------------------------------
/scripts/make_stage2_init_weight.py:
--------------------------------------------------------------------------------
1 | import sys
2 | sys.path.append(".")
3 | from argparse import ArgumentParser
4 | from typing import Dict
5 |
6 | import torch
7 | from omegaconf import OmegaConf
8 |
9 | from utils.common import instantiate_from_config
10 | import os
11 | os.environ['CURL_CA_BUNDLE'] = ''
12 |
13 | def load_weight(weight_path: str) -> Dict[str, torch.Tensor]:
14 | weight = torch.load(weight_path)
15 | if "state_dict" in weight:
16 | weight = weight["state_dict"]
17 |
18 | pure_weight = {}
19 | for key, val in weight.items():
20 | if key.startswith("module."):
21 | key = key[len("module."):]
22 | pure_weight[key] = val
23 |
24 | return pure_weight
25 |
26 | parser = ArgumentParser()
27 | parser.add_argument("--cldm_config", type=str, required=True)
28 | parser.add_argument("--sd_weight", type=str, required=True)
29 | parser.add_argument("--swinir_weight", type=str, required=True)
30 | parser.add_argument("--output", type=str, required=True)
31 | args = parser.parse_args()
32 |
33 | model = instantiate_from_config(OmegaConf.load(args.cldm_config))
34 |
35 | sd_weights = load_weight(args.sd_weight)
36 | swinir_weights = load_weight(args.swinir_weight)
37 | scratch_weights = model.state_dict()
38 |
39 | init_weights = {}
40 | for weight_name in scratch_weights.keys():
41 | # find target pretrained weights for this weight
42 | if weight_name.startswith("control_"):
43 | suffix = weight_name[len("control_"):]
44 | target_name = f"model.diffusion_{suffix}"
45 | target_model_weights = sd_weights
46 | elif weight_name.startswith("preprocess_model."):
47 | suffix = weight_name[len("preprocess_model."):]
48 | target_name = suffix
49 | target_model_weights = swinir_weights
50 | elif weight_name.startswith("cond_encoder."):
51 | suffix = weight_name[len("cond_encoder."):]
52 | target_name = F"first_stage_model.{suffix}"
53 | target_model_weights = sd_weights
54 | else:
55 | target_name = weight_name
56 | target_model_weights = sd_weights
57 |
58 | # if target weight exist in pretrained model
59 | print(f"copy weights: {target_name} -> {weight_name}")
60 | if target_name in target_model_weights:
61 | # get pretrained weight
62 | target_weight = target_model_weights[target_name]
63 | target_shape = target_weight.shape
64 | model_shape = scratch_weights[weight_name].shape
65 | # if pretrained weight has the same shape with model weight, we make a copy
66 | if model_shape == target_shape:
67 | init_weights[weight_name] = target_weight.clone()
68 | # else we copy pretrained weight with additional channels initialized to zero
69 | else:
70 | newly_added_channels = model_shape[1] - target_shape[1]
71 | oc, _, h, w = target_shape
72 | zero_weight = torch.zeros((oc, newly_added_channels, h, w)).type_as(target_weight)
73 | init_weights[weight_name] = torch.cat((target_weight.clone(), zero_weight), dim=1)
74 | print(f"add zero weight to {target_name} in pretrained weights, newly added channels = {newly_added_channels}")
75 | else:
76 | init_weights[weight_name] = scratch_weights[weight_name].clone()
77 | print(f"These weights are newly added: {weight_name}")
78 |
79 | model.load_state_dict(init_weights, strict=True)
80 | torch.save(model.state_dict(), args.output)
81 | print("Done.")
82 |
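Example invocation (all weight paths are placeholders):

    python scripts/make_stage2_init_weight.py \
        --cldm_config configs/model/cldm.yaml \
        --sd_weight /path/to/sd_checkpoint.ckpt \
        --swinir_weight /path/to/stage1_swinir.ckpt \
        --output init_weights/stage2_init.ckpt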
--------------------------------------------------------------------------------
/scripts/merge_img.py:
--------------------------------------------------------------------------------
1 | from PIL import Image
2 |
3 | def merge_images_horizontally(image_paths, output_path):
4 | images = [Image.open(image_path) for image_path in image_paths]
5 |
6 | # Ensure that all images have the same height
7 | max_height = 512
8 | images = [img if img.height == max_height else img.resize((img.width * max_height // img.height, max_height)) for img in images]
9 |
10 | # Calculate the total width for the merged image
11 | total_width = sum(img.width for img in images)
12 |
13 | # Create a blank image with the total width and the maximum height
14 | merged_image = Image.new('RGB', (total_width, max_height))
15 |
16 | # Paste each image into the merged image
17 | x_offset = 0
18 | for img in images:
19 | merged_image.paste(img, (x_offset, 0))
20 | x_offset += img.width
21 |
22 | # Save the merged image
23 | merged_image.save(output_path)
24 |
25 | if __name__ == "__main__":
26 | image_paths_10955 = ["/home/user001/zwl/zyx/CodeFormer-master/inp6/lq/10955.png",
27 | "/home/user001/zwl/zyx/GPEN-main/eval_img_test/10955.png",
28 | "/home/user001/zwl/zyx/CodeFormer-master/inp6/10955.png",
29 | "/home/user001/zwl/zyx/Diffbir/outputs/inp-re3/z2/10955.png",
30 | "/home/user001/zwl/zyx/Diffbir/outputs/inp-re3/hq/10955.png"]
31 |
32 | image_paths_1777 = ["/home/user001/zwl/zyx/CodeFormer-master/inp11/lq/1777.png",
33 | "/home/user001/zwl/zyx/GPEN-main/eval_img_test/1777.png",
34 | "/home/user001/zwl/zyx/CodeFormer-master/inp11/1777.png",
35 | "/home/user001/zwl/zyx/Diffbir/outputs/inp-re10/z2/1777.png",
36 | "/home/user001/zwl/zyx/Diffbir/outputs/inp-re10/hq/1777.png"]
37 |
38 | output_path = "merged_image_10955.png"
39 |
40 | merge_images_horizontally(image_paths_10955, output_path)
41 |
--------------------------------------------------------------------------------
/scripts/metrics.py:
--------------------------------------------------------------------------------
1 | import pyiqa
2 | import torch
3 |
4 | # list all available metrics
5 | print(pyiqa.list_models())
6 |
7 | device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
8 |
9 |
10 |
11 | # For FID metric, use directory or precomputed statistics as inputs
12 | # refer to clean-fid for more details: https://github.com/GaParmar/clean-fid
13 | fid_metric = pyiqa.create_metric('fid')
14 | score = fid_metric('/home/zyx/CodeFormer-master/results/celeba_512_validation_lq_0.5/restored_faces/', '/data1/zyx/FFHQ512/FFHQ_512/')
15 | print(score)
16 | #score = fid_metric('./ResultsCalibra/dist_dir/', dataset_name="FFHQ", dataset_res=1024, dataset_split="trainval70k")
--------------------------------------------------------------------------------
/scripts/rainy.py:
--------------------------------------------------------------------------------
1 | import cv2
2 | import numpy as np
3 | import os
4 | from os.path import join
5 | import pdb
6 | def compute_rainy_layer(derained_image_path, groundtruth_image_path, output_image_dir,filename):
7 | # Load derained and groundtruth images
8 | derained_image = cv2.imread(derained_image_path, cv2.IMREAD_GRAYSCALE)
9 |
10 | gt_image = cv2.imread(groundtruth_image_path, cv2.IMREAD_GRAYSCALE)
11 |     h, w = derained_image.shape[0], derained_image.shape[1]
12 |     gt_image = cv2.resize(gt_image, (w, h))
13 |
14 | # Apply thresholding to groundtruth image
15 | # Calculate the difference between the GT and result images
16 | difference_image = gt_image - derained_image
17 | #print(difference_image[:9,:9])
18 | #pdb.set_trace()
19 | # Apply a threshold to obtain the rainy layer
20 | threshold = 20 # You can adjust this value based on your application
21 | is_rainy = ((difference_image > threshold) * (difference_image <200))
22 | rainy_layer = np.where(is_rainy, 1, 0)
23 | # Save the rainy layer
24 | output_image_path = join(output_image_dir,filename)
25 | cv2.imwrite(output_image_path, rainy_layer * 255)
26 | print(output_image_path)
27 | #cv2.imwrite(output_image_path, normalized_difference)
28 |
29 | if __name__ == "__main__":
30 | derained_image_path = "/home/user001/zwl/zyx/Diffbir/outputs/swin_derain/rain-001.png"
31 | #derained_image_path = "/home/user001/zwl/zyx/Pretrained-IPT/experiment/results/ipt/results-DIV2K/rain-001_x1_SR.png"
32 | groundtruth_image_path = "/home/user001/zwl/data/Derain/Rain100L/rainy/rain-001.png"
33 | output_image_dir ="/home/user001/zwl/zyx/RCDNet-master/RCDNet_code/for_syn/experiment/RCDNet_test/results//rainy" # "/home/user001/zwl/zyx/Pretrained-IPT/experiment/results/ipt/results-DIV2K/rainy/"
34 | os.makedirs(output_image_dir,exist_ok=True)
35 | derained_image_dir = "/home/user001/zwl/zyx/RCDNet-master/RCDNet_code/for_syn/experiment/RCDNet_test/results/"
36 | groundtruth_image_dir = "/home/user001/zwl/data/Derain/Rain100L/rainy"
37 | for filename in os.listdir(derained_image_dir):
38 | if filename[-3:] != 'png':
39 | continue
40 | derain_image_path = join(derained_image_dir,filename)
41 | groundtruth_image_path = join(groundtruth_image_dir,filename)
42 | compute_rainy_layer(derain_image_path, groundtruth_image_path, output_image_dir,filename)
43 |
--------------------------------------------------------------------------------
/scripts/sample_dataset.py:
--------------------------------------------------------------------------------
1 | import sys
2 | sys.path.append(".")
3 | from argparse import ArgumentParser
4 | import os
5 | from typing import Any
6 |
7 | from omegaconf import OmegaConf
8 | from torch.utils.data import DataLoader
9 | import numpy as np
10 | from PIL import Image
11 | import pytorch_lightning as pl
12 |
13 | from utils.common import instantiate_from_config
14 |
15 |
16 | def wrap_dataloader(data_loader: DataLoader) -> Any:
17 | while True:
18 | yield from data_loader
19 |
20 |
21 | pl.seed_everything(231, workers=True)
22 |
23 | parser = ArgumentParser()
24 | parser.add_argument("--config", type=str, required=True)
25 | parser.add_argument("--sample_size", type=int, default=128)
26 | parser.add_argument("--show_gt", action="store_true")
27 | parser.add_argument("--output", type=str, required=True)
28 | args = parser.parse_args()
29 |
30 | config = OmegaConf.load(args.config)
31 | dataset = instantiate_from_config(config.dataset)
32 | transform = instantiate_from_config(config.batch_transform)
33 | data_loader = wrap_dataloader(DataLoader(dataset, batch_size=1, shuffle=True))
34 |
35 | cnt = 0
36 | os.makedirs(args.output, exist_ok=True)
37 |
38 | for batch in data_loader:
39 | batch = transform(batch)
40 | for hq, lq in zip(batch["jpg"], batch["hint"]):
41 | hq = ((hq + 1) * 127.5).numpy().clip(0, 255).astype(np.uint8)
42 | lq = (lq * 255.0).numpy().clip(0, 255).astype(np.uint8)
43 | if args.show_gt:
44 | Image.fromarray(np.concatenate([hq, lq], axis=1)).save(os.path.join(args.output, f"{cnt}.png"))
45 | else:
46 | Image.fromarray(lq).save(os.path.join(args.output, f"{cnt}.png"))
47 | cnt += 1
48 | if cnt >= args.sample_size:
49 | break
50 | if cnt >= args.sample_size:
51 | break
52 |
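Example invocation (the config path is a placeholder; the script assumes the config provides both `dataset` and `batch_transform` entries):

    python scripts/sample_dataset.py \
        --config configs/dataset/face_train.yaml \
        --sample_size 16 --show_gt \
        --output outputs/dataset_preview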
--------------------------------------------------------------------------------
/test.py:
--------------------------------------------------------------------------------
1 | # import os
2 | # os.environ['CUDA_VISIBLE_DEVICES'] = "7"
3 | from argparse import ArgumentParser
4 | import cv2
5 | import os
6 | #os.environ['CUDA_VISIBLE_DEVICE'] = '7'
7 | import pytorch_lightning as pl
8 | from omegaconf import OmegaConf
9 | import torch
10 | from tqdm import tqdm
11 | from utils.common import instantiate_from_config, load_state_dict
12 | from torchvision.utils import save_image
13 | import torch.distributed as dist
14 | from torch.nn.parallel import DistributedDataParallel
15 |
16 | def main() -> None:
17 |
18 | # dist.init_process_group(backend='nccl', init_method='env://')
19 | # rank = torch.cuda.set_device(dist.get_rank())
20 | # device = torch.device('cuda', rank)
21 |
22 | parser = ArgumentParser()
23 | parser.add_argument("--config", type=str, required=True)
24 | parser.add_argument("--output", type=str, required=True)
25 | parser.add_argument("--watch_step", action='store_true')
26 | # parser.add_argument("--local_rank", type=int, default=0)
27 | args = parser.parse_args()
28 |
29 | config = OmegaConf.load(args.config)
30 | pl.seed_everything(config.lightning.seed, workers=True)
31 | model_config = OmegaConf.load(config.model.config)
32 | #print(model_config)
33 | model_config['params']['output'] = args.output
34 |
35 | data_module = instantiate_from_config(config.data)
36 | data_module.setup(stage="fit")
37 | model = instantiate_from_config(model_config)
38 | if config.model.get("resume"):
39 | model_dict = torch.load(config.model.resume, map_location="cpu")
40 | model_dict = model_dict['state_dict'] if 'state_dict' in model_dict.keys() else model_dict
41 | a,b = model.load_state_dict(model_dict, strict=False)
42 | print("missing_keys:",a)
43 | print("unexpected_keys:",b)
44 | print("{} model has been loaded!".format(config.model.resume))
45 | load_state_dict(model.preprocess_model, torch.load('/home/user001/zwl/data/flowir_work_dirs/swin_inp0/lightning_logs/version_0/checkpoints/step=39999.ckpt', map_location="cpu"), strict=True)
46 | #swin_dict = torch.load(config.model.resume, map_location="cpu")
47 | #torch.cuda.empty_cache()
48 | # model.to(device)
49 | #model.cuda()
50 | #model = DistributedDataParallel(model,device_ids=[rank],output_device=rank)
51 | model.eval()
52 |
53 |
54 | save_path = args.output
55 | # final_path = os.path.join(save_path,"final")
56 | # midd_path = os.path.join(save_path,"midd")
57 | # os.makedirs(final_path,exist_ok=True)
58 | # os.makedirs(midd_path,exist_ok=True)
59 |
60 | callbacks = []
61 | for callback_config in config.lightning.callbacks:
62 | callbacks.append(instantiate_from_config(callback_config))
63 | trainer = pl.Trainer(callbacks=callbacks, **config.lightning.trainer)
64 |
65 | testloader = data_module.val_dataloader()
66 | # for batch_idx, batch in tqdm(enumerate(testloader),total=len(testloader)):
67 | # batch['jpg'] = batch['jpg'].to(device)
68 | # batch['hint'] = batch['hint'].to(device)
69 | # imgname_batch = batch['imgname']
70 | # log = model.module.validation_step(batch,args.watch_step)
71 | trainer.test(model,test_dataloaders=testloader)
72 | #assert False
73 | #images = log['samples']
74 | #images_midd = log['samples_3']
75 | # save_batch(images=images,
76 | # imgname_batch=imgname_batch,
77 | # save_path=final_path,
78 | # watch_step=False)
79 | # save_batch(images=images_midd,
80 | # imgname_batch=imgname_batch,
81 | # save_path=midd_path,
82 | # watch_step=False)
83 |
84 | def save_batch(images,imgname_batch, save_path, watch_step=False):
85 | if watch_step:
86 | for list_idx, img_list in enumerate(images):
87 | for img_idx, img in enumerate(img_list):
88 | imgname = str(list_idx)+"_"+imgname_batch[img_idx]
89 | save_img = os.path.join(save_path,imgname)
90 | save_image(img,save_img)
91 | else:
92 | for img_idx, img in enumerate(images):
93 | imgname = imgname_batch[img_idx]
94 | save_img = os.path.join(save_path,imgname)
95 | save_image(img,save_img)
96 |
97 |
98 | if __name__ == "__main__":
99 | main()
100 |
101 | '''
102 | CUDA_VISIBLE_DEVICES=0 \
103 | python3 \
104 | test.py \
105 | --config configs/test_cldm.yaml \
106 | --output /home/zyx/DiffBIR-main/outputs/celebamaskhq_reflow_nolora/
107 |
108 | CUDA_VISIBLE_DEVICES=6 \
109 | python test.py \
110 | --config configs/test_cldm.yaml \
111 | --output outputs/reflow_celeba_lq_than_lq
112 | '''
--------------------------------------------------------------------------------
/train.py:
--------------------------------------------------------------------------------
1 | from argparse import ArgumentParser
2 |
3 | import pytorch_lightning as pl
4 | from omegaconf import OmegaConf
5 | import torch
6 |
7 | from utils.common import instantiate_from_config, load_state_dict
8 |
9 |
10 | def main() -> None:
11 | parser = ArgumentParser()
12 | parser.add_argument("--config", type=str, required=True)
13 | args = parser.parse_args()
14 |
15 | config = OmegaConf.load(args.config)
16 | pl.seed_everything(config.lightning.seed, workers=True)
17 |
18 | data_module = instantiate_from_config(config.data)
19 | model = instantiate_from_config(OmegaConf.load(config.model.config))
20 | # TODO: resume states saved in checkpoint.
21 | if config.model.get("resume"):
22 | weights = torch.load(config.model.resume, map_location="cpu")
23 | '''new_weights = {}
24 | for k in weights['state_dict']:
25 |
26 |
27 | if 'lora' not in k:
28 | new_weights[k] = weights['state_dict'][k]
29 | weights['state_dict'] = new_weights'''
30 | load_state_dict(model, weights, strict=False)
31 |
32 |
33 | #load_state_dict(model.preprocess_model, torch.load('/home/user001/zwl/data/flowir_work_dirs/swin_derain0/lightning_logs/version_6/checkpoints/step=69999.ckpt', map_location="cpu"), strict=True)
34 |
35 | callbacks = []
36 | for callback_config in config.lightning.callbacks:
37 | callbacks.append(instantiate_from_config(callback_config))
38 | trainer = pl.Trainer(callbacks=callbacks, **config.lightning.trainer)
39 | trainer.fit(model, datamodule=data_module)
40 | #trainer.test()
41 |
42 |
43 | if __name__ == "__main__":
44 | main()
45 |
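Example invocation, analogous to the commands listed at the bottom of test.py:

    CUDA_VISIBLE_DEVICES=0 \
    python train.py \
        --config configs/train_cldm.yaml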
--------------------------------------------------------------------------------
/utils/__pycache__/common.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/EternalEvan/FlowIE/549bb871cff104c47fd126d23818f54d74ba1f20/utils/__pycache__/common.cpython-38.pyc
--------------------------------------------------------------------------------
/utils/__pycache__/common.cpython-39.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/EternalEvan/FlowIE/549bb871cff104c47fd126d23818f54d74ba1f20/utils/__pycache__/common.cpython-39.pyc
--------------------------------------------------------------------------------
/utils/__pycache__/degradation.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/EternalEvan/FlowIE/549bb871cff104c47fd126d23818f54d74ba1f20/utils/__pycache__/degradation.cpython-38.pyc
--------------------------------------------------------------------------------
/utils/__pycache__/face_restoration_helper.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/EternalEvan/FlowIE/549bb871cff104c47fd126d23818f54d74ba1f20/utils/__pycache__/face_restoration_helper.cpython-38.pyc
--------------------------------------------------------------------------------
/utils/__pycache__/file.cpython-310.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/EternalEvan/FlowIE/549bb871cff104c47fd126d23818f54d74ba1f20/utils/__pycache__/file.cpython-310.pyc
--------------------------------------------------------------------------------
/utils/__pycache__/file.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/EternalEvan/FlowIE/549bb871cff104c47fd126d23818f54d74ba1f20/utils/__pycache__/file.cpython-38.pyc
--------------------------------------------------------------------------------
/utils/__pycache__/file.cpython-39.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/EternalEvan/FlowIE/549bb871cff104c47fd126d23818f54d74ba1f20/utils/__pycache__/file.cpython-39.pyc
--------------------------------------------------------------------------------
/utils/__pycache__/metrics.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/EternalEvan/FlowIE/549bb871cff104c47fd126d23818f54d74ba1f20/utils/__pycache__/metrics.cpython-38.pyc
--------------------------------------------------------------------------------
/utils/__pycache__/metrics.cpython-39.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/EternalEvan/FlowIE/549bb871cff104c47fd126d23818f54d74ba1f20/utils/__pycache__/metrics.cpython-39.pyc
--------------------------------------------------------------------------------
/utils/__pycache__/process.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/EternalEvan/FlowIE/549bb871cff104c47fd126d23818f54d74ba1f20/utils/__pycache__/process.cpython-38.pyc
--------------------------------------------------------------------------------
/utils/__pycache__/utils.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/EternalEvan/FlowIE/549bb871cff104c47fd126d23818f54d74ba1f20/utils/__pycache__/utils.cpython-38.pyc
--------------------------------------------------------------------------------
/utils/common.py:
--------------------------------------------------------------------------------
1 | from typing import Mapping, Any
2 | import importlib
3 |
4 | from torch import nn
5 |
6 |
7 | def get_obj_from_str(string: str, reload: bool=False) -> object:
8 | module, cls = string.rsplit(".", 1)
9 | if reload:
10 | module_imp = importlib.import_module(module)
11 | importlib.reload(module_imp)
12 | return getattr(importlib.import_module(module, package=None), cls)
13 |
14 |
15 | def instantiate_from_config(config: Mapping[str, Any]) -> object:
16 | if not "target" in config:
17 | raise KeyError("Expected key `target` to instantiate.")
18 | return get_obj_from_str(config["target"])(**config.get("params", dict()))
19 |
20 |
21 | def disabled_train(self: nn.Module) -> nn.Module:
22 | """Overwrite model.train with this function to make sure train/eval mode
23 | does not change anymore."""
24 | return self
25 |
26 |
27 | def frozen_module(module: nn.Module) -> None:
28 | module.eval()
29 | module.train = disabled_train
30 | for p in module.parameters():
31 | p.requires_grad = False
32 |
33 |
34 | def load_state_dict(model: nn.Module, state_dict: Mapping[str, Any], strict: bool=False) -> None:
35 | state_dict = state_dict.get("state_dict", state_dict)
36 |
37 | is_model_key_starts_with_module = list(model.state_dict().keys())[0].startswith("module.")
38 | is_state_dict_key_starts_with_module = list(state_dict.keys())[0].startswith("module.")
39 |
40 | if (
41 | is_model_key_starts_with_module and
42 | (not is_state_dict_key_starts_with_module)
43 | ):
44 | state_dict = {f"module.{key}": value for key, value in state_dict.items()}
45 | if (
46 | (not is_model_key_starts_with_module) and
47 | is_state_dict_key_starts_with_module
48 | ):
49 | state_dict = {key[len("module."):]: value for key, value in state_dict.items()}
50 |
51 | model.load_state_dict(state_dict, strict=strict)
52 |
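A minimal sketch of the "module." prefix handling in load_state_dict; the model and checkpoint below are synthetic:

    from torch import nn
    from utils.common import load_state_dict

    model = nn.Linear(4, 2)
    # checkpoint saved from a DataParallel/DDP-wrapped model: keys carry a "module." prefix
    ckpt = {"state_dict": {f"module.{k}": v for k, v in model.state_dict().items()}}
    load_state_dict(model, ckpt, strict=True)   # prefix is stripped automatically before loading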
--------------------------------------------------------------------------------
/utils/file.py:
--------------------------------------------------------------------------------
1 | import os
2 | from typing import List, Tuple
3 |
4 | from urllib.parse import urlparse
5 | from torch.hub import download_url_to_file, get_dir
6 |
7 |
8 | def load_file_list(file_list_path: str) -> List[str]:
9 | files = []
10 | # each line in file list contains a path of an image
11 | with open(file_list_path, "r") as fin:
12 | for line in fin:
13 | path = line.strip()
14 | if path:
15 | files.append(path)
16 | return files
17 |
18 |
19 | def list_image_files(
20 | img_dir: str,
21 | exts: Tuple[str]=(".jpg", ".png", ".jpeg",".arw"),
22 | follow_links: bool=False,
23 | log_progress: bool=False,
24 | log_every_n_files: int=10000,
25 | max_size: int=-1
26 | ) -> List[str]:
27 | files = []
28 | for dir_path, _, file_names in os.walk(img_dir, followlinks=follow_links):
29 | early_stop = False
30 | for file_name in file_names:
31 | if os.path.splitext(file_name)[1].lower() in exts:
32 | if max_size >= 0 and len(files) >= max_size:
33 | early_stop = True
34 | break
35 | files.append(os.path.join(dir_path, file_name))
36 | if log_progress and len(files) % log_every_n_files == 0:
37 |                     print(f"found {len(files)} images in {img_dir}")
38 | if early_stop:
39 | break
40 | return files
41 |
42 |
43 | def get_file_name_parts(file_path: str) -> Tuple[str, str, str]:
44 | parent_path, file_name = os.path.split(file_path)
45 | stem, ext = os.path.splitext(file_name)
46 | return parent_path, stem, ext
47 |
48 |
49 | # https://github.com/XPixelGroup/BasicSR/blob/master/basicsr/utils/download_util.py/
50 | def load_file_from_url(url, model_dir=None, progress=True, file_name=None):
51 | """Load file form http url, will download models if necessary.
52 |
53 | Ref:https://github.com/1adrianb/face-alignment/blob/master/face_alignment/utils.py
54 |
55 | Args:
56 | url (str): URL to be downloaded.
57 | model_dir (str): The path to save the downloaded model. Should be a full path. If None, use pytorch hub_dir.
58 | Default: None.
59 | progress (bool): Whether to show the download progress. Default: True.
60 | file_name (str): The downloaded file name. If None, use the file name in the url. Default: None.
61 |
62 | Returns:
63 | str: The path to the downloaded file.
64 | """
65 | if model_dir is None: # use the pytorch hub_dir
66 | hub_dir = get_dir()
67 | model_dir = os.path.join(hub_dir, 'checkpoints')
68 |
69 | os.makedirs(model_dir, exist_ok=True)
70 |
71 | parts = urlparse(url)
72 | filename = os.path.basename(parts.path)
73 | if file_name is not None:
74 | filename = file_name
75 | cached_file = os.path.abspath(os.path.join(model_dir, filename))
76 | if not os.path.exists(cached_file):
77 | print(f'Downloading: "{url}" to {cached_file}\n')
78 | download_url_to_file(url, cached_file, hash_prefix=None, progress=progress)
79 | return cached_file
80 |
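A brief usage sketch of the helpers above; the dataset directory and checkpoint URL are placeholders, not artifacts shipped with this repo.

```
from utils.file import list_image_files, load_file_from_url  # assumed import path

# Recursively collect images, stopping early after 1000 files.
paths = list_image_files("datasets/ffhq", exts=(".png", ".jpg"), max_size=1000)

# Download a checkpoint once; later calls with the same URL return the cached file.
ckpt_path = load_file_from_url(
    "https://example.com/models/general_swinir_v1.ckpt",  # placeholder URL
    model_dir="weights",
    progress=True,
)
```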
--------------------------------------------------------------------------------
/utils/image/__init__.py:
--------------------------------------------------------------------------------
1 | from .diffjpeg import DiffJPEG
2 | from .usm_sharp import USMSharp
3 | from .common import (
4 | random_crop_arr, center_crop_arr, augment,
5 | filter2D, rgb2ycbcr_pt, auto_resize, pad
6 | )
7 | from .align_color import (
8 | wavelet_reconstruction, adaptive_instance_normalization
9 | )
10 |
11 | __all__ = [
12 | "DiffJPEG",
13 |
14 | "USMSharp",
15 |
16 | "random_crop_arr",
17 | "center_crop_arr",
18 | "augment",
19 | "filter2D",
20 | "rgb2ycbcr_pt",
21 | "auto_resize",
22 | "pad",
23 |
24 | "wavelet_reconstruction",
25 | "adaptive_instance_normalization"
26 | ]
27 |
--------------------------------------------------------------------------------
/utils/image/__pycache__/__init__.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/EternalEvan/FlowIE/549bb871cff104c47fd126d23818f54d74ba1f20/utils/image/__pycache__/__init__.cpython-38.pyc
--------------------------------------------------------------------------------
/utils/image/__pycache__/__init__.cpython-39.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/EternalEvan/FlowIE/549bb871cff104c47fd126d23818f54d74ba1f20/utils/image/__pycache__/__init__.cpython-39.pyc
--------------------------------------------------------------------------------
/utils/image/__pycache__/align_color.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/EternalEvan/FlowIE/549bb871cff104c47fd126d23818f54d74ba1f20/utils/image/__pycache__/align_color.cpython-38.pyc
--------------------------------------------------------------------------------
/utils/image/__pycache__/align_color.cpython-39.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/EternalEvan/FlowIE/549bb871cff104c47fd126d23818f54d74ba1f20/utils/image/__pycache__/align_color.cpython-39.pyc
--------------------------------------------------------------------------------
/utils/image/__pycache__/common.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/EternalEvan/FlowIE/549bb871cff104c47fd126d23818f54d74ba1f20/utils/image/__pycache__/common.cpython-38.pyc
--------------------------------------------------------------------------------
/utils/image/__pycache__/common.cpython-39.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/EternalEvan/FlowIE/549bb871cff104c47fd126d23818f54d74ba1f20/utils/image/__pycache__/common.cpython-39.pyc
--------------------------------------------------------------------------------
/utils/image/__pycache__/diffjpeg.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/EternalEvan/FlowIE/549bb871cff104c47fd126d23818f54d74ba1f20/utils/image/__pycache__/diffjpeg.cpython-38.pyc
--------------------------------------------------------------------------------
/utils/image/__pycache__/diffjpeg.cpython-39.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/EternalEvan/FlowIE/549bb871cff104c47fd126d23818f54d74ba1f20/utils/image/__pycache__/diffjpeg.cpython-39.pyc
--------------------------------------------------------------------------------
/utils/image/__pycache__/usm_sharp.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/EternalEvan/FlowIE/549bb871cff104c47fd126d23818f54d74ba1f20/utils/image/__pycache__/usm_sharp.cpython-38.pyc
--------------------------------------------------------------------------------
/utils/image/__pycache__/usm_sharp.cpython-39.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/EternalEvan/FlowIE/549bb871cff104c47fd126d23818f54d74ba1f20/utils/image/__pycache__/usm_sharp.cpython-39.pyc
--------------------------------------------------------------------------------
/utils/image/align_color.py:
--------------------------------------------------------------------------------
1 | '''
2 | # --------------------------------------------------------------------------------
3 | # Color fixed script from Li Yi (https://github.com/pkuliyi2015/sd-webui-stablesr/blob/master/srmodule/colorfix.py)
4 | # --------------------------------------------------------------------------------
5 | '''
6 |
7 | import torch
8 | from PIL import Image
9 | from torch import Tensor
10 | from torch.nn import functional as F
11 | from torchvision.transforms import ToTensor, ToPILImage
12 |
13 |
14 | def adain_color_fix(target: Image, source: Image):
15 | # Convert images to tensors
16 | to_tensor = ToTensor()
17 | target_tensor = to_tensor(target).unsqueeze(0)
18 | source_tensor = to_tensor(source).unsqueeze(0)
19 |
20 | # Apply adaptive instance normalization
21 | result_tensor = adaptive_instance_normalization(target_tensor, source_tensor)
22 |
23 | # Convert tensor back to image
24 | to_image = ToPILImage()
25 | result_image = to_image(result_tensor.squeeze(0).clamp_(0.0, 1.0))
26 |
27 | return result_image
28 |
29 | def wavelet_color_fix(target: Image, source: Image):
30 | # Convert images to tensors
31 | to_tensor = ToTensor()
32 | target_tensor = to_tensor(target).unsqueeze(0)
33 | source_tensor = to_tensor(source).unsqueeze(0)
34 |
35 | # Apply wavelet reconstruction
36 | result_tensor = wavelet_reconstruction(target_tensor, source_tensor)
37 |
38 | # Convert tensor back to image
39 | to_image = ToPILImage()
40 | result_image = to_image(result_tensor.squeeze(0).clamp_(0.0, 1.0))
41 |
42 | return result_image
43 |
44 | def calc_mean_std(feat: Tensor, eps=1e-5):
45 | """Calculate mean and std for adaptive_instance_normalization.
46 | Args:
47 | feat (Tensor): 4D tensor.
48 | eps (float): A small value added to the variance to avoid
49 | divide-by-zero. Default: 1e-5.
50 | """
51 | size = feat.size()
52 | assert len(size) == 4, 'The input feature should be 4D tensor.'
53 | b, c = size[:2]
54 | feat_var = feat.reshape(b, c, -1).var(dim=2) + eps
55 | feat_std = feat_var.sqrt().reshape(b, c, 1, 1)
56 | feat_mean = feat.reshape(b, c, -1).mean(dim=2).reshape(b, c, 1, 1)
57 | return feat_mean, feat_std
58 |
59 | def adaptive_instance_normalization(content_feat:Tensor, style_feat:Tensor):
60 | """Adaptive instance normalization.
61 | Adjust the reference features to have color and illumination similar to
62 | those of the degraded features.
63 | Args:
64 | content_feat (Tensor): The reference feature.
65 | style_feat (Tensor): The degraded features.
66 | """
67 | size = content_feat.size()
68 | style_mean, style_std = calc_mean_std(style_feat)
69 | content_mean, content_std = calc_mean_std(content_feat)
70 | normalized_feat = (content_feat - content_mean.expand(size)) / content_std.expand(size)
71 | return normalized_feat * style_std.expand(size) + style_mean.expand(size)
72 |
73 | def wavelet_blur(image: Tensor, radius: int):
74 | """
75 | Apply wavelet blur to the input tensor.
76 | """
77 | # input shape: (1, 3, H, W)
78 | # convolution kernel
79 | kernel_vals = [
80 | [0.0625, 0.125, 0.0625],
81 | [0.125, 0.25, 0.125],
82 | [0.0625, 0.125, 0.0625],
83 | ]
84 | kernel = torch.tensor(kernel_vals, dtype=image.dtype, device=image.device)
85 | # add channel dimensions to the kernel to make it a 4D tensor
86 | kernel = kernel[None, None]
87 | # repeat the kernel across all input channels
88 | kernel = kernel.repeat(3, 1, 1, 1)
89 | image = F.pad(image, (radius, radius, radius, radius), mode='replicate')
90 | # apply convolution
91 | output = F.conv2d(image, kernel, groups=3, dilation=radius)
92 | return output
93 |
94 | def wavelet_decomposition(image: Tensor, levels=5):
95 | """
96 | Apply wavelet decomposition to the input tensor.
97 | It returns only the accumulated high-frequency component and the final low-frequency component.
98 | """
99 | high_freq = torch.zeros_like(image)
100 | for i in range(levels):
101 | radius = 2 ** i
102 | low_freq = wavelet_blur(image, radius)
103 | high_freq += (image - low_freq)
104 | image = low_freq
105 |
106 | return high_freq, low_freq
107 |
108 | def wavelet_reconstruction(content_feat:Tensor, style_feat:Tensor):
109 | """
110 | Recombine the content's high frequencies with the style's low frequencies, so the content takes on the color of the style.
111 | """
112 | # calculate the wavelet decomposition of the content feature
113 | content_high_freq, content_low_freq = wavelet_decomposition(content_feat)
114 | del content_low_freq
115 | # calculate the wavelet decomposition of the style feature
116 | style_high_freq, style_low_freq = wavelet_decomposition(style_feat)
117 | del style_high_freq
118 | # reconstruct the content feature with the style's high frequency
119 | return content_high_freq + style_low_freq
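A minimal sketch of the two color-fix entry points on PIL images; the file paths are placeholders.

```
from PIL import Image

restored = Image.open("results/sample_restored.png").convert("RGB")               # placeholder path
source = Image.open("inputs/sample_lq.png").convert("RGB").resize(restored.size)  # placeholder path

fixed = wavelet_color_fix(restored, source)   # or adain_color_fix(restored, source)
fixed.save("results/sample_colorfix.png")
```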
--------------------------------------------------------------------------------
/utils/image/usm_sharp.py:
--------------------------------------------------------------------------------
1 | # https://github.com/XPixelGroup/BasicSR/blob/master/basicsr/utils/img_process_util.py
2 | import cv2
3 | import numpy as np
4 | import torch
5 |
6 | from .common import filter2D
7 |
8 |
9 | class USMSharp(torch.nn.Module):
10 |
11 | def __init__(self, radius=50, sigma=0):
12 | super(USMSharp, self).__init__()
13 | if radius % 2 == 0:
14 | radius += 1
15 | self.radius = radius
16 | kernel = cv2.getGaussianKernel(radius, sigma)
17 | kernel = torch.FloatTensor(np.dot(kernel, kernel.transpose())).unsqueeze_(0)
18 | self.register_buffer('kernel', kernel)
19 |
20 | def forward(self, img, weight=0.5, threshold=10):
21 | blur = filter2D(img, self.kernel)
22 | residual = img - blur
23 |
24 | mask = torch.abs(residual) * 255 > threshold
25 | mask = mask.float()
26 | soft_mask = filter2D(mask, self.kernel)
27 | sharp = img + weight * residual
28 | sharp = torch.clip(sharp, 0, 1)
29 | return soft_mask * sharp + (1 - soft_mask) * img
30 |
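A quick sketch of the sharpening module on a toy batch in [0, 1].

```
import torch

usm = USMSharp()                    # even default radius 50 is bumped to a 51x51 Gaussian kernel
img = torch.rand(1, 3, 256, 256)    # toy RGB batch; in this repo it is typically a ground-truth batch
sharp = usm(img, weight=0.5, threshold=10)
```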
--------------------------------------------------------------------------------
/utils/metrics.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import lpips
3 |
4 | from .image import rgb2ycbcr_pt
5 | from .common import frozen_module
6 |
7 |
8 | # https://github.com/XPixelGroup/BasicSR/blob/033cd6896d898fdd3dcda32e3102a792efa1b8f4/basicsr/metrics/psnr_ssim.py#L52
9 | def calculate_psnr_pt(img, img2, crop_border, test_y_channel=False):
10 | """Calculate PSNR (Peak Signal-to-Noise Ratio) (PyTorch version).
11 |
12 | Reference: https://en.wikipedia.org/wiki/Peak_signal-to-noise_ratio
13 |
14 | Args:
15 | img (Tensor): Images with range [0, 1], shape (n, 3/1, h, w).
16 | img2 (Tensor): Images with range [0, 1], shape (n, 3/1, h, w).
17 | crop_border (int): Cropped pixels in each edge of an image. These pixels are not involved in the calculation.
18 | test_y_channel (bool): Test on Y channel of YCbCr. Default: False.
19 |
20 | Returns:
21 | float: PSNR result.
22 | """
23 |
24 | assert img.shape == img2.shape, (f'Image shapes are different: {img.shape}, {img2.shape}.')
25 |
26 | if crop_border != 0:
27 | img = img[:, :, crop_border:-crop_border, crop_border:-crop_border]
28 | img2 = img2[:, :, crop_border:-crop_border, crop_border:-crop_border]
29 |
30 | if test_y_channel:
31 | img = rgb2ycbcr_pt(img, y_only=True)
32 | img2 = rgb2ycbcr_pt(img2, y_only=True)
33 |
34 | img = img.to(torch.float64)
35 | img2 = img2.to(torch.float64)
36 |
37 | mse = torch.mean((img - img2)**2, dim=[1, 2, 3])
38 | return 10. * torch.log10(1. / (mse + 1e-8))
39 |
40 |
41 | class LPIPS:
42 |
43 | def __init__(self, net: str) -> None:
44 | self.model = lpips.LPIPS(net=net)
45 | frozen_module(self.model)
46 |
47 | @torch.no_grad()
48 | def __call__(self, img1: torch.Tensor, img2: torch.Tensor, normalize: bool) -> torch.Tensor:
49 | """
50 | Compute LPIPS.
51 |
52 | Args:
53 | img1 (torch.Tensor): The first image (NCHW, RGB, [-1, 1]). Set `normalize` to True if the
54 | input image is in the range [0, 1].
55 | img2 (torch.Tensor): The second image (NCHW, RGB, [-1, 1]). Set `normalize` to True if the
56 | input image is in the range [0, 1].
57 | normalize (bool): If specified, the input images will be normalized from [0, 1] to [-1, 1].
58 |
59 | Returns:
60 | lpips_values (torch.Tensor): The lpips scores of this batch.
61 | """
62 | return self.model(img1, img2, normalize=normalize)
63 |
64 | def to(self, device: str) -> "LPIPS":
65 | self.model.to(device)
66 | return self
67 |
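A short sketch comparing a restored batch against its ground truth with both metrics; the tensors here are random stand-ins for real batches.

```
import torch

pred = torch.rand(4, 3, 256, 256)   # restored images in [0, 1]
gt = torch.rand(4, 3, 256, 256)     # ground-truth images in [0, 1]

psnr = calculate_psnr_pt(pred, gt, crop_border=4, test_y_channel=True)  # (4,) tensor of dB values
lpips_metric = LPIPS(net="alex").to("cpu")
lpips_scores = lpips_metric(pred, gt, normalize=True)                   # normalize=True since inputs are in [0, 1]
```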
--------------------------------------------------------------------------------
/utils/pickout_img.py:
--------------------------------------------------------------------------------
1 | import os
2 | import shutil
3 |
4 | if __name__ == "__main__":
5 | gttest = "outputs/gtcelebamaskhq"
6 | os.makedirs(gttest,exist_ok=True)
7 | with open("/data1/zyx/FFHQ512/test.list", 'r') as f:
8 | lines = f.readlines()
9 | for line in lines:
10 | path = line.strip()
11 | name = os.path.basename(path)
12 | shutil.copy(path, os.path.join(gttest, name))
13 |
14 |
--------------------------------------------------------------------------------
/utils/process.py:
--------------------------------------------------------------------------------
1 | """Forward processing of raw data to sRGB images.
2 |
3 | Unprocessing Images for Learned Raw Denoising
4 | http://timothybrooks.com/tech/unprocessing
5 | """
6 |
7 | import numpy as np
8 | import torch
9 | from torchinterp1d import Interp1d
10 | from os.path import join
11 |
12 | device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
13 |
14 |
15 | def apply_gains(bayer_images, wbs):
16 | """Applies white balance to a batch of Bayer images."""
17 | N, C, _, _ = bayer_images.shape
18 | outs = bayer_images * wbs.view(N, C, 1, 1)
19 | return outs
20 |
21 |
22 | def apply_ccms(images, ccms):
23 | """Applies color correction matrices."""
24 | images = images.permute(
25 | 0, 2, 3, 1) # Permute the image tensor to BxHxWxC format from BxCxHxW format
26 | images = images[:, :, :, None, :]
27 | ccms = ccms[:, None, None, :, :]
28 | outs = torch.sum(images * ccms, dim=-1)
29 | # Re-Permute the tensor back to BxCxHxW format
30 | outs = outs.permute(0, 3, 1, 2)
31 | return outs
32 |
33 |
34 | def gamma_compression(images, gamma=2.2):
35 | """Converts from linear to gamma space."""
36 | outs = torch.clamp(images, min=1e-8) ** (1 / gamma)
37 | # outs = (1 + gamma[0]) * np.power(images, 1.0/gamma[1]) - gamma[0] + gamma[2]*images
38 | outs = torch.clamp((outs*255).int(), min=0, max=255).float() / 255
39 | return outs
40 |
41 |
42 | def binning(bayer_images):
43 | """RGBG -> RGB"""
44 | lin_rgb = torch.stack([
45 | bayer_images[:,0,...],
46 | torch.mean(bayer_images[:, [1,3], ...], dim=1),
47 | bayer_images[:,2,...]], dim=1)
48 |
49 | return lin_rgb
50 |
51 |
52 | def process(bayer_images, wbs, cam2rgbs, gamma=2.2, CRF=None):
53 | """Processes a batch of Bayer RGBG images into sRGB images."""
54 | # White balance.
55 | bayer_images = apply_gains(bayer_images, wbs)
56 | # Binning
57 | bayer_images = torch.clamp(bayer_images, min=0.0, max=1.0)
58 | images = binning(bayer_images)
59 | # Color correction.
60 | images = apply_ccms(images, cam2rgbs)
61 | # Gamma compression.
62 | images = torch.clamp(images, min=0.0, max=1.0)
63 | if CRF is None:
64 | images = gamma_compression(images, gamma)
65 | else:
66 | images = camera_response_function(images, CRF)
67 |
68 | return images
69 |
70 |
71 | def camera_response_function(images, CRF):
72 | E, fs = CRF # unpack CRF data
73 |
74 | outs = torch.zeros_like(images)
75 | device = images.device
76 |
77 | for i in range(images.shape[0]):
78 | img = images[i].view(3, -1)
79 | out = Interp1d()(E.to(device), fs.to(device), img)
80 | outs[i, ...] = out.view(3, images.shape[2], images.shape[3])
81 |
82 | outs = torch.clamp((outs*255).int(), min=0, max=255).float() / 255
83 | return outs
84 |
85 |
86 | def raw2rgb(packed_raw, raw, CRF=None, gamma=2.2):
87 | """Raw2RGB pipeline (preprocess version)"""
88 | wb = np.array(raw.camera_whitebalance)
89 | wb /= wb[1]
90 | cam2rgb = raw.rgb_camera_matrix[:3, :3]
91 |
92 | if isinstance(packed_raw, np.ndarray):
93 | packed_raw = torch.from_numpy(packed_raw).float()
94 |
95 | wb = torch.from_numpy(wb).float().to(packed_raw.device)
96 | cam2rgb = torch.from_numpy(cam2rgb).float().to(packed_raw.device)
97 |
98 | out = process(packed_raw[None], wbs=wb[None], cam2rgbs=cam2rgb[None], gamma=gamma, CRF=CRF)[0, ...].numpy()
99 |
100 | return out
101 |
102 |
103 | def raw2rgb_v2(packed_raw, wb, ccm, CRF=None, gamma=2.2): # RGBG
104 | packed_raw = torch.from_numpy(packed_raw).float()
105 | wb = torch.from_numpy(wb).float()
106 | cam2rgb = torch.from_numpy(ccm).float()
107 | out = process(packed_raw[None], wbs=wb[None], cam2rgbs=cam2rgb[None], gamma=gamma, CRF=CRF)[0, ...].numpy()
108 | return out
109 |
110 |
111 | def raw2rgb_postprocess(packed_raw, raw, CRF=None):
112 | """Raw2RGB pipeline (postprocess version)"""
113 | assert packed_raw.ndimension() == 4 and packed_raw.shape[0] == 1
114 | wb = np.array(raw.camera_whitebalance)
115 | wb /= wb[1]
116 | cam2rgb = raw.rgb_camera_matrix[:3, :3]
117 |
118 | wb = torch.from_numpy(wb[None]).float().to(packed_raw.device)
119 | cam2rgb = torch.from_numpy(cam2rgb[None]).float().to(packed_raw.device)
120 | out = process(packed_raw, wbs=wb, cam2rgbs=cam2rgb, gamma=2.2, CRF=CRF)
121 | return out
122 |
123 |
124 | def read_wb_ccm(raw):
125 | wb = np.array(raw.camera_whitebalance)
126 | wb /= wb[1]
127 | wb = wb.astype(np.float32)
128 | ccm = raw.rgb_camera_matrix[:3, :3].astype(np.float32)
129 | return wb, ccm
130 |
131 |
132 | def read_emor(address):
133 | def _read_curve(lst):
134 | curve = [l.strip() for l in lst]
135 | curve = ' '.join(curve)
136 | curve = np.array(curve.split()).astype(np.float32)
137 | return curve
138 |
139 | with open(address) as f:
140 | lines = f.readlines()
141 | k = 1
142 | E = _read_curve(lines[k:k+256])
143 | k += 257
144 | f0 = _read_curve(lines[k:k+256])
145 | hs = []
146 | for _ in range(25):
147 | k += 257
148 | hs.append(_read_curve(lines[k:k+256]))
149 |
150 | hs = np.array(hs)
151 |
152 | return E, f0, hs
153 |
154 |
155 | def read_dorf(address):
156 | with open(address) as f:
157 | lines = f.readlines()
158 | curve_names = lines[0::6]
159 | Es = lines[3::6]
160 | Bs = lines[5::6]
161 |
162 | Es = [np.array(E.strip().split()).astype(np.float32) for E in Es]
163 | Bs = [np.array(B.strip().split()).astype(np.float32) for B in Bs]
164 |
165 | return curve_names, Es, Bs
166 |
167 |
168 | def load_CRF():
169 | # init CRF function
170 | fs = np.loadtxt(join('EMoR', 'CRF_SonyA7S2_5.txt'))
171 | E, _, _ = read_emor(join('EMoR', 'emor.txt'))
172 | E = torch.from_numpy(E).repeat(3, 1)
173 | fs = torch.from_numpy(fs)
174 | CRF = (E, fs)
175 | return CRF
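A toy sketch of the `process` pipeline on a random packed RGBG batch; the gains and color-correction matrix below are illustrative values, not calibrated camera data.

```
import torch

bayer = torch.rand(1, 4, 128, 128)          # packed RGBG raw in [0, 1]
wb = torch.tensor([[2.0, 1.0, 1.5, 1.0]])   # per-channel gains, green normalized to 1
cam2rgb = torch.eye(3).unsqueeze(0)         # identity CCM for illustration

srgb = process(bayer, wbs=wb, cam2rgbs=cam2rgb, gamma=2.2, CRF=None)  # -> (1, 3, 128, 128) sRGB
```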
--------------------------------------------------------------------------------
/utils/realesrgan/__pycache__/realesrganer.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/EternalEvan/FlowIE/549bb871cff104c47fd126d23818f54d74ba1f20/utils/realesrgan/__pycache__/realesrganer.cpython-38.pyc
--------------------------------------------------------------------------------
/utils/realesrgan/__pycache__/rrdbnet.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/EternalEvan/FlowIE/549bb871cff104c47fd126d23818f54d74ba1f20/utils/realesrgan/__pycache__/rrdbnet.cpython-38.pyc
--------------------------------------------------------------------------------
/utils/realesrgan/rrdbnet.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch import nn as nn
3 | from torch.nn import functional as F
4 | from torch.nn import init as init
5 | from torch.nn.modules.batchnorm import _BatchNorm
6 |
7 |
8 | def default_init_weights(module_list, scale=1, bias_fill=0, **kwargs):
9 | """Initialize network weights.
10 |
11 | Args:
12 | module_list (list[nn.Module] | nn.Module): Modules to be initialized.
13 | scale (float): Scale initialized weights, especially for residual
14 | blocks. Default: 1.
15 | bias_fill (float): The value to fill bias. Default: 0
16 | kwargs (dict): Other arguments for initialization function.
17 | """
18 | if not isinstance(module_list, list):
19 | module_list = [module_list]
20 | for module in module_list:
21 | for m in module.modules():
22 | if isinstance(m, nn.Conv2d):
23 | init.kaiming_normal_(m.weight, **kwargs)
24 | m.weight.data *= scale
25 | if m.bias is not None:
26 | m.bias.data.fill_(bias_fill)
27 | elif isinstance(m, nn.Linear):
28 | init.kaiming_normal_(m.weight, **kwargs)
29 | m.weight.data *= scale
30 | if m.bias is not None:
31 | m.bias.data.fill_(bias_fill)
32 | elif isinstance(m, _BatchNorm):
33 | init.constant_(m.weight, 1)
34 | if m.bias is not None:
35 | m.bias.data.fill_(bias_fill)
36 |
37 |
38 | def make_layer(basic_block, num_basic_block, **kwarg):
39 | """Make layers by stacking the same blocks.
40 |
41 | Args:
42 | basic_block (nn.module): nn.module class for basic block.
43 | num_basic_block (int): number of blocks.
44 |
45 | Returns:
46 | nn.Sequential: Stacked blocks in nn.Sequential.
47 | """
48 | layers = []
49 | for _ in range(num_basic_block):
50 | layers.append(basic_block(**kwarg))
51 | return nn.Sequential(*layers)
52 |
53 |
54 | # TODO: may write a cpp file
55 | def pixel_unshuffle(x, scale):
56 | """ Pixel unshuffle.
57 |
58 | Args:
59 | x (Tensor): Input feature with shape (b, c, hh, hw).
60 | scale (int): Downsample ratio.
61 |
62 | Returns:
63 | Tensor: the pixel unshuffled feature.
64 | """
65 | b, c, hh, hw = x.size()
66 | out_channel = c * (scale**2)
67 | assert hh % scale == 0 and hw % scale == 0
68 | h = hh // scale
69 | w = hw // scale
70 | x_view = x.view(b, c, h, scale, w, scale)
71 | return x_view.permute(0, 1, 3, 5, 2, 4).reshape(b, out_channel, h, w)
72 |
73 |
74 | class ResidualDenseBlock(nn.Module):
75 | """Residual Dense Block.
76 |
77 | Used in RRDB block in ESRGAN.
78 |
79 | Args:
80 | num_feat (int): Channel number of intermediate features.
81 | num_grow_ch (int): Channels for each growth.
82 | """
83 |
84 | def __init__(self, num_feat=64, num_grow_ch=32):
85 | super(ResidualDenseBlock, self).__init__()
86 | self.conv1 = nn.Conv2d(num_feat, num_grow_ch, 3, 1, 1)
87 | self.conv2 = nn.Conv2d(num_feat + num_grow_ch, num_grow_ch, 3, 1, 1)
88 | self.conv3 = nn.Conv2d(num_feat + 2 * num_grow_ch, num_grow_ch, 3, 1, 1)
89 | self.conv4 = nn.Conv2d(num_feat + 3 * num_grow_ch, num_grow_ch, 3, 1, 1)
90 | self.conv5 = nn.Conv2d(num_feat + 4 * num_grow_ch, num_feat, 3, 1, 1)
91 |
92 | self.lrelu = nn.LeakyReLU(negative_slope=0.2, inplace=True)
93 |
94 | # initialization
95 | default_init_weights([self.conv1, self.conv2, self.conv3, self.conv4, self.conv5], 0.1)
96 |
97 | def forward(self, x):
98 | x1 = self.lrelu(self.conv1(x))
99 | x2 = self.lrelu(self.conv2(torch.cat((x, x1), 1)))
100 | x3 = self.lrelu(self.conv3(torch.cat((x, x1, x2), 1)))
101 | x4 = self.lrelu(self.conv4(torch.cat((x, x1, x2, x3), 1)))
102 | x5 = self.conv5(torch.cat((x, x1, x2, x3, x4), 1))
103 | # Empirically, we use 0.2 to scale the residual for better performance
104 | return x5 * 0.2 + x
105 |
106 |
107 | class RRDB(nn.Module):
108 | """Residual in Residual Dense Block.
109 |
110 | Used in RRDB-Net in ESRGAN.
111 |
112 | Args:
113 | num_feat (int): Channel number of intermediate features.
114 | num_grow_ch (int): Channels for each growth.
115 | """
116 |
117 | def __init__(self, num_feat, num_grow_ch=32):
118 | super(RRDB, self).__init__()
119 | self.rdb1 = ResidualDenseBlock(num_feat, num_grow_ch)
120 | self.rdb2 = ResidualDenseBlock(num_feat, num_grow_ch)
121 | self.rdb3 = ResidualDenseBlock(num_feat, num_grow_ch)
122 |
123 | def forward(self, x):
124 | out = self.rdb1(x)
125 | out = self.rdb2(out)
126 | out = self.rdb3(out)
127 | # Empirically, we use 0.2 to scale the residual for better performance
128 | return out * 0.2 + x
129 |
130 |
131 | class RRDBNet(nn.Module):
132 | """Networks consisting of Residual in Residual Dense Block, which is used
133 | in ESRGAN.
134 |
135 | ESRGAN: Enhanced Super-Resolution Generative Adversarial Networks.
136 |
137 | We extend ESRGAN for scale x2 and scale x1.
138 | Note: This is one option for scale 1, scale 2 in RRDBNet.
139 | We first employ pixel-unshuffle (an inverse operation of pixel-shuffle) to reduce the spatial size
140 | and enlarge the channel size before feeding inputs into the main ESRGAN architecture.
141 |
142 | Args:
143 | num_in_ch (int): Channel number of inputs.
144 | num_out_ch (int): Channel number of outputs.
145 | num_feat (int): Channel number of intermediate features.
146 | Default: 64
147 | num_block (int): Number of blocks in the trunk network. Default: 23.
148 | num_grow_ch (int): Channels for each growth. Default: 32.
149 | """
150 |
151 | def __init__(self, num_in_ch, num_out_ch, scale=4, num_feat=64, num_block=23, num_grow_ch=32):
152 | super(RRDBNet, self).__init__()
153 | self.scale = scale
154 | if scale == 2:
155 | num_in_ch = num_in_ch * 4
156 | elif scale == 1:
157 | num_in_ch = num_in_ch * 16
158 | self.conv_first = nn.Conv2d(num_in_ch, num_feat, 3, 1, 1)
159 | self.body = make_layer(RRDB, num_block, num_feat=num_feat, num_grow_ch=num_grow_ch)
160 | self.conv_body = nn.Conv2d(num_feat, num_feat, 3, 1, 1)
161 | # upsample
162 | self.conv_up1 = nn.Conv2d(num_feat, num_feat, 3, 1, 1)
163 | self.conv_up2 = nn.Conv2d(num_feat, num_feat, 3, 1, 1)
164 | self.conv_hr = nn.Conv2d(num_feat, num_feat, 3, 1, 1)
165 | self.conv_last = nn.Conv2d(num_feat, num_out_ch, 3, 1, 1)
166 |
167 | self.lrelu = nn.LeakyReLU(negative_slope=0.2, inplace=True)
168 |
169 | def forward(self, x):
170 | if self.scale == 2:
171 | feat = pixel_unshuffle(x, scale=2)
172 | elif self.scale == 1:
173 | feat = pixel_unshuffle(x, scale=4)
174 | else:
175 | feat = x
176 | feat = self.conv_first(feat)
177 | body_feat = self.conv_body(self.body(feat))
178 | feat = feat + body_feat
179 | # upsample
180 | feat = self.lrelu(self.conv_up1(F.interpolate(feat, scale_factor=2, mode='nearest')))
181 | feat = self.lrelu(self.conv_up2(F.interpolate(feat, scale_factor=2, mode='nearest')))
182 | out = self.conv_last(self.lrelu(self.conv_hr(feat)))
183 | return out
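A minimal sketch instantiating the x4 generator and pushing a toy low-resolution batch through it.

```
import torch

net = RRDBNet(num_in_ch=3, num_out_ch=3, scale=4, num_feat=64, num_block=23, num_grow_ch=32)
lr = torch.rand(1, 3, 64, 64)
with torch.no_grad():
    sr = net(lr)   # (1, 3, 256, 256): two nearest-neighbour x2 upsampling stages
```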
--------------------------------------------------------------------------------
/utils/torchinterp1d-master.zip:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/EternalEvan/FlowIE/549bb871cff104c47fd126d23818f54d74ba1f20/utils/torchinterp1d-master.zip
--------------------------------------------------------------------------------
/utils/torchinterp1d/.gitignore:
--------------------------------------------------------------------------------
1 | # Byte-compiled / optimized / DLL files
2 | __pycache__/
3 | *.py[cod]
4 | *$py.class
5 |
6 | # C extensions
7 | *.so
8 |
9 | # Distribution / packaging
10 | .Python
11 | build/
12 | develop-eggs/
13 | dist/
14 | downloads/
15 | eggs/
16 | .eggs/
17 | lib/
18 | lib64/
19 | parts/
20 | sdist/
21 | var/
22 | wheels/
23 | *.egg-info/
24 | .installed.cfg
25 | *.egg
26 | MANIFEST
27 |
28 | # PyInstaller
29 | # Usually these files are written by a python script from a template
30 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
31 | *.manifest
32 | *.spec
33 |
34 | # Installer logs
35 | pip-log.txt
36 | pip-delete-this-directory.txt
37 |
38 | # Unit test / coverage reports
39 | htmlcov/
40 | .tox/
41 | .coverage
42 | .coverage.*
43 | .cache
44 | nosetests.xml
45 | coverage.xml
46 | *.cover
47 | .hypothesis/
48 | .pytest_cache/
49 |
50 | # Translations
51 | *.mo
52 | *.pot
53 |
54 | # Django stuff:
55 | *.log
56 | local_settings.py
57 | db.sqlite3
58 |
59 | # Flask stuff:
60 | instance/
61 | .webassets-cache
62 |
63 | # Scrapy stuff:
64 | .scrapy
65 |
66 | # Sphinx documentation
67 | docs/_build/
68 |
69 | # PyBuilder
70 | target/
71 |
72 | # Jupyter Notebook
73 | .ipynb_checkpoints
74 |
75 | # pyenv
76 | .python-version
77 |
78 | # celery beat schedule file
79 | celerybeat-schedule
80 |
81 | # SageMath parsed files
82 | *.sage.py
83 |
84 | # Environments
85 | .env
86 | .venv
87 | env/
88 | venv/
89 | ENV/
90 | env.bak/
91 | venv.bak/
92 |
93 | # Spyder project settings
94 | .spyderproject
95 | .spyproject
96 |
97 | # Rope project settings
98 | .ropeproject
99 |
100 | # mkdocs documentation
101 | /site
102 |
103 | # mypy
104 | .mypy_cache/
105 |
--------------------------------------------------------------------------------
/utils/torchinterp1d/.gitmodules:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/EternalEvan/FlowIE/549bb871cff104c47fd126d23818f54d74ba1f20/utils/torchinterp1d/.gitmodules
--------------------------------------------------------------------------------
/utils/torchinterp1d/LICENSE:
--------------------------------------------------------------------------------
1 | BSD 3-Clause License
2 |
3 | Copyright (c) 2019, Inria (Antoine Liutkus)
4 | All rights reserved.
5 |
6 | Redistribution and use in source and binary forms, with or without
7 | modification, are permitted provided that the following conditions are met:
8 |
9 | 1. Redistributions of source code must retain the above copyright notice, this
10 | list of conditions and the following disclaimer.
11 |
12 | 2. Redistributions in binary form must reproduce the above copyright notice,
13 | this list of conditions and the following disclaimer in the documentation
14 | and/or other materials provided with the distribution.
15 |
16 | 3. Neither the name of the copyright holder nor the names of its
17 | contributors may be used to endorse or promote products derived from
18 | this software without specific prior written permission.
19 |
20 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
21 | AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
22 | IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
23 | DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
24 | FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
25 | DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
26 | SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
27 | CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
28 | OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
29 | OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
30 |
--------------------------------------------------------------------------------
/utils/torchinterp1d/README.md:
--------------------------------------------------------------------------------
1 | # torchinterp1d
2 | ## CUDA 1-D interpolation for Pytorch
3 |
4 | Requires PyTorch >= 1.6 (due to [torch.searchsorted](https://pytorch.org/docs/master/generated/torch.searchsorted.html)).
5 |
6 | ## Presentation
7 |
8 | This repository implements an `interp1d` function, built on a `torch.autograd.Function` subclass, that enables
9 | linear 1D interpolation on the GPU for Pytorch.
10 |
11 | ```
12 | def interp1d(x, y, xnew, out=None)
13 | ```
14 |
15 | This function returns interpolated values of a set of 1-D functions at the desired query points `xnew`.
16 |
17 | It works similarly to Matlab™ or scipy functions with
18 | the `linear` interpolation mode on, except that it parallelises over any number of desired interpolation problems and exploits CUDA on the GPU.
19 |
20 | ### Parameters for `interp1d`
21 |
22 | * `x` : a (N, ) or (D, N) Pytorch Tensor:
23 | Either 1-D or 2-D. It contains the coordinates of the observed samples.
24 |
25 | * `y` : (N,) or (D, N) Pytorch Tensor.
26 | Either 1-D or 2-D. It contains the actual values that correspond to the coordinates given by `x`.
27 | The length of `y` along its last dimension must be the same as that of `x`.
28 |
29 | * `xnew` : (P,) or (D, P) Pytorch Tensor.
30 | Either 1-D or 2-D. The x-coordinates at which the interpolated output is desired. If it is not 1-D, its length along the first dimension must be the same as that of whichever of `x` and `y` is 2-D.
31 |
32 | * `out` : (D, P) Pytorch Tensor.
33 | Tensor for the output. If None: allocated automatically.
34 |
35 | ### Results
36 |
37 | a Pytorch tensor of shape (D, P), containing the interpolated values.
38 |
39 | ## Installation
40 |
41 | Type `pip install -e .` in the root folder of this repo.
42 |
43 | ## Usage
44 |
45 | Simply call `torchinterp1d.interp1d`.
46 |
47 | Try out `python test.py` in the `examples` folder.
48 | ```
49 | Solving 100000 interpolation problems: each with 100 observations and 30 desired values
50 | CPU: 8060.260ms, GPU: 70.735ms, error: 0.000000%.
51 | ```
52 |
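A minimal call on 1-D inputs (a sketch; 1-D inputs are treated as a single interpolation problem, so the result carries a leading batch dimension):

```
import math
import torch
from torchinterp1d import interp1d

x = torch.linspace(0, 1, 10)       # observed coordinates (sorted)
y = torch.sin(2 * math.pi * x)     # observed values
xnew = torch.rand(5)               # query points
ynew = interp1d(x, y, xnew)        # -> tensor of shape (1, 5)
```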
--------------------------------------------------------------------------------
/utils/torchinterp1d/examples/test.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import matplotlib.pyplot as plt
3 | import time
4 | import numpy as np
5 | from torchinterp1d import interp1d
6 |
7 |
8 | if __name__ == "__main__":
9 | # defining the number of tests
10 | ntests = 2
11 |
12 | # problem dimensions
13 | D = 1000
14 | Dnew = 1
15 | N = 100
16 | P = 30
17 |
18 | yq_gpu = None
19 | yq_cpu = None
20 | for ntest in range(ntests):
21 | # draw the data
22 | x = torch.rand(D, N) * 10000
23 | x = x.sort(dim=1)[0]
24 |
25 | y = torch.linspace(0, 1000, D*N).view(D, -1)
26 | y -= y[:, 0, None]
27 |
28 | xnew = torch.rand(Dnew, P)*10000
29 |
30 | print('Solving %d interpolation problems: '
31 | 'each with %d observations and %d desired values' % (D, N, P))
32 |
33 | # calling the cpu version
34 | t0_cpu = time.time()
35 | yq_cpu = interp1d(x, y, xnew, yq_cpu)
36 | t1_cpu = time.time()
37 |
38 | display_str = 'CPU: %0.3fms, ' % ((t1_cpu-t0_cpu)*1000)
39 |
40 | if torch.cuda.is_available():
41 | x = x.to('cuda')
42 | y = y.to('cuda')
43 | xnew = xnew.to('cuda')
44 |
45 | # launching the cuda version
46 | t0 = time.time()
47 | yq_gpu = interp1d(x, y, xnew, yq_gpu)
48 | t1 = time.time()
49 |
50 | # compute the difference between both
51 | error = torch.norm(
52 | yq_cpu - yq_gpu.to('cpu'))/torch.norm(yq_cpu)*100.
53 |
54 | display_str += 'GPU: %0.3fms, error: %f%%.' % (
55 | (t1-t0)*1000, error)
56 | print(display_str)
57 |
58 | if torch.cuda.is_available():
59 | # for the last test, plot the result for the first 10 dimensions max
60 | d_plot = min(D, 10)
61 | x = x[:d_plot].cpu().numpy()
62 | y = y[:d_plot].cpu().numpy()
63 | xnew = xnew[:d_plot].cpu().numpy()
64 | yq_cpu = yq_cpu[:d_plot].cpu().numpy()
65 | yq_gpu = yq_gpu[:d_plot].cpu().numpy()
66 |
67 | plt.plot(x.T, y.T, '-',
68 | xnew.T, yq_gpu.T, 'o',
69 | xnew.T, yq_cpu.T, 'x')
70 | not_close = np.nonzero(np.invert(np.isclose(yq_gpu, yq_cpu)))
71 | if not_close[0].size:
72 | plt.scatter(xnew[not_close].T, yq_cpu[not_close].T,
73 | edgecolors='r', s=100, facecolors='none')
74 | plt.grid(True)
75 | plt.show()
76 |
--------------------------------------------------------------------------------
/utils/torchinterp1d/setup.py:
--------------------------------------------------------------------------------
1 | from setuptools import setup
2 | # To use a consistent encoding
3 | from codecs import open
4 | from os import path
5 |
6 | # trying to import the required torch package
7 | try:
8 | import torch
9 | except ImportError:
10 | raise Exception('torchinterp1d requires PyTorch to be installed. Aborting.')
11 |
12 | here = path.abspath(path.dirname(__file__))
13 |
14 | # Get the long description from the README file
15 | with open(path.join(here, 'README.md'), encoding='utf-8') as f:
16 | long_description = f.read()
17 |
18 | # Proceed to setup
19 | setup(
20 | name='torchinterp1d',
21 | version='1.1',
22 | description='An interp1d implementation for pytorch',
23 | long_description=long_description,
24 | long_description_content_type='text/markdown',
25 | author='Antoine Liutkus',
26 | author_email='antoine.liutkus@inria.fr',
27 | packages=['torchinterp1d'],
28 | keywords='interp1d torch',
29 | install_requires=[
30 | 'torch>=1.6',
31 | ],
32 | )
33 |
--------------------------------------------------------------------------------
/utils/torchinterp1d/torchinterp1d/__init__.py:
--------------------------------------------------------------------------------
1 | from .interp1d import interp1d, Interp1d
2 |
--------------------------------------------------------------------------------
/utils/torchinterp1d/torchinterp1d/interp1d.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import contextlib
3 |
4 | class Interp1d(torch.autograd.Function):
5 | @staticmethod
6 | def forward(ctx, x, y, xnew, out=None):
7 | """
8 | Linear 1D interpolation on the GPU for Pytorch.
9 | This function returns interpolated values of a set of 1-D functions at
10 | the desired query points `xnew`.
11 | This function is working similarly to Matlab™ or scipy functions with
12 | the `linear` interpolation mode on, except that it parallelises over
13 | any number of desired interpolation problems.
14 | The code will run on GPU if all the tensors provided are on a cuda
15 | device.
16 |
17 | Parameters
18 | ----------
19 | x : (N, ) or (D, N) Pytorch Tensor
20 | A 1-D or 2-D tensor of real values.
21 | y : (N,) or (D, N) Pytorch Tensor
22 | A 1-D or 2-D tensor of real values. The length of `y` along its
23 | last dimension must be the same as that of `x`
24 | xnew : (P,) or (D, P) Pytorch Tensor
25 | A 1-D or 2-D tensor of real values. `xnew` can only be 1-D if
26 | _both_ `x` and `y` are 1-D. Otherwise, its length along the first
27 | dimension must be the same as that of whichever `x` and `y` is 2-D.
28 | out : Pytorch Tensor, same shape as `xnew`
29 | Tensor for the output. If None: allocated automatically.
30 |
31 | """
32 | # making the vectors at least 2D
33 | is_flat = {}
34 | require_grad = {}
35 | v = {}
36 | device = []
37 | eps = torch.finfo(y.dtype).eps
38 | for name, vec in {'x': x, 'y': y, 'xnew': xnew}.items():
39 | assert len(vec.shape) <= 2, 'interp1d: all inputs must be '\
40 | 'at most 2-D.'
41 | if len(vec.shape) == 1:
42 | v[name] = vec[None, :]
43 | else:
44 | v[name] = vec
45 | is_flat[name] = v[name].shape[0] == 1
46 | require_grad[name] = vec.requires_grad
47 | device = list(set(device + [str(vec.device)]))
48 | assert len(device) == 1, 'All parameters must be on the same device.'
49 | device = device[0]
50 |
51 | # Checking for the dimensions
52 | assert (v['x'].shape[1] == v['y'].shape[1]
53 | and (
54 | v['x'].shape[0] == v['y'].shape[0]
55 | or v['x'].shape[0] == 1
56 | or v['y'].shape[0] == 1
57 | )
58 | ), ("x and y must have the same number of columns, and either "
59 | "the same number of row or one of them having only one "
60 | "row.")
61 |
62 | reshaped_xnew = False
63 | if ((v['x'].shape[0] == 1) and (v['y'].shape[0] == 1)
64 | and (v['xnew'].shape[0] > 1)):
65 | # if there is only one row for both x and y, there is no need to
66 | # loop over the rows of xnew because they will all have to face the
67 | # same interpolation problem. We should just stack them together to
68 | # call interp1d and put them back in place afterwards.
69 | original_xnew_shape = v['xnew'].shape
70 | v['xnew'] = v['xnew'].contiguous().view(1, -1)
71 | reshaped_xnew = True
72 |
73 | # identify the dimensions of output and check if the one provided is ok
74 | D = max(v['x'].shape[0], v['xnew'].shape[0])
75 | shape_ynew = (D, v['xnew'].shape[-1])
76 | if out is not None:
77 | if out.numel() != shape_ynew[0]*shape_ynew[1]:
78 | # The output provided is of incorrect shape.
79 | # Going for a new one
80 | out = None
81 | else:
82 | ynew = out.reshape(shape_ynew)
83 | if out is None:
84 | ynew = torch.zeros(*shape_ynew, device=device)
85 |
86 | # moving everything to the desired device in case it was not there
87 | # already (not handling the case things do not fit entirely, user will
88 | # do it if required.)
89 | for name in v:
90 | v[name] = v[name].to(device)
91 |
92 | # calling searchsorted on the x values.
93 | ind = ynew.long()
94 |
95 | # expanding xnew to match the number of rows of x in case only one xnew is
96 | # provided
97 | if v['xnew'].shape[0] == 1:
98 | v['xnew'] = v['xnew'].expand(v['x'].shape[0], -1)
99 |
100 | # the squeeze is because torch.searchsorted does accept either a nd with
101 | # matching shapes for x and xnew or a 1d vector for x. Here we would
102 | # have (1,len) for x sometimes
103 | torch.searchsorted(v['x'].contiguous().squeeze(),
104 | v['xnew'].contiguous(), out=ind)
105 |
106 | # the `-1` is because searchsorted looks for the index where the values
107 | # must be inserted to preserve order. And we want the index of the
108 | # preceding value.
109 | ind -= 1
110 | # we clamp the index, because the number of intervals is x.shape-1,
111 | # and the left neighbour should hence be at most number of intervals
112 | # -1, i.e. number of columns in x -2
113 | ind = torch.clamp(ind, 0, v['x'].shape[1] - 1 - 1)
114 |
115 | # helper function to select stuff according to the found indices.
116 | def sel(name):
117 | if is_flat[name]:
118 | return v[name].contiguous().view(-1)[ind]
119 | return torch.gather(v[name], 1, ind)
120 |
121 | # activating gradient storing for everything now
122 | enable_grad = False
123 | saved_inputs = []
124 | for name in ['x', 'y', 'xnew']:
125 | if require_grad[name]:
126 | enable_grad = True
127 | saved_inputs += [v[name]]
128 | else:
129 | saved_inputs += [None, ]
130 | # assuming x are sorted in the dimension 1, computing the slopes for
131 | # the segments
132 | is_flat['slopes'] = is_flat['x']
133 | # now we have found the indices of the neighbors, we start building the
134 | # output. Hence, we start also activating gradient tracking
135 | with torch.enable_grad() if enable_grad else contextlib.suppress():
136 | v['slopes'] = (
137 | (v['y'][:, 1:]-v['y'][:, :-1])
138 | /
139 | (eps + (v['x'][:, 1:]-v['x'][:, :-1]))
140 | )
141 |
142 | # now build the linear interpolation
143 | ynew = sel('y') + sel('slopes')*(
144 | v['xnew'] - sel('x'))
145 |
146 | if reshaped_xnew:
147 | ynew = ynew.view(original_xnew_shape)
148 |
149 | ctx.save_for_backward(ynew, *saved_inputs)
150 | return ynew
151 |
152 | @staticmethod
153 | def backward(ctx, grad_out):
154 | inputs = ctx.saved_tensors[1:]
155 | gradients = torch.autograd.grad(
156 | ctx.saved_tensors[0],
157 | [i for i in inputs if i is not None],
158 | grad_out, retain_graph=True)
159 | result = [None, ] * 5
160 | pos = 0
161 | for index in range(len(inputs)):
162 | if inputs[index] is not None:
163 | result[index] = gradients[pos]
164 | pos += 1
165 | return (*result,)
166 |
167 |
168 | interp1d = Interp1d.apply
--------------------------------------------------------------------------------
/utils/utils.py:
--------------------------------------------------------------------------------
1 |
2 | import os
3 | import sys
4 | import time
5 | import math
6 | import torch
7 | import numpy as np
8 | import scipy
9 | import scipy.io as spio
10 | import yaml
11 | from PIL import Image
12 | import torch.nn.functional as F
13 |
14 | def loadmat(filename):
15 | '''
16 | This function should be called instead of spio.loadmat directly,
17 | as it cures the problem of not properly recovering Python dictionaries
18 | from .mat files. It calls `_check_keys` to convert all entries
19 | that are still mat-objects into nested dictionaries.
20 | '''
21 | def _check_keys(d):
22 | '''
23 | Checks if entries in the dictionary are mat-objects. If so,
24 | _todict is called to change them to nested dictionaries.
25 | '''
26 | for key in d:
27 | if isinstance(d[key], spio.matlab.mio5_params.mat_struct):
28 | d[key] = _todict(d[key])
29 | return d
30 |
31 | def _todict(matobj):
32 | '''
33 | A recursive function which constructs from matobjects nested dictionaries
34 | '''
35 | d = {}
36 | for strg in matobj._fieldnames:
37 | elem = matobj.__dict__[strg]
38 | if isinstance(elem, spio.matlab.mio5_params.mat_struct):
39 | d[strg] = _todict(elem)
40 | elif isinstance(elem, np.ndarray):
41 | d[strg] = _tolist(elem)
42 | else:
43 | d[strg] = elem
44 | return d
45 |
46 | def _tolist(ndarray):
47 | '''
48 | A recursive function which constructs lists from cellarrays
49 | (which are loaded as numpy ndarrays), recursing into the elements
50 | if they contain matobjects.
51 | '''
52 | elem_list = []
53 | for sub_elem in ndarray:
54 | if isinstance(sub_elem, spio.matlab.mio5_params.mat_struct):
55 | elem_list.append(_todict(sub_elem))
56 | elif isinstance(sub_elem, np.ndarray):
57 | elem_list.append(_tolist(sub_elem))
58 | else:
59 | elem_list.append(sub_elem)
60 | return elem_list
61 | data = scipy.io.loadmat(filename, struct_as_record=False, squeeze_me=True)
62 | return _check_keys(data)
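A small usage sketch (the .mat path is a placeholder): nested MATLAB structs come back as ordinary Python dicts and lists.

```
meta = loadmat("data/camera_meta.mat")  # placeholder path to a MATLAB .mat file
print(list(meta.keys()))                # struct fields are now plain dictionary keys
```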
--------------------------------------------------------------------------------
/weights/null_token_1024.pth:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/EternalEvan/FlowIE/549bb871cff104c47fd126d23818f54d74ba1f20/weights/null_token_1024.pth
--------------------------------------------------------------------------------