├── .gitignore ├── LICENSE ├── README.md ├── configs ├── autoencoder │ └── vae-768-crop.yaml ├── demoire │ ├── cross-dataset │ │ ├── esdnet │ │ │ ├── fhdmi │ │ │ │ └── cd_unidemoire_esdnet_fhdmi.yaml │ │ │ ├── tip │ │ │ │ └── cd_unidemoire_esdnet_tip.yaml │ │ │ └── uhdm │ │ │ │ └── cd_unidemoire_esdnet_uhdm.yaml │ │ └── mbcnn │ │ │ ├── fhdmi │ │ │ └── cd_unidemoire_mbcnn_fhdmi.yaml │ │ │ ├── tip │ │ │ └── cd_unidemoire_mbcnn_tip.yaml │ │ │ └── uhdm │ │ │ └── cd_unidemoire_mbcnn_uhdm.yaml │ └── mhrnid │ │ ├── mhrnid_esdnet_unidemoire.yaml │ │ └── mhrnid_mbcnn_unidemoire.yaml ├── latent-diffusion │ └── ldm-vae-768-crop.yaml └── moire-blending │ ├── fhdmi │ └── blending_fhdmi.yaml │ ├── tip │ └── blending_tip.yaml │ └── uhdm │ └── blending_uhdm.yaml ├── environment.yaml ├── main.py ├── models ├── moire_blending │ ├── fhdmi │ │ └── config.yaml │ ├── tip │ │ └── config.yaml │ └── uhdm │ │ └── config.yaml └── moire_generator │ └── diffusion │ └── config.yaml ├── scripts └── sample_moire_pattern.py ├── setup.py ├── static └── images │ └── Pipeline.png ├── taming └── modules │ └── autoencoder │ └── lpips │ └── vgg.pth └── unidemoire ├── __init__.py ├── data ├── __init__.py ├── fhdmi.py ├── moire.py ├── moire_blend.py ├── tip.py ├── uhdm.py └── utils.py ├── lr_scheduler.py ├── models ├── MIB │ ├── Blending.py │ └── __init__.py ├── TRN │ ├── __init__.py │ └── model.py ├── autoencoder.py ├── cycle │ ├── Models │ │ ├── Loss_func_demoire.py │ │ ├── models.py │ │ ├── modules.py │ │ └── utils.py │ ├── nets.py │ └── networks.py ├── diffusion │ ├── __init__.py │ ├── classifier.py │ ├── ddim.py │ ├── ddpm.py │ └── plms.py ├── esdnet │ ├── __init__.py │ └── nets.py ├── mbcnn │ ├── LossNet.py │ ├── MBCNN.py │ ├── MBCNN_class.py │ ├── __init__.py │ └── arch_util.py ├── moire_blending.py ├── moire_nets.py ├── pmtnet │ ├── PMTNet.py │ └── __init__.py ├── shooting │ ├── __init__.py │ ├── image_transformer.py │ ├── method.py │ └── mosaicing_demosaicing_v2.py ├── undem │ ├── __init__.py │ └── model.py └── utils │ ├── __init__.py │ ├── common.py │ ├── loss_util.py │ ├── matlab_ssim.py │ └── metric.py ├── modules ├── attention.py ├── diffusionmodules │ ├── __init__.py │ ├── model.py │ ├── openaimodel.py │ └── util.py ├── distributions │ ├── __init__.py │ └── distributions.py ├── ema.py ├── encoders │ ├── __init__.py │ └── modules.py ├── losses │ ├── __init__.py │ ├── contperceptual.py │ └── vqperceptual.py └── x_transformer.py └── util.py /.gitignore: -------------------------------------------------------------------------------- 1 | __pycache__/ 2 | *.py[cod] 3 | *$py.class 4 | .python-version 5 | 6 | .vscode/ 7 | .idea/ 8 | *.swp 9 | *.swo 10 | 11 | .env 12 | venv/ 13 | ENV/ 14 | env.bak/ 15 | 16 | .coverage 17 | htmlcov/ 18 | .pytest_cache/ 19 | 20 | *.pt 21 | *.bin 22 | *.npy 23 | *.npz 24 | *.tmp 25 | *.ckpt 26 | 27 | .DS_Store 28 | Thumbs.db 29 | 30 | 31 | 32 | # Jupyter笔记本检查点 33 | .ipynb_checkpoints/ -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2024 4DVLab 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software 
is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # UniDemoiré: Towards Universal Image Demoiréing with Data Generation and Synthesis 2 | 3 |
Zemin Yang¹, Yujing Sun², Xidong Peng¹, Siu Ming Yiu², Yuexin Ma¹
4 | 
5 | ### [Project Page](https://yizhifengyeyzm.github.io/UniDemoire-page/) | [Paper](https://arxiv.org/abs/2502.06324) | [Dataset](https://drive.google.com/drive/folders/1k48jcgJLMUB0_42H-x1VYl67NP56zel8?usp=drive_link) 
6 | 
7 | *** 
8 | 
9 | The generalization ability of SOTA demoiréing models is greatly limited by the scarcity of data. To obtain a universal model with improved generalization capability, we therefore face two main challenges: acquiring a vast amount of **1) diverse** and **2) realistic-looking moiré data**. Traditional moiré image datasets contain real data, but continuously expanding them to cover more diversity is extremely time-consuming and impractical, while current synthesized datasets and methods struggle to produce realistic-looking moiré images. 
10 | 
11 | ![Pipeline](./static/images/Pipeline.png) 
12 | 
13 | Hence, to tackle these challenges, we introduce a universal solution, **UniDemoiré**. The data diversity challenge is addressed by collecting a more diverse moiré pattern dataset and presenting a moiré pattern generator that further increases pattern variations. Meanwhile, the realism challenge is handled by a moiré image synthesis module. As a result, our solution produces realistic-looking moiré images of sufficient diversity, substantially enhancing the zero-shot and cross-domain performance of demoiréing models. 
14 | 
15 | *** 
16 | 
17 | ## :hourglass_flowing_sand: To Do 
18 | 
19 | - [x] Release training code 
20 | - [x] Release testing code 
21 | - [x] Release dataset 
22 | - [x] Release pre-trained models 
23 | 
24 | ## 🛠️ Environment 
25 | The entire UniDemoiré framework is built on the Latent Diffusion Model and requires Python 3.8 and PyTorch-Lightning 1.4.2. 
26 | You can install the UniDemoiré environment in either of the following two ways: 
27 | ``` 
28 | conda env create -f environment.yaml 
29 | conda activate unidemoire 
30 | ``` 
31 | If the installation doesn't go well, you can also follow the [instructions](https://github.com/CompVis/latent-diffusion?tab=readme-ov-file#requirements) to install the Latent Diffusion Model environment first, and then install the rest via pip: 
32 | ``` 
33 | conda activate unidemoire 
34 | 
35 | ... 
36 | (install the ldm environment first) 
37 | ... 
38 | 
39 | pip install colour-demosaicing==0.2.2 
40 | pip install thop==0.1.1-2209072238 
41 | pip install lpips==0.1.4 
42 | pip install timm==0.9.16 
43 | pip install pillow==9.5.0 
44 | ``` 
45 | 
46 | ## 📦 Dataset and Pre-trained Models 
47 | 
48 | We provide the captured 4K moiré pattern dataset, the sampled moiré pattern dataset, the MHRNID dataset, and the pre-trained models for both the Moiré Pattern Generator and Moiré Image Synthesis stages, which can be downloaded through the following links: 
49 | 
50 | **\[[Baidu Drive](https://pan.baidu.com/s/1YI4NO5xyC8oK3ZOFHpTa1w?pwd=sthx)\]** | **\[[Google Drive](https://drive.google.com/drive/folders/1k48jcgJLMUB0_42H-x1VYl67NP56zel8?usp=drive_link)\]** 
51 | 
52 | 
53 | ## 🚀 Getting Started 
54 | 
55 | >**Some important tips about the training and testing process of our code:** 
56 | 
57 | The config files follow a similar style to [ldm](https://github.com/CompVis/latent-diffusion), and **the paths to the training/testing datasets can be changed inside the config files.** 
58 | 
59 | Logs and checkpoints for trained models are saved under `logs/`.
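For example, each demoiréing config selects its test dataset through a `test:` block under `data.params`. A minimal sketch of such a block (mirroring the files under `configs/demoire/` shown later in this repository; the `dataset_path` value is a placeholder you must replace with your local path):
```
test:                                          # e.g. evaluate on the TIP test set
  target: unidemoire.data.tip.tip_datasets     # dataset class to instantiate
  params:
    args:
      dataset_path: /path/to/your/TIP_dataset  # placeholder: set your local dataset root
      LOADER: default
      mode: test
```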
60 | 
61 | **If you need to continue training a specific model, you can simply re-run the training code with the “`-r`” parameter followed by the path to your model ckpt.** 
62 | 
63 | The dataset type and path of the test set must be specified in the config file. **The program will automatically start the testing process after training is complete (the same pattern as in the Latent Diffusion Model)**. If you want to change the test dataset, edit the config file and then re-run your training code with “`-r`” pointing to the checkpoint from the previous step, and the program will go directly to the test session! 
64 | 
65 | If you want to train with multiple GPUs, remember to replace `<gpu_id>` with your GPU ids in the command templates below, and be sure to adjust the “`--gpus`” parameter that follows it as well. 
66 | - For example: to train with `4` GPUs (assuming they are numbered `5`, `6`, `7`, and `8`), set `CUDA_VISIBLE_DEVICES=5,6,7,8` and use `--gpus 0,1,2,3,` 
67 | 
68 | ### Moiré Pattern Generator 
69 | 
70 | #### 1. AutoEncoder 
71 | Configs for training a KL-regularized autoencoder on the captured moiré pattern dataset are provided at `configs/autoencoder`. Training can be started by running: 
72 | ``` 
73 | CUDA_VISIBLE_DEVICES=<gpu_id> python main.py --base configs/autoencoder/<config_name>.yaml --scale_lr False -t --gpus 0, 
74 | ``` 
75 | After training, place the ckpt file in `models/moire_generator/autoencoder`. 
76 | 
77 | #### 2. Diffusion Model 
78 | In `configs/latent-diffusion/` we provide configs for training the diffusion model on the captured moiré pattern dataset. Training can be started by running: 
79 | ``` 
80 | CUDA_VISIBLE_DEVICES=<gpu_id> python main.py --base configs/latent-diffusion/<config_name>.yaml -t --gpus 0, 
81 | ``` 
82 | After training, place the ckpt file in `models/moire_generator/diffusion`. 
83 | 
84 | #### 3. Sampling 
85 | Run the script via: 
86 | ``` 
87 | CUDA_VISIBLE_DEVICES=<gpu_id> python scripts/sample_moire_pattern.py 
88 | -r <path_to_diffusion_ckpt> 
89 | -n <number_of_samples> 
90 | 
91 | For example: 
92 | CUDA_VISIBLE_DEVICES=0 python scripts/sample_moire_pattern.py 
93 | -r ./models/moire_generator/diffusion/last.ckpt 
94 | -n 10000 
95 | ``` 
96 | 
97 | ### Moiré Image Synthesis 
98 | In `configs/moire-blending/` we provide configs for training the synthesis model on the UHDM, FHDMi, and TIP datasets. Training can be started by running: 
99 | ``` 
100 | CUDA_VISIBLE_DEVICES=<gpu_id> python main.py --base configs/moire-blending/<config_name>.yaml --scale_lr False -t --gpus 0, 
101 | 
102 | For example: (training on UHDM dataset) 
103 | CUDA_VISIBLE_DEVICES=0 python main.py --base configs/moire-blending/uhdm/blending_uhdm.yaml --scale_lr False -t --gpus 0, 
104 | ``` 
105 | where `<config_name>` is one of {`uhdm/blending_uhdm`, `fhdmi/blending_fhdmi`, `tip/blending_tip`}. 
106 | 
107 | After training, place the ckpt file in `models/moire_blending/<dataset_name>/`. You can find the original config file in these paths. If you want to change the training config in `configs/moire-blending/`, then you also need to change the config file in `models/moire_blending/<dataset_name>/` accordingly. 
108 | 
109 | ### Demoiréing 
110 | 
111 | #### 1. Zero-Shot Demoiréing 
112 | First, download and unzip the MHRNID dataset.
**(to be updated)** 
113 | Then run the following code to start training on MHRNID: 
114 | ``` 
115 | CUDA_VISIBLE_DEVICES=<gpu_id> python main.py --base configs/demoire/mhrnid/<config_name>.yaml --scale_lr False -t --gpus 0, 
116 | 
117 | For example: (using ESDNet) 
118 | CUDA_VISIBLE_DEVICES=0 python main.py --base configs/demoire/mhrnid/mhrnid_esdnet_unidemoire.yaml --scale_lr False -t --gpus 0, 
119 | ``` 
120 | where `<config_name>` is one of {`mhrnid_esdnet_unidemoire`, `mhrnid_mbcnn_unidemoire`}. 
121 | 
122 | #### 2. Cross-Dataset Demoiréing 
123 | 
124 | Run the following code to start training: 
125 | ``` 
126 | CUDA_VISIBLE_DEVICES=<gpu_id> python main.py --base configs/demoire/cross-dataset/<method>/<dataset>/cd_unidemoire_<method>_<dataset>.yaml --scale_lr False -t --gpus 0, 
127 | 
128 | For example: (using ESDNet, train on UHDM dataset) 
129 | CUDA_VISIBLE_DEVICES=0 python main.py --base configs/demoire/cross-dataset/esdnet/uhdm/cd_unidemoire_esdnet_uhdm.yaml --scale_lr False -t --gpus 0, 
130 | ``` 
131 | where `<method>` is one of {`esdnet`, `mbcnn`}, and `<dataset>` is one of {`uhdm`, `fhdmi`, `tip`}. 
132 | 
133 | 
134 | 
135 | ## 🙏 Acknowledgements 
136 | 
137 | We would like to express our gratitude to the authors and contributors of the following projects: 
138 | 
139 | - [Latent Diffusion Model](https://github.com/CompVis/latent-diffusion) 
140 | - [UHDM](https://github.com/CVMI-Lab/UHDM) 
141 | - [FHDMi](https://github.com/PKU-IMRE/FHDe2Net) 
142 | - [TIP](https://github.com/ZhengJun-AI/MoirePhotoRestoration-MCNN) 
143 | - [Uformer](https://github.com/ZhendongWang6/Uformer) 
144 | - [UnDeM](https://github.com/zysxmu/UnDeM) 
145 | 
146 | 
147 | 
148 | ## 📑 Citation 
149 | 
150 | If you find our work useful, please consider citing us using the following BibTeX entry: 
151 | 
152 | ``` 
153 | @misc{yang2025unidemoire, 
154 |   author = {Zemin Yang and Yujing Sun and Xidong Peng and Siu Ming Yiu and Yuexin Ma}, 
155 |   title = {UniDemoir\'e: Towards Universal Image Demoir\'eing with Data Generation and Synthesis}, 
156 |   year = {2025}, 
157 |   eprint = {2502.06324}, 
158 |   archivePrefix = {arXiv}, 
159 |   primaryClass = {cs.CV}, 
160 |   url = {https://arxiv.org/abs/2502.06324}, 
161 | } 
162 | ``` 
163 | 
164 | 
--------------------------------------------------------------------------------
/configs/autoencoder/vae-768-crop.yaml:
--------------------------------------------------------------------------------
1 | model:
2 |   base_learning_rate: 4.5e-6
3 |   target: unidemoire.models.autoencoder.AutoencoderKL
4 |   params:
5 |     monitor: "val/rec_loss"
6 |     embed_dim: 64
7 |     lossconfig:
8 |       target: unidemoire.modules.losses.LPIPSWithDiscriminator
9 |       params:
10 |         disc_start: 50001
11 |         kl_weight: 0.000001
12 |         disc_weight: 0.5
13 |         disc_in_channels: 3
14 | 
15 |     ddconfig:
16 |       double_z: True
17 |       z_channels: 64
18 |       resolution: 768
19 |       in_channels: 3
20 |       out_ch: 3
21 |       ch: 64
22 |       ch_mult: [1,1,2,2,4,4]
23 |       num_res_blocks: 2
24 |       attn_resolutions: [16,8]
25 |       dropout: 0.0
26 | 
27 | data:
28 |   target: main.DataModuleFromConfig
29 |   params:
30 |     batch_size: 2
31 |     wrap: True
32 |     train:
33 |       target: unidemoire.data.moire.MoirePattern
34 |       params:
35 |         dataset_path: "./data/captured_data" # Please set the path to your moire pattern dataset
36 |         resolution: 768
37 | 
38 | lightning:
39 |   callbacks:
40 |     image_logger:
41 |       target: main.ImageLogger
42 |       params:
43 |         batch_frequency: 1000
44 |         max_images: 8
45 |         increase_log_steps: True
46 | 
47 |   trainer:
48 |     benchmark: True
49 |     accumulate_grad_batches: 2
50 | 
51 | 
--------------------------------------------------------------------------------
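All of the YAML files in this repository follow the ldm-style `target`/`params` convention shown above. As a rough, minimal sketch of how such a config is resolved into a model object, mirroring the `OmegaConf` + `instantiate_from_config` pattern used in `scripts/sample_moire_pattern.py` further below (the config path here is only an illustrative choice):
```
from omegaconf import OmegaConf
from unidemoire.util import instantiate_from_config

# Load an ldm-style config and build whatever class its "target" names,
# passing the "params" block as keyword arguments.
config = OmegaConf.load("configs/autoencoder/vae-768-crop.yaml")
model = instantiate_from_config(config.model)
```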
/configs/demoire/cross-dataset/esdnet/fhdmi/cd_unidemoire_esdnet_fhdmi.yaml: -------------------------------------------------------------------------------- 1 | model: 2 | base_learning_rate: 1e-4 3 | target: unidemoire.models.moire_nets.Demoireing_Model # ldm.models.autoencoder.AutoencoderKL 4 | params: 5 | model_name: ESDNet 6 | mode: COMBINE_ONLINE # COMBINE_ONLINE, COMBINE_ONLINE_ONLY, original 7 | blending_method: unidemoire # unidemoire, shooting, moirespace, undem 8 | blending_model_path: ./models/moire_blending/bl_fhdmi.ckpt 9 | dataset: TIP 10 | evaluation_time: False 11 | evaluation_metric: True 12 | save_img: True 13 | network_config: 14 | # ESDNet 15 | en_feature_num: 48 16 | en_inter_num: 32 17 | de_feature_num: 64 18 | de_inter_num: 32 19 | sam_number: 2 # ESDNet:1, ESDNet-L:2 20 | 21 | # MBCNN 22 | n_filters: 64 23 | 24 | loss_config: 25 | # ESDNet 26 | LAM: 1 27 | LAM_P: 1 28 | 29 | optimizer_config: 30 | # ESDNet 31 | beta1: 0.9 32 | beta2: 0.999 33 | T_0: 50 # The total epochs for the first learning cycle (learning rate warms up then) 34 | T_mult: 1 # The learning cycle would be (T_0, T_0*T_MULT, T_0*T_MULT^2, T_0*T_MULT^3, ...) 35 | eta_min: 0.000001 36 | 37 | data: 38 | target: main.DataModuleFromConfig 39 | params: 40 | batch_size: 4 41 | num_workers: 4 42 | wrap: True 43 | train: 44 | target: unidemoire.data.moire_blend.moire_blending_datasets 45 | params: 46 | args: 47 | natural_dataset_name: FHDMi 48 | uhdm_dataset_path: # Please set the path to your moire pattern dataset 49 | fhdmi_dataset_path: # Please set the path to your moire pattern dataset 50 | tip_dataset_path: # Please set the path to your moire pattern dataset 51 | moire_pattern_path: # Please set the path to your moire pattern dataset 52 | loader: crop 53 | crop_size: 384 54 | paired: True 55 | mode: train 56 | 57 | # test: # UHDM 58 | # target: demoire.data.uhdm.uhdm_datasets 59 | # params: 60 | # args: 61 | # dataset_path: # Please set the path to your moire pattern dataset 62 | # LOADER: default 63 | # mode: test 64 | 65 | # test: # FHDMi 66 | # target: demoire.data.fhdmi.fhdmi_datasets 67 | # params: 68 | # args: 69 | # dataset_path: # Please set the path to your moire pattern dataset 70 | # LOADER: default 71 | # mode: test 72 | 73 | test: # TIP 74 | target: demoire.data.tip.tip_datasets 75 | params: 76 | args: 77 | dataset_path: # Please set the path to your moire pattern dataset 78 | LOADER: default 79 | mode: test 80 | 81 | lightning: 82 | callbacks: 83 | image_logger: 84 | target: main.ImageLogger 85 | params: 86 | increase_log_steps: False 87 | rescale: False 88 | batch_frequency: 500 89 | max_images: 8 90 | 91 | trainer: 92 | benchmark: True 93 | # accumulate_grad_batches: 1 94 | max_epochs: 150 -------------------------------------------------------------------------------- /configs/demoire/cross-dataset/esdnet/tip/cd_unidemoire_esdnet_tip.yaml: -------------------------------------------------------------------------------- 1 | model: 2 | base_learning_rate: 1e-4 3 | target: unidemoire.models.moire_nets.Demoireing_Model 4 | params: 5 | model_name: ESDNet 6 | mode: COMBINE_ONLINE 7 | blending_method: unidemoire # unidemoire, shooting, moirespace, undem 8 | blending_model_path: ./models/moire_blending/bl_tip.ckpt 9 | dataset: UHDM 10 | evaluation_time: False 11 | evaluation_metric: True 12 | save_img: True 13 | network_config: 14 | # ESDNet 15 | en_feature_num: 48 16 | en_inter_num: 32 17 | de_feature_num: 64 18 | de_inter_num: 32 19 | sam_number: 2 # ESDNet:1, ESDNet-L:2 20 | 21 | # MBCNN 22 
| n_filters: 64 23 | 24 | loss_config: 25 | # ESDNet 26 | LAM: 1 27 | LAM_P: 1 28 | 29 | optimizer_config: 30 | # ESDNet 31 | beta1: 0.9 32 | beta2: 0.999 33 | T_0: 50 # The total epochs for the first learning cycle (learning rate warms up then) 34 | T_mult: 1 # The learning cycle would be (T_0, T_0*T_MULT, T_0*T_MULT^2, T_0*T_MULT^3, ...) 35 | eta_min: 0.000001 36 | 37 | data: 38 | target: main.DataModuleFromConfig 39 | params: 40 | batch_size: 8 41 | num_workers: 4 42 | wrap: True 43 | train: 44 | target: unidemoire.data.moire_blend.moire_blending_datasets 45 | params: 46 | args: 47 | natural_dataset_name: TIP 48 | uhdm_dataset_path: # Please set the path to your moire pattern dataset 49 | fhdmi_dataset_path: # Please set the path to your moire pattern dataset 50 | tip_dataset_path: # Please set the path to your moire pattern dataset 51 | moire_pattern_path: # Please set the path to your moire pattern dataset 52 | loader: crop 53 | crop_size: 256 54 | paired: True 55 | mode: train 56 | 57 | test: # UHDM 58 | target: unidemoire.data.uhdm.uhdm_datasets 59 | params: 60 | args: 61 | dataset_path: # Please set the path to your moire pattern dataset 62 | LOADER: default 63 | mode: test 64 | 65 | # test: # FHDMi 66 | # target: unidemoire.data.fhdmi.fhdmi_datasets 67 | # params: 68 | # args: 69 | # dataset_path: # Please set the path to your moire pattern dataset 70 | # LOADER: default 71 | # mode: test 72 | 73 | # test: # TIP 74 | # target: unidemoire.data.tip.tip_datasets 75 | # params: 76 | # args: 77 | # dataset_path: # Please set the path to your moire pattern dataset 78 | # LOADER: default 79 | # mode: test 80 | 81 | lightning: 82 | callbacks: 83 | image_logger: 84 | target: main.ImageLogger 85 | params: 86 | increase_log_steps: False 87 | rescale: False 88 | batch_frequency: 500 89 | max_images: 8 90 | 91 | trainer: 92 | benchmark: True 93 | # accumulate_grad_batches: 1 94 | max_epochs: 70 -------------------------------------------------------------------------------- /configs/demoire/cross-dataset/esdnet/uhdm/cd_unidemoire_esdnet_uhdm.yaml: -------------------------------------------------------------------------------- 1 | model: 2 | base_learning_rate: 1e-4 3 | target: unidemoire.models.moire_nets.Demoireing_Model 4 | params: 5 | model_name: ESDNet 6 | mode: COMBINE_ONLINE 7 | blending_method: unidemoire # unidemoire, shooting, moirespace, undem 8 | blending_model_path: ./models/moire_blending/bl_uhdm.ckpt 9 | dataset: TIP 10 | evaluation_time: False 11 | evaluation_metric: True 12 | save_img: True 13 | network_config: 14 | # ESDNet 15 | en_feature_num: 48 16 | en_inter_num: 32 17 | de_feature_num: 64 18 | de_inter_num: 32 19 | sam_number: 2 # ESDNet:1, ESDNet-L:2 20 | 21 | # MBCNN 22 | n_filters: 64 23 | 24 | loss_config: 25 | # ESDNet 26 | LAM: 1 27 | LAM_P: 1 28 | 29 | optimizer_config: 30 | # ESDNet 31 | beta1: 0.9 32 | beta2: 0.999 33 | T_0: 50 # The total epochs for the first learning cycle (learning rate warms up then) 34 | T_mult: 1 # The learning cycle would be (T_0, T_0*T_MULT, T_0*T_MULT^2, T_0*T_MULT^3, ...) 
35 | eta_min: 0.000001 36 | 37 | data: 38 | target: main.DataModuleFromConfig 39 | params: 40 | batch_size: 4 41 | num_workers: 4 42 | wrap: True 43 | train: 44 | target: unidemoire.data.moire_blend.moire_blending_datasets 45 | params: 46 | args: 47 | natural_dataset_name: FHDMi 48 | uhdm_dataset_path: # Please set the path to your moire pattern dataset 49 | fhdmi_dataset_path: # Please set the path to your moire pattern dataset 50 | tip_dataset_path: # Please set the path to your moire pattern dataset 51 | moire_pattern_path: # Please set the path to your moire pattern dataset 52 | loader: crop 53 | crop_size: 384 54 | paired: True 55 | mode: train 56 | 57 | # test: # UHDM 58 | # target: unidemoire.data.uhdm.uhdm_datasets 59 | # params: 60 | # args: 61 | # dataset_path: # Please set the path to your moire pattern dataset 62 | # LOADER: default 63 | # mode: test 64 | 65 | # test: # FHDMi 66 | # target: unidemoire.data.fhdmi.fhdmi_datasets 67 | # params: 68 | # args: 69 | # dataset_path: # Please set the path to your moire pattern dataset 70 | # LOADER: default 71 | # mode: test 72 | 73 | test: # TIP 74 | target: unidemoire.data.tip.tip_datasets 75 | params: 76 | args: 77 | dataset_path: # Please set the path to your moire pattern dataset 78 | LOADER: default 79 | mode: test 80 | 81 | lightning: 82 | callbacks: 83 | image_logger: 84 | target: main.ImageLogger 85 | params: 86 | increase_log_steps: False 87 | rescale: False 88 | batch_frequency: 500 89 | max_images: 8 90 | 91 | trainer: 92 | benchmark: True 93 | # accumulate_grad_batches: 1 94 | max_epochs: 150 -------------------------------------------------------------------------------- /configs/demoire/cross-dataset/mbcnn/fhdmi/cd_unidemoire_mbcnn_fhdmi.yaml: -------------------------------------------------------------------------------- 1 | model: 2 | base_learning_rate: 1e-4 3 | target: unidemoire.models.moire_nets.Demoireing_Model # ldm.models.autoencoder.AutoencoderKL 4 | params: 5 | model_name: MBCNN 6 | mode: COMBINE_ONLINE # COMBINE_ONLINE, COMBINE_ONLINE_ONLY, original 7 | blending_method: unidemoire # unidemoire, shooting, moirespace, undem 8 | blending_model_path: ./models/moire_blending/bl_fhdmi.ckpt 9 | dataset: UHDM 10 | evaluation_time: False 11 | evaluation_metric: True 12 | save_img: True 13 | network_config: 14 | # ESDNet 15 | en_feature_num: 48 16 | en_inter_num: 32 17 | de_feature_num: 64 18 | de_inter_num: 32 19 | sam_number: 2 # ESDNet:1, ESDNet-L:2 20 | 21 | # MBCNN 22 | n_filters: 64 23 | 24 | loss_config: 25 | # ESDNet 26 | LAM: 1 27 | LAM_P: 1 28 | 29 | optimizer_config: 30 | # ESDNet 31 | beta1: 0.9 32 | beta2: 0.999 33 | T_0: 50 # The total epochs for the first learning cycle (learning rate warms up then) 34 | T_mult: 1 # The learning cycle would be (T_0, T_0*T_MULT, T_0*T_MULT^2, T_0*T_MULT^3, ...) 
35 | eta_min: 0.000001 36 | 37 | data: 38 | target: main.DataModuleFromConfig 39 | params: 40 | batch_size: 4 41 | num_workers: 4 42 | wrap: True 43 | train: 44 | target: unidemoire.data.moire_blend.moire_blending_datasets 45 | params: 46 | args: 47 | natural_dataset_name: FHDMi 48 | uhdm_dataset_path: # Please set the path to your moire pattern dataset 49 | fhdmi_dataset_path: # Please set the path to your moire pattern dataset 50 | tip_dataset_path: # Please set the path to your moire pattern dataset 51 | moire_pattern_path: # Please set the path to your moire pattern dataset 52 | loader: crop 53 | crop_size: 384 54 | paired: True 55 | mode: train 56 | 57 | test: # UHDM 58 | target: unidemoire.data.uhdm.uhdm_datasets 59 | params: 60 | args: 61 | dataset_path: # Please set the path to your moire pattern dataset 62 | LOADER: default 63 | mode: test 64 | 65 | # test: # FHDMi 66 | # target: unidemoire.data.fhdmi.fhdmi_datasets 67 | # params: 68 | # args: 69 | # dataset_path: # Please set the path to your moire pattern dataset 70 | # LOADER: default 71 | # mode: test 72 | 73 | # test: # TIP 74 | # target: unidemoire.data.tip.tip_datasets 75 | # params: 76 | # args: 77 | # dataset_path: # Please set the path to your moire pattern dataset 78 | # LOADER: default 79 | # mode: test 80 | 81 | lightning: 82 | callbacks: 83 | image_logger: 84 | target: main.ImageLogger 85 | params: 86 | increase_log_steps: False 87 | rescale: False 88 | batch_frequency: 500 89 | max_images: 8 90 | 91 | trainer: 92 | benchmark: True 93 | # accumulate_grad_batches: 1 94 | max_epochs: 150 -------------------------------------------------------------------------------- /configs/demoire/cross-dataset/mbcnn/tip/cd_unidemoire_mbcnn_tip.yaml: -------------------------------------------------------------------------------- 1 | model: 2 | base_learning_rate: 1e-4 3 | target: demoire.models.moire_nets.Demoireing_Model 4 | params: 5 | model_name: MBCNN 6 | mode: COMBINE_ONLINE 7 | blending_method: unidemoire # unidemoire, shooting, moirespace, undem 8 | blending_model_path: ./models/moire_blending/bl_tip.ckpt 9 | dataset: UHDM 10 | evaluation_time: False 11 | evaluation_metric: True 12 | save_img: True 13 | network_config: 14 | # ESDNet 15 | en_feature_num: 48 16 | en_inter_num: 32 17 | de_feature_num: 64 18 | de_inter_num: 32 19 | sam_number: 2 # ESDNet:1, ESDNet-L:2 20 | 21 | # MBCNN 22 | n_filters: 64 23 | 24 | loss_config: 25 | # ESDNet 26 | LAM: 1 27 | LAM_P: 1 28 | 29 | optimizer_config: 30 | # ESDNet 31 | beta1: 0.9 32 | beta2: 0.999 33 | T_0: 50 # The total epochs for the first learning cycle (learning rate warms up then) 34 | T_mult: 1 # The learning cycle would be (T_0, T_0*T_MULT, T_0*T_MULT^2, T_0*T_MULT^3, ...) 
35 | eta_min: 0.000001 36 | 37 | data: 38 | target: main.DataModuleFromConfig 39 | params: 40 | batch_size: 8 41 | num_workers: 4 42 | wrap: True 43 | train: 44 | target: unidemoire.data.moire_blend.moire_blending_datasets 45 | params: 46 | args: 47 | natural_dataset_name: UHDM 48 | uhdm_dataset_path: # Please set the path to your moire pattern dataset 49 | fhdmi_dataset_path: # Please set the path to your moire pattern dataset 50 | tip_dataset_path: # Please set the path to your moire pattern dataset 51 | moire_pattern_path: # Please set the path to your moire pattern dataset 52 | loader: crop 53 | crop_size: 256 54 | paired: True 55 | mode: train 56 | 57 | test: # UHDM 58 | target: unidemoire.data.uhdm.uhdm_datasets 59 | params: 60 | args: 61 | dataset_path: # Please set the path to your moire pattern dataset 62 | LOADER: default 63 | mode: test 64 | 65 | # test: # FHDMi 66 | # target: unidemoire.data.fhdmi.fhdmi_datasets 67 | # params: 68 | # args: 69 | # dataset_path: # Please set the path to your moire pattern dataset 70 | # LOADER: default 71 | # mode: test 72 | 73 | # test: # TIP 74 | # target: unidemoire.data.tip.tip_datasets 75 | # params: 76 | # args: 77 | # dataset_path: # Please set the path to your moire pattern dataset 78 | # LOADER: default 79 | # mode: test 80 | 81 | lightning: 82 | callbacks: 83 | image_logger: 84 | target: main.ImageLogger 85 | params: 86 | increase_log_steps: False 87 | rescale: False 88 | batch_frequency: 500 89 | max_images: 8 90 | 91 | trainer: 92 | benchmark: True 93 | # accumulate_grad_batches: 1 94 | max_epochs: 70 -------------------------------------------------------------------------------- /configs/demoire/cross-dataset/mbcnn/uhdm/cd_unidemoire_mbcnn_uhdm.yaml: -------------------------------------------------------------------------------- 1 | model: 2 | base_learning_rate: 1e-4 3 | target: unidemoire.models.moire_nets.Demoireing_Model 4 | params: 5 | model_name: MBCNN 6 | mode: COMBINE_ONLINE 7 | blending_method: unidemoire # unidemoire, shooting, moirespace, undem 8 | blending_model_path: ./models/moire_blending/bl_uhdm.ckpt 9 | dataset: TIP 10 | evaluation_time: False 11 | evaluation_metric: True 12 | save_img: True 13 | network_config: 14 | # ESDNet 15 | en_feature_num: 48 16 | en_inter_num: 32 17 | de_feature_num: 64 18 | de_inter_num: 32 19 | sam_number: 2 # ESDNet:1, ESDNet-L:2 20 | 21 | # MBCNN 22 | n_filters: 64 23 | 24 | loss_config: 25 | # ESDNet 26 | LAM: 1 27 | LAM_P: 1 28 | 29 | optimizer_config: 30 | # ESDNet 31 | beta1: 0.9 32 | beta2: 0.999 33 | T_0: 50 # The total epochs for the first learning cycle (learning rate warms up then) 34 | T_mult: 1 # The learning cycle would be (T_0, T_0*T_MULT, T_0*T_MULT^2, T_0*T_MULT^3, ...) 
35 | eta_min: 0.000001 36 | 37 | data: 38 | target: main.DataModuleFromConfig 39 | params: 40 | batch_size: 4 41 | num_workers: 4 42 | wrap: True 43 | train: 44 | target: unidemoire.data.moire_blend.moire_blending_datasets 45 | params: 46 | args: 47 | natural_dataset_name: UHDM 48 | uhdm_dataset_path: # Please set the path to your moire pattern dataset 49 | fhdmi_dataset_path: # Please set the path to your moire pattern dataset 50 | tip_dataset_path: # Please set the path to your moire pattern dataset 51 | moire_pattern_path: # Please set the path to your moire pattern dataset 52 | loader: crop 53 | crop_size: 384 54 | paired: True 55 | mode: train 56 | 57 | # test: # UHDM 58 | # target: unidemoire.data.uhdm.uhdm_datasets 59 | # params: 60 | # args: 61 | # dataset_path: # Please set the path to your moire pattern dataset 62 | # LOADER: default 63 | # mode: test 64 | 65 | # test: # FHDMi 66 | # target: unidemoire.data.fhdmi.fhdmi_datasets 67 | # params: 68 | # args: 69 | # dataset_path: # Please set the path to your moire pattern dataset 70 | # LOADER: default 71 | # mode: test 72 | 73 | test: # TIP 74 | target: unidemoire.data.tip.tip_datasets 75 | params: 76 | args: 77 | dataset_path: # Please set the path to your moire pattern dataset 78 | LOADER: default 79 | mode: test 80 | 81 | lightning: 82 | callbacks: 83 | image_logger: 84 | target: main.ImageLogger 85 | params: 86 | increase_log_steps: False 87 | rescale: False 88 | batch_frequency: 500 89 | max_images: 8 90 | 91 | trainer: 92 | benchmark: True 93 | # accumulate_grad_batches: 1 94 | max_epochs: 150 -------------------------------------------------------------------------------- /configs/demoire/mhrnid/mhrnid_esdnet_unidemoire.yaml: -------------------------------------------------------------------------------- 1 | model: 2 | base_learning_rate: 1e-4 3 | target: unidemoire.models.moire_nets.Demoireing_Model 4 | params: 5 | model_name: ESDNet 6 | mode: use_synthetic_moire_image_only 7 | blending_method: unidemoire # unidemoire, shooting, moirespace, undem 8 | blending_model_path: ./models/moire_blending/bl_tip.ckpt 9 | dataset: UHDM 10 | evaluation_time: False 11 | evaluation_metric: True 12 | save_img: True 13 | ckpt_path: ## for testing 14 | 15 | network_config: 16 | # ESDNet 17 | en_feature_num: 48 18 | en_inter_num: 32 19 | de_feature_num: 64 20 | de_inter_num: 32 21 | sam_number: 2 # ESDNet:1, ESDNet-L:2 22 | 23 | # MBCNN 24 | n_filters: 64 25 | 26 | loss_config: 27 | LAM: 1 28 | LAM_P: 1 29 | 30 | optimizer_config: 31 | beta1: 0.9 32 | beta2: 0.999 33 | T_0: 50 # The total epochs for the first learning cycle (learning rate warms up then) 34 | T_mult: 1 # The learning cycle would be (T_0, T_0*T_MULT, T_0*T_MULT^2, T_0*T_MULT^3, ...) 
35 | eta_min: 0.000001 36 | 37 | data: 38 | target: main.DataModuleFromConfig 39 | params: 40 | batch_size: 4 41 | num_workers: 8 42 | wrap: True 43 | train: 44 | target: unidemoire.data.moire_blend.moire_blending_datasets 45 | params: 46 | args: 47 | natural_dataset_name: MHRNID 48 | mhrnid_dataset_path: # Please set the path to your moire pattern dataset 49 | moire_pattern_path: # Please set the path to your moire pattern dataset 50 | loader: crop 51 | crop_size: 384 52 | paired: True 53 | mode: train 54 | 55 | test: # UHDM 56 | target: unidemoire.data.uhdm.uhdm_datasets 57 | params: 58 | args: 59 | dataset_path: # Please set the path to your moire pattern dataset 60 | LOADER: default 61 | mode: test 62 | 63 | # test: # FHDMi 64 | # target: unidemoire.data.fhdmi.fhdmi_datasets 65 | # params: 66 | # args: 67 | # dataset_path: # Please set the path to your moire pattern dataset 68 | # LOADER: default 69 | # mode: test 70 | 71 | 72 | lightning: 73 | callbacks: 74 | image_logger: 75 | target: main.ImageLogger 76 | params: 77 | increase_log_steps: False 78 | rescale: False 79 | batch_frequency: 500 80 | max_images: 8 81 | 82 | trainer: 83 | benchmark: True 84 | # accumulate_grad_batches: 1 85 | max_epochs: 50 -------------------------------------------------------------------------------- /configs/demoire/mhrnid/mhrnid_mbcnn_unidemoire.yaml: -------------------------------------------------------------------------------- 1 | model: 2 | base_learning_rate: 1e-4 3 | target: unidemoire.models.moire_nets.Demoireing_Model 4 | params: 5 | model_name: MBCNN 6 | mode: use_synthetic_moire_image_only 7 | blending_method: unidemoire # unidemoire, shooting, moirespace, undem 8 | blending_model_path: ./models/moire_blending/bl_tip.ckpt 9 | dataset: UHDM 10 | evaluation_time: False 11 | evaluation_metric: True 12 | save_img: True 13 | ckpt_path: ## for testing 14 | 15 | network_config: 16 | # ESDNet 17 | en_feature_num: 48 18 | en_inter_num: 32 19 | de_feature_num: 64 20 | de_inter_num: 32 21 | sam_number: 2 # ESDNet:1, ESDNet-L:2 22 | 23 | # MBCNN 24 | n_filters: 64 25 | 26 | loss_config: 27 | LAM: 1 28 | LAM_P: 1 29 | 30 | optimizer_config: 31 | beta1: 0.9 32 | beta2: 0.999 33 | T_0: 50 # The total epochs for the first learning cycle (learning rate warms up then) 34 | T_mult: 1 # The learning cycle would be (T_0, T_0*T_MULT, T_0*T_MULT^2, T_0*T_MULT^3, ...) 
35 | eta_min: 0.000001 36 | 37 | data: 38 | target: main.DataModuleFromConfig 39 | params: 40 | batch_size: 4 41 | num_workers: 8 42 | wrap: True 43 | train: 44 | target: unidemoire.data.moire_blend.moire_blending_datasets 45 | params: 46 | args: 47 | natural_dataset_name: MHRNID 48 | mhrnid_dataset_path: # Please set the path to your moire pattern dataset 49 | moire_pattern_path: # Please set the path to your moire pattern dataset 50 | loader: crop 51 | crop_size: 384 52 | paired: True 53 | mode: train 54 | 55 | test: # UHDM 56 | target: unidemoire.data.uhdm.uhdm_datasets 57 | params: 58 | args: 59 | dataset_path: # Please set the path to your moire pattern dataset 60 | LOADER: default 61 | mode: test 62 | 63 | # test: # FHDMi 64 | # target: unidemoire.data.fhdmi.fhdmi_datasets 65 | # params: 66 | # args: 67 | # dataset_path: # Please set the path to your moire pattern dataset 68 | # LOADER: default 69 | # mode: test 70 | 71 | 72 | lightning: 73 | callbacks: 74 | image_logger: 75 | target: main.ImageLogger 76 | params: 77 | increase_log_steps: False 78 | rescale: False 79 | batch_frequency: 500 80 | max_images: 8 81 | 82 | trainer: 83 | benchmark: True 84 | # accumulate_grad_batches: 1 85 | max_epochs: 50 -------------------------------------------------------------------------------- /configs/latent-diffusion/ldm-vae-768-crop.yaml: -------------------------------------------------------------------------------- 1 | model: 2 | base_learning_rate: 1.0e-4 # 5.0e-5 # set to target_lr by starting main.py with '--scale_lr False' 3 | target: unidemoire.models.diffusion.ddpm.LatentDiffusion 4 | params: 5 | linear_start: 0.0015 6 | linear_end: 0.0155 7 | num_timesteps_cond: 1 8 | log_every_t: 200 9 | timesteps: 1000 10 | loss_type: l1 11 | first_stage_key: "image" 12 | cond_stage_key: "image" 13 | image_size: 24 14 | channels: 64 15 | cond_stage_trainable: False 16 | concat_mode: False 17 | scale_by_std: True 18 | monitor: 'val/loss_simple_ema' 19 | 20 | scheduler_config: # 10000 warmup steps 21 | target: unidemoire.lr_scheduler.LambdaLinearScheduler 22 | params: 23 | warm_up_steps: [10000] 24 | cycle_lengths: [10000000000000] 25 | f_start: [1.e-6] 26 | f_max: [1.] 27 | f_min: [ 1.] 
28 | 29 | unet_config: 30 | target: unidemoire.modules.diffusionmodules.openaimodel.UNetModel 31 | params: 32 | image_size: 24 33 | in_channels: 64 34 | out_channels: 64 35 | model_channels: 192 36 | attention_resolutions: [1, 2, 4, 8] 37 | num_res_blocks: 2 38 | channel_mult: [1,2,2,4] 39 | num_heads: 8 40 | use_scale_shift_norm: True 41 | resblock_updown: True 42 | 43 | first_stage_config: 44 | target: unidemoire.models.autoencoder.AutoencoderKL 45 | params: 46 | embed_dim: 64 47 | monitor: "val/rec_loss" 48 | ckpt_path: "models/moire_generator/autoencoder/last.ckpt" 49 | ddconfig: 50 | double_z: True 51 | z_channels: 64 52 | resolution: 768 53 | in_channels: 3 54 | out_ch: 3 55 | wide_scale_resolution: False 56 | ch: 64 57 | ch_mult: [1,1,2,2,4,4] 58 | num_res_blocks: 2 59 | attn_resolutions: [16,8] 60 | dropout: 0.0 61 | 62 | lossconfig: 63 | target: torch.nn.Identity 64 | 65 | cond_stage_config: "__is_unconditional__" 66 | 67 | data: 68 | target: main.DataModuleFromConfig 69 | params: 70 | batch_size: 2 71 | wrap: True 72 | train: 73 | target: unidemoire.data.moire.MoirePattern 74 | params: 75 | dataset_path: # Please set the path to your moire pattern dataset 76 | resolution: 768 77 | validation: 78 | target: unidemoire.data.moire.MoirePattern 79 | params: 80 | dataset_path: # Please set the path to your moire pattern dataset 81 | resolution: 768 82 | 83 | lightning: 84 | callbacks: 85 | image_logger: 86 | target: main.ImageLogger 87 | params: 88 | batch_frequency: 1000 89 | max_images: 8 90 | increase_log_steps: False 91 | 92 | 93 | trainer: 94 | benchmark: True 95 | # precision: 16 -------------------------------------------------------------------------------- /configs/moire-blending/fhdmi/blending_fhdmi.yaml: -------------------------------------------------------------------------------- 1 | model: 2 | base_learning_rate: 1e-5 3 | target: unidemoire.models.moire_blending.MoireBlending_Model 4 | params: 5 | model_name: UniDemoire 6 | network_config: 7 | init_blending_args: 8 | # MIB 9 | bl_method_1: multiply 10 | bl_method_1_op: 1.0 11 | bl_method_2: grain_merge 12 | bl_method_2_op: 0.8 13 | bl_final_weight_min: 0.65 14 | bl_final_weight_max: 0.75 15 | 16 | blending_network_args: 17 | # TRN 18 | depths: [1, 1, 1, 1, 1, 1, 1, 1, 1] 19 | embed_dim: 16 20 | win_size: 8 21 | modulator: False 22 | shift_flag: False 23 | 24 | loss_config: 25 | LAM: 1 26 | LAM_P: 1 27 | 28 | optimizer_config: 29 | beta1: 0.9 30 | beta2: 0.999 31 | T_0: 50 # The total epochs for the first learning cycle (learning rate warms up then) 32 | T_mult: 1 # The learning cycle would be (T_0, T_0*T_MULT, T_0*T_MULT^2, T_0*T_MULT^3, ...) 
33 | eta_min: 0.000001 34 | 35 | data: 36 | target: main.DataModuleFromConfig 37 | params: 38 | batch_size: 2 39 | num_workers: 8 40 | wrap: True 41 | train: 42 | target: unidemoire.data.moire_blend.moire_blending_datasets 43 | params: 44 | args: 45 | natural_dataset_name: FHDMi 46 | uhdm_dataset_path: # Please set the path to your moire pattern dataset 47 | fhdmi_dataset_path: # Please set the path to your moire pattern dataset 48 | tip_dataset_path: # Please set the path to your moire pattern dataset 49 | moire_pattern_path: # Please set the path to your moire pattern dataset 50 | loader: crop 51 | crop_size: 384 52 | paired: True 53 | mode: train 54 | 55 | lightning: 56 | callbacks: 57 | image_logger: 58 | target: main.ImageLogger 59 | params: 60 | increase_log_steps: False 61 | rescale: False 62 | batch_frequency: 500 63 | max_images: 8 64 | 65 | trainer: 66 | benchmark: True 67 | accumulate_grad_batches: 1 68 | max_epochs: 25 -------------------------------------------------------------------------------- /configs/moire-blending/tip/blending_tip.yaml: -------------------------------------------------------------------------------- 1 | model: 2 | base_learning_rate: 1e-5 3 | target: unidemoire.models.moire_blending.MoireBlending_Model 4 | params: 5 | model_name: UniDemoire 6 | network_config: 7 | init_blending_args: 8 | # MIB 9 | bl_method_1: multiply 10 | bl_method_1_op: 1.0 11 | bl_method_2: grain_merge 12 | bl_method_2_op: 0.8 13 | bl_final_weight_min: 0.65 14 | bl_final_weight_max: 0.75 15 | 16 | blending_network_args: 17 | # TRN 18 | depths: [1, 1, 1, 1, 1, 1, 1, 1, 1] 19 | embed_dim: 16 20 | win_size: 8 21 | modulator: False 22 | shift_flag: False 23 | 24 | loss_config: 25 | LAM: 1 26 | LAM_P: 1 27 | 28 | optimizer_config: 29 | beta1: 0.9 30 | beta2: 0.999 31 | T_0: 50 # The total epochs for the first learning cycle (learning rate warms up then) 32 | T_mult: 1 # The learning cycle would be (T_0, T_0*T_MULT, T_0*T_MULT^2, T_0*T_MULT^3, ...) 
33 | eta_min: 0.000001 34 | 35 | data: 36 | target: main.DataModuleFromConfig 37 | params: 38 | batch_size: 2 39 | num_workers: 4 40 | wrap: True 41 | train: 42 | target: unidemoire.data.moire_blend.moire_blending_datasets 43 | params: 44 | args: 45 | natural_dataset_name: TIP 46 | uhdm_dataset_path: # Please set the path to your moire pattern dataset 47 | fhdmi_dataset_path: # Please set the path to your moire pattern dataset 48 | tip_dataset_path: # Please set the path to your moire pattern dataset 49 | moire_pattern_path: # Please set the path to your moire pattern dataset 50 | loader: crop 51 | crop_size: 256 52 | paired: True 53 | mode: train 54 | 55 | lightning: 56 | callbacks: 57 | image_logger: 58 | target: main.ImageLogger 59 | params: 60 | increase_log_steps: False 61 | rescale: False 62 | batch_frequency: 500 63 | max_images: 8 64 | 65 | trainer: 66 | benchmark: True 67 | accumulate_grad_batches: 1 68 | max_epochs: 2 -------------------------------------------------------------------------------- /configs/moire-blending/uhdm/blending_uhdm.yaml: -------------------------------------------------------------------------------- 1 | model: 2 | base_learning_rate: 1e-5 3 | target: unidemoire.models.moire_blending.MoireBlending_Model 4 | params: 5 | model_name: UniDemoire 6 | network_config: 7 | init_blending_args: 8 | # MIB 9 | bl_method_1: multiply 10 | bl_method_1_op: 1.0 11 | bl_method_2: grain_merge 12 | bl_method_2_op: 0.8 13 | bl_final_weight_min: 0.65 14 | bl_final_weight_max: 0.75 15 | 16 | blending_network_args: 17 | # TRN 18 | depths: [1, 1, 1, 1, 1, 1, 1, 1, 1] 19 | embed_dim: 16 20 | win_size: 8 21 | modulator: False 22 | shift_flag: False 23 | 24 | loss_config: 25 | LAM: 1 26 | LAM_P: 1 27 | 28 | optimizer_config: 29 | beta1: 0.9 30 | beta2: 0.999 31 | T_0: 50 # The total epochs for the first learning cycle (learning rate warms up then) 32 | T_mult: 1 # The learning cycle would be (T_0, T_0*T_MULT, T_0*T_MULT^2, T_0*T_MULT^3, ...) 
33 | eta_min: 0.000001 34 | 35 | data: 36 | target: main.DataModuleFromConfig 37 | params: 38 | batch_size: 2 39 | num_workers: 4 40 | wrap: True 41 | train: 42 | target: unidemoire.data.moire_blend.moire_blending_datasets 43 | params: 44 | args: 45 | natural_dataset_name: UHDM 46 | uhdm_dataset_path: # Please set the path to your moire pattern dataset 47 | fhdmi_dataset_path: # Please set the path to your moire pattern dataset 48 | tip_dataset_path: # Please set the path to your moire pattern dataset 49 | moire_pattern_path: # Please set the path to your moire pattern dataset 50 | loader: crop 51 | crop_size: 384 52 | paired: True 53 | mode: train 54 | 55 | lightning: 56 | callbacks: 57 | image_logger: 58 | target: main.ImageLogger 59 | params: 60 | increase_log_steps: False 61 | rescale: False 62 | batch_frequency: 500 63 | max_images: 8 64 | 65 | trainer: 66 | benchmark: True 67 | accumulate_grad_batches: 1 68 | max_epochs: 50 -------------------------------------------------------------------------------- /environment.yaml: -------------------------------------------------------------------------------- 1 | name: unidemoire 2 | channels: 3 | - pytorch 4 | - defaults 5 | dependencies: 6 | - python=3.8.5 7 | - pip=20.3 8 | - cudatoolkit=11.0 9 | - pytorch=1.7.0 10 | - torchvision=0.8.1 11 | - numpy=1.19.2 12 | - pip: 13 | - albumentations==0.4.3 14 | - opencv-python==4.1.2.30 15 | - pudb==2019.2 16 | - imageio==2.9.0 17 | - imageio-ffmpeg==0.4.2 18 | - pytorch-lightning==1.4.2 19 | - omegaconf==2.1.1 20 | - test-tube>=0.7.5 21 | - streamlit>=0.73.1 22 | - einops==0.3.0 23 | - torch-fidelity==0.3.0 24 | - transformers==4.3.1 25 | - -e git+https://github.com/CompVis/taming-transformers.git@master#egg=taming-transformers 26 | - -e git+https://github.com/openai/CLIP.git@main#egg=clip 27 | - -e . 
28 | - colour-demosaicing==0.2.2 29 | - thop==0.1.1-2209072238 30 | - lpips==0.1.4 31 | - timm==0.9.16 32 | - pillow==9.5.0 33 | -------------------------------------------------------------------------------- /models/moire_blending/fhdmi/config.yaml: -------------------------------------------------------------------------------- 1 | model: 2 | base_learning_rate: 0.00001 3 | target: unidemoire.models.moire_blending.MoireBlending_Model 4 | params: 5 | model_name: UniDemoire 6 | network_config: 7 | init_blending_args: 8 | bl_method_1: multiply 9 | bl_method_1_op: 1.0 10 | bl_method_2: grain_merge 11 | bl_method_2_op: 0.8 12 | bl_final_weight_min: 0.65 13 | bl_final_weight_max: 0.75 14 | blending_network_args: 15 | depths: [1,1,1,1,1,1,1,1,1] 16 | embed_dim: 16 17 | win_size: 8 18 | modulator: true 19 | shift_flag: false 20 | -------------------------------------------------------------------------------- /models/moire_blending/tip/config.yaml: -------------------------------------------------------------------------------- 1 | model: 2 | base_learning_rate: 0.00001 3 | target: unidemoire.models.moire_blending.MoireBlending_Model 4 | params: 5 | model_name: UniDemoire 6 | network_config: 7 | init_blending_args: 8 | bl_method_1: multiply 9 | bl_method_1_op: 1.0 10 | bl_method_2: grain_merge 11 | bl_method_2_op: 0.8 12 | bl_final_weight_min: 0.65 13 | bl_final_weight_max: 0.75 14 | blending_network_args: 15 | depths: [1,1,1,1,1,1,1,1,1] 16 | embed_dim: 16 17 | win_size: 8 18 | modulator: true 19 | shift_flag: false 20 | -------------------------------------------------------------------------------- /models/moire_blending/uhdm/config.yaml: -------------------------------------------------------------------------------- 1 | model: 2 | base_learning_rate: 0.00001 3 | target: unidemoire.models.moire_blending.MoireBlending_Model 4 | params: 5 | model_name: UniDemoire 6 | network_config: 7 | init_blending_args: 8 | bl_method_1: multiply 9 | bl_method_1_op: 1.0 10 | bl_method_2: grain_merge 11 | bl_method_2_op: 0.8 12 | bl_final_weight_min: 0.65 13 | bl_final_weight_max: 0.75 14 | blending_network_args: 15 | depths: [1,1,1,1,1,1,1,1,1] 16 | embed_dim: 16 17 | win_size: 8 18 | modulator: true 19 | shift_flag: false 20 | -------------------------------------------------------------------------------- /models/moire_generator/diffusion/config.yaml: -------------------------------------------------------------------------------- 1 | model: 2 | base_learning_rate: 0.0001 3 | target: unidemoire.models.diffusion.ddpm.LatentDiffusion 4 | params: 5 | linear_start: 0.0015 6 | linear_end: 0.0155 7 | num_timesteps_cond: 1 8 | log_every_t: 200 9 | timesteps: 1000 10 | loss_type: l1 11 | first_stage_key: image 12 | cond_stage_key: image 13 | image_size: 24 14 | # wide_scale_resolution: false 15 | channels: 64 16 | cond_stage_trainable: false 17 | concat_mode: false 18 | scale_by_std: true 19 | monitor: val/loss_simple_ema 20 | scheduler_config: 21 | target: unidemoire.lr_scheduler.LambdaLinearScheduler 22 | params: 23 | warm_up_steps: 24 | - 10000 25 | cycle_lengths: 26 | - 10000000000000 27 | f_start: 28 | - 1.0e-06 29 | f_max: 30 | - 1.0 31 | f_min: 32 | - 1.0 33 | unet_config: 34 | target: unidemoire.modules.diffusionmodules.openaimodel.UNetModel 35 | params: 36 | image_size: 24 37 | in_channels: 64 38 | out_channels: 64 39 | model_channels: 192 40 | attention_resolutions: 41 | - 1 42 | - 2 43 | - 4 44 | - 8 45 | num_res_blocks: 2 46 | channel_mult: 47 | - 1 48 | - 2 49 | - 2 50 | - 4 51 | num_heads: 8 52 | 
use_scale_shift_norm: true 53 | resblock_updown: true 54 | first_stage_config: 55 | target: unidemoire.models.autoencoder.AutoencoderKL 56 | params: 57 | embed_dim: 64 58 | monitor: val/rec_loss 59 | # VAE 模型路径 60 | ckpt_path: models/moire_generator/autoencoder/last.ckpt 61 | ddconfig: 62 | double_z: true 63 | z_channels: 64 64 | resolution: 768 65 | in_channels: 3 66 | out_ch: 3 67 | # wide_scale_resolution: false 68 | ch: 64 69 | ch_mult: 70 | - 1 71 | - 1 72 | - 2 73 | - 2 74 | - 4 75 | - 4 76 | num_res_blocks: 2 77 | attn_resolutions: 78 | - 16 79 | - 8 80 | dropout: 0.0 81 | lossconfig: 82 | target: torch.nn.Identity 83 | cond_stage_config: __is_unconditional__ 84 | data: 85 | target: main.DataModuleFromConfig 86 | params: 87 | batch_size: 2 88 | wrap: true 89 | train: 90 | target: unidemoire.data.moire.MoirePattern 91 | params: 92 | dataset_path: "/inspurfs/group/mayuexin/yangzemin/data/captured_data" 93 | resolution: 768 94 | validation: 95 | target: unidemoire.data.moire.MoirePattern 96 | params: 97 | dataset_path: "/inspurfs/group/mayuexin/yangzemin/data/captured_data" 98 | resolution: 768 99 | -------------------------------------------------------------------------------- /scripts/sample_moire_pattern.py: -------------------------------------------------------------------------------- 1 | import argparse, os, sys, glob, datetime, yaml 2 | sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) 3 | 4 | import torch 5 | import time 6 | import numpy as np 7 | from tqdm import trange 8 | 9 | from omegaconf import OmegaConf 10 | from PIL import Image 11 | 12 | 13 | from unidemoire.models.diffusion.ddim import DDIMSampler 14 | from unidemoire.util import instantiate_from_config 15 | 16 | rescale = lambda x: (x + 1.) / 2. 17 | 18 | def custom_to_pil(x): 19 | x = x.detach().cpu() 20 | x = torch.clamp(x, -1., 1.) 21 | x = (x + 1.) / 2. 22 | x = x.permute(1, 2, 0).numpy() 23 | x = (255 * x).astype(np.uint8) 24 | if x.shape[2] == 1: 25 | x = x.squeeze() 26 | x = Image.fromarray(x, mode='L') 27 | else: 28 | x = Image.fromarray(x) 29 | x = x.convert("RGB") 30 | return x 31 | 32 | def custom_to_np(x): 33 | # saves the batch in adm style as in https://github.com/openai/guided-diffusion/blob/main/scripts/image_sample.py 34 | sample = x.detach().cpu() 35 | sample = ((sample + 1) * 127.5).clamp(0, 255).to(torch.uint8) 36 | sample = sample.permute(0, 2, 3, 1) 37 | sample = sample.contiguous() 38 | return sample 39 | 40 | def logs2pil(logs, keys=["sample"]): 41 | imgs = dict() 42 | for k in logs: 43 | try: 44 | if len(logs[k].shape) == 4: 45 | img = custom_to_pil(logs[k][0, ...]) 46 | elif len(logs[k].shape) == 3: 47 | img = custom_to_pil(logs[k]) 48 | else: 49 | print(f"Unknown format for key {k}. 
") 50 | img = None 51 | except: 52 | img = None 53 | imgs[k] = img 54 | return imgs 55 | 56 | @torch.no_grad() 57 | def convsample(model, shape, return_intermediates=True, 58 | verbose=True, 59 | make_prog_row=False): 60 | if not make_prog_row: 61 | return model.p_sample_loop(None, shape, 62 | return_intermediates=return_intermediates, verbose=verbose) 63 | else: 64 | return model.progressive_denoising( 65 | None, shape, verbose=True 66 | ) 67 | 68 | @torch.no_grad() 69 | def convsample_ddim(model, steps, shape, eta=1.0): 70 | ddim = DDIMSampler(model) 71 | bs = shape[0] 72 | shape = shape[1:] 73 | samples, intermediates = ddim.sample(steps, batch_size=bs, shape=shape, eta=eta, verbose=False,) 74 | return samples, intermediates 75 | 76 | @torch.no_grad() 77 | def make_convolutional_sample(model, batch_size, vanilla=False, custom_steps=None, eta=1.0,): 78 | 79 | log = dict() 80 | shape = [batch_size, 81 | model.model.diffusion_model.in_channels, 82 | model.model.diffusion_model.image_size, 83 | model.model.diffusion_model.image_size] 84 | 85 | with model.ema_scope("Plotting"): 86 | t0 = time.time() 87 | if vanilla: 88 | sample, progrow = convsample(model, shape, 89 | make_prog_row=True) 90 | else: 91 | sample, intermediates = convsample_ddim(model, steps=custom_steps, shape=shape, 92 | eta=eta) 93 | t1 = time.time() 94 | 95 | x_sample = model.decode_first_stage(sample) 96 | log["sample"] = x_sample 97 | log["time"] = t1 - t0 98 | log['throughput'] = sample.shape[0] / (t1 - t0) 99 | return log 100 | 101 | def run(model, logdir, batch_size=50, vanilla=False, custom_steps=None, eta=None, n_samples=50000, nplog=None): 102 | if vanilla: 103 | print(f'Using Vanilla DDPM sampling with {model.num_timesteps} sampling steps.') 104 | else: # Using DDIM sampling with 200 sampling steps and eta=1.0 105 | print(f'Using DDIM sampling with {custom_steps} sampling steps and eta={eta}') 106 | 107 | tstart = time.time() 108 | n_saved = len(glob.glob(os.path.join(logdir,'*.png'))) 109 | 110 | if model.cond_stage_model is None: 111 | all_images = [] 112 | print(f"Running unconditional sampling for {n_samples} samples") 113 | for _ in trange(n_samples // batch_size, desc="Sampling Batches (unconditional)"): 114 | if n_saved >= n_samples: 115 | print(f'Finish after generating {n_saved} samples') 116 | break 117 | logs = make_convolutional_sample(model, batch_size=batch_size, 118 | vanilla=vanilla, custom_steps=custom_steps, 119 | eta=eta) 120 | n_saved = save_logs(logs, logdir, n_saved=n_saved, key="sample") 121 | all_images.extend([custom_to_np(logs["sample"])]) 122 | 123 | else: 124 | raise NotImplementedError('Currently only sampling for unconditional models supported.') 125 | 126 | print(f"sampling of {n_saved} images finished in {(time.time() - tstart) / 60.:.2f} minutes.") 127 | 128 | def save_logs(logs, path, n_saved=0, key="sample", np_path=None): 129 | for k in logs: 130 | if k == key: 131 | batch = logs[key] 132 | if np_path is None: 133 | for x in batch: 134 | img = custom_to_pil(x) 135 | imgpath = os.path.join(path, f"{n_saved:07}.png") 136 | img.save(imgpath) 137 | n_saved += 1 138 | else: 139 | npbatch = custom_to_np(batch) 140 | shape_str = "x".join([str(x) for x in npbatch.shape]) 141 | nppath = os.path.join(np_path, f"{n_saved}-{shape_str}-samples.npz") 142 | np.savez(nppath, npbatch) 143 | n_saved += npbatch.shape[0] 144 | return n_saved 145 | 146 | def get_parser(): 147 | parser = argparse.ArgumentParser() 148 | parser.add_argument( 149 | "-r", 150 | "--resume", 151 | type=str, 152 | nargs="?", 
153 | help="load from logdir or checkpoint in logdir", 154 | default="./data/generated" 155 | ) 156 | parser.add_argument( 157 | "-n", 158 | "--n_samples", 159 | type=int, 160 | nargs="?", 161 | help="number of samples to draw", 162 | default=20 163 | ) 164 | parser.add_argument( 165 | "-e", 166 | "--eta", 167 | type=float, 168 | nargs="?", 169 | help="eta for ddim sampling (0.0 yields deterministic sampling)", 170 | default=1.0 171 | ) 172 | parser.add_argument( 173 | "-v", 174 | "--vanilla_sample", 175 | default=False, 176 | action='store_true', 177 | help="vanilla sampling (default option is DDIM sampling)?", 178 | ) 179 | parser.add_argument( 180 | "-l", 181 | "--logdir", 182 | type=str, 183 | nargs="?", 184 | help="extra logdir", 185 | default="./data/generated" 186 | ) 187 | parser.add_argument( 188 | "-c", 189 | "--custom_steps", 190 | type=int, 191 | nargs="?", 192 | help="number of steps for ddim and fastdpm sampling", 193 | default=200 194 | ) 195 | parser.add_argument( 196 | "--batch_size", 197 | type=int, 198 | nargs="?", 199 | help="the bs", 200 | default=1 201 | ) 202 | return parser 203 | 204 | def load_model_from_config(config, sd): 205 | model = instantiate_from_config(config) 206 | model.load_state_dict(sd,strict=False) 207 | model.cuda() 208 | model.eval() 209 | return model 210 | 211 | def load_model(config, ckpt, gpu, eval_mode): 212 | if ckpt: 213 | print(f"Loading model from {ckpt}") 214 | pl_sd = torch.load(ckpt, map_location="cpu") 215 | global_step = pl_sd["global_step"] 216 | else: 217 | pl_sd = {"state_dict": None} 218 | global_step = None 219 | model = load_model_from_config(config.model, 220 | pl_sd["state_dict"]) 221 | 222 | return model, global_step 223 | 224 | 225 | if __name__ == "__main__": 226 | now = datetime.datetime.now().strftime("%Y-%m-%d-%H-%M-%S") 227 | sys.path.append(os.getcwd()) 228 | command = " ".join(sys.argv) 229 | print(75 * "=") 230 | parser = get_parser() 231 | opt, unknown = parser.parse_known_args() 232 | ckpt = None 233 | 234 | if not os.path.exists(opt.resume): 235 | raise ValueError("Cannot find {}".format(opt.resume)) 236 | 237 | if os.path.isfile(opt.resume): 238 | try: 239 | logdir = '/'.join(opt.resume.split('/')[:-1]) 240 | print(f'Logdir is {logdir}') 241 | except ValueError: 242 | paths = opt.resume.split("/") 243 | idx = -2 # take a guess: path/to/logdir/checkpoints/model.ckpt 244 | logdir = "/".join(paths[:idx]) 245 | ckpt = opt.resume 246 | else: 247 | assert os.path.isdir(opt.resume), f"{opt.resume} is not a directory" 248 | logdir = opt.resume.rstrip("/") 249 | ckpt = os.path.join(logdir, "model.ckpt") 250 | 251 | base_configs = sorted(glob.glob(os.path.join(logdir, "config.yaml"))) 252 | opt.base = base_configs 253 | 254 | configs = [OmegaConf.load(cfg) for cfg in opt.base] 255 | cli = OmegaConf.from_dotlist(unknown) 256 | config = OmegaConf.merge(*configs, cli) 257 | 258 | gpu = True 259 | eval_mode = True 260 | 261 | if opt.logdir != "none": 262 | locallog = logdir.split(os.sep)[-1] 263 | if locallog == "": locallog = logdir.split(os.sep)[-2] 264 | print(f"Switching logdir from '{logdir}' to '{os.path.join(opt.logdir, locallog)}'") 265 | logdir = os.path.join(opt.logdir, locallog) 266 | 267 | model, global_step = load_model(config, ckpt, gpu, eval_mode) 268 | print(f"global step: {global_step}") 269 | print(75 * "=") 270 | print("logging to:") 271 | logdir = os.path.join(logdir, now) 272 | imglogdir = os.path.join(logdir, "moire_patterns") 273 | os.makedirs(imglogdir) 274 | print(logdir) 275 | print(75 * "=") 276 | 
277 | # write config out 278 | sampling_file = os.path.join(logdir, "sampling_config.yaml") 279 | sampling_conf = vars(opt) 280 | 281 | with open(sampling_file, 'w') as f: 282 | yaml.dump(sampling_conf, f, default_flow_style=False) 283 | 284 | run(model, imglogdir, eta=opt.eta, 285 | vanilla=opt.vanilla_sample, n_samples=opt.n_samples, custom_steps=opt.custom_steps, 286 | batch_size=opt.batch_size) 287 | 288 | print("done.") -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | from setuptools import setup, find_packages 2 | 3 | setup( 4 | name='unidemoire', 5 | version='0.0.1', 6 | description='', 7 | packages=find_packages(), 8 | install_requires=[ 9 | 'torch', 10 | 'numpy', 11 | 'tqdm', 12 | ], 13 | ) -------------------------------------------------------------------------------- /static/images/Pipeline.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/4DVLab/UniDemoire/4221a0e98f0078c3f72ef1da8128642e7f02662a/static/images/Pipeline.png -------------------------------------------------------------------------------- /taming/modules/autoencoder/lpips/vgg.pth: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/4DVLab/UniDemoire/4221a0e98f0078c3f72ef1da8128642e7f02662a/taming/modules/autoencoder/lpips/vgg.pth -------------------------------------------------------------------------------- /unidemoire/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/4DVLab/UniDemoire/4221a0e98f0078c3f72ef1da8128642e7f02662a/unidemoire/__init__.py -------------------------------------------------------------------------------- /unidemoire/data/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/4DVLab/UniDemoire/4221a0e98f0078c3f72ef1da8128642e7f02662a/unidemoire/data/__init__.py -------------------------------------------------------------------------------- /unidemoire/data/fhdmi.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | import argparse 4 | import cv2 5 | import torch.utils.data as data 6 | import torchvision.transforms as transforms 7 | import random 8 | from PIL import Image 9 | from PIL import ImageFile 10 | import os 11 | 12 | from .utils import default_loader, crop_loader, resize_loader, resize_then_crop_loader 13 | 14 | 15 | class fhdmi_datasets(data.Dataset): 16 | def __init__(self, args, mode='train'): 17 | self.args = args 18 | self.mode = mode 19 | self.loader = args["LOADER"] 20 | self.image_list = sorted([file for file in os.listdir(self.args["dataset_path"] + '/target') if file.endswith('.png')]) 21 | 22 | def __getitem__(self, index): 23 | ImageFile.LOAD_TRUNCATED_IMAGES = True 24 | data = {} 25 | image_in_gt = self.image_list[index] 26 | number = image_in_gt[4:9] 27 | image_in = 'src_' + number + '.png' 28 | if self.mode == 'train': 29 | path_tar = self.args["dataset_path"] + '/target/' + image_in_gt 30 | path_src = self.args["dataset_path"] + '/source/' + image_in 31 | if self.loader == 'crop': 32 | x = random.randint(0, 1920 - self.args["CROP_SIZE"]) 33 | y = random.randint(0, 1080 - self.args["CROP_SIZE"]) 34 | labels, moire_imgs = crop_loader(self.args["CROP_SIZE"], x, y, [path_tar, path_src]) 35 
| 36 | elif self.loader == 'resize': 37 | labels, moire_imgs = resize_loader(self.args["RESIZE_SIZE"], [path_tar, path_src]) 38 | data['origin_label'] = default_loader([path_tar])[0] 39 | 40 | elif self.loader == 'default': 41 | labels, moire_imgs = default_loader([path_tar, path_src]) 42 | 43 | elif self.mode == 'test': 44 | path_tar = self.args["dataset_path"] + '/target/' + image_in_gt 45 | path_src = self.args["dataset_path"] + '/source/' + image_in 46 | if self.loader == 'resize': 47 | labels, moire_imgs = resize_loader(self.args["RESIZE_SIZE"], [path_tar, path_src]) 48 | data['origin_label'] = default_loader([path_tar])[0] 49 | else: 50 | labels, moire_imgs = default_loader([path_tar, path_src]) 51 | 52 | else: 53 | print('Unrecognized mode! Please select either "train" or "test"') 54 | raise NotImplementedError 55 | 56 | data['in_img'] = moire_imgs 57 | data['label'] = labels 58 | data['number'] = number 59 | data['mode'] = self.mode 60 | return data 61 | 62 | def __len__(self): 63 | return len(self.image_list) -------------------------------------------------------------------------------- /unidemoire/data/moire.py: -------------------------------------------------------------------------------- 1 | import os 2 | import random 3 | 4 | from PIL import Image, ImageFilter, ImageEnhance 5 | 6 | import numpy as np 7 | from torch.utils.data import Dataset 8 | import torchvision.transforms as transforms 9 | 10 | IMG_EXTENSIONS = ['.jpg', '.JPG', '.jpeg', '.JPEG', 11 | '.png', '.PNG', '.ppm', '.PPM', '.bmp', '.BMP'] 12 | 13 | def is_image_file(filename): 14 | return any(filename.endswith(extension) for extension in IMG_EXTENSIONS) 15 | 16 | def get_paths_from_images(path): 17 | assert os.path.isdir(path), '{:s} is not a valid directory'.format(path) 18 | images = [] 19 | for dirpath, _, fnames in sorted(os.walk(path)): 20 | for fname in sorted(fnames): 21 | if is_image_file(fname): 22 | img_path = os.path.join(dirpath, fname) 23 | images.append(img_path) 24 | assert images, '{:s} has no valid image file'.format(path) 25 | return sorted(images) 26 | 27 | 28 | class MoirePattern(Dataset): 29 | def __init__(self, dataset_path, resolution): 30 | self.resolution = resolution 31 | self.pil_to_tensor = transforms.ToTensor() 32 | self.dataset_path = dataset_path 33 | self.moire_layer_path = get_paths_from_images(self.dataset_path) 34 | 35 | def __len__(self): 36 | return len(self.moire_layer_path) 37 | 38 | def calculate_sharpness(self, image): 39 | image_gray = image.convert('L') 40 | image_laplace = image_gray.filter(ImageFilter.FIND_EDGES) 41 | sharpness = np.std(np.array(image_laplace)) 42 | return sharpness 43 | 44 | def calculate_colorfulness(self, image): 45 | image_lab = image.convert('LAB') 46 | l, a, b = image_lab.split() 47 | std_a = np.std(np.array(a)) 48 | std_b = np.std(np.array(b)) 49 | colorfulness = np.sqrt(std_a ** 2 + std_b ** 2) 50 | return colorfulness 51 | 52 | def calculate_image_quality(self, image): 53 | sharpness = self.calculate_sharpness(image) 54 | colorfulness = self.calculate_colorfulness(image) 55 | return sharpness, colorfulness 56 | 57 | def __getitem__(self, index): 58 | while(True): 59 | ## TODO: try different index moire patterns 60 | for i in range(3): 61 | ## TODO: [Multi crop] + [Sharpness & Colorfulness selection] 62 | img_moire_layer = Image.open(self.moire_layer_path[index]) 63 | self.transform_init() 64 | img_moire_layer = self.transform(img_moire_layer) 65 | sharpness, colorfulness = self.calculate_image_quality(img_moire_layer) 66 | if sharpness < 15 
or colorfulness < 2.0: 67 | continue 68 | else: 69 | img_moire_layer = ImageEnhance.Contrast(img_moire_layer).enhance(2.0) 70 | img_moire_layer = self.pil_to_tensor(img_moire_layer) 71 | return { "image": img_moire_layer } 72 | index = random.randint(0, len(self.moire_layer_path) - 1) 73 | 74 | def transform_init(self): 75 | w = h = self.resolution 76 | base_transforms = [transforms.RandomHorizontalFlip(p=0.5),] 77 | 78 | q = random.randint(0, 2) 79 | r = random.randint(0, 1) 80 | if r == 0: # 4K crop into (w, h) 81 | extra_transforms = [transforms.RandomCrop(size=(h, w))] 82 | elif q == 0: # 4K to 2K, then crop into (w, h) 83 | extra_transforms = [transforms.Resize(size=(1440, 2560)), transforms.RandomCrop(size=(h, w))] 84 | elif q == 1: # 4K to 1080P, then crop into (w, h) 85 | extra_transforms = [transforms.Resize(size=(1080, 1920)), transforms.RandomCrop(size=(h, w))] 86 | elif q == 2: # 4K resize into (w, h) 87 | extra_transforms = [transforms.Resize(size=(h, w))] 88 | 89 | tran_transform = transforms.Compose(extra_transforms + base_transforms) 90 | # test_transform = transforms.Compose([transforms.Resize((h, w))] + base_transforms) 91 | self.transform = tran_transform -------------------------------------------------------------------------------- /unidemoire/data/tip.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | import argparse 4 | import cv2 5 | import torch.utils.data as data 6 | import torchvision.transforms as transforms 7 | import random 8 | from PIL import Image 9 | from PIL import ImageFile 10 | import os 11 | 12 | from .utils import default_loader, crop_loader, resize_loader, resize_then_crop_loader 13 | 14 | class tip_datasets(data.Dataset): 15 | 16 | def __init__(self, args, mode='train'): 17 | 18 | data_path = args['dataset_path'] 19 | image_list = sorted([file for file in os.listdir(data_path + '/source') if file.endswith('.png')]) 20 | self.image_list = image_list 21 | self.args = args 22 | self.mode = mode 23 | t_list = [transforms.ToTensor()] 24 | self.composed_transform = transforms.Compose(t_list) 25 | 26 | def default_loader(self, path): 27 | return Image.open(path).convert('RGB') 28 | 29 | def __getitem__(self, index): 30 | ImageFile.LOAD_TRUNCATED_IMAGES = True 31 | data = {} 32 | image_in = self.image_list[index] 33 | image_in_gt = image_in[:-10] + 'target.png' 34 | number = image_in_gt[:-11] 35 | 36 | if self.mode == 'train': 37 | labels = self.default_loader(self.args['dataset_path'] + '/target/' + image_in_gt) 38 | moire_imgs = self.default_loader(self.args['dataset_path'] + '/source/' + image_in) 39 | 40 | w, h = labels.size 41 | i = random.randint(-6, 6) 42 | j = random.randint(-6, 6) 43 | labels = labels.crop((int(w / 6) + i, int(h / 6) + j, int(w * 5 / 6) + i, int(h * 5 / 6) + j)) 44 | moire_imgs = moire_imgs.crop((int(w / 6) + i, int(h / 6) + j, int(w * 5 / 6) + i, int(h * 5 / 6) + j)) 45 | 46 | labels = labels.resize((256, 256), Image.BILINEAR) 47 | moire_imgs = moire_imgs.resize((256, 256), Image.BILINEAR) 48 | 49 | elif self.mode == 'test': 50 | labels = self.default_loader(self.args['dataset_path'] + '/target/' + image_in_gt) 51 | moire_imgs = self.default_loader(self.args['dataset_path'] + '/source/' + image_in) 52 | 53 | w, h = labels.size 54 | labels = labels.crop((int(w / 6), int(h / 6), int(w * 5 / 6), int(h * 5 / 6))) 55 | moire_imgs = moire_imgs.crop((int(w / 6), int(h / 6), int(w * 5 / 6), int(h * 5 / 6))) 56 | 57 | labels = labels.resize((256, 256), 
Image.BILINEAR) 58 | moire_imgs = moire_imgs.resize((256, 256), Image.BILINEAR) 59 | 60 | 61 | else: 62 | print('Unrecognized mode! Please select either "train" or "test"') 63 | raise NotImplementedError 64 | 65 | moire_imgs = self.composed_transform(moire_imgs) 66 | labels = self.composed_transform(labels) 67 | 68 | data['in_img'] = moire_imgs 69 | data['label'] = labels 70 | data['number'] = number 71 | data['mode'] = self.mode 72 | return data 73 | 74 | def __len__(self): 75 | return len(self.image_list) -------------------------------------------------------------------------------- /unidemoire/data/uhdm.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | import argparse 4 | import cv2 5 | import torch.utils.data as data 6 | import torchvision.transforms as transforms 7 | import random 8 | from PIL import Image 9 | from PIL import ImageFile 10 | import os 11 | 12 | from .utils import default_loader, crop_loader, resize_loader, resize_then_crop_loader 13 | 14 | 15 | class uhdm_datasets(data.Dataset): 16 | 17 | def __init__(self, args, mode='train'): 18 | self.args = args 19 | self.mode = mode 20 | self.loader = args["LOADER"] 21 | self.image_list = self._list_image_files_recursively(data_dir=self.args["dataset_path"]) 22 | 23 | def _list_image_files_recursively(self, data_dir): 24 | file_list = [] 25 | for home, dirs, files in os.walk(data_dir): 26 | for filename in files: 27 | if filename.endswith('gt.jpg'): 28 | file_list.append(os.path.join(home, filename)) 29 | file_list.sort() 30 | return file_list 31 | 32 | def __getitem__(self, index): 33 | ImageFile.LOAD_TRUNCATED_IMAGES = True 34 | data = {} 35 | path_tar = self.image_list[index] 36 | number = os.path.split(path_tar)[-1][0:4] 37 | path_src = os.path.split(path_tar)[0] + '/' + os.path.split(path_tar)[-1][0:4] + '_moire.jpg' 38 | if self.mode == 'train': 39 | if self.loader == 'crop': 40 | if os.path.split(path_tar)[0][-5:-3] == 'mi': 41 | w = 4624 42 | h = 3472 43 | else: 44 | w = 4032 45 | h = 3024 46 | x = random.randint(0, w - self.args["CROP_SIZE"]) 47 | y = random.randint(0, h - self.args["CROP_SIZE"]) 48 | labels, moire_imgs = crop_loader(self.args["CROP_SIZE"], x, y, [path_tar, path_src]) 49 | 50 | elif self.loader == 'resize_then_crop': 51 | labels, moire_imgs = resize_then_crop_loader(self.args["CROP_SIZE"], self.args["RESIZE_SIZE"], [path_tar, path_src]) 52 | data['origin_label'] = default_loader([path_tar])[0] 53 | 54 | elif self.loader == 'resize': 55 | labels, moire_imgs = resize_loader(self.args["RESIZE_SIZE"], [path_tar, path_src]) 56 | data['origin_label'] = default_loader([path_tar])[0] 57 | 58 | elif self.loader == 'default': 59 | labels, moire_imgs = default_loader([path_tar, path_src]) 60 | 61 | elif self.mode == 'test': 62 | if self.loader == 'resize': 63 | labels, moire_imgs = resize_loader(self.args["RESIZE_SIZE"], [path_tar, path_src]) 64 | data['origin_label'] = default_loader([path_tar])[0] 65 | elif self.loader == 'resize_then_crop': 66 | labels, moire_imgs = resize_loader(self.args["RESIZE_SIZE"], [path_tar, path_src]) 67 | data['origin_label'] = default_loader([path_tar])[0] 68 | else: 69 | labels, moire_imgs = default_loader([path_tar, path_src]) 70 | 71 | else: 72 | print('Unrecognized mode! 
Please select either "train" or "test"') 73 | raise NotImplementedError 74 | 75 | data['in_img'] = moire_imgs 76 | data['label'] = labels 77 | data['number'] = number 78 | 79 | data['mode'] = self.mode 80 | 81 | return data 82 | 83 | def __len__(self): 84 | # return 10 # debug 85 | return len(self.image_list) -------------------------------------------------------------------------------- /unidemoire/data/utils.py: -------------------------------------------------------------------------------- 1 | import os 2 | import random 3 | from PIL import Image 4 | import torchvision.transforms as transforms 5 | 6 | IMG_EXTENSIONS = ['.jpg', '.JPG', '.jpeg', '.JPEG', 7 | '.png', '.PNG', '.ppm', '.PPM', '.bmp', '.BMP'] 8 | 9 | def is_image_file(filename): 10 | return any(filename.endswith(extension) for extension in IMG_EXTENSIONS) 11 | 12 | def get_natural_image_list_and_moire_pattern_list(args, mode='train', add_clean_percent=1.0): 13 | moire_pattern_files = _list_moire_pattern_files_recursively(data_dir=args["moire_pattern_path"]) 14 | if args.natural_dataset_name == 'UHDM': 15 | uhdm_natural_files = _list_image_files_recursively(data_dir=args["uhdm_dataset_path"]) 16 | return uhdm_natural_files, moire_pattern_files 17 | 18 | elif args.natural_dataset_name == 'FHDMi': 19 | fhdmi_natural_files = sorted([file for file in os.listdir(args["fhdmi_dataset_path"] + '/target') if file.endswith('.png')]) 20 | return fhdmi_natural_files, moire_pattern_files 21 | 22 | elif args.natural_dataset_name == 'TIP': 23 | tip_natural_files = sorted([file for file in os.listdir(args["tip_dataset_path"] + '/source') if file.endswith('.png')]) 24 | return tip_natural_files, moire_pattern_files 25 | 26 | elif args.natural_dataset_name == 'AIM': 27 | if mode=='train': 28 | aim_natural_files = sorted([file for file in os.listdir(args["aim_dataset_path"] + '/moire') if file.endswith('.jpg')]) 29 | else: 30 | aim_natural_files = sorted([file for file in os.listdir(args["aim_dataset_path"] + '/moire') if file.endswith('.png')]) 31 | return aim_natural_files, moire_pattern_files 32 | 33 | elif args.natural_dataset_name == 'MHRNID': 34 | mhrnid_files = get_paths_from_images(path=args["mhrnid_dataset_path"]) 35 | return mhrnid_files, moire_pattern_files 36 | 37 | elif args.natural_dataset_name == 'UHDM and FHDMi': 38 | uhdm_natural_files = _list_image_files_recursively(data_dir=args["uhdm_dataset_path"]) 39 | fhdmi_natural_files = sorted([file for file in os.listdir(args["fhdmi_dataset_path"] + '/target') if file.endswith('.png')]) 40 | 41 | print(f'Clean image percentage: {add_clean_percent*100}%') 42 | fhdmi_size = len(fhdmi_natural_files) 43 | fhdmi_sublist_size = int(fhdmi_size * add_clean_percent) 44 | fhdmi_sublist_files = fhdmi_natural_files[:fhdmi_sublist_size] 45 | 46 | return uhdm_natural_files + fhdmi_sublist_files, moire_pattern_files 47 | 48 | else: 49 | print('Unrecognized data_type!') 50 | raise NotImplementedError 51 | 52 | 53 | def get_unpaired_moire_images(args): 54 | if args.unpaired_real_moire_dataset == 'TIP': 55 | tip_real_moire_files = sorted([file for file in os.listdir(args["tip_dataset_path"] + '/source') if file.endswith('.png')]) 56 | return tip_real_moire_files 57 | else: 58 | print('Unrecognized data_type!') 59 | raise NotImplementedError 60 | 61 | 62 | def _list_image_files_recursively(data_dir): 63 | file_list = [] 64 | for home, dirs, files in os.walk(data_dir): 65 | for filename in files: 66 | if filename.endswith('gt.jpg'): 67 | file_list.append(os.path.join(home, filename)) 68 | 
file_list.sort() 69 | return file_list 70 | 71 | def get_paths_from_images(path): 72 | assert os.path.isdir(path), '{:s} is not a valid directory'.format(path) 73 | images = [] 74 | for dirpath, _, fnames in sorted(os.walk(path)): 75 | for fname in sorted(fnames): 76 | if is_image_file(fname): 77 | img_path = os.path.join(dirpath, fname) 78 | images.append(img_path) 79 | assert images, '{:s} has no valid image file'.format(path) 80 | return sorted(images) 81 | 82 | def _list_moire_pattern_files_recursively(data_dir): 83 | assert os.path.isdir(data_dir), '{:s} is not a valid directory'.format(data_dir) 84 | images = [] 85 | for dirpath, _, fnames in sorted(os.walk(data_dir)): 86 | for fname in sorted(fnames): 87 | if is_image_file(fname): 88 | img_path = os.path.join(dirpath, fname) 89 | images.append(img_path) 90 | assert images, '{:s} has no valid image file'.format(data_dir) 91 | return sorted(images) 92 | 93 | 94 | 95 | def default_loader(path_set=[]): 96 | imgs = [] 97 | for path in path_set: 98 | img = Image.open(path).convert('RGB') 99 | img = default_toTensor(img) 100 | imgs.append(img) 101 | 102 | return imgs 103 | 104 | def crop_loader(crop_size, x, y, path_set=[]): 105 | imgs = [] 106 | for path in path_set: 107 | img = Image.open(path).convert('RGB') 108 | img = img.crop((x, y, x + crop_size, y + crop_size)) 109 | img = default_toTensor(img) 110 | imgs.append(img) 111 | return imgs 112 | 113 | def resize_loader(resize_size, path_set=[]): 114 | imgs = [] 115 | for path in path_set: 116 | img = Image.open(path).convert('RGB') 117 | img = img.resize((resize_size,resize_size),Image.BICUBIC) 118 | img = default_toTensor(img) 119 | imgs.append(img) 120 | 121 | return imgs 122 | 123 | def resize_then_crop_loader(crop_size, resize_size, path_set=[]): 124 | imgs = [] 125 | for path in path_set: 126 | img = Image.open(path).convert('RGB') 127 | if resize_size == 1920: 128 | img = img.resize((1920,1080),Image.BICUBIC) 129 | x = random.randint(0, 1920 - crop_size) 130 | y = random.randint(0, 1080 - crop_size) 131 | else: 132 | img = img.resize((resize_size,resize_size),Image.BICUBIC) 133 | x = random.randint(0, resize_size - crop_size) 134 | y = random.randint(0, resize_size - crop_size) 135 | img = img.crop((x, y, x + crop_size, y + crop_size)) 136 | img = default_toTensor(img) 137 | imgs.append(img) 138 | return imgs 139 | 140 | 141 | def default_toTensor(img): 142 | t_list = [transforms.ToTensor()] 143 | composed_transform = transforms.Compose(t_list) 144 | return composed_transform(img) -------------------------------------------------------------------------------- /unidemoire/lr_scheduler.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | 4 | class LambdaWarmUpCosineScheduler: 5 | """ 6 | note: use with a base_lr of 1.0 7 | """ 8 | def __init__(self, warm_up_steps, lr_min, lr_max, lr_start, max_decay_steps, verbosity_interval=0): 9 | self.lr_warm_up_steps = warm_up_steps 10 | self.lr_start = lr_start 11 | self.lr_min = lr_min 12 | self.lr_max = lr_max 13 | self.lr_max_decay_steps = max_decay_steps 14 | self.last_lr = 0. 
15 | self.verbosity_interval = verbosity_interval 16 | 17 | def schedule(self, n, **kwargs): 18 | if self.verbosity_interval > 0: 19 | if n % self.verbosity_interval == 0: print(f"current step: {n}, recent lr-multiplier: {self.last_lr}") 20 | if n < self.lr_warm_up_steps: 21 | lr = (self.lr_max - self.lr_start) / self.lr_warm_up_steps * n + self.lr_start 22 | self.last_lr = lr 23 | return lr 24 | else: 25 | t = (n - self.lr_warm_up_steps) / (self.lr_max_decay_steps - self.lr_warm_up_steps) 26 | t = min(t, 1.0) 27 | lr = self.lr_min + 0.5 * (self.lr_max - self.lr_min) * ( 28 | 1 + np.cos(t * np.pi)) 29 | self.last_lr = lr 30 | return lr 31 | 32 | def __call__(self, n, **kwargs): 33 | return self.schedule(n,**kwargs) 34 | 35 | 36 | class LambdaWarmUpCosineScheduler2: 37 | """ 38 | supports repeated iterations, configurable via lists 39 | note: use with a base_lr of 1.0. 40 | """ 41 | def __init__(self, warm_up_steps, f_min, f_max, f_start, cycle_lengths, verbosity_interval=0): 42 | assert len(warm_up_steps) == len(f_min) == len(f_max) == len(f_start) == len(cycle_lengths) 43 | self.lr_warm_up_steps = warm_up_steps 44 | self.f_start = f_start 45 | self.f_min = f_min 46 | self.f_max = f_max 47 | self.cycle_lengths = cycle_lengths 48 | self.cum_cycles = np.cumsum([0] + list(self.cycle_lengths)) 49 | self.last_f = 0. 50 | self.verbosity_interval = verbosity_interval 51 | 52 | def find_in_interval(self, n): 53 | interval = 0 54 | for cl in self.cum_cycles[1:]: 55 | if n <= cl: 56 | return interval 57 | interval += 1 58 | 59 | def schedule(self, n, **kwargs): 60 | cycle = self.find_in_interval(n) 61 | n = n - self.cum_cycles[cycle] 62 | if self.verbosity_interval > 0: 63 | if n % self.verbosity_interval == 0: print(f"current step: {n}, recent lr-multiplier: {self.last_f}, " 64 | f"current cycle {cycle}") 65 | if n < self.lr_warm_up_steps[cycle]: 66 | f = (self.f_max[cycle] - self.f_start[cycle]) / self.lr_warm_up_steps[cycle] * n + self.f_start[cycle] 67 | self.last_f = f 68 | return f 69 | else: 70 | t = (n - self.lr_warm_up_steps[cycle]) / (self.cycle_lengths[cycle] - self.lr_warm_up_steps[cycle]) 71 | t = min(t, 1.0) 72 | f = self.f_min[cycle] + 0.5 * (self.f_max[cycle] - self.f_min[cycle]) * ( 73 | 1 + np.cos(t * np.pi)) 74 | self.last_f = f 75 | return f 76 | 77 | def __call__(self, n, **kwargs): 78 | return self.schedule(n, **kwargs) 79 | 80 | 81 | class LambdaLinearScheduler(LambdaWarmUpCosineScheduler2): 82 | 83 | def schedule(self, n, **kwargs): 84 | cycle = self.find_in_interval(n) 85 | n = n - self.cum_cycles[cycle] 86 | if self.verbosity_interval > 0: 87 | if n % self.verbosity_interval == 0: print(f"current step: {n}, recent lr-multiplier: {self.last_f}, " 88 | f"current cycle {cycle}") 89 | 90 | if n < self.lr_warm_up_steps[cycle]: 91 | f = (self.f_max[cycle] - self.f_start[cycle]) / self.lr_warm_up_steps[cycle] * n + self.f_start[cycle] 92 | self.last_f = f 93 | return f 94 | else: 95 | f = self.f_min[cycle] + (self.f_max[cycle] - self.f_min[cycle]) * (self.cycle_lengths[cycle] - n) / (self.cycle_lengths[cycle]) 96 | self.last_f = f 97 | return f 98 | 99 | -------------------------------------------------------------------------------- /unidemoire/models/MIB/Blending.py: -------------------------------------------------------------------------------- 1 | 2 | import random 3 | import torch 4 | import torch.nn as nn 5 | 6 | class Blending(nn.Module): 7 | def __init__(self, args): 8 | super(Blending, self).__init__() 9 | self.args = args 10 | self.final_weight_range = 
(self.args["bl_final_weight_min"], self.args["bl_final_weight_max"]) 11 | self.bl_method_1_op = torch.Tensor([self.args["bl_method_1_op"]]) 12 | self.bl_method_2_op = torch.Tensor([self.args["bl_method_2_op"]]) 13 | 14 | def forward(self, img_background, img_foreground): 15 | # bs,c,h,w = img_background.shape 16 | self.device = img_background.device 17 | 18 | self.img_background = self.RGB_to_RGBA(img_background) 19 | self.img_foreground = self.RGB_to_RGBA(img_foreground) 20 | 21 | img_result_1 = self.get_blending_result(method=self.args["bl_method_1"], opacity=self.bl_method_1_op) 22 | img_result_2 = self.get_blending_result(method=self.args["bl_method_2"], opacity=self.bl_method_2_op) 23 | self.weight = torch.Tensor([random.uniform(*self.final_weight_range)]).to(self.device) 24 | result = img_result_1 * self.weight + img_result_2 * (1 - self.weight) 25 | 26 | return result, self.weight 27 | 28 | def init_from_ckpt(self, path, ignore_keys=list()): 29 | sd = torch.load(path, map_location="cpu")["state_dict"] 30 | keys = list(sd.keys()) 31 | for k in keys: 32 | for ik in ignore_keys: 33 | if k.startswith(ik): 34 | print("Deleting key {} from state_dict.".format(k)) 35 | del sd[k] 36 | self.load_state_dict(sd, strict=False) 37 | print(f"MIB Module Restored from {path}, weight = {self.mib_weight}") 38 | 39 | def RGBA_to_RGB(self, image): 40 | return image[:,:3,:,:] 41 | 42 | def RGB_to_RGBA(self, image): 43 | b, c, w, h = image.shape 44 | img = torch.ones([b, c + 1, w, h]).to(self.device) 45 | img[:,:3,:,:] = image 46 | 47 | return img 48 | 49 | def soft_light(self): 50 | """ 51 | if A ≤ 0.5: C = (2A-1)(B-B^2) + B 52 | if A > 0.5: C = (2A-1)(sqrt(B)-B) + B 53 | """ 54 | A = self.img_foreground[:, :3, :, :] 55 | B = self.img_background[:, :3, :, :] 56 | C = torch.where(A <= 0.5, 57 | (2 * A - 1.0)*(B - torch.pow(B,2)) + B, 58 | (2 * A - 1.0)*(torch.sqrt(B) - B) + B 59 | ) 60 | return C 61 | 62 | def hard_light(self): 63 | """ 64 | if A ≤ 0.5: C = 2*A*B 65 | if A > 0.5: C = 1-2*(1-A)(1-B) 66 | """ 67 | A = self.img_foreground[:, :3, :, :] 68 | B = self.img_background[:, :3, :, :] 69 | C = torch.where(A <= 0.5, 70 | 2 * A * B, 71 | 1 - 2 * (1.0 - A)*(1.0 - B) 72 | ) 73 | return C 74 | 75 | def lighten(self): 76 | """ 77 | if B ≤ A: C = A 78 | if B > A: C = B 79 | """ 80 | A = self.img_foreground[:, :3, :, :] 81 | B = self.img_background[:, :3, :, :] 82 | C = torch.maximum(A, B) 83 | return C 84 | 85 | def darken(self): 86 | """ 87 | if B ≤ A: C = B 88 | if B > A: C = A 89 | """ 90 | A = self.img_foreground[:, :3, :, :] 91 | B = self.img_background[:, :3, :, :] 92 | C = torch.minimum(A, B) 93 | return C 94 | 95 | def multiply(self): 96 | """ 97 | C = A * B 98 | """ 99 | A = self.img_foreground[:, :3, :, :] 100 | B = self.img_background[:, :3, :, :] 101 | C = A * B 102 | return C 103 | 104 | def grain_merge(self): 105 | """ 106 | C = A + B - 0.5 107 | """ 108 | A = self.img_foreground[:, :3, :, :] 109 | B = self.img_background[:, :3, :, :] 110 | C = A + B - 0.5 111 | return C 112 | 113 | def _compose_alpha(self, opacity): 114 | comp = self.img_foreground[:,3,:,:] 115 | 116 | comp_alpha = comp * opacity 117 | new_alpha = comp_alpha + (1.0 - comp_alpha) * self.img_background[:,3,:,:] 118 | 119 | ratio = comp_alpha / new_alpha 120 | ratio[torch.isnan(ratio)] = 0.0 121 | ratio[torch.isinf(ratio)] = 0.0 122 | 123 | return ratio 124 | 125 | def get_blending_result(self, method, opacity): 126 | opacity = opacity.to(self.device) 127 | ratio = self._compose_alpha(opacity) 128 | comp = 
torch.clip(getattr(self, method)(), 0.0, 1.0) 129 | ratio_rs = torch.stack([ratio,ratio,ratio],dim=1).to(self.device) 130 | img_out = comp * ratio_rs + self.img_background[:,:3,:,:] * (1.0 - ratio_rs) 131 | 132 | alpha_channel = self.img_background[:,3,:,:] 133 | alpha_channel = alpha_channel.unsqueeze(dim=1) 134 | img_out = torch.nan_to_num(torch.cat((img_out, alpha_channel),dim=1)) # add alpha channel and replace nans 135 | 136 | return self.RGBA_to_RGB(img_out).to(self.device) 137 | -------------------------------------------------------------------------------- /unidemoire/models/MIB/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/4DVLab/UniDemoire/4221a0e98f0078c3f72ef1da8128642e7f02662a/unidemoire/models/MIB/__init__.py -------------------------------------------------------------------------------- /unidemoire/models/TRN/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/4DVLab/UniDemoire/4221a0e98f0078c3f72ef1da8128642e7f02662a/unidemoire/models/TRN/__init__.py -------------------------------------------------------------------------------- /unidemoire/models/cycle/Models/utils.py: -------------------------------------------------------------------------------- 1 | import random 2 | import time 3 | import datetime 4 | import sys 5 | 6 | from torch.autograd import Variable 7 | import torch 8 | import numpy as np 9 | 10 | import torch.nn as nn 11 | from torchvision.utils import save_image 12 | from math import log10, exp, sqrt, cos, pi 13 | import torch.nn.functional as F 14 | 15 | class ReplayBuffer: 16 | def __init__(self, max_size=50): 17 | assert max_size > 0, "Empty buffer or trying to create a black hole. Be careful." 18 | self.max_size = max_size 19 | self.data = [] 20 | 21 | def push_and_pop(self, data): 22 | to_return = [] 23 | for element in data.data: 24 | element = torch.unsqueeze(element, 0) 25 | if len(self.data) < self.max_size: 26 | self.data.append(element) 27 | to_return.append(element) 28 | else: 29 | if random.uniform(0, 1) > 0.5: 30 | i = random.randint(0, self.max_size - 1) 31 | to_return.append(self.data[i].clone()) 32 | self.data[i] = element 33 | else: 34 | to_return.append(element) 35 | return Variable(torch.cat(to_return)) 36 | 37 | 38 | class LambdaLR: 39 | def __init__(self, n_epochs, offset, decay_start_epoch): 40 | assert (n_epochs - decay_start_epoch) > 0, "Decay must start before the training session ends!" 
41 | self.n_epochs = n_epochs 42 | self.offset = offset 43 | self.decay_start_epoch = decay_start_epoch 44 | 45 | def step(self, epoch): 46 | return 1.0 - max(0, epoch + self.offset - self.decay_start_epoch) / (self.n_epochs - self.decay_start_epoch) 47 | 48 | ## DCT Transform 49 | class DCT(nn.Module): 50 | def __init__(self): 51 | super(DCT, self).__init__() 52 | 53 | conv_shape = (1, 1, 64, 64) 54 | kernel = np.zeros(conv_shape) 55 | r1 = sqrt(1.0/8) 56 | r2 = sqrt(2.0/8) 57 | for i in range(8): 58 | _u = 2*i+1 59 | for j in range(8): 60 | _v = 2*j+1 61 | index = i*8+j 62 | for u in range(8): 63 | for v in range(8): 64 | index2 = u*8+v 65 | t = cos(_u*u*pi/16)*cos(_v*v*pi/16) 66 | t = t*r1 if u==0 else t*r2 67 | t = t*r1 if v==0 else t*r2 68 | kernel[0,0,index2,index] = t 69 | 70 | self.kernel = torch.tensor(kernel, requires_grad = False, dtype=torch.float32) 71 | 72 | def forward(self, inputs): 73 | 74 | device = inputs.device 75 | kernel = self.kernel.to(device) 76 | k = kernel.permute(3, 1, 2, 0) 77 | k = torch.reshape(k, (64, 1, 8, 8)) 78 | 79 | b, c, h, w = inputs.size() 80 | scale_r = h // 8 81 | new_inputs = torch.reshape(inputs, (b, c * scale_r * scale_r, 8, 8)) 82 | 83 | outputs = torch.zeros_like(new_inputs) 84 | 85 | num_of_p = c * scale_r * scale_r 86 | 87 | for i in range(num_of_p): 88 | patch = new_inputs[:, i, :, :] 89 | patch = patch.unsqueeze(dim=1) 90 | patch = patch.to(device).float() 91 | 92 | new_patch = F.conv2d(patch, k, stride = 8) 93 | new_patch = torch.reshape(new_patch, (b, 1, 8, 8)).squeeze(dim = 1) 94 | 95 | outputs[:, i, :, :] = new_patch 96 | 97 | outputs = torch.reshape(outputs, (b, c, h, w)) 98 | 99 | return outputs 100 | 101 | class Local_DCT(nn.Module): 102 | def __init__(self): 103 | super(Local_DCT, self).__init__() 104 | 105 | conv_shape = (1, 1, 64, 64) 106 | kernel = np.zeros(conv_shape) 107 | r1 = sqrt(1.0 / 8) 108 | r2 = sqrt(2.0 / 8) 109 | for i in range(8): 110 | _u = 2 * i + 1 111 | for j in range(8): 112 | _v = 2 * j + 1 113 | index = i * 8 + j 114 | for u in range(8): 115 | for v in range(8): 116 | index2 = u * 8 + v 117 | t = cos(_u * u * pi / 16) * cos(_v * v * pi / 16) 118 | t = t * r1 if u == 0 else t * r2 119 | t = t * r1 if v == 0 else t * r2 120 | kernel[0, 0, index2, index] = t 121 | 122 | self.kernel = torch.tensor(kernel, requires_grad=False, dtype=torch.float32) 123 | 124 | def forward(self, inputs): 125 | 126 | device = inputs.device 127 | kernel = self.kernel.to(device) 128 | k = kernel.permute(3, 1, 2, 0) 129 | k = torch.reshape(k, (64, 1, 8, 8)) 130 | 131 | b, c, h, w = inputs.size() 132 | scale_r = h // 8 133 | new_inputs = torch.reshape(inputs, (b, c * scale_r * scale_r, 8, 8)) 134 | 135 | outputs = torch.zeros_like(new_inputs) 136 | 137 | num_of_p = c * scale_r * scale_r 138 | 139 | for i in range(num_of_p): 140 | patch = new_inputs[:, i, :, :] 141 | patch = patch.unsqueeze(dim=1) 142 | patch = patch.to(device).float() 143 | 144 | new_patch = F.conv2d(patch, k, stride=8) 145 | new_patch = torch.reshape(new_patch, (b, 1, 8, 8)).squeeze(dim=1) 146 | 147 | outputs[:, i, :, :] = new_patch 148 | 149 | outputs = torch.reshape(outputs, (b, c, h, w)) 150 | 151 | return outputs 152 | 153 | class Inverse_DCT(nn.Module): 154 | def __init__(self): 155 | super(Inverse_DCT, self).__init__() 156 | 157 | conv_shape = (1, 1, 64, 64) 158 | kernel = np.zeros(conv_shape) 159 | r1 = sqrt(1.0/8) 160 | r2 = sqrt(2.0/8) 161 | for i in range(8): 162 | _u = 2*i+1 163 | for j in range(8): 164 | _v = 2*j+1 165 | index = i*8+j 166 | for u in range(8): 
167 | for v in range(8): 168 | index2 = u*8+v 169 | t = cos(_u*u*pi/16)*cos(_v*v*pi/16) 170 | t = t*r1 if u==0 else t*r2 171 | t = t*r1 if v==0 else t*r2 172 | kernel[0,0,index2,index] = t 173 | 174 | self.kernel = torch.tensor(kernel, requires_grad = False, dtype=torch.float32) 175 | 176 | self.kernel = self.kernel.permute(0, 1, 3, 2) 177 | 178 | def forward(self, inputs): 179 | 180 | device = inputs.device 181 | kernel = self.kernel.to(device) 182 | k = kernel.permute(3, 1, 2, 0) 183 | k = torch.reshape(k, (64, 1, 8, 8)) 184 | 185 | b, c, h, w = inputs.size() 186 | scale_r = h // 8 187 | new_inputs = torch.reshape(inputs, (b, c * scale_r * scale_r, 8, 8)) 188 | 189 | outputs = torch.zeros_like(new_inputs) 190 | 191 | num_of_p = c * scale_r * scale_r 192 | 193 | for i in range(num_of_p): 194 | patch = new_inputs[:, i, :, :] 195 | patch = patch.unsqueeze(dim=1) 196 | patch = patch.to(device).float() 197 | 198 | new_patch = F.conv2d(patch, k, stride = 8) 199 | new_patch = torch.reshape(new_patch, (b, 1, 8, 8)).squeeze(dim = 1) 200 | 201 | outputs[:, i, :, :] = new_patch 202 | 203 | outputs = torch.reshape(outputs, (b, c, h, w)) 204 | 205 | return outputs.clamp(min = 0, max = 1) 206 | 207 | ## block wise mapping 208 | def block_wise_mapping(net, input, input_size, pad): 209 | b, c, _, _ = input.size() 210 | window = create_window(pad, b, c, pad // 2) 211 | 212 | pad_in = padarray(input, pad) 213 | 214 | pad_out = torch.zeros_like(pad_in) 215 | pnorm = torch.zeros_like(pad_in) 216 | 217 | device = input.device 218 | 219 | i = 0 220 | j = 0 221 | 222 | stride = pad // 2 223 | 224 | _,_, height, width = pad_in.size() 225 | 226 | while(i < height - input_size + 1): 227 | while(j < width - input_size + 1): 228 | patch = pad_in[:,:,i : i + input_size, j : j + input_size] 229 | patch = patch.to(device).float() 230 | 231 | pout = net(patch) 232 | 233 | if i < height - input_size and j < width - input_size: 234 | pout = pout[:,:,0 : 0 + pad, 0 : 0 + pad] 235 | 236 | mask = window.to(device) 237 | p_after = pout * mask 238 | 239 | pad_out[:,:,i : i + pad, j : j + pad] = pad_out[:,:,i : i + pad, j : j + pad] + p_after 240 | pnorm[:,:,i : i + pad, j : j + pad] = pnorm[:,:,i : i + pad, j : j + pad] + mask 241 | else: 242 | pad_out[:, :, i : i + input_size, j : j + input_size] = pad_out[:, :, i : i + input_size, j : j + input_size] + pout 243 | pnorm[:, :, i : i + input_size, j : j + input_size] = pnorm[:, :, i : i + input_size, j : j + input_size] + 1.0 244 | 245 | j = j + stride 246 | 247 | i = i + stride 248 | j = 0 249 | 250 | output = pad_out[:,:,0 : 1024, 0 : 1024] / pnorm[:,:,0 : 1024, 0 : 1024] 251 | 252 | return output 253 | 254 | def create_window(window_size, batch, channel, sigma): 255 | _1D_window = gaussian(window_size, sigma).unsqueeze(1) 256 | _2D_window = _1D_window.mm(_1D_window.t()).float().unsqueeze(0).unsqueeze(0) 257 | window = _2D_window.expand(batch, channel, window_size, window_size) 258 | return window 259 | 260 | def padarray(input, size_pad): 261 | b,c,h,w = input.size() 262 | device = input.device 263 | 264 | new_h = h + size_pad 265 | new_w = w + size_pad 266 | output = torch.zeros((b, c, new_h, new_w)).to(device) 267 | 268 | output[:,:,0 : h, 0 : w] = input[:,:,:,:] 269 | # output[:,:,h : new_h, w : new_w] = 0.0 270 | 271 | return output 272 | 273 | def gaussian(window_size, sigma): 274 | gauss = torch.Tensor([exp(-(x - window_size//2)**2 / float(2*sigma**2)) for x in range(window_size)]) 275 | return gauss / gauss.sum() 
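
A minimal round-trip sketch for the block-wise DCT modules defined above (the import path is an assumption; adjust it to however this module is exposed in your environment). Because Inverse_DCT is built from the transposed orthonormal kernel, chaining the two should reproduce a [0, 1] input up to float32 round-off, provided the input is square with a side divisible by 8:

import torch
from unidemoire.models.cycle.Models.utils import DCT, Inverse_DCT  # assumed import path

dct, idct = DCT(), Inverse_DCT()
x = torch.rand(1, 3, 32, 32)            # square input, side divisible by 8, values in [0, 1]
recon = idct(dct(x))                    # transposed kernel inverts the transform, then clamps to [0, 1]
print(torch.max(torch.abs(recon - x)))  # expected to be small (float32 round-off)

Note that the reshape inside these modules groups 64 consecutive row-major values per "patch" rather than spatial 8x8 tiles; the inverse uses the same grouping, so the round trip still holds.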
-------------------------------------------------------------------------------- /unidemoire/models/cycle/nets.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.optim as optim 4 | from torch.autograd import Variable 5 | 6 | from Models import GM2_UNet5_256, GM2_UNet5_128, GM2_UNet5_64, TMB, Discriminator, L1_ASL 7 | 8 | class CycleModel(nn.Module): 9 | def __init__(self): 10 | super(CycleModel, self).__init__() 11 | 12 | self.resolution_dict = { 13 | '256': {'Net_Demoire':'256', 'G_Artifact':'256_2', 'D_moire':'256', 'D_clear':'256'}, 14 | '128': {'Net_Demoire':'128', 'G_Artifact':'128_2', 'D_moire':'128', 'D_clear':'128'}, 15 | '64': {'Net_Demoire':'64', 'G_Artifact':['64_2', '64_1'], 'D_moire':'64', 'D_clear':'64'}, 16 | } 17 | 18 | self.Net_Demoire = { 19 | '256': GM2_UNet5_256(6, 3), 20 | '128': GM2_UNet5_128(6, 3), 21 | '64': GM2_UNet5_64(3, 3), 22 | 'TMB': TMB(256, 1) 23 | } 24 | 25 | self.G_Artifact = { 26 | '256_2': GM2_UNet5_256(6, 3), 27 | '128_2': GM2_UNet5_128(6, 3), 28 | '64_2': GM2_UNet5_64(3, 3), 29 | '64_1': TMB(256, 1), 30 | } 31 | 32 | self.D_moire = { 33 | '256': Discriminator(6, 256, 256), 34 | '128': Discriminator(6, 128, 128), 35 | '64': Discriminator(6, 64, 64), 36 | } 37 | 38 | self.D_clear = { 39 | '256': Discriminator(6, 256, 256), 40 | '128': Discriminator(6, 128, 128), 41 | '64': Discriminator(6, 64, 64), 42 | } 43 | 44 | self.downx2 = nn.UpsamplingNearest2d(scale_factor = 0.5) 45 | self.upx2 = nn.UpsamplingNearest2d(scale_factor = 2) 46 | 47 | 48 | # LOSS FUNCTIONS 49 | self.criterion_GAN = torch.nn.MSELoss() 50 | self.criterion_cycle = torch.nn.L1Loss() 51 | self.criterion_MSE = torch.nn.MSELoss() 52 | self.criterion_content = L1_ASL() 53 | self.Loss = L1_ASL() 54 | 55 | # Initialize weights 56 | for key in self.Net_Demoire.keys(): 57 | self.Net_Demoire[key].apply(self.weights_init) 58 | 59 | for key in self.G_Artifact.keys(): 60 | self.G_Artifact[key].apply(self.weights_init) 61 | 62 | for key in self.D_moire.keys(): 63 | self.D_moire[key].apply(self.weights_init) 64 | 65 | for key in self.D_clear.keys(): 66 | self.D_clear[key].apply(self.weights_init) 67 | 68 | 69 | # Custom weights initialization called on network 70 | def weights_init(m): 71 | if isinstance(m, nn.Conv2d): 72 | nn.init.kaiming_uniform_(m.weight) 73 | if m.bias is not None: 74 | m.bias.data.zero_() 75 | 76 | 77 | def forward(self, MOIRE, CLEAR, historgram, device): 78 | 79 | Tensor = torch.cuda.FloatTensor 80 | 81 | # load data 82 | MOIRE_256 = MOIRE 83 | MOIRE_128 = self.downx2(MOIRE_256) 84 | MOIRE_64 = self.downx2(MOIRE_128) 85 | 86 | CLEAR_256 = CLEAR 87 | CLEAR_128 = self.downx2(CLEAR_256) 88 | CLEAR_64 = self.downx2(CLEAR_128) 89 | 90 | historgram = historgram.float() 91 | 92 | valid_256 = Variable(Tensor(MOIRE_256.size(0), 1, 1, 1).fill_(1.0).to(device), requires_grad=False) 93 | fake_256 = Variable(Tensor(MOIRE_256.size(0), 1, 1, 1).fill_(0.0).to(device), requires_grad=False) 94 | 95 | valid_128 = Variable(Tensor(MOIRE_128.size(0), 1, 1, 1).fill_(1.0).to(device), requires_grad=False) 96 | fake_128 = Variable(Tensor(MOIRE_128.size(0), 1, 1, 1).fill_(0.0).to(device), requires_grad=False) 97 | 98 | valid_64 = Variable(Tensor(MOIRE_64.size(0), 1, 1, 1).fill_(1.0).to(device), requires_grad=False) 99 | fake_64 = Variable(Tensor(MOIRE_64.size(0), 1, 1, 1).fill_(0.0).to(device), requires_grad=False) 100 | 101 | for resolution in self.resolution_dict.keys(): 102 | 103 | Net_Demoire = 
self.Net_Demoire[self.resolution_dict[resolution]['Net_Demoire']] 104 | 105 | if resolution == '64': 106 | G_Artifact_1 = self.G_Artifact[self.resolution_dict[resolution]['G_Artifact'][0]] 107 | G_Artifact_2 = self.G_Artifact[self.resolution_dict[resolution]['G_Artifact']] if resolution != '64' else self.G_Artifact[self.resolution_dict[resolution]['G_Artifact'][1]] 108 | 109 | D_moire = self.D_moire[self.resolution_dict[resolution]['D_moire']] 110 | D_clear = self.D_clear[self.resolution_dict[resolution]['D_clear']] 111 | 112 | -------------------------------------------------------------------------------- /unidemoire/models/diffusion/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/4DVLab/UniDemoire/4221a0e98f0078c3f72ef1da8128642e7f02662a/unidemoire/models/diffusion/__init__.py -------------------------------------------------------------------------------- /unidemoire/models/esdnet/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/4DVLab/UniDemoire/4221a0e98f0078c3f72ef1da8128642e7f02662a/unidemoire/models/esdnet/__init__.py -------------------------------------------------------------------------------- /unidemoire/models/esdnet/nets.py: -------------------------------------------------------------------------------- 1 | """ 2 | Implementation of ESDNet for image demoireing 3 | """ 4 | 5 | 6 | import torch 7 | import torch.nn as nn 8 | import torch.nn.functional as F 9 | import torchvision 10 | from torch.nn.parameter import Parameter 11 | 12 | class ESDNet(nn.Module): 13 | def __init__(self, 14 | en_feature_num, 15 | en_inter_num, 16 | de_feature_num, 17 | de_inter_num, 18 | sam_number=1, 19 | ): 20 | super(ESDNet, self).__init__() 21 | self.encoder = Encoder(feature_num=en_feature_num, inter_num=en_inter_num, sam_number=sam_number) 22 | self.decoder = Decoder(en_num=en_feature_num, feature_num=de_feature_num, inter_num=de_inter_num, 23 | sam_number=sam_number) 24 | 25 | def forward(self, x): 26 | y_1, y_2, y_3 = self.encoder(x) 27 | out_1, out_2, out_3 = self.decoder(y_1, y_2, y_3) 28 | 29 | return out_1, out_2, out_3 30 | 31 | def _initialize_weights(self): 32 | for m in self.modules(): 33 | if isinstance(m, nn.Conv2d): 34 | m.weight.data.normal_(0.0, 0.02) 35 | if m.bias is not None: 36 | m.bias.data.normal_(0.0, 0.02) 37 | if isinstance(m, nn.ConvTranspose2d): 38 | m.weight.data.normal_(0.0, 0.02) 39 | 40 | 41 | class Decoder(nn.Module): 42 | def __init__(self, en_num, feature_num, inter_num, sam_number): 43 | super(Decoder, self).__init__() 44 | self.preconv_3 = conv_relu(4 * en_num, feature_num, 3, padding=1) 45 | self.decoder_3 = Decoder_Level(feature_num, inter_num, sam_number) 46 | 47 | self.preconv_2 = conv_relu(2 * en_num + feature_num, feature_num, 3, padding=1) 48 | self.decoder_2 = Decoder_Level(feature_num, inter_num, sam_number) 49 | 50 | self.preconv_1 = conv_relu(en_num + feature_num, feature_num, 3, padding=1) 51 | self.decoder_1 = Decoder_Level(feature_num, inter_num, sam_number) 52 | 53 | def forward(self, y_1, y_2, y_3): 54 | x_3 = y_3 55 | x_3 = self.preconv_3(x_3) 56 | out_3, feat_3 = self.decoder_3(x_3) 57 | 58 | x_2 = torch.cat([y_2, feat_3], dim=1) 59 | x_2 = self.preconv_2(x_2) 60 | out_2, feat_2 = self.decoder_2(x_2) 61 | 62 | x_1 = torch.cat([y_1, feat_2], dim=1) 63 | x_1 = self.preconv_1(x_1) 64 | out_1 = self.decoder_1(x_1, feat=False) 65 | 66 | return out_1, out_2, out_3 67 | 68 | 69 | class 
Encoder(nn.Module): 70 | def __init__(self, feature_num, inter_num, sam_number): 71 | super(Encoder, self).__init__() 72 | self.conv_first = nn.Sequential( 73 | nn.Conv2d(12, feature_num, kernel_size=5, stride=1, padding=2, bias=True), 74 | nn.ReLU(inplace=True) 75 | ) 76 | self.encoder_1 = Encoder_Level(feature_num, inter_num, level=1, sam_number=sam_number) 77 | self.encoder_2 = Encoder_Level(2 * feature_num, inter_num, level=2, sam_number=sam_number) 78 | self.encoder_3 = Encoder_Level(4 * feature_num, inter_num, level=3, sam_number=sam_number) 79 | 80 | def forward(self, x): 81 | x = F.pixel_unshuffle(x, 2) 82 | x = self.conv_first(x) 83 | 84 | out_feature_1, down_feature_1 = self.encoder_1(x) 85 | out_feature_2, down_feature_2 = self.encoder_2(down_feature_1) 86 | out_feature_3 = self.encoder_3(down_feature_2) 87 | 88 | return out_feature_1, out_feature_2, out_feature_3 89 | 90 | 91 | class Encoder_Level(nn.Module): 92 | def __init__(self, feature_num, inter_num, level, sam_number): 93 | super(Encoder_Level, self).__init__() 94 | self.rdb = RDB(in_channel=feature_num, d_list=(1, 2, 1), inter_num=inter_num) 95 | self.sam_blocks = nn.ModuleList() 96 | for _ in range(sam_number): 97 | sam_block = SAM(in_channel=feature_num, d_list=(1, 2, 3, 2, 1), inter_num=inter_num) 98 | self.sam_blocks.append(sam_block) 99 | 100 | if level < 3: 101 | self.down = nn.Sequential( 102 | nn.Conv2d(feature_num, 2 * feature_num, kernel_size=3, stride=2, padding=1, bias=True), 103 | nn.ReLU(inplace=True) 104 | ) 105 | self.level = level 106 | 107 | def forward(self, x): 108 | out_feature = self.rdb(x) 109 | for sam_block in self.sam_blocks: 110 | out_feature = sam_block(out_feature) 111 | if self.level < 3: 112 | down_feature = self.down(out_feature) 113 | return out_feature, down_feature 114 | return out_feature 115 | 116 | 117 | class Decoder_Level(nn.Module): 118 | def __init__(self, feature_num, inter_num, sam_number): 119 | super(Decoder_Level, self).__init__() 120 | self.rdb = RDB(feature_num, (1, 2, 1), inter_num) 121 | self.sam_blocks = nn.ModuleList() 122 | for _ in range(sam_number): 123 | sam_block = SAM(in_channel=feature_num, d_list=(1, 2, 3, 2, 1), inter_num=inter_num) 124 | self.sam_blocks.append(sam_block) 125 | self.conv = conv(in_channel=feature_num, out_channel=12, kernel_size=3, padding=1) 126 | 127 | def forward(self, x, feat=True): 128 | x = self.rdb(x) 129 | for sam_block in self.sam_blocks: 130 | x = sam_block(x) 131 | out = self.conv(x) 132 | out = F.pixel_shuffle(out, 2) 133 | 134 | if feat: 135 | feature = F.interpolate(x, scale_factor=2, mode='bilinear') 136 | return out, feature 137 | else: 138 | return out 139 | 140 | 141 | class DB(nn.Module): 142 | def __init__(self, in_channel, d_list, inter_num): 143 | super(DB, self).__init__() 144 | self.d_list = d_list 145 | self.conv_layers = nn.ModuleList() 146 | c = in_channel 147 | for i in range(len(d_list)): 148 | dense_conv = conv_relu(in_channel=c, out_channel=inter_num, kernel_size=3, dilation_rate=d_list[i], 149 | padding=d_list[i]) 150 | self.conv_layers.append(dense_conv) 151 | c = c + inter_num 152 | self.conv_post = conv(in_channel=c, out_channel=in_channel, kernel_size=1) 153 | 154 | def forward(self, x): 155 | t = x 156 | for conv_layer in self.conv_layers: 157 | _t = conv_layer(t) 158 | t = torch.cat([_t, t], dim=1) 159 | t = self.conv_post(t) 160 | return t 161 | 162 | 163 | class SAM(nn.Module): 164 | def __init__(self, in_channel, d_list, inter_num): 165 | super(SAM, self).__init__() 166 | self.basic_block = 
DB(in_channel=in_channel, d_list=d_list, inter_num=inter_num) 167 | self.basic_block_2 = DB(in_channel=in_channel, d_list=d_list, inter_num=inter_num) 168 | self.basic_block_4 = DB(in_channel=in_channel, d_list=d_list, inter_num=inter_num) 169 | self.fusion = CSAF(3 * in_channel) 170 | 171 | def forward(self, x): 172 | x_0 = x 173 | x_2 = F.interpolate(x, scale_factor=0.5, mode='bilinear') 174 | x_4 = F.interpolate(x, scale_factor=0.25, mode='bilinear') 175 | 176 | y_0 = self.basic_block(x_0) 177 | y_2 = self.basic_block_2(x_2) 178 | y_4 = self.basic_block_4(x_4) 179 | 180 | y_2 = F.interpolate(y_2, scale_factor=2, mode='bilinear') 181 | y_4 = F.interpolate(y_4, scale_factor=4, mode='bilinear') 182 | 183 | y = self.fusion(y_0, y_2, y_4) 184 | y = x + y 185 | 186 | return y 187 | 188 | 189 | class CSAF(nn.Module): 190 | def __init__(self, in_chnls, ratio=4): 191 | super(CSAF, self).__init__() 192 | self.squeeze = nn.AdaptiveAvgPool2d((1, 1)) 193 | self.compress1 = nn.Conv2d(in_chnls, in_chnls // ratio, 1, 1, 0) 194 | self.compress2 = nn.Conv2d(in_chnls // ratio, in_chnls // ratio, 1, 1, 0) 195 | self.excitation = nn.Conv2d(in_chnls // ratio, in_chnls, 1, 1, 0) 196 | 197 | def forward(self, x0, x2, x4): 198 | out0 = self.squeeze(x0) 199 | out2 = self.squeeze(x2) 200 | out4 = self.squeeze(x4) 201 | out = torch.cat([out0, out2, out4], dim=1) 202 | out = self.compress1(out) 203 | out = F.relu(out) 204 | out = self.compress2(out) 205 | out = F.relu(out) 206 | out = self.excitation(out) 207 | out = F.sigmoid(out) 208 | w0, w2, w4 = torch.chunk(out, 3, dim=1) 209 | x = x0 * w0 + x2 * w2 + x4 * w4 210 | 211 | return x 212 | 213 | 214 | class RDB(nn.Module): 215 | def __init__(self, in_channel, d_list, inter_num): 216 | super(RDB, self).__init__() 217 | self.d_list = d_list 218 | self.conv_layers = nn.ModuleList() 219 | c = in_channel 220 | for i in range(len(d_list)): 221 | dense_conv = conv_relu(in_channel=c, out_channel=inter_num, kernel_size=3, dilation_rate=d_list[i], 222 | padding=d_list[i]) 223 | self.conv_layers.append(dense_conv) 224 | c = c + inter_num 225 | self.conv_post = conv(in_channel=c, out_channel=in_channel, kernel_size=1) 226 | 227 | def forward(self, x): 228 | t = x 229 | for conv_layer in self.conv_layers: 230 | _t = conv_layer(t) 231 | t = torch.cat([_t, t], dim=1) 232 | 233 | t = self.conv_post(t) 234 | return t + x 235 | 236 | 237 | class conv(nn.Module): 238 | def __init__(self, in_channel, out_channel, kernel_size, dilation_rate=1, padding=0, stride=1): 239 | super(conv, self).__init__() 240 | self.conv = nn.Conv2d(in_channels=in_channel, out_channels=out_channel, kernel_size=kernel_size, stride=stride, 241 | padding=padding, bias=True, dilation=dilation_rate) 242 | 243 | def forward(self, x_input): 244 | out = self.conv(x_input) 245 | return out 246 | 247 | 248 | class conv_relu(nn.Module): 249 | def __init__(self, in_channel, out_channel, kernel_size, dilation_rate=1, padding=0, stride=1): 250 | super(conv_relu, self).__init__() 251 | self.conv = nn.Sequential( 252 | nn.Conv2d(in_channels=in_channel, out_channels=out_channel, kernel_size=kernel_size, stride=stride, 253 | padding=padding, bias=True, dilation=dilation_rate), 254 | nn.ReLU(inplace=True) 255 | ) 256 | 257 | def forward(self, x_input): 258 | out = self.conv(x_input) 259 | return out 260 | -------------------------------------------------------------------------------- /unidemoire/models/mbcnn/MBCNN.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import 
torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | from .MBCNN_class import * 6 | import torch.nn.functional as F 7 | import pdb 8 | 9 | class MBCNN(nn.Module): 10 | def __init__(self, nFilters, multi=True): 11 | super().__init__() 12 | self.imagesize = 256 13 | self.sigmoid = nn.Sigmoid() 14 | self.Space2Depth1 = nn.PixelUnshuffle(2) 15 | self.Depth2space1 = nn.PixelShuffle(2) 16 | 17 | self.conv_func1 = conv_relu1(12, nFilters * 2, 3, padding=1) 18 | self.pre_block1 = pre_block((1, 2, 3, 2, 1)) 19 | self.conv_func2 = conv_relu1(128, nFilters * 2, 3, padding=0, stride=2) 20 | self.pre_block2 = pre_block((1, 2, 3, 2, 1)) 21 | 22 | self.conv_func3 = conv_relu1(128, nFilters * 2, 3, padding=0, stride=2) 23 | self.pre_block3 = pre_block((1, 2, 2, 2, 1)) 24 | self.global_block1 = global_block(self.imagesize // 8) 25 | self.pos_block1 = pos_block((1, 2, 2, 2, 1)) 26 | self.conv1 = conv1(128, 12, 3,us=[True,False]) 27 | 28 | self.conv_func4 = conv_relu1(131, nFilters * 2, 1, padding=0,cat_shape=(3,nFilters*2),set_cat_mul=(False,True)) 29 | self.global_block2 = global_block(self.imagesize // 4) 30 | self.pre_block4 = pre_block((1, 2, 3, 2, 1)) 31 | self.global_block3 = global_block(self.imagesize // 4) 32 | self.pos_block2 = pos_block((1, 2, 3, 2, 1)) 33 | self.conv2 = conv1(128, 12, 3,us=[True,False]) 34 | 35 | self.conv_func5 = conv_relu1(131, nFilters * 2, 1, padding=0,cat_shape=(3,nFilters*2),set_cat_mul=(False,True)) 36 | 37 | self.global_block4 = global_block(self.imagesize // 2) 38 | self.pre_block5 = pre_block((1, 2, 3, 2, 1)) 39 | self.global_block5 = global_block(self.imagesize // 2) 40 | self.pos_block3 = pos_block((1, 2, 3, 2, 1)) 41 | self.conv3 = conv1(128, 12, 3,us=[True,False]) 42 | 43 | def forward(self, x): 44 | output_list = [] 45 | shape = list(x.shape) # [2, 3, 512, 512] 46 | # batch, channel, height, width = shape 47 | _x = self.Space2Depth1(x) 48 | t1 = self.conv_func1(_x) 49 | t1 = self.pre_block1(t1) 50 | 51 | t2 = F.pad(t1, (1, 1, 1, 1)) 52 | t2 = self.conv_func2(t2) 53 | t2 = self.pre_block2(t2) 54 | t3 = F.pad(t2, (1, 1, 1, 1)) 55 | t3 = self.conv_func3(t3) 56 | t3 = self.pre_block3(t3) 57 | t3 = self.global_block1(t3) 58 | t3 = self.pos_block1(t3) 59 | t3_out = self.conv1(t3) 60 | t3_out = self.Depth2space1(t3_out) 61 | t3_out = F.sigmoid(t3_out) 62 | output_list.append(t3_out) 63 | 64 | _t2 = torch.cat([t3_out, t2], dim=-3) 65 | _t2 = self.conv_func4(_t2) 66 | _t2 = self.global_block2(_t2) 67 | _t2 = self.pre_block4(_t2) 68 | _t2 = self.global_block3(_t2) 69 | _t2 = self.pos_block2(_t2) 70 | t2_out = self.conv2(_t2) 71 | t2_out = self.Depth2space1(t2_out) 72 | t2_out = F.sigmoid(t2_out) 73 | output_list.append(t2_out) 74 | 75 | _t1 = torch.cat([t1, t2_out], dim=-3) 76 | _t1 = self.conv_func5(_t1) 77 | _t1 = self.global_block4(_t1) 78 | _t1 = self.pre_block5(_t1) 79 | _t1 = self.global_block5(_t1) 80 | _t1 = self.pos_block3(_t1) 81 | _t1 = self.conv3(_t1) 82 | y = self.Depth2space1(_t1) 83 | 84 | y = self.sigmoid(y) + torch.Tensor([1e-10]).to(_t1.device) 85 | output_list.append(y) 86 | return t3_out,t2_out,y 87 | #return output_list 88 | 89 | # import os 90 | # from torchinfo import summary 91 | # from rich import print 92 | # GPU_ID = 5 93 | # os.environ["CUDA_VISIBLE_DEVICES"] = "%d" % GPU_ID 94 | # device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") 95 | # net = MBCNN(64).to(device) 96 | # #print(summary(net, input_size=(1, 3, 512, 512))) 97 | # # print(summary(net)) 98 | 99 | # # model_stats = summary( 100 | # # net, 101 | # # 
input_size=(1, 3, 512, 512), 102 | # # verbose=1, 103 | # # col_names=["kernel_size", "output_size", "num_params"], 104 | # # row_settings=["var_names"], 105 | # # ) 106 | -------------------------------------------------------------------------------- /unidemoire/models/mbcnn/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/4DVLab/UniDemoire/4221a0e98f0078c3f72ef1da8128642e7f02662a/unidemoire/models/mbcnn/__init__.py -------------------------------------------------------------------------------- /unidemoire/models/mbcnn/arch_util.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.init as init 4 | import torch.nn.functional as F 5 | import math 6 | import numpy as np 7 | from torch import Tensor 8 | from typing import Optional, List 9 | import pdb 10 | 11 | def make_divisible(v, divisor=8, min_value=8): 12 | if min_value is None: 13 | min_value = divisor 14 | new_v = max(min_value, int(v + divisor / 2) // divisor * divisor) 15 | # Make sure that round down does not go down by more than 10%. 16 | if new_v < 0.9 * v: 17 | new_v += divisor 18 | return int(new_v) 19 | 20 | def initialize_weights(net_l, scale=1): 21 | if not isinstance(net_l, list): 22 | net_l = [net_l] 23 | for net in net_l: 24 | for m in net.modules(): 25 | if isinstance(m, nn.Conv2d): 26 | init.kaiming_normal_(m.weight, a=0, mode='fan_in') 27 | m.weight.data *= scale # for residual block 28 | if m.bias is not None: 29 | m.bias.data.zero_() 30 | elif isinstance(m, nn.Linear): 31 | init.kaiming_normal_(m.weight, a=0, mode='fan_in') 32 | m.weight.data *= scale 33 | if m.bias is not None: 34 | m.bias.data.zero_() 35 | elif isinstance(m, nn.BatchNorm2d): 36 | init.constant_(m.weight, 1) 37 | init.constant_(m.bias.data, 0.0) 38 | 39 | class MeanShift(nn.Conv2d): 40 | def __init__(self, rgb_range, rgb_mean, rgb_std, sign=-1): 41 | super(MeanShift, self).__init__(3, 3, kernel_size=1) 42 | std = torch.Tensor(rgb_std) 43 | self.weight.data = torch.eye(3).view(3, 3, 1, 1) 44 | self.weight.data.div_(std.view(3, 1, 1, 1)) 45 | self.bias.data = sign * rgb_range * torch.Tensor(rgb_mean) 46 | self.bias.data.div_(std) 47 | # self.requires_grad = False 48 | for p in self.parameters(): 49 | p.requires_grad = False 50 | 51 | class USConv2d(nn.Conv2d): 52 | def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=0, dilation=1, groups=1, bias=True, us=[False, False],cat_shape=None,set_cat_mul=None): 53 | super(USConv2d, self).__init__(in_channels, out_channels, kernel_size, stride=stride, padding=padding, dilation=dilation, groups=groups, bias=bias) 54 | self.width_mult = 1 55 | self.us = us 56 | self.cat_shape = cat_shape 57 | self.set_cat_mul = set_cat_mul 58 | self.in_channels_index_list = [None]*4 59 | 60 | self.unrank = False 61 | 62 | def forward(self, inputs): 63 | in_channels = inputs.shape[1] // self.groups if self.us[0] else self.in_channels // self.groups 64 | out_channels = int(self.out_channels * self.width_mult) if self.us[1] else self.out_channels 65 | if self.width_mult < 0.3: 66 | in_channels_index=self.in_channels_index_list[-1] 67 | elif self.width_mult < 0.6: 68 | in_channels_index=self.in_channels_index_list[-2] 69 | elif self.width_mult < 0.8: 70 | in_channels_index=self.in_channels_index_list[-3] 71 | else: 72 | in_channels_index=self.in_channels_index_list[-4] 73 | 74 | if in_channels == self.in_channels: 75 | weight = 
self.weight[:out_channels, :in_channels, :, :] 76 | elif in_channels_index is None and self.cat_shape is not None: 77 | if self.set_cat_mul is None: 78 | cat_num = len(self.cat_shape) 79 | inchannel_index = np.zeros(self.in_channels) 80 | start = 0 81 | for i in range(cat_num): 82 | inchannel_index[start:start+int(self.width_mult*self.cat_shape[i])]=1 83 | start += self.cat_shape[i] 84 | else: 85 | assert len(self.set_cat_mul) == len(self.cat_shape), 'USconv2d use cat now and partially prune, need len(self.set_cat_mul) == len(self.cat_shape)' 86 | inchannel_index = np.zeros(self.in_channels) 87 | start=0 88 | for i in range(len(self.set_cat_mul)): 89 | if self.set_cat_mul[i] == True: 90 | inchannel_index[start:start+int(self.width_mult*self.cat_shape[i])]=1 91 | else: 92 | inchannel_index[start:start+int(self.cat_shape[i])]=1 93 | start += self.cat_shape[i] 94 | # pdb.set_trace() 95 | inchannel_index = np.squeeze(np.argwhere(inchannel_index)) 96 | in_channels_index = inchannel_index 97 | weight = self.weight[:out_channels,inchannel_index, :, :] 98 | elif in_channels_index is not None and self.cat_shape is not None: 99 | inchannel_index = in_channels_index 100 | weight = self.weight[:out_channels,inchannel_index, :, :] 101 | else: 102 | weight = self.weight[:out_channels, :in_channels, :, :] 103 | 104 | if self.bias is not None: 105 | bias = self.bias[:out_channels] 106 | else: 107 | bias = self.bias 108 | y = F.conv2d(inputs, weight, bias, self.stride, self.padding, self.dilation, self.groups) 109 | # self.y = y 110 | return y 111 | 112 | 113 | class USConvTranspose2d(nn.ConvTranspose2d): 114 | def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=0, output_padding=0, us=[False, False]): 115 | super(USConvTranspose2d, self).__init__(in_channels, out_channels, kernel_size, stride, padding, output_padding = output_padding) 116 | self.width_mult = None 117 | self.us = us 118 | 119 | def forward(self, input: Tensor, output_size: Optional[List[int]] = None) -> Tensor: 120 | # in_channels = make_divisible(self.in_channels * self.width_mult) if self.us[0] else self.in_channels 121 | in_channels = int(self.in_channels * self.width_mult) if self.us[0] else self.in_channels 122 | out_channels = input.shape[1] if self.us[1] else self.out_channels 123 | 124 | 125 | weight = self.weight[:in_channels, :out_channels, :, :] 126 | 127 | assert isinstance(self.padding, tuple) 128 | output_padding = self._output_padding(input, output_size, self.stride, self.padding, self.kernel_size, self.dilation) # type: ignore[arg-type] 129 | 130 | return F.conv_transpose2d(input, weight, self.bias, self.stride, self.padding, output_padding, self.groups, self.dilation) 131 | 132 | 133 | class USBatchNorm2d(nn.BatchNorm2d): 134 | def __init__(self, num_features, width_list = None): 135 | super(USBatchNorm2d, self).__init__(num_features, affine=True, track_running_stats=False) 136 | self.width_id = None 137 | 138 | self.bn = nn.ModuleList([ 139 | nn.BatchNorm2d(self.num_features, affine=False) for _ in range(len(width_list)) 140 | ]) 141 | # raise NotImplementedError 142 | 143 | def forward(self, inputs): 144 | num_features = inputs.size(1) 145 | y = F.batch_norm( 146 | inputs, 147 | self.bn[self.width_id].running_mean[:num_features], 148 | self.bn[self.width_id].running_var[:num_features], 149 | self.weight[:num_features], 150 | self.bias[:num_features], 151 | self.training, 152 | self.momentum, 153 | self.eps) 154 | return y 155 | 
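A minimal usage sketch of the slimmable USConv2d defined above (hypothetical example shapes): reducing `width_mult` slices both the stored weight tensor and the number of output filters at run time, so the same layer can run at several widths.

if __name__ == "__main__":
    # illustrative only: both the input and output widths are slimmable (us=[True, True])
    _conv = USConv2d(64, 64, 3, padding=1, us=[True, True])
    _x = torch.randn(1, 64, 32, 32)

    _conv.width_mult = 1.0
    print(_conv(_x).shape)  # full width -> torch.Size([1, 64, 32, 32])

    _conv.width_mult = 0.5
    print(_conv(_x).shape)  # half width: only the first 32 output filters are used -> torch.Size([1, 32, 32, 32])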
-------------------------------------------------------------------------------- /unidemoire/models/moire_blending.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import pytorch_lightning as pl 4 | # from torch.nn.parameter import Parameter 5 | from torch.optim import lr_scheduler 6 | import torch.optim as optim 7 | from omegaconf import OmegaConf 8 | import glob 9 | 10 | from .MIB.Blending import Blending 11 | from .TRN.model import Uformer 12 | 13 | from .utils.loss_util import * 14 | from .utils.common import * 15 | 16 | torch.autograd.set_detect_anomaly(True) 17 | 18 | class MoireBlending_Model(pl.LightningModule): 19 | def __init__(self, model_name, network_config, loss_config=None, optimizer_config=None, ckpt_path=None, ignore_keys=[]): 20 | super().__init__() 21 | self.model_name = model_name 22 | self.network_config = network_config 23 | self.loss_config = loss_config 24 | self.optimizer_config = optimizer_config 25 | self.init_blending_args = network_config["init_blending_args"] 26 | self.blending_network_args = network_config["blending_network_args"] 27 | 28 | # model 29 | self.model = self.build_up_models() 30 | self.loss_fn = self.loss_function() 31 | 32 | if self.model_name == "UniDemoire": 33 | self.init_blend, self.refine_net = self.model 34 | if ckpt_path is not None: 35 | print(f"Loading Checkpoint from {ckpt_path}") 36 | self.init_from_ckpt(ckpt_path, ignore_keys=ignore_keys) 37 | 38 | 39 | def on_load_checkpoint(self, checkpoint): 40 | print("Loading checkpoint...") 41 | 42 | 43 | def on_train_batch_start(self, batch, batch_idx, dataloader_idx): 44 | # only for very first batch 45 | if self.current_epoch == 0 and self.global_step == 0 and batch_idx == 0: 46 | self.moire_image_encoder = None 47 | 48 | def get_config_from_ckpt_path(self, ckpt_path): 49 | if os.path.isfile(ckpt_path): 50 | # paths = opt.resume.split("/") 51 | try: 52 | logdir = '/'.join(ckpt_path.split('/')[:-1]) 53 | print(f'Encoder dir is {logdir}') 54 | except ValueError: 55 | paths = ckpt_path.split("/") 56 | idx = -2 # take a guess: path/to/logdir/checkpoints/model.ckpt 57 | logdir = "/".join(paths[:idx]) 58 | ckpt = ckpt_path 59 | else: 60 | assert os.path.isdir(ckpt_path), f"{ckpt_path} is not a directory" 61 | logdir = ckpt_path.rstrip("/") 62 | ckpt = os.path.join(logdir, "model.ckpt") 63 | 64 | base_configs = sorted(glob.glob(os.path.join(logdir, "config.yaml"))) 65 | base = base_configs 66 | configs = [OmegaConf.load(cfg) for cfg in base] 67 | return configs[0]['model'] 68 | 69 | def val_mode(self, model): 70 | model = model.eval() 71 | for param in model.parameters(): 72 | param.requires_grad = False 73 | return model 74 | 75 | def init_from_ckpt(self, path, ignore_keys=list()): 76 | sd = torch.load(path, map_location="cpu")["state_dict"] 77 | keys = list(sd.keys()) 78 | for k in keys: 79 | for ik in ignore_keys: 80 | if k.startswith(ik): 81 | print("Deleting key {} from state_dict.".format(k)) 82 | del sd[k] 83 | self.load_state_dict(sd, strict=False) 84 | # print(f"Restored from {path}, self.mib_weight = {self.mib_weight}, self.init_blend.weight = {self.init_blend.weight}", sd['init_blend.weight']) 85 | 86 | def build_up_models(self): 87 | if self.model_name == "UniDemoire": 88 | init_blend = Blending(self.init_blending_args) 89 | refine_net = Uformer( 90 | embed_dim=self.blending_network_args['embed_dim'], 91 | depths=self.blending_network_args['depths'], 92 | win_size=self.blending_network_args['win_size'], 
93 | modulator=self.blending_network_args['modulator'], 94 | shift_flag=self.blending_network_args['shift_flag'] 95 | ) 96 | model = [init_blend, refine_net] 97 | else: 98 | model = None 99 | return model 100 | 101 | def loss_function(self): 102 | if self.model_name == "UniDemoire": 103 | Perceptual_Loss = PerceptualLoss() 104 | TV_Loss = TVLoss() 105 | ColorHistogram_Loss = ColorHistogramMatchingLoss() 106 | loss_fn = [Perceptual_Loss, TV_Loss, ColorHistogram_Loss] 107 | else: 108 | loss_fn = [] 109 | 110 | return loss_fn 111 | 112 | def setup_optimizer(self): 113 | if self.model_name == "UniDemoire": 114 | optimizer = optim.Adam( 115 | [{ 116 | 'params': 117 | list(self.model[1].parameters()), # self.refine_net 118 | 'initial_lr': 119 | self.learning_rate, 120 | 'lr': self.learning_rate 121 | }], 122 | betas=(self.optimizer_config["beta1"], self.optimizer_config["beta2"]) 123 | ) 124 | else: 125 | optimizer = optim.Adam(params=self.model.parameters(), lr=self.learning_rate) 126 | 127 | return optimizer 128 | 129 | def setup_scheduler(self): 130 | if self.model_name == "UniDemoire": 131 | scheduler = lr_scheduler.CosineAnnealingWarmRestarts( 132 | self.optimizer, 133 | T_0=self.optimizer_config["T_0"], 134 | T_mult=self.optimizer_config["T_mult"], 135 | eta_min=self.optimizer_config["eta_min"], 136 | ) 137 | else: 138 | scheduler = None 139 | 140 | return scheduler 141 | 142 | def configure_optimizers(self): 143 | self.optimizer = self.setup_optimizer() 144 | self.scheduler = self.setup_scheduler() 145 | if self.scheduler is not None: 146 | return [self.optimizer],[self.scheduler] 147 | else: 148 | return self.optimizer 149 | 150 | def training_epoch_end(self, outputs): 151 | if self.scheduler is not None: 152 | self.scheduler.step() 153 | 154 | def get_input(self, batch): 155 | moire_pattern = batch['moire_pattern'] 156 | natural = batch['natural'] 157 | real_moire = batch['real_moire'] 158 | number = batch['number'] 159 | 160 | return moire_pattern, natural, real_moire, number 161 | 162 | def forward(self, moire_pattern, natural, real_moire): 163 | if self.model_name == "UniDemoire": 164 | moire_pattern = moire_pattern.to(self.device) 165 | natural = natural.to(self.device) 166 | real_moire = real_moire.to(self.device) 167 | self.init_blend.to(self.device) 168 | 169 | ##* Here's the MIB: 170 | mib_result, weight = self.init_blend(natural, moire_pattern) 171 | mib_result = mib_result.to(self.device) 172 | self.log('w_mib', weight, prog_bar=True, logger=True) 173 | 174 | ##* And here's the TRN: 175 | refine_result = mib_result * self.refine_net(mib_result, real_moire) 176 | min_val = torch.min(refine_result) 177 | max_val = torch.max(refine_result) 178 | refine_result = (refine_result - min_val) / (max_val - min_val) 179 | refine_result = refine_result 180 | 181 | return mib_result, refine_result 182 | else: 183 | return None 184 | 185 | def training_step(self, batch, batch_idx): 186 | if self.model_name == "UniDemoire": 187 | # Get data 188 | moire_pattern, natural, real_moire, number = self.get_input(batch) 189 | # Get Loss function 190 | Perceptual_Loss, TV_Loss, ColorHistogram_Loss = self.loss_fn 191 | #* Get the result 192 | mib_result, refine_result = self(moire_pattern, natural, real_moire) 193 | 194 | #* Calculate the losses 195 | content_loss = Perceptual_Loss(input=refine_result, target=mib_result, device=self.device, feature_layers=[0,1,2]) 196 | color_loss = ColorHistogram_Loss(x=refine_result, y=real_moire, device=self.device) 197 | tv_loss = TV_Loss(refine_result) 198 | 199 
| #** Total Loss: 200 | loss = color_loss + content_loss + 0.1 * tv_loss 201 | 202 | # Logging 203 | self.log('L_p', content_loss, prog_bar=True, logger=True) 204 | self.log('L_c', color_loss, prog_bar=True, logger=True) 205 | self.log('L_tv', tv_loss, prog_bar=True, logger=True) 206 | self.log('L_total', loss, prog_bar=False, logger=True) 207 | 208 | lr = self.optimizer.param_groups[0]['lr'] 209 | self.log('lr', lr, prog_bar=True, logger=False) 210 | 211 | return loss 212 | 213 | def feature_norm(self, feature): 214 | normed_feature = feature / feature.norm(dim=-1, keepdim=True) 215 | return normed_feature 216 | 217 | @torch.no_grad() 218 | def log_images(self, batch, only_inputs=False, **kwargs): 219 | log = dict() 220 | moire_pattern, natural, real_moire, number = self.get_input(batch) 221 | log["natural"] = natural 222 | log["moire_pattern"] = moire_pattern 223 | log["real_moire"] = real_moire 224 | if not only_inputs: 225 | mib_result, trn_result = self(moire_pattern, natural, real_moire) 226 | log["init_blending_result"] = mib_result 227 | log["fusion_result"] = trn_result 228 | 229 | return log -------------------------------------------------------------------------------- /unidemoire/models/pmtnet/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/4DVLab/UniDemoire/4221a0e98f0078c3f72ef1da8128642e7f02662a/unidemoire/models/pmtnet/__init__.py -------------------------------------------------------------------------------- /unidemoire/models/shooting/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/4DVLab/UniDemoire/4221a0e98f0078c3f72ef1da8128642e7f02662a/unidemoire/models/shooting/__init__.py -------------------------------------------------------------------------------- /unidemoire/models/shooting/image_transformer.py: -------------------------------------------------------------------------------- 1 | #from utils import * 2 | import numpy as np 3 | import cv2 4 | from math import pi 5 | import torch 6 | import time 7 | import torch.nn.functional as F 8 | 9 | class ImageTransformer(object): 10 | """ Perspective transformation class for image 11 | with shape (height, width, #channels) """ 12 | 13 | def __init__(self, img): 14 | self.image = img # (h, w, c) 15 | self.image = self.image.unsqueeze(0) # (1, h, w, c) 16 | # self.image = self.image.permute(0, 3, 1, 2) # (1, c, h, w) 17 | self.batchsize = self.image.shape[0] 18 | self.num_channels = self.image.shape[1] 19 | self.height = self.image.shape[2] 20 | self.width = self.image.shape[3] 21 | self.device = img.device 22 | 23 | def get_rad(self, theta, phi, gamma): 24 | return (self.deg_to_rad(theta), 25 | self.deg_to_rad(phi), 26 | self.deg_to_rad(gamma)) 27 | 28 | def get_deg(self, rtheta, rphi, rgamma): 29 | return (self.rad_to_deg(rtheta), 30 | self.rad_to_deg(rphi), 31 | self.rad_to_deg(rgamma)) 32 | 33 | def deg_to_rad(self, deg): 34 | return deg * pi / 180.0 35 | 36 | def rad_to_deg(self, rad): 37 | return rad * 180.0 / pi 38 | 39 | """ Wrapper of Rotating a Image """ 40 | def rotate_along_axis(self, random_f, theta=0, phi=0, gamma=0, dx=0, dy=0, dz=0): 41 | 42 | # Get radius of rotation along 3 axes 43 | if random_f: 44 | theta = np.random.randint(-20, 20) 45 | phi = np.random.randint(-20, 20) 46 | gamma = np.random.randint(-20, 20) 47 | 48 | # theta = 0 49 | # phi = 0 50 | # gamma = 0 51 | rtheta, rphi, rgamma =self.get_rad(theta, phi, gamma) 52 | 53 | # Get ideal 
focal length on z axis 54 | # NOTE: Change this section to other axis if needed 55 | d = np.sqrt(self.height**2 + self.width**2) 56 | 57 | self.focal = d / (2 * np.sin(rgamma) if np.sin(rgamma) != 0 else 1) 58 | dz = self.focal 59 | 60 | # Get projection matrix 61 | mat = self.get_M(rtheta, rphi, rgamma, dx, dy, dz) 62 | 63 | # print(type(mat), mat.shape) 64 | # mat_inv = np.linalg.pinv(mat) 65 | mat_inv = mat 66 | 67 | time.sleep(0.1) 68 | rotate_img = cv2.warpPerspective(self.image.cpu().numpy(), mat, (self.width, self.height)) 69 | # rotate_img = self.image.cpu() 70 | 71 | rotate_img = torch.from_numpy(rotate_img) 72 | return theta, phi, gamma, rotate_img, mat_inv, mat 73 | 74 | def Perspective(self, random_f, theta=0, phi=0, gamma=0, dx=0, dy=0, dz=0): 75 | 76 | # Get radius of rotation along 3 axes 77 | if random_f: 78 | theta = torch.randint(-20,20,(1,)) 79 | phi = torch.randint(-20,20,(1,)) 80 | gamma = torch.randint(-20,20,(1,)) 81 | rtheta, rphi, rgamma =self.get_rad(theta, phi, gamma) 82 | 83 | # Get ideal focal length on z axis 84 | # NOTE: Change this section to other axis if needed 85 | d = torch.sqrt(torch.tensor(self.height**2) + torch.tensor(self.width**2)) 86 | self.focal = d / (2 * torch.sin(rgamma) if torch.sin(rgamma) != 0 else 1) 87 | dz = self.focal 88 | 89 | # Get projection matrix 90 | mat = self.get_M_2(rtheta, rphi, rgamma, dx, dy, dz) 91 | 92 | # rotate_img = cv2.warpPerspective(self.image.cpu().numpy(), mat, (self.width, self.height)) 93 | rotate_img = self.warpPerspective(image=self.image, M=mat) 94 | 95 | return theta, phi, gamma, rotate_img 96 | 97 | 98 | def warpPerspective(self, image, M): 99 | M_norm = self.matrix_normalization(M) 100 | grid = F.affine_grid(torch.eye(2, 3).unsqueeze(0), image.size(), align_corners=False).to(self.device) 101 | homogeneous_grid = torch.cat([grid, torch.ones(self.batchsize, self.height, self.width, 1, device=self.device)], dim=-1) 102 | 103 | warped_grid = torch.matmul(homogeneous_grid, M_norm.transpose(1, 2)) 104 | warped_grid_xy = warped_grid[..., :2] / warped_grid[..., 2:3] 105 | 106 | transformed_image = F.grid_sample(image, warped_grid_xy, align_corners=False, padding_mode='zeros') 107 | 108 | return transformed_image 109 | 110 | 111 | def matrix_normalization(self, M_cv): 112 | M_cv = M_cv.unsqueeze(0) 113 | B = M_cv.shape[0] 114 | H = self.height 115 | W = self.width 116 | device = self.device 117 | 118 | norm_matrix = torch.tensor([ 119 | [2.0/W, 0, -1], 120 | [ 0, 2.0/H, -1], 121 | [ 0, 0, 1] 122 | ], dtype=torch.float32, device=device).unsqueeze(0).repeat(B, 1, 1) 123 | 124 | inv_norm_matrix = torch.tensor([ 125 | [W/2.0, 0, W/2.0], 126 | [ 0, H/2.0, H/2.0], 127 | [ 0, 0, 1] 128 | ], dtype=torch.float32, device=device).unsqueeze(0).repeat(B, 1, 1) 129 | 130 | M_norm = torch.bmm(torch.bmm(norm_matrix, torch.inverse(M_cv)), inv_norm_matrix) 131 | 132 | return M_norm 133 | 134 | def get_M_2(self, theta, phi, gamma, dx, dy, dz): 135 | w = self.width 136 | h = self.height 137 | f = self.focal 138 | 139 | # Projection 2D -> 3D matrix 140 | A1 = torch.tensor([ [1, 0, -w/2], 141 | [0, 1, -h/2], 142 | [0, 0, 1], 143 | [0, 0, 1] ], dtype=torch.float32, device=self.device) 144 | 145 | # Rotation matrices around the X, Y, and Z axis 146 | RX = torch.tensor([ [1, 0, 0, 0], 147 | [0, torch.cos(theta), -torch.sin(theta), 0], 148 | [0, torch.sin(theta), torch.cos(theta), 0], 149 | [0, 0, 0, 1] ], dtype=torch.float32, device=self.device) 150 | 151 | RY = torch.tensor([ [torch.cos(phi), 0, -torch.sin(phi), 0], 152 | [0, 1, 0, 0], 153 
| [torch.sin(phi), 0, torch.cos(phi), 0], 154 | [0, 0, 0, 1] ], dtype=torch.float32, device=self.device) 155 | 156 | RZ = torch.tensor([ [torch.cos(gamma), -torch.sin(gamma), 0, 0], 157 | [torch.sin(gamma), torch.cos(gamma), 0, 0], 158 | [0, 0, 1, 0], 159 | [0, 0, 0, 1] ], dtype=torch.float32, device=self.device) 160 | 161 | # Composed rotation matrix with (RX, RY, RZ) 162 | R = torch.matmul(torch.matmul(RX, RY), RZ) 163 | 164 | # Translation matrix 165 | T = torch.tensor([ [1, 0, 0, dx], 166 | [0, 1, 0, dy], 167 | [0, 0, 1, dz], 168 | [0, 0, 0, 1] ], dtype=torch.float32, device=self.device) 169 | 170 | # Projection 3D -> 2D matrix 171 | A2 = torch.tensor([ [f, 0, w/2, 0], 172 | [0, f, h/2, 0], 173 | [0, 0, 1, 0] ], dtype=torch.float32, device=self.device) 174 | 175 | # Final transformation matrix 176 | M = torch.matmul(A2, torch.matmul(T, torch.matmul(R, A1))) 177 | 178 | return M 179 | 180 | 181 | """ Get Perspective Projection Matrix """ 182 | def get_M(self, theta, phi, gamma, dx, dy, dz): 183 | w = self.width 184 | h = self.height 185 | f = self.focal 186 | 187 | # Projection 2D -> 3D matrix 188 | A1 = np.array([ [1, 0, -w/2], 189 | [0, 1, -h/2], 190 | [0, 0, 1], 191 | [0, 0, 1]]) 192 | 193 | # Rotation matrices around the X, Y, and Z axis 194 | RX = np.array([ [1, 0, 0, 0], 195 | [0, np.cos(theta), -np.sin(theta), 0], 196 | [0, np.sin(theta), np.cos(theta), 0], 197 | [0, 0, 0, 1]]) 198 | 199 | RY = np.array([ [np.cos(phi), 0, -np.sin(phi), 0], 200 | [0, 1, 0, 0], 201 | [np.sin(phi), 0, np.cos(phi), 0], 202 | [0, 0, 0, 1]]) 203 | 204 | RZ = np.array([ [np.cos(gamma), -np.sin(gamma), 0, 0], 205 | [np.sin(gamma), np.cos(gamma), 0, 0], 206 | [0, 0, 1, 0], 207 | [0, 0, 0, 1]]) 208 | 209 | # Composed rotation matrix with (RX, RY, RZ) 210 | R = np.dot(np.dot(RX, RY), RZ) 211 | 212 | # Translation matrix 213 | T = np.array([ [1, 0, 0, dx], 214 | [0, 1, 0, dy], 215 | [0, 0, 1, dz], 216 | [0, 0, 0, 1]]) 217 | 218 | # Projection 3D -> 2D matrix 219 | A2 = np.array([ [f, 0, w/2, 0], 220 | [0, f, h/2, 0], 221 | [0, 0, 1, 0]]) 222 | 223 | # Final transformation matrix 224 | M = np.dot(A2, np.dot(T, np.dot(R, A1))) 225 | 226 | return M -------------------------------------------------------------------------------- /unidemoire/models/shooting/method.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | 4 | import random 5 | 6 | import torch 7 | import torch.nn as nn 8 | import torch.nn.functional as F 9 | import torchvision.utils 10 | import torchvision.transforms as transforms 11 | 12 | from PIL import Image 13 | 14 | from unidemoire.models.shooting.mosaicing_demosaicing_v2 import * 15 | from unidemoire.models.shooting.image_transformer import ImageTransformer 16 | 17 | def adjust_contrast_and_brightness(input_img, beta = 30): 18 | beta = beta / 255.0 #* brightness strength 19 | input_img = torch.clamp(input_img + beta, 0, 1) 20 | 21 | return input_img 22 | 23 | def simulate_LCD_display(input_img, device): 24 | """ Simulate the display of raw images on LCD screen 25 | Input: 26 | original images (tensor): batch x channel x height x width 27 | Output: 28 | LCD images (tensor): batch x channel x (height x scale_factor) x (width x scale_factor) 29 | """ 30 | b, c, h, w = input_img.shape 31 | 32 | simulate_imgs = torch.zeros((b, c, h * 3, w * 3), dtype=torch.float32, device=device) 33 | red = input_img[:, 0, :, :].repeat_interleave(3, dim=1) 34 | green = input_img[:, 1, :, :].repeat_interleave(3, dim=1) 35 | blue = input_img[:, 2, :, 
:].repeat_interleave(3, dim=1) 36 | 37 | simulate_imgs[:, 0, :, 0::3] = red 38 | simulate_imgs[:, 1, :, 1::3] = green 39 | simulate_imgs[:, 2, :, 2::3] = blue 40 | 41 | return simulate_imgs 42 | 43 | 44 | def demosaic_and_denoise(input_img, device): 45 | """ Apply demosaicing to the images 46 | Input: 47 | images (tensor): batch x (height x scale_factor) x (width x scale_factor) 48 | Output: 49 | demosaicing images (tensor): batch x channel x (height x scale_factor) x (width x scale_factor) 50 | """ 51 | input_img = input_img.double() 52 | demosaicing_imgs = demosaicing_CFA_Bayer_bilinear(input_img) 53 | demosaicing_imgs = demosaicing_imgs.permute(0, 3, 1, 2) 54 | return demosaicing_imgs 55 | 56 | def simulate_CFA(input_img, device): 57 | """ Simulate the raw reading of the camera sensor using bayer CFA 58 | Input: 59 | images (tensor): batch x channel x (height x scale_factor) x (width x scale_factor) 60 | Output: 61 | mosaicing images (tensor): batch x (height x scale_factor) x (width x scale_factor) 62 | """ 63 | input_img = input_img.permute(0, 2, 3, 1) 64 | mosaicing_imgs = mosaicing_CFA_Bayer(input_img) 65 | return mosaicing_imgs 66 | 67 | def random_rotation_3(org_images, lcd_images, device): 68 | """ Simulate the 3D rotatation during the shooting 69 | Input: 70 | images (tensor): batch x channel x height x width 71 | Rotate angle: 72 | theta (int): (-20, 20) 73 | phi (int): (-20, 20) 74 | gamma (int): (-20, 20) 75 | Output: 76 | rotated original images (tensor): batch x channel x height x width 77 | rotated LCD images (tensor): batch x channel x (height x scale_factor) x (width x scale_factor) 78 | """ 79 | rotate_images = torch.zeros_like(org_images).to(device) # (bs, c, h, w) 80 | rotate_lcd_images = torch.zeros_like(lcd_images).to(device) # (bs, c, 3h, 3w) 81 | 82 | for n, img in enumerate(org_images): 83 | 84 | Trans_org = ImageTransformer(img) 85 | Trans_lcd = ImageTransformer(lcd_images[n]) 86 | 87 | theta, phi, gamma, rotate_img = Trans_org.Perspective(random_f=True) 88 | _, _, _, rotate_lcd_img = Trans_lcd.Perspective(random_f=False, theta=theta, phi=phi, gamma=gamma) 89 | 90 | rotate_img = rotate_img.squeeze(0) 91 | rotate_lcd_img = rotate_lcd_img.squeeze(0) 92 | 93 | rotate_images[n, :] = rotate_img 94 | rotate_lcd_images[n, :] = rotate_lcd_img 95 | 96 | return rotate_images, rotate_lcd_images 97 | 98 | 99 | def Shooting(org_imgs, device): 100 | batch_size, channel, img_h, img_w = org_imgs.shape 101 | alpha = random.randint(1,4) 102 | crop_ratio = 0.7 103 | 104 | noise = torch.randn([batch_size, img_h * alpha * 3, img_w * alpha * 3]).to(device) 105 | noise = noise / 256.0 106 | 107 | resize_before_lcd = F.interpolate(org_imgs, scale_factor=alpha, mode="bilinear", align_corners=True) 108 | lcd_images = simulate_LCD_display(resize_before_lcd, device) 109 | rotate_images, rotate_lcd_images = random_rotation_3(org_imgs, lcd_images, device) 110 | 111 | cfa_img = simulate_CFA(rotate_lcd_images, device) 112 | cfa_img_noise = cfa_img + noise 113 | # cfa_img_noise = cfa_img_noise.double() 114 | demosaic_img = demosaic_and_denoise(cfa_img_noise, device) 115 | brighter_img = adjust_contrast_and_brightness(demosaic_img, beta=20) 116 | 117 | at_images = F.interpolate(brighter_img, [img_h, img_w], mode='bilinear', align_corners=True) 118 | at_images = torch.clamp(at_images, min=0, max=1) 119 | 120 | crop_edges = transforms.Compose([ 121 | transforms.CenterCrop((int(img_h*crop_ratio), int(img_w*crop_ratio))), 122 | transforms.Resize((img_h, img_w)), 123 | 124 | ]) 125 | rotate_images = 
crop_edges(rotate_images) 126 | at_images = crop_edges(at_images) 127 | 128 | return at_images, rotate_images 129 | 130 | 131 | 132 | trans = transforms.Compose([ 133 | transforms.Resize((384,384)), 134 | transforms.ToTensor() 135 | ]) 136 | -------------------------------------------------------------------------------- /unidemoire/models/shooting/mosaicing_demosaicing_v2.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | from colour_demosaicing.bayer import masks_CFA_Bayer 4 | import torch.nn as nn 5 | import torch.nn.functional as F 6 | 7 | def demosaicing_CFA_Bayer_bilinear(CFA, pattern='RGGB'): 8 | """ 9 | Returns the demosaiced *RGB* colourspace array from given *Bayer* CFA using 10 | bilinear interpolation. 11 | 12 | Parameters 13 | ---------- 14 | CFA : array_like 15 | *Bayer* CFA. 16 | pattern : unicode, optional 17 | **{'RGGB', 'BGGR', 'GRBG', 'GBRG'}**, 18 | Arrangement of the colour filters on the pixel array. 19 | 20 | Returns 21 | ------- 22 | ndarray 23 | *RGB* colourspace array. 24 | 25 | Notes 26 | ----- 27 | - The definition output is not clipped in range [0, 1] : this allows for 28 | direct HDRI / radiance image generation on *Bayer* CFA data and post 29 | demosaicing of the high dynamic range data as showcased in this 30 | `Jupyter Notebook `__. 33 | 34 | References 35 | ---------- 36 | :cite:`Losson2010c` 37 | 38 | Examples 39 | -------- 40 | >>> import numpy as np 41 | >>> CFA = np.array( 42 | ... [[0.30980393, 0.36078432, 0.30588236, 0.3764706], 43 | ... [0.35686275, 0.39607844, 0.36078432, 0.40000001]]) 44 | >>> demosaicing_CFA_Bayer_bilinear(CFA) 45 | array([[[ 0.69705884, 0.17941177, 0.09901961], 46 | [ 0.46176472, 0.4509804 , 0.19803922], 47 | [ 0.45882354, 0.27450981, 0.19901961], 48 | [ 0.22941177, 0.5647059 , 0.30000001]], 49 | 50 | [[ 0.23235295, 0.53529412, 0.29705883], 51 | [ 0.15392157, 0.26960785, 0.59411766], 52 | [ 0.15294118, 0.4509804 , 0.59705884], 53 | [ 0.07647059, 0.18431373, 0.90000002]]]) 54 | >>> CFA = np.array( 55 | ... [[0.3764706, 0.360784320, 0.40784314, 0.3764706], 56 | ... 
[0.35686275, 0.30980393, 0.36078432, 0.29803923]]) 57 | >>> demosaicing_CFA_Bayer_bilinear(CFA, 'BGGR') 58 | array([[[ 0.07745098, 0.17941177, 0.84705885], 59 | [ 0.15490197, 0.4509804 , 0.5882353 ], 60 | [ 0.15196079, 0.27450981, 0.61176471], 61 | [ 0.22352942, 0.5647059 , 0.30588235]], 62 | 63 | [[ 0.23235295, 0.53529412, 0.28235295], 64 | [ 0.4647059 , 0.26960785, 0.19607843], 65 | [ 0.45588237, 0.4509804 , 0.20392157], 66 | [ 0.67058827, 0.18431373, 0.10196078]]]) 67 | """ 68 | 69 | ## Above is the original version on mosaicing_demosaicing package processing image based on numpy arrays, we adapt it to a torch tensor version as follows: 70 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 71 | batch, h, w= CFA.size() 72 | 73 | R_m, G_m, B_m = masks_CFA_Bayer([h, w], pattern) 74 | 75 | R_m = R_m[np.newaxis, np.newaxis, :] 76 | R_m = np.repeat(R_m, batch, axis = 0) 77 | G_m = G_m[np.newaxis, np.newaxis, :] 78 | G_m = np.repeat(G_m, batch, axis=0) 79 | B_m = B_m[np.newaxis, np.newaxis, :] 80 | B_m = np.repeat(B_m, batch, axis=0) 81 | 82 | R_m = torch.from_numpy(R_m).to(device) 83 | G_m = torch.from_numpy(G_m).to(device) 84 | B_m = torch.from_numpy(B_m).to(device) 85 | 86 | H_G = np.array( 87 | [[0, 1, 0], 88 | [1, 4, 1], 89 | [0, 1, 0]]) / 4 # yapf: disable 90 | 91 | H_G = H_G[np.newaxis, np.newaxis, :] 92 | H_G = torch.from_numpy(H_G).to(device) 93 | 94 | H_RB = np.array( 95 | [[1, 2, 1], 96 | [2, 4, 2], 97 | [1, 2, 1]]) / 4 # yapf: disable 98 | 99 | H_RB = H_RB[np.newaxis, np.newaxis, :] 100 | H_RB = torch.from_numpy(H_RB).to(device) 101 | CFA = CFA.unsqueeze(1) 102 | 103 | R = F.conv2d(CFA * R_m, H_RB, stride=1, padding=1) 104 | G = F.conv2d(CFA * G_m, H_G, stride=1, padding=1) 105 | B = F.conv2d(CFA * B_m, H_RB, stride=1, padding=1) 106 | 107 | R = R.squeeze(1) 108 | G = G.squeeze(1) 109 | B = B.squeeze(1) 110 | 111 | del R_m, G_m, B_m, H_RB, H_G 112 | torch.cuda.empty_cache() 113 | 114 | return torch.stack((R, G, B), dim = 3) 115 | 116 | def mosaicing_CFA_Bayer(RGB, pattern = 'RGGB'): 117 | """ 118 | Returns the *Bayer* CFA mosaic for a given *RGB* colourspace array. 119 | 120 | Parameters 121 | ---------- 122 | RGB : array_like 123 | *RGB* colourspace array. 124 | pattern : unicode, optional 125 | **{'RGGB', 'BGGR', 'GRBG', 'GBRG'}**, 126 | Arrangement of the colour filters on the pixel array. 127 | 128 | Returns 129 | ------- 130 | ndarray 131 | *Bayer* CFA mosaic. 132 | 133 | Examples 134 | -------- 135 | >>> import numpy as np 136 | >>> RGB = np.array([[[0, 1, 2], 137 | ... [0, 1, 2]], 138 | ... [[0, 1, 2], 139 | ... 
[0, 1, 2]]]) 140 | >>> mosaicing_CFA_Bayer(RGB) 141 | array([[ 0., 1.], 142 | [ 1., 2.]]) 143 | >>> mosaicing_CFA_Bayer(RGB, pattern='BGGR') 144 | array([[ 2., 1.], 145 | [ 1., 0.]]) 146 | """ 147 | 148 | ## Above is the original version on mosaicing_demosaicing package processing image based on numpy arrays, we adapt it to a torch tensor version as follows: 149 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 150 | 151 | R = RGB[:, :, :, 0] 152 | G = RGB[:, :, :, 1] 153 | B = RGB[:, :, :, 2] 154 | 155 | batch, _, _, _ = RGB.shape 156 | R_m, G_m, B_m = masks_CFA_Bayer(RGB.shape[1:3], pattern) 157 | 158 | G_m = G_m[np.newaxis, :] 159 | G_m = np.repeat(G_m, batch, axis = 0) 160 | B_m = B_m[np.newaxis, :] 161 | B_m = np.repeat(B_m, batch, axis = 0) 162 | R_m = R_m[np.newaxis, :] 163 | R_m = np.repeat(R_m, batch, axis = 0) 164 | 165 | R_m = torch.from_numpy(R_m).to(device) 166 | G_m = torch.from_numpy(G_m).to(device) 167 | B_m = torch.from_numpy(B_m).to(device) 168 | 169 | CFA = R * R_m + G * G_m + B * B_m 170 | del R_m, G_m, B_m 171 | torch.cuda.empty_cache() 172 | 173 | return CFA 174 | 175 | -------------------------------------------------------------------------------- /unidemoire/models/undem/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/4DVLab/UniDemoire/4221a0e98f0078c3f72ef1da8128642e7f02662a/unidemoire/models/undem/__init__.py -------------------------------------------------------------------------------- /unidemoire/models/utils/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/4DVLab/UniDemoire/4221a0e98f0078c3f72ef1da8128642e7f02662a/unidemoire/models/utils/__init__.py -------------------------------------------------------------------------------- /unidemoire/models/utils/matlab_ssim.py: -------------------------------------------------------------------------------- 1 | """ 2 | A pytorch implementation for reproducing results in MATLAB, slightly modified from 3 | https://github.com/mayorx/matlab_ssim_pytorch_implementation. 
4 | """ 5 | 6 | import torch 7 | import cv2 8 | import numpy as np 9 | 10 | def generate_1d_gaussian_kernel(): 11 | return cv2.getGaussianKernel(11, 1.5) 12 | 13 | def generate_2d_gaussian_kernel(): 14 | kernel = generate_1d_gaussian_kernel() 15 | return np.outer(kernel, kernel.transpose()) 16 | 17 | def generate_3d_gaussian_kernel(): 18 | kernel = generate_1d_gaussian_kernel() 19 | window = generate_2d_gaussian_kernel() 20 | return np.stack([window * k for k in kernel], axis=0) 21 | 22 | class MATLAB_SSIM(torch.nn.Module): 23 | def __init__(self, device='cpu'): 24 | super(MATLAB_SSIM, self).__init__() 25 | self.device = device 26 | conv3d = torch.nn.Conv3d(1, 1, (11, 11, 11), stride=1, padding=(5, 5, 5), bias=False, padding_mode='replicate') 27 | conv3d.weight.requires_grad = False 28 | conv3d.weight[0, 0, :, :, :] = torch.tensor(generate_3d_gaussian_kernel()) 29 | self.conv3d = conv3d.to(device) 30 | 31 | conv2d = torch.nn.Conv2d(1, 1, (11, 11), stride=1, padding=(5, 5), bias=False, padding_mode='replicate') 32 | conv2d.weight.requires_grad = False 33 | conv2d.weight[0, 0, :, :] = torch.tensor(generate_2d_gaussian_kernel()) 34 | self.conv2d = conv2d.to(device) 35 | 36 | def forward(self, img1, img2, device='cuda'): 37 | assert len(img1.shape) == len(img2.shape) 38 | self.device = device 39 | self.conv2d = self.conv2d.to(self.device) 40 | self.conv3d = self.conv3d.to(self.device) 41 | with torch.no_grad(): 42 | img1 = torch.tensor(img1).to(self.device).float() 43 | img2 = torch.tensor(img2).to(self.device).float() 44 | 45 | if len(img1.shape) == 2: 46 | conv = self.conv2d 47 | elif len(img1.shape) == 3: 48 | conv = self.conv3d 49 | else: 50 | raise not NotImplementedError('only support 2d / 3d images.') 51 | return self._ssim(img1, img2, conv) 52 | 53 | def _ssim(self, img1, img2, conv): 54 | img1 = img1.unsqueeze(0).unsqueeze(0) 55 | img2 = img2.unsqueeze(0).unsqueeze(0) 56 | 57 | C1 = (0.01 * 255) ** 2 58 | C2 = (0.03 * 255) ** 2 59 | 60 | mu1 = conv(img1) 61 | mu2 = conv(img2) 62 | 63 | mu1_sq = mu1 ** 2 64 | mu2_sq = mu2 ** 2 65 | mu1_mu2 = mu1 * mu2 66 | sigma1_sq = conv(img1 ** 2) - mu1_sq 67 | sigma2_sq = conv(img2 ** 2) - mu2_sq 68 | sigma12 = conv(img1 * img2) - mu1_mu2 69 | 70 | ssim_map = ((2 * mu1_mu2 + C1) * 71 | (2 * sigma12 + C2)) / ((mu1_sq + mu2_sq + C1) * 72 | (sigma1_sq + sigma2_sq + C2)) 73 | 74 | return float(ssim_map.mean()) 75 | -------------------------------------------------------------------------------- /unidemoire/models/utils/metric.py: -------------------------------------------------------------------------------- 1 | from .common import SSIM, PSNR, tensor2img 2 | from skimage.metrics import peak_signal_noise_ratio as ski_psnr 3 | from skimage.metrics import structural_similarity as ski_ssim 4 | from unidemoire.models.utils.matlab_ssim import MATLAB_SSIM 5 | import lpips 6 | import torch 7 | import numpy as np 8 | from math import log10 9 | 10 | class create_metrics(): 11 | """ 12 | We note that for different benchmarks, previous works calculate metrics in different ways, which might 13 | lead to inconsistent SSIM results (and slightly different PSNR), and thus we follow their individual 14 | ways to compute metrics on each individual dataset for fair comparisons. 15 | For our 4K dataset, calculating metrics for 4k image is much time-consuming, 16 | thus we benchmark evaluations for all methods with a fast pytorch SSIM implementation referred from 17 | "https://github.com/jorge-pessoa/pytorch-msssim/blob/master/pytorch_msssim/__init__.py". 
18 | """ 19 | def __init__(self, dataset, device): 20 | self.data_type = dataset 21 | # self.lpips_fn = lpips.LPIPS(net='alex').cuda() 22 | self.lpips_fn = lpips.LPIPS(net='alex') 23 | # self.lpips_fn = self.lpips_fn.to(device) 24 | self.fast_ssim = SSIM() 25 | self.fast_psnr = PSNR() 26 | self.matlab_ssim = MATLAB_SSIM(device=device) 27 | 28 | def compute(self, out_img, gt, device=None): 29 | if self.data_type == 'UHDM': 30 | res_psnr, res_ssim = self.fast_psnr_ssim(out_img, gt) 31 | elif self.data_type == 'FHDMi': 32 | res_psnr, res_ssim = self.skimage_psnr_ssim(out_img, gt) 33 | elif self.data_type == 'TIP': 34 | res_psnr, res_ssim = self.matlab_psnr_ssim(out_img, gt, device) 35 | elif self.data_type == 'AIM': 36 | res_psnr, res_ssim = self.aim_psnr_ssim(out_img, gt) 37 | else: 38 | print('Unrecognized data_type for evaluation!') 39 | raise NotImplementedError 40 | pre = torch.clamp(out_img, min=0, max=1) 41 | tar = torch.clamp(gt, min=0, max=1) 42 | self.lpips_fn = self.lpips_fn.to(device) 43 | res_lpips = self.lpips_fn.forward(pre, tar, normalize=True).item() 44 | 45 | return res_lpips, res_psnr, res_ssim 46 | 47 | 48 | def fast_psnr_ssim(self, out_img, gt): 49 | pre = torch.clamp(out_img, min=0, max=1) 50 | tar = torch.clamp(gt, min=0, max=1) 51 | psnr = self.fast_psnr(pre, tar) 52 | ssim = self.fast_ssim(pre, tar) 53 | return psnr, ssim 54 | 55 | def skimage_psnr_ssim(self, out_img, gt): 56 | """ 57 | Same with the previous SOTA FHDe2Net: https://github.com/PKU-IMRE/FHDe2Net/blob/main/test.py 58 | """ 59 | mi1 = tensor2img(out_img) 60 | mt1 = tensor2img(gt) 61 | 62 | psnr = ski_psnr(mt1, mi1) 63 | ssim = ski_ssim(mt1, mi1, multichannel=True, channel_axis=2) 64 | return psnr, ssim 65 | 66 | def matlab_psnr_ssim(self, out_img, gt, device): 67 | """ 68 | A pytorch implementation for reproducing SSIM results when using MATLAB 69 | same with the previous SOTA MopNet: https://github.com/PKU-IMRE/MopNet/blob/master/test_with_matlabcode.m 70 | """ 71 | mi1 = tensor2img(out_img) 72 | mt1 = tensor2img(gt) 73 | psnr = ski_psnr(mt1, mi1) 74 | ssim = self.matlab_ssim(mt1, mi1, device) 75 | return psnr, ssim 76 | 77 | def aim_psnr_ssim(self, out_img, gt): 78 | """ 79 | Same with the previous SOTA MBCNN: https://github.com/zhenngbolun/Learnbale_Bandpass_Filter/blob/master/main_multiscale.py 80 | """ 81 | mi1 = tensor2img(out_img) 82 | mt1 = tensor2img(gt) 83 | mi1 = mi1.astype(np.float32) / 255.0 84 | mt1 = mt1.astype(np.float32) / 255.0 85 | psnr = 10 * log10(1 / np.mean((mt1 - mi1) ** 2)) 86 | ssim = ski_ssim(mt1, mi1, multichannel=True) 87 | return psnr, ssim -------------------------------------------------------------------------------- /unidemoire/modules/attention.py: -------------------------------------------------------------------------------- 1 | from inspect import isfunction 2 | import math 3 | import torch 4 | import torch.nn.functional as F 5 | from torch import nn, einsum 6 | from einops import rearrange, repeat 7 | 8 | from unidemoire.modules.diffusionmodules.util import checkpoint 9 | 10 | 11 | def exists(val): 12 | return val is not None 13 | 14 | 15 | def uniq(arr): 16 | return{el: True for el in arr}.keys() 17 | 18 | 19 | def default(val, d): 20 | if exists(val): 21 | return val 22 | return d() if isfunction(d) else d 23 | 24 | 25 | def max_neg_value(t): 26 | return -torch.finfo(t.dtype).max 27 | 28 | 29 | def init_(tensor): 30 | dim = tensor.shape[-1] 31 | std = 1 / math.sqrt(dim) 32 | tensor.uniform_(-std, std) 33 | return tensor 34 | 35 | 36 | # feedforward 37 | class 
GEGLU(nn.Module): 38 | def __init__(self, dim_in, dim_out): 39 | super().__init__() 40 | self.proj = nn.Linear(dim_in, dim_out * 2) 41 | 42 | def forward(self, x): 43 | x, gate = self.proj(x).chunk(2, dim=-1) 44 | return x * F.gelu(gate) 45 | 46 | 47 | class FeedForward(nn.Module): 48 | def __init__(self, dim, dim_out=None, mult=4, glu=False, dropout=0.): 49 | super().__init__() 50 | inner_dim = int(dim * mult) 51 | dim_out = default(dim_out, dim) 52 | project_in = nn.Sequential( 53 | nn.Linear(dim, inner_dim), 54 | nn.GELU() 55 | ) if not glu else GEGLU(dim, inner_dim) 56 | 57 | self.net = nn.Sequential( 58 | project_in, 59 | nn.Dropout(dropout), 60 | nn.Linear(inner_dim, dim_out) 61 | ) 62 | 63 | def forward(self, x): 64 | return self.net(x) 65 | 66 | 67 | def zero_module(module): 68 | """ 69 | Zero out the parameters of a module and return it. 70 | """ 71 | for p in module.parameters(): 72 | p.detach().zero_() 73 | return module 74 | 75 | 76 | def Normalize(in_channels): 77 | return torch.nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True) 78 | 79 | 80 | class LinearAttention(nn.Module): 81 | def __init__(self, dim, heads=4, dim_head=32): 82 | super().__init__() 83 | self.heads = heads 84 | hidden_dim = dim_head * heads 85 | self.to_qkv = nn.Conv2d(dim, hidden_dim * 3, 1, bias = False) 86 | self.to_out = nn.Conv2d(hidden_dim, dim, 1) 87 | 88 | def forward(self, x): 89 | b, c, h, w = x.shape 90 | qkv = self.to_qkv(x) 91 | q, k, v = rearrange(qkv, 'b (qkv heads c) h w -> qkv b heads c (h w)', heads = self.heads, qkv=3) 92 | k = k.softmax(dim=-1) 93 | context = torch.einsum('bhdn,bhen->bhde', k, v) 94 | out = torch.einsum('bhde,bhdn->bhen', context, q) 95 | out = rearrange(out, 'b heads c (h w) -> b (heads c) h w', heads=self.heads, h=h, w=w) 96 | return self.to_out(out) 97 | 98 | 99 | class SpatialSelfAttention(nn.Module): 100 | def __init__(self, in_channels): 101 | super().__init__() 102 | self.in_channels = in_channels 103 | 104 | self.norm = Normalize(in_channels) 105 | self.q = torch.nn.Conv2d(in_channels, 106 | in_channels, 107 | kernel_size=1, 108 | stride=1, 109 | padding=0) 110 | self.k = torch.nn.Conv2d(in_channels, 111 | in_channels, 112 | kernel_size=1, 113 | stride=1, 114 | padding=0) 115 | self.v = torch.nn.Conv2d(in_channels, 116 | in_channels, 117 | kernel_size=1, 118 | stride=1, 119 | padding=0) 120 | self.proj_out = torch.nn.Conv2d(in_channels, 121 | in_channels, 122 | kernel_size=1, 123 | stride=1, 124 | padding=0) 125 | 126 | def forward(self, x): 127 | h_ = x 128 | h_ = self.norm(h_) 129 | q = self.q(h_) 130 | k = self.k(h_) 131 | v = self.v(h_) 132 | 133 | # compute attention 134 | b,c,h,w = q.shape 135 | q = rearrange(q, 'b c h w -> b (h w) c') 136 | k = rearrange(k, 'b c h w -> b c (h w)') 137 | w_ = torch.einsum('bij,bjk->bik', q, k) 138 | 139 | w_ = w_ * (int(c)**(-0.5)) 140 | w_ = torch.nn.functional.softmax(w_, dim=2) 141 | 142 | # attend to values 143 | v = rearrange(v, 'b c h w -> b c (h w)') 144 | w_ = rearrange(w_, 'b i j -> b j i') 145 | h_ = torch.einsum('bij,bjk->bik', v, w_) 146 | h_ = rearrange(h_, 'b c (h w) -> b c h w', h=h) 147 | h_ = self.proj_out(h_) 148 | 149 | return x+h_ 150 | 151 | 152 | class CrossAttention(nn.Module): 153 | def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0.): 154 | super().__init__() 155 | inner_dim = dim_head * heads 156 | context_dim = default(context_dim, query_dim) 157 | 158 | self.scale = dim_head ** -0.5 159 | self.heads = heads 160 | 161 | self.to_q = 
nn.Linear(query_dim, inner_dim, bias=False) 162 | self.to_k = nn.Linear(context_dim, inner_dim, bias=False) 163 | self.to_v = nn.Linear(context_dim, inner_dim, bias=False) 164 | 165 | self.to_out = nn.Sequential( 166 | nn.Linear(inner_dim, query_dim), 167 | nn.Dropout(dropout) 168 | ) 169 | 170 | def forward(self, x, context=None, mask=None): 171 | h = self.heads 172 | 173 | q = self.to_q(x) 174 | context = default(context, x) 175 | k = self.to_k(context) 176 | v = self.to_v(context) 177 | 178 | q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=h), (q, k, v)) 179 | 180 | sim = einsum('b i d, b j d -> b i j', q, k) * self.scale 181 | 182 | if exists(mask): 183 | mask = rearrange(mask, 'b ... -> b (...)') 184 | max_neg_value = -torch.finfo(sim.dtype).max 185 | mask = repeat(mask, 'b j -> (b h) () j', h=h) 186 | sim.masked_fill_(~mask, max_neg_value) 187 | 188 | # attention, what we cannot get enough of 189 | attn = sim.softmax(dim=-1) 190 | 191 | out = einsum('b i j, b j d -> b i d', attn, v) 192 | out = rearrange(out, '(b h) n d -> b n (h d)', h=h) 193 | return self.to_out(out) 194 | 195 | 196 | class BasicTransformerBlock(nn.Module): 197 | def __init__(self, dim, n_heads, d_head, dropout=0., context_dim=None, gated_ff=True, checkpoint=True): 198 | super().__init__() 199 | self.attn1 = CrossAttention(query_dim=dim, heads=n_heads, dim_head=d_head, dropout=dropout) # is a self-attention 200 | self.ff = FeedForward(dim, dropout=dropout, glu=gated_ff) 201 | self.attn2 = CrossAttention(query_dim=dim, context_dim=context_dim, 202 | heads=n_heads, dim_head=d_head, dropout=dropout) # is self-attn if context is none 203 | self.norm1 = nn.LayerNorm(dim) 204 | self.norm2 = nn.LayerNorm(dim) 205 | self.norm3 = nn.LayerNorm(dim) 206 | self.checkpoint = checkpoint 207 | 208 | def forward(self, x, context=None): 209 | return checkpoint(self._forward, (x, context), self.parameters(), self.checkpoint) 210 | 211 | def _forward(self, x, context=None): 212 | x = self.attn1(self.norm1(x)) + x 213 | x = self.attn2(self.norm2(x), context=context) + x 214 | x = self.ff(self.norm3(x)) + x 215 | return x 216 | 217 | 218 | class SpatialTransformer(nn.Module): 219 | """ 220 | Transformer block for image-like data. 221 | First, project the input (aka embedding) 222 | and reshape to b, t, d. 223 | Then apply standard transformer action. 
224 | Finally, reshape to image 225 | """ 226 | def __init__(self, in_channels, n_heads, d_head, 227 | depth=1, dropout=0., context_dim=None): 228 | super().__init__() 229 | self.in_channels = in_channels 230 | inner_dim = n_heads * d_head 231 | self.norm = Normalize(in_channels) 232 | 233 | self.proj_in = nn.Conv2d(in_channels, 234 | inner_dim, 235 | kernel_size=1, 236 | stride=1, 237 | padding=0) 238 | 239 | self.transformer_blocks = nn.ModuleList( 240 | [BasicTransformerBlock(inner_dim, n_heads, d_head, dropout=dropout, context_dim=context_dim) 241 | for d in range(depth)] 242 | ) 243 | 244 | self.proj_out = zero_module(nn.Conv2d(inner_dim, 245 | in_channels, 246 | kernel_size=1, 247 | stride=1, 248 | padding=0)) 249 | 250 | def forward(self, x, context=None): 251 | # note: if no context is given, cross-attention defaults to self-attention 252 | b, c, h, w = x.shape 253 | x_in = x 254 | x = self.norm(x) 255 | x = self.proj_in(x) 256 | x = rearrange(x, 'b c h w -> b (h w) c') 257 | for block in self.transformer_blocks: 258 | x = block(x, context=context) 259 | x = rearrange(x, 'b (h w) c -> b c h w', h=h, w=w) 260 | x = self.proj_out(x) 261 | return x + x_in -------------------------------------------------------------------------------- /unidemoire/modules/diffusionmodules/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/4DVLab/UniDemoire/4221a0e98f0078c3f72ef1da8128642e7f02662a/unidemoire/modules/diffusionmodules/__init__.py -------------------------------------------------------------------------------- /unidemoire/modules/diffusionmodules/util.py: -------------------------------------------------------------------------------- 1 | # adopted from 2 | # https://github.com/openai/improved-diffusion/blob/main/improved_diffusion/gaussian_diffusion.py 3 | # and 4 | # https://github.com/lucidrains/denoising-diffusion-pytorch/blob/7706bdfc6f527f58d33f84b7b522e61e6e3164b3/denoising_diffusion_pytorch/denoising_diffusion_pytorch.py 5 | # and 6 | # https://github.com/openai/guided-diffusion/blob/0ba878e517b276c45d1195eb29f6f5f72659a05b/guided_diffusion/nn.py 7 | # 8 | # thanks! 
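# A brief illustrative note (assumed usage; the values are examples only): the helpers in this
# module build the DDPM beta schedules and the sub-sampled DDIM timesteps used by the
# diffusion-based moire-pattern generator, e.g.
#
#     from unidemoire.modules.diffusionmodules.util import make_beta_schedule, make_ddim_timesteps
#     betas = make_beta_schedule("linear", n_timestep=1000)   # (1000,) array of increasing noise variances
#     steps = make_ddim_timesteps("uniform", 50, 1000)        # 50 evenly spaced DDPM steps (offset by +1)
#
# See the function definitions below for the exact signatures and options.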
9 | 10 | 11 | import os 12 | import math 13 | import torch 14 | import torch.nn as nn 15 | import numpy as np 16 | from einops import repeat 17 | 18 | from unidemoire.util import instantiate_from_config 19 | 20 | 21 | def make_beta_schedule(schedule, n_timestep, linear_start=1e-4, linear_end=2e-2, cosine_s=8e-3): 22 | if schedule == "linear": 23 | betas = ( 24 | torch.linspace(linear_start ** 0.5, linear_end ** 0.5, n_timestep, dtype=torch.float64) ** 2 25 | ) 26 | 27 | elif schedule == "cosine": 28 | timesteps = ( 29 | torch.arange(n_timestep + 1, dtype=torch.float64) / n_timestep + cosine_s 30 | ) 31 | alphas = timesteps / (1 + cosine_s) * np.pi / 2 32 | alphas = torch.cos(alphas).pow(2) 33 | alphas = alphas / alphas[0] 34 | betas = 1 - alphas[1:] / alphas[:-1] 35 | betas = np.clip(betas, a_min=0, a_max=0.999) 36 | 37 | elif schedule == "sqrt_linear": 38 | betas = torch.linspace(linear_start, linear_end, n_timestep, dtype=torch.float64) 39 | elif schedule == "sqrt": 40 | betas = torch.linspace(linear_start, linear_end, n_timestep, dtype=torch.float64) ** 0.5 41 | else: 42 | raise ValueError(f"schedule '{schedule}' unknown.") 43 | return betas.numpy() 44 | 45 | 46 | def make_ddim_timesteps(ddim_discr_method, num_ddim_timesteps, num_ddpm_timesteps, verbose=True): 47 | if ddim_discr_method == 'uniform': 48 | c = num_ddpm_timesteps // num_ddim_timesteps 49 | ddim_timesteps = np.asarray(list(range(0, num_ddpm_timesteps, c))) 50 | elif ddim_discr_method == 'quad': 51 | ddim_timesteps = ((np.linspace(0, np.sqrt(num_ddpm_timesteps * .8), num_ddim_timesteps)) ** 2).astype(int) 52 | else: 53 | raise NotImplementedError(f'There is no ddim discretization method called "{ddim_discr_method}"') 54 | 55 | # assert ddim_timesteps.shape[0] == num_ddim_timesteps 56 | # add one to get the final alpha values right (the ones from first scale to data during sampling) 57 | steps_out = ddim_timesteps + 1 58 | if verbose: 59 | print(f'Selected timesteps for ddim sampler: {steps_out}') 60 | return steps_out 61 | 62 | 63 | def make_ddim_sampling_parameters(alphacums, ddim_timesteps, eta, verbose=True): 64 | # select alphas for computing the variance schedule 65 | alphas = alphacums[ddim_timesteps] 66 | alphas_prev = np.asarray([alphacums[0]] + alphacums[ddim_timesteps[:-1]].tolist()) 67 | 68 | # according the the formula provided in https://arxiv.org/abs/2010.02502 69 | sigmas = eta * np.sqrt((1 - alphas_prev) / (1 - alphas) * (1 - alphas / alphas_prev)) 70 | if verbose: 71 | print(f'Selected alphas for ddim sampler: a_t: {alphas}; a_(t-1): {alphas_prev}') 72 | print(f'For the chosen value of eta, which is {eta}, ' 73 | f'this results in the following sigma_t schedule for ddim sampler {sigmas}') 74 | return sigmas, alphas, alphas_prev 75 | 76 | 77 | def betas_for_alpha_bar(num_diffusion_timesteps, alpha_bar, max_beta=0.999): 78 | """ 79 | Create a beta schedule that discretizes the given alpha_t_bar function, 80 | which defines the cumulative product of (1-beta) over time from t = [0,1]. 81 | :param num_diffusion_timesteps: the number of betas to produce. 82 | :param alpha_bar: a lambda that takes an argument t from 0 to 1 and 83 | produces the cumulative product of (1-beta) up to that 84 | part of the diffusion process. 85 | :param max_beta: the maximum beta to use; use values lower than 1 to 86 | prevent singularities. 
87 | """ 88 | betas = [] 89 | for i in range(num_diffusion_timesteps): 90 | t1 = i / num_diffusion_timesteps 91 | t2 = (i + 1) / num_diffusion_timesteps 92 | betas.append(min(1 - alpha_bar(t2) / alpha_bar(t1), max_beta)) 93 | return np.array(betas) 94 | 95 | 96 | def extract_into_tensor(a, t, x_shape): 97 | b, *_ = t.shape 98 | out = a.gather(-1, t) 99 | return out.reshape(b, *((1,) * (len(x_shape) - 1))) 100 | 101 | 102 | def checkpoint(func, inputs, params, flag): 103 | """ 104 | Evaluate a function without caching intermediate activations, allowing for 105 | reduced memory at the expense of extra compute in the backward pass. 106 | :param func: the function to evaluate. 107 | :param inputs: the argument sequence to pass to `func`. 108 | :param params: a sequence of parameters `func` depends on but does not 109 | explicitly take as arguments. 110 | :param flag: if False, disable gradient checkpointing. 111 | """ 112 | if flag: 113 | args = tuple(inputs) + tuple(params) 114 | return CheckpointFunction.apply(func, len(inputs), *args) 115 | else: 116 | return func(*inputs) 117 | 118 | 119 | class CheckpointFunction(torch.autograd.Function): 120 | @staticmethod 121 | def forward(ctx, run_function, length, *args): 122 | ctx.run_function = run_function 123 | ctx.input_tensors = list(args[:length]) 124 | ctx.input_params = list(args[length:]) 125 | 126 | with torch.no_grad(): 127 | output_tensors = ctx.run_function(*ctx.input_tensors) 128 | return output_tensors 129 | 130 | @staticmethod 131 | def backward(ctx, *output_grads): 132 | ctx.input_tensors = [x.detach().requires_grad_(True) for x in ctx.input_tensors] 133 | with torch.enable_grad(): 134 | # Fixes a bug where the first op in run_function modifies the 135 | # Tensor storage in place, which is not allowed for detach()'d 136 | # Tensors. 137 | shallow_copies = [x.view_as(x) for x in ctx.input_tensors] 138 | output_tensors = ctx.run_function(*shallow_copies) 139 | input_grads = torch.autograd.grad( 140 | output_tensors, 141 | ctx.input_tensors + ctx.input_params, 142 | output_grads, 143 | allow_unused=True, 144 | ) 145 | del ctx.input_tensors 146 | del ctx.input_params 147 | del output_tensors 148 | return (None, None) + input_grads 149 | 150 | 151 | def timestep_embedding(timesteps, dim, max_period=10000, repeat_only=False): 152 | """ 153 | Create sinusoidal timestep embeddings. 154 | :param timesteps: a 1-D Tensor of N indices, one per batch element. 155 | These may be fractional. 156 | :param dim: the dimension of the output. 157 | :param max_period: controls the minimum frequency of the embeddings. 158 | :return: an [N x dim] Tensor of positional embeddings. 159 | """ 160 | if not repeat_only: 161 | half = dim // 2 162 | freqs = torch.exp( 163 | -math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half 164 | ).to(device=timesteps.device) 165 | args = timesteps[:, None].float() * freqs[None] 166 | embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1) 167 | if dim % 2: 168 | embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1) 169 | else: 170 | embedding = repeat(timesteps, 'b -> b d', d=dim) 171 | return embedding 172 | 173 | 174 | def zero_module(module): 175 | """ 176 | Zero out the parameters of a module and return it. 177 | """ 178 | for p in module.parameters(): 179 | p.detach().zero_() 180 | return module 181 | 182 | 183 | def scale_module(module, scale): 184 | """ 185 | Scale the parameters of a module and return it. 
186 | """ 187 | for p in module.parameters(): 188 | p.detach().mul_(scale) 189 | return module 190 | 191 | 192 | def mean_flat(tensor): 193 | """ 194 | Take the mean over all non-batch dimensions. 195 | """ 196 | return tensor.mean(dim=list(range(1, len(tensor.shape)))) 197 | 198 | 199 | def normalization(channels): 200 | """ 201 | Make a standard normalization layer. 202 | :param channels: number of input channels. 203 | :return: an nn.Module for normalization. 204 | """ 205 | return GroupNorm32(32, channels) 206 | 207 | 208 | # PyTorch 1.7 has SiLU, but we support PyTorch 1.5. 209 | class SiLU(nn.Module): 210 | def forward(self, x): 211 | return x * torch.sigmoid(x) 212 | 213 | 214 | class GroupNorm32(nn.GroupNorm): 215 | def forward(self, x): 216 | return super().forward(x.float()).type(x.dtype) 217 | 218 | def conv_nd(dims, *args, **kwargs): 219 | """ 220 | Create a 1D, 2D, or 3D convolution module. 221 | """ 222 | if dims == 1: 223 | return nn.Conv1d(*args, **kwargs) 224 | elif dims == 2: 225 | return nn.Conv2d(*args, **kwargs) 226 | elif dims == 3: 227 | return nn.Conv3d(*args, **kwargs) 228 | raise ValueError(f"unsupported dimensions: {dims}") 229 | 230 | 231 | def linear(*args, **kwargs): 232 | """ 233 | Create a linear module. 234 | """ 235 | return nn.Linear(*args, **kwargs) 236 | 237 | 238 | def avg_pool_nd(dims, *args, **kwargs): 239 | """ 240 | Create a 1D, 2D, or 3D average pooling module. 241 | """ 242 | if dims == 1: 243 | return nn.AvgPool1d(*args, **kwargs) 244 | elif dims == 2: 245 | return nn.AvgPool2d(*args, **kwargs) 246 | elif dims == 3: 247 | return nn.AvgPool3d(*args, **kwargs) 248 | raise ValueError(f"unsupported dimensions: {dims}") 249 | 250 | 251 | class HybridConditioner(nn.Module): 252 | 253 | def __init__(self, c_concat_config, c_crossattn_config): 254 | super().__init__() 255 | self.concat_conditioner = instantiate_from_config(c_concat_config) 256 | self.crossattn_conditioner = instantiate_from_config(c_crossattn_config) 257 | 258 | def forward(self, c_concat, c_crossattn): 259 | c_concat = self.concat_conditioner(c_concat) 260 | c_crossattn = self.crossattn_conditioner(c_crossattn) 261 | return {'c_concat': [c_concat], 'c_crossattn': [c_crossattn]} 262 | 263 | 264 | def noise_like(shape, device, repeat=False): 265 | repeat_noise = lambda: torch.randn((1, *shape[1:]), device=device).repeat(shape[0], *((1,) * (len(shape) - 1))) 266 | noise = lambda: torch.randn(shape, device=device) 267 | return repeat_noise() if repeat else noise() -------------------------------------------------------------------------------- /unidemoire/modules/distributions/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/4DVLab/UniDemoire/4221a0e98f0078c3f72ef1da8128642e7f02662a/unidemoire/modules/distributions/__init__.py -------------------------------------------------------------------------------- /unidemoire/modules/distributions/distributions.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | 4 | 5 | class AbstractDistribution: 6 | def sample(self): 7 | raise NotImplementedError() 8 | 9 | def mode(self): 10 | raise NotImplementedError() 11 | 12 | 13 | class DiracDistribution(AbstractDistribution): 14 | def __init__(self, value): 15 | self.value = value 16 | 17 | def sample(self): 18 | return self.value 19 | 20 | def mode(self): 21 | return self.value 22 | 23 | 24 | class DiagonalGaussianDistribution(object): 25 | def 
__init__(self, parameters, deterministic=False): 26 | self.parameters = parameters 27 | self.mean, self.logvar = torch.chunk(parameters, 2, dim=1) 28 | self.logvar = torch.clamp(self.logvar, -30.0, 20.0) 29 | self.deterministic = deterministic 30 | self.std = torch.exp(0.5 * self.logvar) 31 | self.var = torch.exp(self.logvar) 32 | if self.deterministic: 33 | self.var = self.std = torch.zeros_like(self.mean).to(device=self.parameters.device) 34 | 35 | def sample(self): 36 | x = self.mean + self.std * torch.randn(self.mean.shape).to(device=self.parameters.device) 37 | return x 38 | 39 | def kl(self, other=None): 40 | if self.deterministic: 41 | return torch.Tensor([0.]) 42 | else: 43 | if other is None: 44 | return 0.5 * torch.sum(torch.pow(self.mean, 2) 45 | + self.var - 1.0 - self.logvar, 46 | dim=[1, 2, 3]) 47 | else: 48 | return 0.5 * torch.sum( 49 | torch.pow(self.mean - other.mean, 2) / other.var 50 | + self.var / other.var - 1.0 - self.logvar + other.logvar, 51 | dim=[1, 2, 3]) 52 | 53 | def nll(self, sample, dims=[1,2,3]): 54 | if self.deterministic: 55 | return torch.Tensor([0.]) 56 | logtwopi = np.log(2.0 * np.pi) 57 | return 0.5 * torch.sum( 58 | logtwopi + self.logvar + torch.pow(sample - self.mean, 2) / self.var, 59 | dim=dims) 60 | 61 | def mode(self): 62 | return self.mean 63 | 64 | 65 | def normal_kl(mean1, logvar1, mean2, logvar2): 66 | """ 67 | source: https://github.com/openai/guided-diffusion/blob/27c20a8fab9cb472df5d6bdd6c8d11c8f430b924/guided_diffusion/losses.py#L12 68 | Compute the KL divergence between two gaussians. 69 | Shapes are automatically broadcasted, so batches can be compared to 70 | scalars, among other use cases. 71 | """ 72 | tensor = None 73 | for obj in (mean1, logvar1, mean2, logvar2): 74 | if isinstance(obj, torch.Tensor): 75 | tensor = obj 76 | break 77 | assert tensor is not None, "at least one argument must be a Tensor" 78 | 79 | # Force variances to be Tensors. Broadcasting helps convert scalars to 80 | # Tensors, but it does not work for torch.exp(). 
81 | logvar1, logvar2 = [ 82 | x if isinstance(x, torch.Tensor) else torch.tensor(x).to(tensor) 83 | for x in (logvar1, logvar2) 84 | ] 85 | 86 | return 0.5 * ( 87 | -1.0 88 | + logvar2 89 | - logvar1 90 | + torch.exp(logvar1 - logvar2) 91 | + ((mean1 - mean2) ** 2) * torch.exp(-logvar2) 92 | ) 93 | -------------------------------------------------------------------------------- /unidemoire/modules/ema.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | 4 | 5 | class LitEma(nn.Module): 6 | def __init__(self, model, decay=0.9999, use_num_upates=True): 7 | super().__init__() 8 | if decay < 0.0 or decay > 1.0: 9 | raise ValueError('Decay must be between 0 and 1') 10 | 11 | self.m_name2s_name = {} 12 | self.register_buffer('decay', torch.tensor(decay, dtype=torch.float32)) 13 | self.register_buffer('num_updates', torch.tensor(0,dtype=torch.int) if use_num_upates 14 | else torch.tensor(-1,dtype=torch.int)) 15 | 16 | for name, p in model.named_parameters(): 17 | if p.requires_grad: 18 | #remove as '.'-character is not allowed in buffers 19 | s_name = name.replace('.','') 20 | self.m_name2s_name.update({name:s_name}) 21 | self.register_buffer(s_name,p.clone().detach().data) 22 | 23 | self.collected_params = [] 24 | 25 | def forward(self,model): 26 | decay = self.decay 27 | 28 | if self.num_updates >= 0: 29 | self.num_updates += 1 30 | decay = min(self.decay,(1 + self.num_updates) / (10 + self.num_updates)) 31 | 32 | one_minus_decay = 1.0 - decay 33 | 34 | with torch.no_grad(): 35 | m_param = dict(model.named_parameters()) 36 | shadow_params = dict(self.named_buffers()) 37 | 38 | for key in m_param: 39 | if m_param[key].requires_grad: 40 | sname = self.m_name2s_name[key] 41 | shadow_params[sname] = shadow_params[sname].type_as(m_param[key]) 42 | shadow_params[sname].sub_(one_minus_decay * (shadow_params[sname] - m_param[key])) 43 | else: 44 | assert not key in self.m_name2s_name 45 | 46 | def copy_to(self, model): 47 | m_param = dict(model.named_parameters()) 48 | shadow_params = dict(self.named_buffers()) 49 | for key in m_param: 50 | if m_param[key].requires_grad: 51 | m_param[key].data.copy_(shadow_params[self.m_name2s_name[key]].data) 52 | else: 53 | assert not key in self.m_name2s_name 54 | 55 | def store(self, parameters): 56 | """ 57 | Save the current parameters for restoring later. 58 | Args: 59 | parameters: Iterable of `torch.nn.Parameter`; the parameters to be 60 | temporarily stored. 61 | """ 62 | self.collected_params = [param.clone() for param in parameters] 63 | 64 | def restore(self, parameters): 65 | """ 66 | Restore the parameters stored with the `store` method. 67 | Useful to validate the model with EMA parameters without affecting the 68 | original optimization process. Store the parameters before the 69 | `copy_to` method. After validation (or model saving), use this to 70 | restore the former parameters. 71 | Args: 72 | parameters: Iterable of `torch.nn.Parameter`; the parameters to be 73 | updated with the stored parameters. 
74 | """ 75 | for c_param, param in zip(self.collected_params, parameters): 76 | param.data.copy_(c_param.data) 77 | -------------------------------------------------------------------------------- /unidemoire/modules/encoders/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/4DVLab/UniDemoire/4221a0e98f0078c3f72ef1da8128642e7f02662a/unidemoire/modules/encoders/__init__.py -------------------------------------------------------------------------------- /unidemoire/modules/encoders/modules.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from functools import partial 4 | import clip 5 | from einops import rearrange, repeat 6 | import kornia 7 | 8 | 9 | from unidemoire.modules.x_transformer import Encoder, TransformerWrapper # TODO: can we directly rely on lucidrains code and simply add this as a reuirement? --> test 10 | 11 | 12 | class AbstractEncoder(nn.Module): 13 | def __init__(self): 14 | super().__init__() 15 | 16 | def encode(self, *args, **kwargs): 17 | raise NotImplementedError 18 | 19 | 20 | 21 | class ClassEmbedder(nn.Module): 22 | def __init__(self, embed_dim, n_classes=1000, key='class'): 23 | super().__init__() 24 | self.key = key 25 | self.embedding = nn.Embedding(n_classes, embed_dim) 26 | 27 | def forward(self, batch, key=None): 28 | if key is None: 29 | key = self.key 30 | # this is for use in crossattn 31 | c = batch[key][:, None] 32 | c = self.embedding(c) 33 | return c 34 | 35 | 36 | class TransformerEmbedder(AbstractEncoder): 37 | """Some transformer encoder layers""" 38 | def __init__(self, n_embed, n_layer, vocab_size, max_seq_len=77, device="cuda"): 39 | super().__init__() 40 | self.device = device 41 | self.transformer = TransformerWrapper(num_tokens=vocab_size, max_seq_len=max_seq_len, 42 | attn_layers=Encoder(dim=n_embed, depth=n_layer)) 43 | 44 | def forward(self, tokens): 45 | tokens = tokens.to(self.device) # meh 46 | z = self.transformer(tokens, return_embeddings=True) 47 | return z 48 | 49 | def encode(self, x): 50 | return self(x) 51 | 52 | 53 | class BERTTokenizer(AbstractEncoder): 54 | """ Uses a pretrained BERT tokenizer by huggingface. 
Vocab size: 30522 (?)""" 55 | def __init__(self, device="cuda", vq_interface=True, max_length=77): 56 | super().__init__() 57 | from transformers import BertTokenizerFast # TODO: add to reuquirements 58 | self.tokenizer = BertTokenizerFast.from_pretrained("models/bert") 59 | 60 | self.device = device 61 | self.vq_interface = vq_interface 62 | self.max_length = max_length 63 | 64 | def forward(self, text): 65 | batch_encoding = self.tokenizer(text, truncation=True, max_length=self.max_length, return_length=True, 66 | return_overflowing_tokens=False, padding="max_length", return_tensors="pt") 67 | tokens = batch_encoding["input_ids"].to(self.device) 68 | return tokens 69 | 70 | @torch.no_grad() 71 | def encode(self, text): 72 | tokens = self(text) 73 | if not self.vq_interface: 74 | return tokens 75 | return None, None, [None, None, tokens] 76 | 77 | def decode(self, text): 78 | return text 79 | 80 | 81 | class BERTEmbedder(AbstractEncoder): 82 | """Uses the BERT tokenizr model and add some transformer encoder layers""" 83 | def __init__(self, n_embed, n_layer, vocab_size=30522, max_seq_len=77, 84 | device="cuda",use_tokenizer=True, embedding_dropout=0.0): 85 | super().__init__() 86 | self.use_tknz_fn = use_tokenizer 87 | if self.use_tknz_fn: 88 | self.tknz_fn = BERTTokenizer(vq_interface=False, max_length=max_seq_len) 89 | self.device = device 90 | self.transformer = TransformerWrapper(num_tokens=vocab_size, max_seq_len=max_seq_len, 91 | attn_layers=Encoder(dim=n_embed, depth=n_layer), 92 | emb_dropout=embedding_dropout) 93 | 94 | def forward(self, text): 95 | if self.use_tknz_fn: 96 | tokens = self.tknz_fn(text)#.to(self.device) 97 | else: 98 | tokens = text 99 | z = self.transformer(tokens, return_embeddings=True) 100 | return z 101 | 102 | def encode(self, text): 103 | # output of length 77 104 | return self(text) 105 | 106 | 107 | class SpatialRescaler(nn.Module): 108 | def __init__(self, 109 | n_stages=1, 110 | method='bilinear', 111 | multiplier=0.5, 112 | in_channels=3, 113 | out_channels=None, 114 | bias=False): 115 | super().__init__() 116 | self.n_stages = n_stages 117 | assert self.n_stages >= 0 118 | assert method in ['nearest','linear','bilinear','trilinear','bicubic','area'] 119 | self.multiplier = multiplier 120 | self.interpolator = partial(torch.nn.functional.interpolate, mode=method) 121 | self.remap_output = out_channels is not None 122 | if self.remap_output: 123 | print(f'Spatial Rescaler mapping from {in_channels} to {out_channels} channels after resizing.') 124 | self.channel_mapper = nn.Conv2d(in_channels,out_channels,1,bias=bias) 125 | 126 | def forward(self,x): 127 | for stage in range(self.n_stages): 128 | x = self.interpolator(x, scale_factor=self.multiplier) 129 | 130 | 131 | if self.remap_output: 132 | x = self.channel_mapper(x) 133 | return x 134 | 135 | def encode(self, x): 136 | return self(x) 137 | 138 | 139 | class FrozenCLIPTextEmbedder(nn.Module): 140 | """ 141 | Uses the CLIP transformer encoder for text. 
142 | """ 143 | def __init__(self, version='ViT-L/14', device="cuda", max_length=77, n_repeat=1, normalize=True): 144 | super().__init__() 145 | self.model, _ = clip.load(version, jit=False, device="cpu") 146 | self.device = device 147 | self.max_length = max_length 148 | self.n_repeat = n_repeat 149 | self.normalize = normalize 150 | 151 | def freeze(self): 152 | self.model = self.model.eval() 153 | for param in self.parameters(): 154 | param.requires_grad = False 155 | 156 | def forward(self, text): 157 | tokens = clip.tokenize(text).to(self.device) 158 | z = self.model.encode_text(tokens) 159 | if self.normalize: 160 | z = z / torch.linalg.norm(z, dim=1, keepdim=True) 161 | return z 162 | 163 | def encode(self, text): 164 | z = self(text) 165 | if z.ndim==2: 166 | z = z[:, None, :] 167 | z = repeat(z, 'b 1 d -> b k d', k=self.n_repeat) 168 | return z 169 | 170 | 171 | class FrozenClipImageEmbedder(nn.Module): 172 | """ 173 | Uses the CLIP image encoder. 174 | """ 175 | def __init__( 176 | self, 177 | model, 178 | jit=False, 179 | device='cuda' if torch.cuda.is_available() else 'cpu', 180 | antialias=False, 181 | ): 182 | super().__init__() 183 | self.model, _ = clip.load(name=model, device=device, jit=jit) 184 | 185 | self.antialias = antialias 186 | 187 | self.register_buffer('mean', torch.Tensor([0.48145466, 0.4578275, 0.40821073]), persistent=False) 188 | self.register_buffer('std', torch.Tensor([0.26862954, 0.26130258, 0.27577711]), persistent=False) 189 | 190 | def preprocess(self, x): 191 | # normalize to [0,1] 192 | x = kornia.geometry.resize(x, (224, 224), 193 | interpolation='bicubic',align_corners=True, 194 | antialias=self.antialias) 195 | x = (x + 1.) / 2. 196 | # renormalize according to clip 197 | x = kornia.enhance.normalize(x, self.mean, self.std) 198 | return x 199 | 200 | def forward(self, x): 201 | # x is assumed to be in range [-1,1] 202 | return self.model.encode_image(self.preprocess(x)) 203 | 204 | -------------------------------------------------------------------------------- /unidemoire/modules/losses/__init__.py: -------------------------------------------------------------------------------- 1 | from unidemoire.modules.losses.contperceptual import LPIPSWithDiscriminator -------------------------------------------------------------------------------- /unidemoire/modules/losses/contperceptual.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | from taming.modules.losses.vqperceptual import * # TODO: taming dependency yes/no? 
5 | 6 | 7 | class LPIPSWithDiscriminator(nn.Module): 8 | def __init__(self, disc_start, logvar_init=0.0, kl_weight=1.0, pixelloss_weight=1.0, 9 | disc_num_layers=3, disc_in_channels=3, disc_factor=1.0, disc_weight=1.0, 10 | perceptual_weight=1.0, use_actnorm=False, disc_conditional=False, 11 | disc_loss="hinge"): 12 | 13 | super().__init__() 14 | assert disc_loss in ["hinge", "vanilla"] 15 | self.kl_weight = kl_weight 16 | self.pixel_weight = pixelloss_weight 17 | self.perceptual_loss = LPIPS().eval() 18 | self.perceptual_weight = perceptual_weight 19 | # output log variance 20 | self.logvar = nn.Parameter(torch.ones(size=()) * logvar_init) 21 | 22 | self.discriminator = NLayerDiscriminator(input_nc=disc_in_channels, 23 | n_layers=disc_num_layers, 24 | use_actnorm=use_actnorm 25 | ).apply(weights_init) 26 | self.discriminator_iter_start = disc_start 27 | self.disc_loss = hinge_d_loss if disc_loss == "hinge" else vanilla_d_loss 28 | self.disc_factor = disc_factor 29 | self.discriminator_weight = disc_weight 30 | self.disc_conditional = disc_conditional 31 | 32 | def calculate_adaptive_weight(self, nll_loss, g_loss, last_layer=None): 33 | if last_layer is not None: 34 | nll_grads = torch.autograd.grad(nll_loss, last_layer, retain_graph=True)[0] 35 | g_grads = torch.autograd.grad(g_loss, last_layer, retain_graph=True)[0] 36 | else: 37 | nll_grads = torch.autograd.grad(nll_loss, self.last_layer[0], retain_graph=True)[0] 38 | g_grads = torch.autograd.grad(g_loss, self.last_layer[0], retain_graph=True)[0] 39 | 40 | d_weight = torch.norm(nll_grads) / (torch.norm(g_grads) + 1e-4) 41 | d_weight = torch.clamp(d_weight, 0.0, 1e4).detach() 42 | d_weight = d_weight * self.discriminator_weight 43 | return d_weight 44 | 45 | def forward(self, inputs, reconstructions, posteriors, optimizer_idx, 46 | global_step, last_layer=None, cond=None, split="train", 47 | weights=None): 48 | 49 | rec_loss = torch.abs(inputs.contiguous() - reconstructions.contiguous()) 50 | if self.perceptual_weight > 0: 51 | p_loss = self.perceptual_loss(inputs.contiguous(), reconstructions.contiguous()) 52 | rec_loss = rec_loss + self.perceptual_weight * p_loss 53 | 54 | nll_loss = rec_loss / torch.exp(self.logvar) + self.logvar 55 | weighted_nll_loss = nll_loss 56 | if weights is not None: 57 | weighted_nll_loss = weights*nll_loss 58 | weighted_nll_loss = torch.sum(weighted_nll_loss) / weighted_nll_loss.shape[0] 59 | nll_loss = torch.sum(nll_loss) / nll_loss.shape[0] 60 | 61 | kl_loss = posteriors.kl() 62 | kl_loss = torch.sum(kl_loss) / kl_loss.shape[0] 63 | 64 | # now the GAN part 65 | if optimizer_idx == 0: 66 | # generator update 67 | if cond is None: 68 | assert not self.disc_conditional 69 | logits_fake = self.discriminator(reconstructions.contiguous()) 70 | else: 71 | assert self.disc_conditional 72 | logits_fake = self.discriminator(torch.cat((reconstructions.contiguous(), cond), dim=1)) 73 | g_loss = -torch.mean(logits_fake) 74 | 75 | if self.disc_factor > 0.0: 76 | try: 77 | d_weight = self.calculate_adaptive_weight(nll_loss, g_loss, last_layer=last_layer) 78 | except RuntimeError: 79 | assert not self.training 80 | d_weight = torch.tensor(0.0) 81 | else: 82 | d_weight = torch.tensor(0.0) 83 | 84 | disc_factor = adopt_weight(self.disc_factor, global_step, threshold=self.discriminator_iter_start) 85 | loss = weighted_nll_loss + self.kl_weight * kl_loss + d_weight * disc_factor * g_loss 86 | 87 | log = {"{}/total_loss".format(split): loss.clone().detach().mean(), "{}/logvar".format(split): self.logvar.detach(), 88 | 
"{}/kl_loss".format(split): kl_loss.detach().mean(), "{}/nll_loss".format(split): nll_loss.detach().mean(), 89 | "{}/rec_loss".format(split): rec_loss.detach().mean(), 90 | "{}/d_weight".format(split): d_weight.detach(), 91 | "{}/disc_factor".format(split): torch.tensor(disc_factor), 92 | "{}/g_loss".format(split): g_loss.detach().mean(), 93 | } 94 | return loss, log 95 | 96 | if optimizer_idx == 1: 97 | # second pass for discriminator update 98 | if cond is None: 99 | logits_real = self.discriminator(inputs.contiguous().detach()) 100 | logits_fake = self.discriminator(reconstructions.contiguous().detach()) 101 | else: 102 | logits_real = self.discriminator(torch.cat((inputs.contiguous().detach(), cond), dim=1)) 103 | logits_fake = self.discriminator(torch.cat((reconstructions.contiguous().detach(), cond), dim=1)) 104 | 105 | disc_factor = adopt_weight(self.disc_factor, global_step, threshold=self.discriminator_iter_start) 106 | d_loss = disc_factor * self.disc_loss(logits_real, logits_fake) 107 | 108 | log = {"{}/disc_loss".format(split): d_loss.clone().detach().mean(), 109 | "{}/logits_real".format(split): logits_real.detach().mean(), 110 | "{}/logits_fake".format(split): logits_fake.detach().mean() 111 | } 112 | return d_loss, log 113 | 114 | -------------------------------------------------------------------------------- /unidemoire/modules/losses/vqperceptual.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | import torch.nn.functional as F 4 | from einops import repeat 5 | 6 | from taming.modules.discriminator.model import NLayerDiscriminator, weights_init 7 | from taming.modules.losses.lpips import LPIPS 8 | from taming.modules.losses.vqperceptual import hinge_d_loss, vanilla_d_loss 9 | 10 | from unidemoire.util import exists 11 | 12 | 13 | def hinge_d_loss_with_exemplar_weights(logits_real, logits_fake, weights): 14 | assert weights.shape[0] == logits_real.shape[0] == logits_fake.shape[0] 15 | loss_real = torch.mean(F.relu(1. - logits_real), dim=[1,2,3]) 16 | loss_fake = torch.mean(F.relu(1. + logits_fake), dim=[1,2,3]) 17 | loss_real = (weights * loss_real).sum() / weights.sum() 18 | loss_fake = (weights * loss_fake).sum() / weights.sum() 19 | d_loss = 0.5 * (loss_real + loss_fake) 20 | return d_loss 21 | 22 | def adopt_weight(weight, global_step, threshold=0, value=0.): 23 | if global_step < threshold: 24 | weight = value 25 | return weight 26 | 27 | 28 | def measure_perplexity(predicted_indices, n_embed): 29 | # src: https://github.com/karpathy/deep-vector-quantization/blob/main/model.py 30 | # eval cluster perplexity. 
when perplexity == num_embeddings then all clusters are used exactly equally 31 | encodings = F.one_hot(predicted_indices, n_embed).float().reshape(-1, n_embed) 32 | avg_probs = encodings.mean(0) 33 | perplexity = (-(avg_probs * torch.log(avg_probs + 1e-10)).sum()).exp() 34 | cluster_use = torch.sum(avg_probs > 0) 35 | return perplexity, cluster_use 36 | 37 | def l1(x, y): 38 | return torch.abs(x-y) 39 | 40 | 41 | def l2(x, y): 42 | return torch.pow((x-y), 2) 43 | 44 | 45 | class VQLPIPSWithDiscriminator(nn.Module): 46 | def __init__(self, disc_start, codebook_weight=1.0, pixelloss_weight=1.0, 47 | disc_num_layers=3, disc_in_channels=3, disc_factor=1.0, disc_weight=1.0, 48 | perceptual_weight=1.0, use_actnorm=False, disc_conditional=False, 49 | disc_ndf=64, disc_loss="hinge", n_classes=None, perceptual_loss="lpips", 50 | pixel_loss="l1"): 51 | super().__init__() 52 | assert disc_loss in ["hinge", "vanilla"] 53 | assert perceptual_loss in ["lpips", "clips", "dists"] 54 | assert pixel_loss in ["l1", "l2"] 55 | self.codebook_weight = codebook_weight 56 | self.pixel_weight = pixelloss_weight 57 | if perceptual_loss == "lpips": 58 | print(f"{self.__class__.__name__}: Running with LPIPS.") 59 | self.perceptual_loss = LPIPS().eval() 60 | else: 61 | raise ValueError(f"Unknown perceptual loss: >> {perceptual_loss} <<") 62 | self.perceptual_weight = perceptual_weight 63 | 64 | if pixel_loss == "l1": 65 | self.pixel_loss = l1 66 | else: 67 | self.pixel_loss = l2 68 | 69 | self.discriminator = NLayerDiscriminator(input_nc=disc_in_channels, 70 | n_layers=disc_num_layers, 71 | use_actnorm=use_actnorm, 72 | ndf=disc_ndf 73 | ).apply(weights_init) 74 | self.discriminator_iter_start = disc_start 75 | if disc_loss == "hinge": 76 | self.disc_loss = hinge_d_loss 77 | elif disc_loss == "vanilla": 78 | self.disc_loss = vanilla_d_loss 79 | else: 80 | raise ValueError(f"Unknown GAN loss '{disc_loss}'.") 81 | print(f"VQLPIPSWithDiscriminator running with {disc_loss} loss.") 82 | self.disc_factor = disc_factor 83 | self.discriminator_weight = disc_weight 84 | self.disc_conditional = disc_conditional 85 | self.n_classes = n_classes 86 | 87 | def calculate_adaptive_weight(self, nll_loss, g_loss, last_layer=None): 88 | if last_layer is not None: 89 | nll_grads = torch.autograd.grad(nll_loss, last_layer, retain_graph=True)[0] 90 | g_grads = torch.autograd.grad(g_loss, last_layer, retain_graph=True)[0] 91 | else: 92 | nll_grads = torch.autograd.grad(nll_loss, self.last_layer[0], retain_graph=True)[0] 93 | g_grads = torch.autograd.grad(g_loss, self.last_layer[0], retain_graph=True)[0] 94 | 95 | d_weight = torch.norm(nll_grads) / (torch.norm(g_grads) + 1e-4) 96 | d_weight = torch.clamp(d_weight, 0.0, 1e4).detach() 97 | d_weight = d_weight * self.discriminator_weight 98 | return d_weight 99 | 100 | def forward(self, codebook_loss, inputs, reconstructions, optimizer_idx, 101 | global_step, last_layer=None, cond=None, split="train", predicted_indices=None): 102 | if not exists(codebook_loss): 103 | codebook_loss = torch.tensor([0.]).to(inputs.device) 104 | #rec_loss = torch.abs(inputs.contiguous() - reconstructions.contiguous()) 105 | rec_loss = self.pixel_loss(inputs.contiguous(), reconstructions.contiguous()) 106 | if self.perceptual_weight > 0: 107 | p_loss = self.perceptual_loss(inputs.contiguous(), reconstructions.contiguous()) 108 | rec_loss = rec_loss + self.perceptual_weight * p_loss 109 | else: 110 | p_loss = torch.tensor([0.0]) 111 | 112 | nll_loss = rec_loss 113 | #nll_loss = torch.sum(nll_loss) / 
nll_loss.shape[0] 114 | nll_loss = torch.mean(nll_loss) 115 | 116 | # now the GAN part 117 | if optimizer_idx == 0: 118 | # generator update 119 | if cond is None: 120 | assert not self.disc_conditional 121 | logits_fake = self.discriminator(reconstructions.contiguous()) 122 | else: 123 | assert self.disc_conditional 124 | logits_fake = self.discriminator(torch.cat((reconstructions.contiguous(), cond), dim=1)) 125 | g_loss = -torch.mean(logits_fake) 126 | 127 | try: 128 | d_weight = self.calculate_adaptive_weight(nll_loss, g_loss, last_layer=last_layer) 129 | except RuntimeError: 130 | assert not self.training 131 | d_weight = torch.tensor(0.0) 132 | 133 | disc_factor = adopt_weight(self.disc_factor, global_step, threshold=self.discriminator_iter_start) 134 | loss = nll_loss + d_weight * disc_factor * g_loss + self.codebook_weight * codebook_loss.mean() 135 | 136 | log = {"{}/total_loss".format(split): loss.clone().detach().mean(), 137 | "{}/quant_loss".format(split): codebook_loss.detach().mean(), 138 | "{}/nll_loss".format(split): nll_loss.detach().mean(), 139 | "{}/rec_loss".format(split): rec_loss.detach().mean(), 140 | "{}/p_loss".format(split): p_loss.detach().mean(), 141 | "{}/d_weight".format(split): d_weight.detach(), 142 | "{}/disc_factor".format(split): torch.tensor(disc_factor), 143 | "{}/g_loss".format(split): g_loss.detach().mean(), 144 | } 145 | if predicted_indices is not None: 146 | assert self.n_classes is not None 147 | with torch.no_grad(): 148 | perplexity, cluster_usage = measure_perplexity(predicted_indices, self.n_classes) 149 | log[f"{split}/perplexity"] = perplexity 150 | log[f"{split}/cluster_usage"] = cluster_usage 151 | return loss, log 152 | 153 | if optimizer_idx == 1: 154 | # second pass for discriminator update 155 | if cond is None: 156 | logits_real = self.discriminator(inputs.contiguous().detach()) 157 | logits_fake = self.discriminator(reconstructions.contiguous().detach()) 158 | else: 159 | logits_real = self.discriminator(torch.cat((inputs.contiguous().detach(), cond), dim=1)) 160 | logits_fake = self.discriminator(torch.cat((reconstructions.contiguous().detach(), cond), dim=1)) 161 | 162 | disc_factor = adopt_weight(self.disc_factor, global_step, threshold=self.discriminator_iter_start) 163 | d_loss = disc_factor * self.disc_loss(logits_real, logits_fake) 164 | 165 | log = {"{}/disc_loss".format(split): d_loss.clone().detach().mean(), 166 | "{}/logits_real".format(split): logits_real.detach().mean(), 167 | "{}/logits_fake".format(split): logits_fake.detach().mean() 168 | } 169 | return d_loss, log 170 | -------------------------------------------------------------------------------- /unidemoire/util.py: -------------------------------------------------------------------------------- 1 | import importlib 2 | import sys 3 | import torch 4 | import numpy as np 5 | from collections import abc 6 | from einops import rearrange 7 | 8 | import multiprocessing as mp 9 | from threading import Thread 10 | from queue import Queue 11 | from inspect import isfunction 12 | from PIL import Image, ImageDraw, ImageFont 13 | 14 | def log_txt_as_img(wh, xc, size=10): 15 | # wh a tuple of (width, height) 16 | # xc a list of captions to plot 17 | b = len(xc) 18 | txts = list() 19 | for bi in range(b): 20 | txt = Image.new("RGB", wh, color="white") 21 | draw = ImageDraw.Draw(txt) 22 | font = ImageFont.truetype('data/DejaVuSans.ttf', size=size) 23 | nc = int(40 * (wh[0] / 256)) 24 | lines = "\n".join(xc[bi][start:start + nc] for start in range(0, len(xc[bi]), nc)) 25 | 26 | 
try: 27 | draw.text((0, 0), lines, fill="black", font=font) 28 | except UnicodeEncodeError: 29 | print("Cant encode string for logging. Skipping.") 30 | 31 | txt = np.array(txt).transpose(2, 0, 1) / 127.5 - 1.0 32 | txts.append(txt) 33 | txts = np.stack(txts) 34 | txts = torch.tensor(txts) 35 | return txts 36 | 37 | 38 | def ismap(x): 39 | if not isinstance(x, torch.Tensor): 40 | return False 41 | return (len(x.shape) == 4) and (x.shape[1] > 3) 42 | 43 | 44 | def isimage(x): 45 | if not isinstance(x, torch.Tensor): 46 | return False 47 | return (len(x.shape) == 4) and (x.shape[1] == 3 or x.shape[1] == 1) 48 | 49 | 50 | def exists(x): 51 | return x is not None 52 | 53 | 54 | def default(val, d): 55 | if exists(val): 56 | return val 57 | return d() if isfunction(d) else d 58 | 59 | 60 | def mean_flat(tensor): 61 | """ 62 | https://github.com/openai/guided-diffusion/blob/27c20a8fab9cb472df5d6bdd6c8d11c8f430b924/guided_diffusion/nn.py#L86 63 | Take the mean over all non-batch dimensions. 64 | """ 65 | return tensor.mean(dim=list(range(1, len(tensor.shape)))) 66 | 67 | 68 | def count_params(model, verbose=False): 69 | total_params = sum(p.numel() for p in model.parameters()) 70 | if verbose: 71 | print(f"{model.__class__.__name__} has {total_params * 1.e-6:.2f} M params.") 72 | return total_params 73 | 74 | def instantiate_from_config(config): 75 | if not "target" in config: 76 | if config == '__is_first_stage__': 77 | return None 78 | elif config == "__is_unconditional__": 79 | return None 80 | raise KeyError("Expected key `target` to instantiate.") 81 | target_cls = get_obj_from_str(config["target"]) 82 | if target_cls is None: 83 | print(f"Warning: Target class {config['target']} not found, skipping instantiation.") 84 | return None 85 | return target_cls(**config.get("params", dict())) 86 | 87 | 88 | 89 | def get_obj_from_str(string, reload=False, silent=True): 90 | # module, cls = string.rsplit(".", 1) 91 | # if reload: 92 | # module_imp = importlib.port_module(module) 93 | # importlib.reload(module_imp) 94 | # return getattr(importlib.import_module(module, package=None), cls) 95 | try: 96 | module, cls = string.rsplit(".", 1) 97 | if reload: 98 | if module in sys.modules: 99 | importlib.reload(sys.modules[module]) 100 | return getattr(importlib.import_module(module, package=None), cls) 101 | except (ModuleNotFoundError, AttributeError) as e: 102 | if not silent: 103 | raise 104 | print(f"Warning: Could not import {string} - {str(e)}. Skipping...") 105 | return None 106 | 107 | 108 | def _do_parallel_data_prefetch(func, Q, data, idx, idx_to_fn=False): 109 | # create dummy dataset instance 110 | 111 | # run prefetching 112 | if idx_to_fn: 113 | res = func(data, worker_id=idx) 114 | else: 115 | res = func(data) 116 | Q.put([idx, res]) 117 | Q.put("Done") 118 | 119 | 120 | def parallel_data_prefetch( 121 | func: callable, data, n_proc, target_data_type="ndarray", cpu_intensive=True, use_worker_id=False 122 | ): 123 | # if target_data_type not in ["ndarray", "list"]: 124 | # raise ValueError( 125 | # "Data, which is passed to parallel_data_prefetch has to be either of type list or ndarray." 126 | # ) 127 | if isinstance(data, np.ndarray) and target_data_type == "list": 128 | raise ValueError("list expected but function got ndarray.") 129 | elif isinstance(data, abc.Iterable): 130 | if isinstance(data, dict): 131 | print( 132 | f'WARNING:"data" argument passed to parallel_data_prefetch is a dict: Using only its values and disregarding keys.' 
133 | ) 134 | data = list(data.values()) 135 | if target_data_type == "ndarray": 136 | data = np.asarray(data) 137 | else: 138 | data = list(data) 139 | else: 140 | raise TypeError( 141 | f"The data, that shall be processed parallel has to be either an np.ndarray or an Iterable, but is actually {type(data)}." 142 | ) 143 | 144 | if cpu_intensive: 145 | Q = mp.Queue(1000) 146 | proc = mp.Process 147 | else: 148 | Q = Queue(1000) 149 | proc = Thread 150 | # spawn processes 151 | if target_data_type == "ndarray": 152 | arguments = [ 153 | [func, Q, part, i, use_worker_id] 154 | for i, part in enumerate(np.array_split(data, n_proc)) 155 | ] 156 | else: 157 | step = ( 158 | int(len(data) / n_proc + 1) 159 | if len(data) % n_proc != 0 160 | else int(len(data) / n_proc) 161 | ) 162 | arguments = [ 163 | [func, Q, part, i, use_worker_id] 164 | for i, part in enumerate( 165 | [data[i: i + step] for i in range(0, len(data), step)] 166 | ) 167 | ] 168 | processes = [] 169 | for i in range(n_proc): 170 | p = proc(target=_do_parallel_data_prefetch, args=arguments[i]) 171 | processes += [p] 172 | 173 | # start processes 174 | print(f"Start prefetching...") 175 | import time 176 | 177 | start = time.time() 178 | gather_res = [[] for _ in range(n_proc)] 179 | try: 180 | for p in processes: 181 | p.start() 182 | 183 | k = 0 184 | while k < n_proc: 185 | # get result 186 | res = Q.get() 187 | if res == "Done": 188 | k += 1 189 | else: 190 | gather_res[res[0]] = res[1] 191 | 192 | except Exception as e: 193 | print("Exception: ", e) 194 | for p in processes: 195 | p.terminate() 196 | 197 | raise e 198 | finally: 199 | for p in processes: 200 | p.join() 201 | print(f"Prefetching complete. [{time.time() - start} sec.]") 202 | 203 | if target_data_type == 'ndarray': 204 | if not isinstance(gather_res[0], np.ndarray): 205 | return np.concatenate([np.asarray(r) for r in gather_res], axis=0) 206 | 207 | # order outputs 208 | return np.concatenate(gather_res, axis=0) 209 | elif target_data_type == 'list': 210 | out = [] 211 | for r in gather_res: 212 | out.extend(r) 213 | return out 214 | else: 215 | return gather_res 216 | --------------------------------------------------------------------------------
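A minimal usage sketch for the LitEma helper defined in unidemoire/modules/ema.py, following the store → copy_to → restore workflow described in its restore() docstring. This is an illustration only: the import path is read off the repository layout above, and the nn.Linear is a made-up stand-in for the actual model being tracked.

import torch
from torch import nn
from unidemoire.modules.ema import LitEma
from unidemoire.modules.distributions.distributions import DiagonalGaussianDistribution

model = nn.Linear(4, 4)                   # hypothetical stand-in for the tracked network
ema = LitEma(model, decay=0.999)

for _ in range(3):                        # after each optimizer step, refresh the shadow weights
    model.weight.data.add_(0.01 * torch.randn_like(model.weight))
    ema(model)                            # forward() applies the EMA update to the buffers

ema.store(model.parameters())             # keep the raw training weights
ema.copy_to(model)                        # validate / save a checkpoint with the smoothed EMA weights
ema.restore(model.parameters())           # put the training weights back afterwards

Likewise, a short sketch of how DiagonalGaussianDistribution from unidemoire/modules/distributions/distributions.py is typically driven. The (B, 2*C, H, W) moment tensor below is an assumed toy shape, not a value taken from the configs: the first half of the channels is read as the mean, the second half as the log-variance.

moments = torch.randn(4, 8, 32, 32)       # 4 mean channels + 4 log-variance channels
posterior = DiagonalGaussianDistribution(moments)
z = posterior.sample()                    # reparameterized sample, shape (4, 4, 32, 32)
kl = posterior.kl()                       # per-sample KL to a standard normal, shape (4,)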