├── .gitignore
├── LICENSE
├── README.md
├── configs
│   ├── autoencoder
│   │   └── vae-768-crop.yaml
│   ├── demoire
│   │   ├── cross-dataset
│   │   │   ├── esdnet
│   │   │   │   ├── fhdmi
│   │   │   │   │   └── cd_unidemoire_esdnet_fhdmi.yaml
│   │   │   │   ├── tip
│   │   │   │   │   └── cd_unidemoire_esdnet_tip.yaml
│   │   │   │   └── uhdm
│   │   │   │       └── cd_unidemoire_esdnet_uhdm.yaml
│   │   │   └── mbcnn
│   │   │       ├── fhdmi
│   │   │       │   └── cd_unidemoire_mbcnn_fhdmi.yaml
│   │   │       ├── tip
│   │   │       │   └── cd_unidemoire_mbcnn_tip.yaml
│   │   │       └── uhdm
│   │   │           └── cd_unidemoire_mbcnn_uhdm.yaml
│   │   └── mhrnid
│   │       ├── mhrnid_esdnet_unidemoire.yaml
│   │       └── mhrnid_mbcnn_unidemoire.yaml
│   ├── latent-diffusion
│   │   └── ldm-vae-768-crop.yaml
│   └── moire-blending
│       ├── fhdmi
│       │   └── blending_fhdmi.yaml
│       ├── tip
│       │   └── blending_tip.yaml
│       └── uhdm
│           └── blending_uhdm.yaml
├── environment.yaml
├── main.py
├── models
│   ├── moire_blending
│   │   ├── fhdmi
│   │   │   └── config.yaml
│   │   ├── tip
│   │   │   └── config.yaml
│   │   └── uhdm
│   │       └── config.yaml
│   └── moire_generator
│       └── diffusion
│           └── config.yaml
├── scripts
│   └── sample_moire_pattern.py
├── setup.py
├── static
│   └── images
│       └── Pipeline.png
├── taming
│   └── modules
│       └── autoencoder
│           └── lpips
│               └── vgg.pth
└── unidemoire
    ├── __init__.py
    ├── data
    │   ├── __init__.py
    │   ├── fhdmi.py
    │   ├── moire.py
    │   ├── moire_blend.py
    │   ├── tip.py
    │   ├── uhdm.py
    │   └── utils.py
    ├── lr_scheduler.py
    ├── models
    │   ├── MIB
    │   │   ├── Blending.py
    │   │   └── __init__.py
    │   ├── TRN
    │   │   ├── __init__.py
    │   │   └── model.py
    │   ├── autoencoder.py
    │   ├── cycle
    │   │   ├── Models
    │   │   │   ├── Loss_func_demoire.py
    │   │   │   ├── models.py
    │   │   │   ├── modules.py
    │   │   │   └── utils.py
    │   │   ├── nets.py
    │   │   └── networks.py
    │   ├── diffusion
    │   │   ├── __init__.py
    │   │   ├── classifier.py
    │   │   ├── ddim.py
    │   │   ├── ddpm.py
    │   │   └── plms.py
    │   ├── esdnet
    │   │   ├── __init__.py
    │   │   └── nets.py
    │   ├── mbcnn
    │   │   ├── LossNet.py
    │   │   ├── MBCNN.py
    │   │   ├── MBCNN_class.py
    │   │   ├── __init__.py
    │   │   └── arch_util.py
    │   ├── moire_blending.py
    │   ├── moire_nets.py
    │   ├── pmtnet
    │   │   ├── PMTNet.py
    │   │   └── __init__.py
    │   ├── shooting
    │   │   ├── __init__.py
    │   │   ├── image_transformer.py
    │   │   ├── method.py
    │   │   └── mosaicing_demosaicing_v2.py
    │   ├── undem
    │   │   ├── __init__.py
    │   │   └── model.py
    │   └── utils
    │       ├── __init__.py
    │       ├── common.py
    │       ├── loss_util.py
    │       ├── matlab_ssim.py
    │       └── metric.py
    ├── modules
    │   ├── attention.py
    │   ├── diffusionmodules
    │   │   ├── __init__.py
    │   │   ├── model.py
    │   │   ├── openaimodel.py
    │   │   └── util.py
    │   ├── distributions
    │   │   ├── __init__.py
    │   │   └── distributions.py
    │   ├── ema.py
    │   ├── encoders
    │   │   ├── __init__.py
    │   │   └── modules.py
    │   ├── losses
    │   │   ├── __init__.py
    │   │   ├── contperceptual.py
    │   │   └── vqperceptual.py
    │   └── x_transformer.py
    └── util.py
/.gitignore:
--------------------------------------------------------------------------------
1 | __pycache__/
2 | *.py[cod]
3 | *$py.class
4 | .python-version
5 |
6 | .vscode/
7 | .idea/
8 | *.swp
9 | *.swo
10 |
11 | .env
12 | venv/
13 | ENV/
14 | env.bak/
15 |
16 | .coverage
17 | htmlcov/
18 | .pytest_cache/
19 |
20 | *.pt
21 | *.bin
22 | *.npy
23 | *.npz
24 | *.tmp
25 | *.ckpt
26 |
27 | .DS_Store
28 | Thumbs.db
29 |
30 |
31 |
32 | # Jupyter notebook checkpoints
33 | .ipynb_checkpoints/
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2024 4DVLab
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # UniDemoiré: Towards Universal Image Demoiréing with Data Generation and Synthesis
2 |
3 | Zemin Yang¹, Yujing Sun², Xidong Peng¹, Siu Ming Yiu², Yuexin Ma¹
4 |
5 | ### [Project Page](https://yizhifengyeyzm.github.io/UniDemoire-page/) | [Paper](https://arxiv.org/abs/2502.06324) | [Dataset](https://drive.google.com/drive/folders/1k48jcgJLMUB0_42H-x1VYl67NP56zel8?usp=drive_link)
6 |
7 | ***
8 |
9 | The generalization ability of SOTA demoiréing models is greatly limited by the scarcity of data. To obtain a universal model with improved generalization capability, we therefore face two main challenges: acquiring a vast amount of **1) diverse** and **2) realistic-looking moiré data**. Traditional moiré image datasets do contain real data, but continuously expanding them to cover more diversity is extremely time-consuming and impractical, while current synthetic datasets and methods struggle to produce realistic-looking moiré images.
10 |
11 | 
12 |
13 | Hence, to tackle these challenges, we introduce a universal solution, **UniDemoiré**. The data diversity challenge is addressed by collecting a more diverse moiré pattern dataset and presenting a moiré pattern generator to further increase pattern variation, while the realism challenge is handled by a moiré image synthesis module. As a result, our solution produces realistic-looking moiré images of sufficient diversity, substantially enhancing the zero-shot and cross-domain performance of demoiréing models.
14 |
15 | ***
16 |
17 | ## :hourglass_flowing_sand: To Do
18 |
19 | - [x] Release training code
20 | - [x] Release testing code
21 | - [x] Release dataset
22 | - [x] Release pre-trained models
23 |
24 | ## 🛠️ Environment
25 | The entire UniDemoiré framework is built on the Latent Diffusion Model and requires Python 3.8 and PyTorch-Lightning 1.4.2.
26 | You can install the UniDemoiré environment in the following two ways:
27 | ```
28 | conda env create -f environment.yaml
29 | conda activate unidemoire
30 | ```
31 | If the installation doesn't go smoothly, you can also follow the [instructions](https://github.com/CompVis/latent-diffusion?tab=readme-ov-file#requirements) to install the Latent Diffusion Model environment first, and then install the remaining packages via pip:
32 | ```
33 | conda activate unidemoire
34 |
35 | ...
36 | (install the ldm environment first)
37 | ...
38 |
39 | pip install colour-demosaicing==0.2.2
40 | pip install thop==0.1.1-2209072238
41 | pip install lpips==0.1.4
42 | pip install timm==0.9.16
43 | pip install pillow==9.5.0
44 | ```
45 |
46 | ## 📦 Dataset and Pre-trained Models
47 |
48 | We provide the captured 4K moiré pattern dataset, the sampled moiré pattern dataset, the MHRNID dataset, and the pre-trained models for both the Moiré Pattern Generator and the Moiré Image Synthesis stage, which can be downloaded through the following links:
49 |
50 | **\[[Baidu Drive](https://pan.baidu.com/s/1YI4NO5xyC8oK3ZOFHpTa1w?pwd=sthx)\]** | **\[[Google Drive](https://drive.google.com/drive/folders/1k48jcgJLMUB0_42H-x1VYl67NP56zel8?usp=drive_link)\]**
51 |
52 |
53 | ## 🚀 Getting Started
54 |
55 | >**Some important tips about the training and testing process of our code:**
56 |
57 | The style of the config file is similar to [ldm](https://github.com/CompVis/latent-diffusion), and **the paths to the training/testing datasets can be changed inside config.**
58 |
59 | Logs and checkpoints for trained models are saved to `logs/_`.
60 |
61 | **To resume training of a specific model, simply re-run the training command with the “`-r`” parameter followed by the path to your model checkpoint.**
62 |
63 | The dataset type and path for the test set must be specified in the config file. **The program automatically starts the testing process once training is complete (the same pattern as in the Latent Diffusion Model)**. If you want to test on a different dataset, change the config file and then re-run your training command with “`-r`” pointing to the previous checkpoint; the program will go directly to the test session!
64 |
65 | If you want to train with multiple GPUs, remember to replace `` with your GPU IDs in the code templates below, and be sure to adjust the “`--gpus`” parameter that follows accordingly.
66 | - For example: to train with `4` GPUs (assuming they are numbered `5`, `6`, `7`, and `8`), you should set `CUDA_VISIBLE_DEVICES=5,6,7,8` and pass `--gpus 0,1,2,3,`
67 |
68 | ### Moiré Pattern Generator
69 |
70 | #### 1. AutoEncoder
71 | Configs for training a KL-regularized autoencoder on the captured moiré pattern dataset are provided in `configs/autoencoder`. Training can be started by running:
72 | ```
73 | CUDA_VISIBLE_DEVICES= python main.py --base configs/autoencoder/.yaml --scale_lr False -t --gpus 0,
74 | ```
75 | After training, place the ckpt file in `models/moire_generator/autoencoder`.
76 |
77 | #### 2. Diffusion Model
78 | In `configs/latent-diffusion/` we provide configs for training the diffusion model on the captured moiré pattern dataset. Training can be started by running:
79 | ```
80 | CUDA_VISIBLE_DEVICES= python main.py --base configs/latent-diffusion/.yaml -t --gpus 0,
81 | ```
82 | After training, place the ckpt file in `models/moire_generator/diffusion`.
83 |
84 | #### 3. Sampling
85 | Run the script via:
86 | ```
87 | CUDA_VISIBLE_DEVICES= python scripts/sample_moire_pattern.py
88 | -r
90 |
91 | For example:
92 | CUDA_VISIBLE_DEVICES=0 python scripts/sample_moire_pattern.py
93 | -r ./models/moire_generator/diffusion/last.ckpt
94 | -n 10000
95 | ```
96 |
97 | ### Moiré Image Synthesis
98 | In `configs/moire-blending/` we provide configs for training the synthesis model on the UHDM, FHDMi, and TIP datasets. Training can be started by running:
99 | ```
100 | CUDA_VISIBLE_DEVICES= python main.py --base configs/moire-blending/.yaml --scale_lr False -t --gpus 0,
101 |
102 | For example: (training on UHDM dataset)
103 | CUDA_VISIBLE_DEVICES=0 python main.py --base configs/moire-blending/uhdm/blending_uhdm.yaml --scale_lr False -t --gpus 0,
104 | ```
105 | where `` is one of {`uhdm/blending_uhdm`, `fhdmi/blending_fhdmi`, `tip/blending_tip`}.
106 |
107 | After training, place the ckpt file in `models/moire_blending//`. You can find the original config files at these paths. If you change the training config in `configs/moire-blending/`, you also need to update the config file in `models/moire_blending//` accordingly.
108 |
109 | ### Demoiréing
110 |
111 | #### 1. Zero-Shot Demoiréing
112 | First, download and unzip the MHRNID dataset. **(to be updated)**
113 | Then run the following code to start training on MHRNID:
114 | ```
115 | CUDA_VISIBLE_DEVICES= python main.py --base configs/demoire/mhrnid/.yaml --scale_lr False -t --gpus 0,
116 |
117 | For example: (using ESDNet)
118 | CUDA_VISIBLE_DEVICES=0 python main.py --base configs/demoire/mhrnid/mhrnid_esdnet_unidemoire.yaml --scale_lr False -t --gpus 0,
119 | ```
120 | where `` is one of {`mhrnid_esdnet_unidemoire`, `mhrnid_mbcnn_unidemoire`}.
121 |
122 | #### 2. Cross-Dataset Demoiréing
123 |
124 | Run the following code to start training:
125 | ```
126 | CUDA_VISIBLE_DEVICES= python main.py --base configs/demoire/cross-dataset//.yaml --scale_lr False -t --gpus 0,
127 |
128 | For example: (using ESDNet, train on UHDM dataset)
129 | CUDA_VISIBLE_DEVICES=0 python main.py --base configs/demoire/cross-dataset/esdnet/cd_unidemoire_esdnet_uhdm.yaml --scale_lr False -t --gpus 0,
130 | ```
131 | where `` is one of {`esdnet`, `mbcnn`}, and `` is one of {`uhdm`, `fhdmi`, `tip`}.
132 |
133 |
134 |
135 | ## 🙏 Acknowledgements
136 |
137 | We would like to express our gratitude to the authors and contributors of the following projects:
138 |
139 | - [Latent Diffusion Model](https://github.com/CompVis/latent-diffusion)
140 | - [UHDM](https://github.com/CVMI-Lab/UHDM)
141 | - [FHDMi](https://github.com/PKU-IMRE/FHDe2Net)
142 | - [TIP](https://github.com/ZhengJun-AI/MoirePhotoRestoration-MCNN)
143 | - [Uformer](https://github.com/ZhendongWang6/Uformer)
144 | - [UnDeM](https://github.com/zysxmu/UnDeM)
145 |
146 |
147 |
148 | ## 📑 Citation
149 |
150 | If you find our work useful, please consider citing us using the following BibTeX entry:
151 |
152 | ```
153 | @misc{yang2025unidemoire,
154 | author = {Zemin Yang and Yujing Sun and Xidong Peng and Siu Ming Yiu and Yuexin Ma},
155 | title = {UniDemoir\'e: Towards Universal Image Demoir\'eing with Data Generation and Synthesis},
156 | year = {2025},
157 | eprint = {2502.06324},
158 | archivePrefix = {arXiv},
159 | primaryClass = {cs.CV},
160 | url={https://arxiv.org/abs/2502.06324},
161 | }
162 | ```
163 |
164 |
--------------------------------------------------------------------------------
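A note on the config format used throughout `configs/`: every block with a `target:` key names a Python class by its dotted import path, and `params:` is forwarded to that class's constructor, exactly as in ldm. The repository resolves these blocks with `unidemoire.util.instantiate_from_config` (imported in `scripts/sample_moire_pattern.py` below); its exact implementation is not part of this dump, so the following is only a minimal sketch of the convention, with the last two lines as a purely illustrative usage example:

```python
import importlib

from omegaconf import OmegaConf


def get_obj_from_str(name: str):
    # "unidemoire.models.autoencoder.AutoencoderKL" -> the AutoencoderKL class
    module, cls = name.rsplit(".", 1)
    return getattr(importlib.import_module(module), cls)


def instantiate_from_config(config):
    # Build the object named by `target`, passing `params` as keyword arguments.
    if "target" not in config:
        raise KeyError("Expected key `target` to instantiate.")
    return get_obj_from_str(config["target"])(**config.get("params", dict()))


# Illustrative only: build the training dataset declared in the autoencoder config.
cfg = OmegaConf.load("configs/autoencoder/vae-768-crop.yaml")
train_dataset = instantiate_from_config(cfg.data.params.train)
```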
/configs/autoencoder/vae-768-crop.yaml:
--------------------------------------------------------------------------------
1 | model:
2 | base_learning_rate: 4.5e-6
3 | target: unidemoire.models.autoencoder.AutoencoderKL
4 | params:
5 | monitor: "val/rec_loss"
6 | embed_dim: 64
7 | lossconfig:
8 | target: unidemoire.modules.losses.LPIPSWithDiscriminator
9 | params:
10 | disc_start: 50001
11 | kl_weight: 0.000001
12 | disc_weight: 0.5
13 | disc_in_channels: 3
14 |
15 | ddconfig:
16 | double_z: True
17 | z_channels: 64
18 | resolution: 768
19 | in_channels: 3
20 | out_ch: 3
21 | ch: 64
22 | ch_mult: [1,1,2,2,4,4]
23 | num_res_blocks: 2
24 | attn_resolutions: [16,8]
25 | dropout: 0.0
26 |
27 | data:
28 | target: main.DataModuleFromConfig
29 | params:
30 | batch_size: 2
31 | wrap: True
32 | train:
33 | target: unidemoire.data.moire.MoirePattern
34 | params:
35 | dataset_path: "./data/captured_data" # Please set the path to your moire pattern dataset
36 | resolution: 768
37 |
38 | lightning:
39 | callbacks:
40 | image_logger:
41 | target: main.ImageLogger
42 | params:
43 | batch_frequency: 1000
44 | max_images: 8
45 | increase_log_steps: True
46 |
47 | trainer:
48 | benchmark: True
49 | accumulate_grad_batches: 2
50 |
51 |
--------------------------------------------------------------------------------
/configs/demoire/cross-dataset/esdnet/fhdmi/cd_unidemoire_esdnet_fhdmi.yaml:
--------------------------------------------------------------------------------
1 | model:
2 | base_learning_rate: 1e-4
3 | target: unidemoire.models.moire_nets.Demoireing_Model # ldm.models.autoencoder.AutoencoderKL
4 | params:
5 | model_name: ESDNet
6 | mode: COMBINE_ONLINE # COMBINE_ONLINE, COMBINE_ONLINE_ONLY, original
7 | blending_method: unidemoire # unidemoire, shooting, moirespace, undem
8 | blending_model_path: ./models/moire_blending/bl_fhdmi.ckpt
9 | dataset: TIP
10 | evaluation_time: False
11 | evaluation_metric: True
12 | save_img: True
13 | network_config:
14 | # ESDNet
15 | en_feature_num: 48
16 | en_inter_num: 32
17 | de_feature_num: 64
18 | de_inter_num: 32
19 | sam_number: 2 # ESDNet:1, ESDNet-L:2
20 |
21 | # MBCNN
22 | n_filters: 64
23 |
24 | loss_config:
25 | # ESDNet
26 | LAM: 1
27 | LAM_P: 1
28 |
29 | optimizer_config:
30 | # ESDNet
31 | beta1: 0.9
32 | beta2: 0.999
33 | T_0: 50 # The total epochs for the first learning cycle (learning rate warms up then)
34 | T_mult: 1 # The learning cycle would be (T_0, T_0*T_MULT, T_0*T_MULT^2, T_0*T_MULT^3, ...)
35 | eta_min: 0.000001
36 |
37 | data:
38 | target: main.DataModuleFromConfig
39 | params:
40 | batch_size: 4
41 | num_workers: 4
42 | wrap: True
43 | train:
44 | target: unidemoire.data.moire_blend.moire_blending_datasets
45 | params:
46 | args:
47 | natural_dataset_name: FHDMi
48 | uhdm_dataset_path: # Please set the path to your moire pattern dataset
49 | fhdmi_dataset_path: # Please set the path to your moire pattern dataset
50 | tip_dataset_path: # Please set the path to your moire pattern dataset
51 | moire_pattern_path: # Please set the path to your moire pattern dataset
52 | loader: crop
53 | crop_size: 384
54 | paired: True
55 | mode: train
56 |
57 | # test: # UHDM
58 | # target: unidemoire.data.uhdm.uhdm_datasets
59 | # params:
60 | # args:
61 | # dataset_path: # Please set the path to your moire pattern dataset
62 | # LOADER: default
63 | # mode: test
64 |
65 | # test: # FHDMi
66 | # target: unidemoire.data.fhdmi.fhdmi_datasets
67 | # params:
68 | # args:
69 | # dataset_path: # Please set the path to your moire pattern dataset
70 | # LOADER: default
71 | # mode: test
72 |
73 | test: # TIP
74 | target: unidemoire.data.tip.tip_datasets
75 | params:
76 | args:
77 | dataset_path: # Please set the path to your moire pattern dataset
78 | LOADER: default
79 | mode: test
80 |
81 | lightning:
82 | callbacks:
83 | image_logger:
84 | target: main.ImageLogger
85 | params:
86 | increase_log_steps: False
87 | rescale: False
88 | batch_frequency: 500
89 | max_images: 8
90 |
91 | trainer:
92 | benchmark: True
93 | # accumulate_grad_batches: 1
94 | max_epochs: 150
--------------------------------------------------------------------------------
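The `optimizer_config` block above (repeated in the other demoiréing configs) reads like the hyperparameters of Adam plus PyTorch's `CosineAnnealingWarmRestarts` schedule; the actual wiring lives in `unidemoire/models/moire_nets.py`, which is not included in this dump, so the snippet below is only a sketch under that assumption:

```python
import torch
from torch.optim.lr_scheduler import CosineAnnealingWarmRestarts

opt_cfg = {"beta1": 0.9, "beta2": 0.999, "T_0": 50, "T_mult": 1, "eta_min": 1e-6}
base_lr = 1e-4                  # model.base_learning_rate (run with --scale_lr False)
net = torch.nn.Conv2d(3, 3, 3)  # stand-in for the demoiréing network

optimizer = torch.optim.Adam(net.parameters(), lr=base_lr,
                             betas=(opt_cfg["beta1"], opt_cfg["beta2"]))
# Cosine decay from base_lr down to eta_min, restarting every T_0 epochs
# (then T_0*T_mult, T_0*T_mult^2, ... epochs per cycle).
scheduler = CosineAnnealingWarmRestarts(optimizer,
                                        T_0=opt_cfg["T_0"],
                                        T_mult=opt_cfg["T_mult"],
                                        eta_min=opt_cfg["eta_min"])

for epoch in range(150):        # max_epochs in this config
    # ... one training epoch ...
    scheduler.step()            # advance the schedule once per epoch
```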
/configs/demoire/cross-dataset/esdnet/tip/cd_unidemoire_esdnet_tip.yaml:
--------------------------------------------------------------------------------
1 | model:
2 | base_learning_rate: 1e-4
3 | target: unidemoire.models.moire_nets.Demoireing_Model
4 | params:
5 | model_name: ESDNet
6 | mode: COMBINE_ONLINE
7 | blending_method: unidemoire # unidemoire, shooting, moirespace, undem
8 | blending_model_path: ./models/moire_blending/bl_tip.ckpt
9 | dataset: UHDM
10 | evaluation_time: False
11 | evaluation_metric: True
12 | save_img: True
13 | network_config:
14 | # ESDNet
15 | en_feature_num: 48
16 | en_inter_num: 32
17 | de_feature_num: 64
18 | de_inter_num: 32
19 | sam_number: 2 # ESDNet:1, ESDNet-L:2
20 |
21 | # MBCNN
22 | n_filters: 64
23 |
24 | loss_config:
25 | # ESDNet
26 | LAM: 1
27 | LAM_P: 1
28 |
29 | optimizer_config:
30 | # ESDNet
31 | beta1: 0.9
32 | beta2: 0.999
33 | T_0: 50 # The total epochs for the first learning cycle (learning rate warms up then)
34 | T_mult: 1 # The learning cycle would be (T_0, T_0*T_MULT, T_0*T_MULT^2, T_0*T_MULT^3, ...)
35 | eta_min: 0.000001
36 |
37 | data:
38 | target: main.DataModuleFromConfig
39 | params:
40 | batch_size: 8
41 | num_workers: 4
42 | wrap: True
43 | train:
44 | target: unidemoire.data.moire_blend.moire_blending_datasets
45 | params:
46 | args:
47 | natural_dataset_name: TIP
48 | uhdm_dataset_path: # Please set the path to your moire pattern dataset
49 | fhdmi_dataset_path: # Please set the path to your moire pattern dataset
50 | tip_dataset_path: # Please set the path to your moire pattern dataset
51 | moire_pattern_path: # Please set the path to your moire pattern dataset
52 | loader: crop
53 | crop_size: 256
54 | paired: True
55 | mode: train
56 |
57 | test: # UHDM
58 | target: unidemoire.data.uhdm.uhdm_datasets
59 | params:
60 | args:
61 | dataset_path: # Please set the path to your moire pattern dataset
62 | LOADER: default
63 | mode: test
64 |
65 | # test: # FHDMi
66 | # target: unidemoire.data.fhdmi.fhdmi_datasets
67 | # params:
68 | # args:
69 | # dataset_path: # Please set the path to your moire pattern dataset
70 | # LOADER: default
71 | # mode: test
72 |
73 | # test: # TIP
74 | # target: unidemoire.data.tip.tip_datasets
75 | # params:
76 | # args:
77 | # dataset_path: # Please set the path to your moire pattern dataset
78 | # LOADER: default
79 | # mode: test
80 |
81 | lightning:
82 | callbacks:
83 | image_logger:
84 | target: main.ImageLogger
85 | params:
86 | increase_log_steps: False
87 | rescale: False
88 | batch_frequency: 500
89 | max_images: 8
90 |
91 | trainer:
92 | benchmark: True
93 | # accumulate_grad_batches: 1
94 | max_epochs: 70
--------------------------------------------------------------------------------
/configs/demoire/cross-dataset/esdnet/uhdm/cd_unidemoire_esdnet_uhdm.yaml:
--------------------------------------------------------------------------------
1 | model:
2 | base_learning_rate: 1e-4
3 | target: unidemoire.models.moire_nets.Demoireing_Model
4 | params:
5 | model_name: ESDNet
6 | mode: COMBINE_ONLINE
7 | blending_method: unidemoire # unidemoire, shooting, moirespace, undem
8 | blending_model_path: ./models/moire_blending/bl_uhdm.ckpt
9 | dataset: TIP
10 | evaluation_time: False
11 | evaluation_metric: True
12 | save_img: True
13 | network_config:
14 | # ESDNet
15 | en_feature_num: 48
16 | en_inter_num: 32
17 | de_feature_num: 64
18 | de_inter_num: 32
19 | sam_number: 2 # ESDNet:1, ESDNet-L:2
20 |
21 | # MBCNN
22 | n_filters: 64
23 |
24 | loss_config:
25 | # ESDNet
26 | LAM: 1
27 | LAM_P: 1
28 |
29 | optimizer_config:
30 | # ESDNet
31 | beta1: 0.9
32 | beta2: 0.999
33 | T_0: 50 # The total epochs for the first learning cycle (learning rate warms up then)
34 | T_mult: 1 # The learning cycle would be (T_0, T_0*T_MULT, T_0*T_MULT^2, T_0*T_MULT^3, ...)
35 | eta_min: 0.000001
36 |
37 | data:
38 | target: main.DataModuleFromConfig
39 | params:
40 | batch_size: 4
41 | num_workers: 4
42 | wrap: True
43 | train:
44 | target: unidemoire.data.moire_blend.moire_blending_datasets
45 | params:
46 | args:
47 | natural_dataset_name: FHDMi
48 | uhdm_dataset_path: # Please set the path to your moire pattern dataset
49 | fhdmi_dataset_path: # Please set the path to your moire pattern dataset
50 | tip_dataset_path: # Please set the path to your moire pattern dataset
51 | moire_pattern_path: # Please set the path to your moire pattern dataset
52 | loader: crop
53 | crop_size: 384
54 | paired: True
55 | mode: train
56 |
57 | # test: # UHDM
58 | # target: unidemoire.data.uhdm.uhdm_datasets
59 | # params:
60 | # args:
61 | # dataset_path: # Please set the path to your moire pattern dataset
62 | # LOADER: default
63 | # mode: test
64 |
65 | # test: # FHDMi
66 | # target: unidemoire.data.fhdmi.fhdmi_datasets
67 | # params:
68 | # args:
69 | # dataset_path: # Please set the path to your moire pattern dataset
70 | # LOADER: default
71 | # mode: test
72 |
73 | test: # TIP
74 | target: unidemoire.data.tip.tip_datasets
75 | params:
76 | args:
77 | dataset_path: # Please set the path to your moire pattern dataset
78 | LOADER: default
79 | mode: test
80 |
81 | lightning:
82 | callbacks:
83 | image_logger:
84 | target: main.ImageLogger
85 | params:
86 | increase_log_steps: False
87 | rescale: False
88 | batch_frequency: 500
89 | max_images: 8
90 |
91 | trainer:
92 | benchmark: True
93 | # accumulate_grad_batches: 1
94 | max_epochs: 150
--------------------------------------------------------------------------------
/configs/demoire/cross-dataset/mbcnn/fhdmi/cd_unidemoire_mbcnn_fhdmi.yaml:
--------------------------------------------------------------------------------
1 | model:
2 | base_learning_rate: 1e-4
3 | target: unidemoire.models.moire_nets.Demoireing_Model # ldm.models.autoencoder.AutoencoderKL
4 | params:
5 | model_name: MBCNN
6 | mode: COMBINE_ONLINE # COMBINE_ONLINE, COMBINE_ONLINE_ONLY, original
7 | blending_method: unidemoire # unidemoire, shooting, moirespace, undem
8 | blending_model_path: ./models/moire_blending/bl_fhdmi.ckpt
9 | dataset: UHDM
10 | evaluation_time: False
11 | evaluation_metric: True
12 | save_img: True
13 | network_config:
14 | # ESDNet
15 | en_feature_num: 48
16 | en_inter_num: 32
17 | de_feature_num: 64
18 | de_inter_num: 32
19 | sam_number: 2 # ESDNet:1, ESDNet-L:2
20 |
21 | # MBCNN
22 | n_filters: 64
23 |
24 | loss_config:
25 | # ESDNet
26 | LAM: 1
27 | LAM_P: 1
28 |
29 | optimizer_config:
30 | # ESDNet
31 | beta1: 0.9
32 | beta2: 0.999
33 | T_0: 50 # The total epochs for the first learning cycle (learning rate warms up then)
34 | T_mult: 1 # The learning cycle would be (T_0, T_0*T_MULT, T_0*T_MULT^2, T_0*T_MULT^3, ...)
35 | eta_min: 0.000001
36 |
37 | data:
38 | target: main.DataModuleFromConfig
39 | params:
40 | batch_size: 4
41 | num_workers: 4
42 | wrap: True
43 | train:
44 | target: unidemoire.data.moire_blend.moire_blending_datasets
45 | params:
46 | args:
47 | natural_dataset_name: FHDMi
48 | uhdm_dataset_path: # Please set the path to your moire pattern dataset
49 | fhdmi_dataset_path: # Please set the path to your moire pattern dataset
50 | tip_dataset_path: # Please set the path to your moire pattern dataset
51 | moire_pattern_path: # Please set the path to your moire pattern dataset
52 | loader: crop
53 | crop_size: 384
54 | paired: True
55 | mode: train
56 |
57 | test: # UHDM
58 | target: unidemoire.data.uhdm.uhdm_datasets
59 | params:
60 | args:
61 | dataset_path: # Please set the path to your moire pattern dataset
62 | LOADER: default
63 | mode: test
64 |
65 | # test: # FHDMi
66 | # target: unidemoire.data.fhdmi.fhdmi_datasets
67 | # params:
68 | # args:
69 | # dataset_path: # Please set the path to your moire pattern dataset
70 | # LOADER: default
71 | # mode: test
72 |
73 | # test: # TIP
74 | # target: unidemoire.data.tip.tip_datasets
75 | # params:
76 | # args:
77 | # dataset_path: # Please set the path to your moire pattern dataset
78 | # LOADER: default
79 | # mode: test
80 |
81 | lightning:
82 | callbacks:
83 | image_logger:
84 | target: main.ImageLogger
85 | params:
86 | increase_log_steps: False
87 | rescale: False
88 | batch_frequency: 500
89 | max_images: 8
90 |
91 | trainer:
92 | benchmark: True
93 | # accumulate_grad_batches: 1
94 | max_epochs: 150
--------------------------------------------------------------------------------
/configs/demoire/cross-dataset/mbcnn/tip/cd_unidemoire_mbcnn_tip.yaml:
--------------------------------------------------------------------------------
1 | model:
2 | base_learning_rate: 1e-4
3 | target: unidemoire.models.moire_nets.Demoireing_Model
4 | params:
5 | model_name: MBCNN
6 | mode: COMBINE_ONLINE
7 | blending_method: unidemoire # unidemoire, shooting, moirespace, undem
8 | blending_model_path: ./models/moire_blending/bl_tip.ckpt
9 | dataset: UHDM
10 | evaluation_time: False
11 | evaluation_metric: True
12 | save_img: True
13 | network_config:
14 | # ESDNet
15 | en_feature_num: 48
16 | en_inter_num: 32
17 | de_feature_num: 64
18 | de_inter_num: 32
19 | sam_number: 2 # ESDNet:1, ESDNet-L:2
20 |
21 | # MBCNN
22 | n_filters: 64
23 |
24 | loss_config:
25 | # ESDNet
26 | LAM: 1
27 | LAM_P: 1
28 |
29 | optimizer_config:
30 | # ESDNet
31 | beta1: 0.9
32 | beta2: 0.999
33 | T_0: 50 # The total epochs for the first learning cycle (learning rate warms up then)
34 | T_mult: 1 # The learning cycle would be (T_0, T_0*T_MULT, T_0*T_MULT^2, T_0*T_MULT^3, ...)
35 | eta_min: 0.000001
36 |
37 | data:
38 | target: main.DataModuleFromConfig
39 | params:
40 | batch_size: 8
41 | num_workers: 4
42 | wrap: True
43 | train:
44 | target: unidemoire.data.moire_blend.moire_blending_datasets
45 | params:
46 | args:
47 | natural_dataset_name: UHDM
48 | uhdm_dataset_path: # Please set the path to your moire pattern dataset
49 | fhdmi_dataset_path: # Please set the path to your moire pattern dataset
50 | tip_dataset_path: # Please set the path to your moire pattern dataset
51 | moire_pattern_path: # Please set the path to your moire pattern dataset
52 | loader: crop
53 | crop_size: 256
54 | paired: True
55 | mode: train
56 |
57 | test: # UHDM
58 | target: unidemoire.data.uhdm.uhdm_datasets
59 | params:
60 | args:
61 | dataset_path: # Please set the path to your moire pattern dataset
62 | LOADER: default
63 | mode: test
64 |
65 | # test: # FHDMi
66 | # target: unidemoire.data.fhdmi.fhdmi_datasets
67 | # params:
68 | # args:
69 | # dataset_path: # Please set the path to your moire pattern dataset
70 | # LOADER: default
71 | # mode: test
72 |
73 | # test: # TIP
74 | # target: unidemoire.data.tip.tip_datasets
75 | # params:
76 | # args:
77 | # dataset_path: # Please set the path to your moire pattern dataset
78 | # LOADER: default
79 | # mode: test
80 |
81 | lightning:
82 | callbacks:
83 | image_logger:
84 | target: main.ImageLogger
85 | params:
86 | increase_log_steps: False
87 | rescale: False
88 | batch_frequency: 500
89 | max_images: 8
90 |
91 | trainer:
92 | benchmark: True
93 | # accumulate_grad_batches: 1
94 | max_epochs: 70
--------------------------------------------------------------------------------
/configs/demoire/cross-dataset/mbcnn/uhdm/cd_unidemoire_mbcnn_uhdm.yaml:
--------------------------------------------------------------------------------
1 | model:
2 | base_learning_rate: 1e-4
3 | target: unidemoire.models.moire_nets.Demoireing_Model
4 | params:
5 | model_name: MBCNN
6 | mode: COMBINE_ONLINE
7 | blending_method: unidemoire # unidemoire, shooting, moirespace, undem
8 | blending_model_path: ./models/moire_blending/bl_uhdm.ckpt
9 | dataset: TIP
10 | evaluation_time: False
11 | evaluation_metric: True
12 | save_img: True
13 | network_config:
14 | # ESDNet
15 | en_feature_num: 48
16 | en_inter_num: 32
17 | de_feature_num: 64
18 | de_inter_num: 32
19 | sam_number: 2 # ESDNet:1, ESDNet-L:2
20 |
21 | # MBCNN
22 | n_filters: 64
23 |
24 | loss_config:
25 | # ESDNet
26 | LAM: 1
27 | LAM_P: 1
28 |
29 | optimizer_config:
30 | # ESDNet
31 | beta1: 0.9
32 | beta2: 0.999
33 | T_0: 50 # The total epochs for the first learning cycle (learning rate warms up then)
34 | T_mult: 1 # The learning cycle would be (T_0, T_0*T_MULT, T_0*T_MULT^2, T_0*T_MULT^3, ...)
35 | eta_min: 0.000001
36 |
37 | data:
38 | target: main.DataModuleFromConfig
39 | params:
40 | batch_size: 4
41 | num_workers: 4
42 | wrap: True
43 | train:
44 | target: unidemoire.data.moire_blend.moire_blending_datasets
45 | params:
46 | args:
47 | natural_dataset_name: UHDM
48 | uhdm_dataset_path: # Please set the path to your moire pattern dataset
49 | fhdmi_dataset_path: # Please set the path to your moire pattern dataset
50 | tip_dataset_path: # Please set the path to your moire pattern dataset
51 | moire_pattern_path: # Please set the path to your moire pattern dataset
52 | loader: crop
53 | crop_size: 384
54 | paired: True
55 | mode: train
56 |
57 | # test: # UHDM
58 | # target: unidemoire.data.uhdm.uhdm_datasets
59 | # params:
60 | # args:
61 | # dataset_path: # Please set the path to your moire pattern dataset
62 | # LOADER: default
63 | # mode: test
64 |
65 | # test: # FHDMi
66 | # target: unidemoire.data.fhdmi.fhdmi_datasets
67 | # params:
68 | # args:
69 | # dataset_path: # Please set the path to your moire pattern dataset
70 | # LOADER: default
71 | # mode: test
72 |
73 | test: # TIP
74 | target: unidemoire.data.tip.tip_datasets
75 | params:
76 | args:
77 | dataset_path: # Please set the path to your moire pattern dataset
78 | LOADER: default
79 | mode: test
80 |
81 | lightning:
82 | callbacks:
83 | image_logger:
84 | target: main.ImageLogger
85 | params:
86 | increase_log_steps: False
87 | rescale: False
88 | batch_frequency: 500
89 | max_images: 8
90 |
91 | trainer:
92 | benchmark: True
93 | # accumulate_grad_batches: 1
94 | max_epochs: 150
--------------------------------------------------------------------------------
/configs/demoire/mhrnid/mhrnid_esdnet_unidemoire.yaml:
--------------------------------------------------------------------------------
1 | model:
2 | base_learning_rate: 1e-4
3 | target: unidemoire.models.moire_nets.Demoireing_Model
4 | params:
5 | model_name: ESDNet
6 | mode: use_synthetic_moire_image_only
7 | blending_method: unidemoire # unidemoire, shooting, moirespace, undem
8 | blending_model_path: ./models/moire_blending/bl_tip.ckpt
9 | dataset: UHDM
10 | evaluation_time: False
11 | evaluation_metric: True
12 | save_img: True
13 | ckpt_path: ## for testing
14 |
15 | network_config:
16 | # ESDNet
17 | en_feature_num: 48
18 | en_inter_num: 32
19 | de_feature_num: 64
20 | de_inter_num: 32
21 | sam_number: 2 # ESDNet:1, ESDNet-L:2
22 |
23 | # MBCNN
24 | n_filters: 64
25 |
26 | loss_config:
27 | LAM: 1
28 | LAM_P: 1
29 |
30 | optimizer_config:
31 | beta1: 0.9
32 | beta2: 0.999
33 | T_0: 50 # The total epochs for the first learning cycle (learning rate warms up then)
34 | T_mult: 1 # The learning cycle would be (T_0, T_0*T_MULT, T_0*T_MULT^2, T_0*T_MULT^3, ...)
35 | eta_min: 0.000001
36 |
37 | data:
38 | target: main.DataModuleFromConfig
39 | params:
40 | batch_size: 4
41 | num_workers: 8
42 | wrap: True
43 | train:
44 | target: unidemoire.data.moire_blend.moire_blending_datasets
45 | params:
46 | args:
47 | natural_dataset_name: MHRNID
48 | mhrnid_dataset_path: # Please set the path to your moire pattern dataset
49 | moire_pattern_path: # Please set the path to your moire pattern dataset
50 | loader: crop
51 | crop_size: 384
52 | paired: True
53 | mode: train
54 |
55 | test: # UHDM
56 | target: unidemoire.data.uhdm.uhdm_datasets
57 | params:
58 | args:
59 | dataset_path: # Please set the path to your moire pattern dataset
60 | LOADER: default
61 | mode: test
62 |
63 | # test: # FHDMi
64 | # target: unidemoire.data.fhdmi.fhdmi_datasets
65 | # params:
66 | # args:
67 | # dataset_path: # Please set the path to your moire pattern dataset
68 | # LOADER: default
69 | # mode: test
70 |
71 |
72 | lightning:
73 | callbacks:
74 | image_logger:
75 | target: main.ImageLogger
76 | params:
77 | increase_log_steps: False
78 | rescale: False
79 | batch_frequency: 500
80 | max_images: 8
81 |
82 | trainer:
83 | benchmark: True
84 | # accumulate_grad_batches: 1
85 | max_epochs: 50
--------------------------------------------------------------------------------
/configs/demoire/mhrnid/mhrnid_mbcnn_unidemoire.yaml:
--------------------------------------------------------------------------------
1 | model:
2 | base_learning_rate: 1e-4
3 | target: unidemoire.models.moire_nets.Demoireing_Model
4 | params:
5 | model_name: MBCNN
6 | mode: use_synthetic_moire_image_only
7 | blending_method: unidemoire # unidemoire, shooting, moirespace, undem
8 | blending_model_path: ./models/moire_blending/bl_tip.ckpt
9 | dataset: UHDM
10 | evaluation_time: False
11 | evaluation_metric: True
12 | save_img: True
13 | ckpt_path: ## for testing
14 |
15 | network_config:
16 | # ESDNet
17 | en_feature_num: 48
18 | en_inter_num: 32
19 | de_feature_num: 64
20 | de_inter_num: 32
21 | sam_number: 2 # ESDNet:1, ESDNet-L:2
22 |
23 | # MBCNN
24 | n_filters: 64
25 |
26 | loss_config:
27 | LAM: 1
28 | LAM_P: 1
29 |
30 | optimizer_config:
31 | beta1: 0.9
32 | beta2: 0.999
33 | T_0: 50 # The total epochs for the first learning cycle (learning rate warms up then)
34 | T_mult: 1 # The learning cycle would be (T_0, T_0*T_MULT, T_0*T_MULT^2, T_0*T_MULT^3, ...)
35 | eta_min: 0.000001
36 |
37 | data:
38 | target: main.DataModuleFromConfig
39 | params:
40 | batch_size: 4
41 | num_workers: 8
42 | wrap: True
43 | train:
44 | target: unidemoire.data.moire_blend.moire_blending_datasets
45 | params:
46 | args:
47 | natural_dataset_name: MHRNID
48 | mhrnid_dataset_path: # Please set the path to your moire pattern dataset
49 | moire_pattern_path: # Please set the path to your moire pattern dataset
50 | loader: crop
51 | crop_size: 384
52 | paired: True
53 | mode: train
54 |
55 | test: # UHDM
56 | target: unidemoire.data.uhdm.uhdm_datasets
57 | params:
58 | args:
59 | dataset_path: # Please set the path to your moire pattern dataset
60 | LOADER: default
61 | mode: test
62 |
63 | # test: # FHDMi
64 | # target: unidemoire.data.fhdmi.fhdmi_datasets
65 | # params:
66 | # args:
67 | # dataset_path: # Please set the path to your moire pattern dataset
68 | # LOADER: default
69 | # mode: test
70 |
71 |
72 | lightning:
73 | callbacks:
74 | image_logger:
75 | target: main.ImageLogger
76 | params:
77 | increase_log_steps: False
78 | rescale: False
79 | batch_frequency: 500
80 | max_images: 8
81 |
82 | trainer:
83 | benchmark: True
84 | # accumulate_grad_batches: 1
85 | max_epochs: 50
--------------------------------------------------------------------------------
/configs/latent-diffusion/ldm-vae-768-crop.yaml:
--------------------------------------------------------------------------------
1 | model:
2 | base_learning_rate: 1.0e-4 # 5.0e-5 # set to target_lr by starting main.py with '--scale_lr False'
3 | target: unidemoire.models.diffusion.ddpm.LatentDiffusion
4 | params:
5 | linear_start: 0.0015
6 | linear_end: 0.0155
7 | num_timesteps_cond: 1
8 | log_every_t: 200
9 | timesteps: 1000
10 | loss_type: l1
11 | first_stage_key: "image"
12 | cond_stage_key: "image"
13 | image_size: 24
14 | channels: 64
15 | cond_stage_trainable: False
16 | concat_mode: False
17 | scale_by_std: True
18 | monitor: 'val/loss_simple_ema'
19 |
20 | scheduler_config: # 10000 warmup steps
21 | target: unidemoire.lr_scheduler.LambdaLinearScheduler
22 | params:
23 | warm_up_steps: [10000]
24 | cycle_lengths: [10000000000000]
25 | f_start: [1.e-6]
26 | f_max: [1.]
27 | f_min: [ 1.]
28 |
29 | unet_config:
30 | target: unidemoire.modules.diffusionmodules.openaimodel.UNetModel
31 | params:
32 | image_size: 24
33 | in_channels: 64
34 | out_channels: 64
35 | model_channels: 192
36 | attention_resolutions: [1, 2, 4, 8]
37 | num_res_blocks: 2
38 | channel_mult: [1,2,2,4]
39 | num_heads: 8
40 | use_scale_shift_norm: True
41 | resblock_updown: True
42 |
43 | first_stage_config:
44 | target: unidemoire.models.autoencoder.AutoencoderKL
45 | params:
46 | embed_dim: 64
47 | monitor: "val/rec_loss"
48 | ckpt_path: "models/moire_generator/autoencoder/last.ckpt"
49 | ddconfig:
50 | double_z: True
51 | z_channels: 64
52 | resolution: 768
53 | in_channels: 3
54 | out_ch: 3
55 | wide_scale_resolution: False
56 | ch: 64
57 | ch_mult: [1,1,2,2,4,4]
58 | num_res_blocks: 2
59 | attn_resolutions: [16,8]
60 | dropout: 0.0
61 |
62 | lossconfig:
63 | target: torch.nn.Identity
64 |
65 | cond_stage_config: "__is_unconditional__"
66 |
67 | data:
68 | target: main.DataModuleFromConfig
69 | params:
70 | batch_size: 2
71 | wrap: True
72 | train:
73 | target: unidemoire.data.moire.MoirePattern
74 | params:
75 | dataset_path: # Please set the path to your moire pattern dataset
76 | resolution: 768
77 | validation:
78 | target: unidemoire.data.moire.MoirePattern
79 | params:
80 | dataset_path: # Please set the path to your moire pattern dataset
81 | resolution: 768
82 |
83 | lightning:
84 | callbacks:
85 | image_logger:
86 | target: main.ImageLogger
87 | params:
88 | batch_frequency: 1000
89 | max_images: 8
90 | increase_log_steps: False
91 |
92 |
93 | trainer:
94 | benchmark: True
95 | # precision: 16
--------------------------------------------------------------------------------
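A quick consistency check between this diffusion config and the autoencoder config above: as in the ldm autoencoder, each additional `ch_mult` entry adds one 2x downsampling stage, so six entries give five downsamples and a 768x768 moiré pattern crop is encoded to a latent of spatial size 768 / 2^5 = 24 with `z_channels: 64` channels, which is exactly the tensor the UNet is configured for (`image_size: 24`, `in_channels: 64`).

```python
resolution = 768                     # ddconfig.resolution
ch_mult = [1, 1, 2, 2, 4, 4]         # ddconfig.ch_mult
num_downsamples = len(ch_mult) - 1   # one 2x downsampling step between stages
latent_size = resolution // 2 ** num_downsamples
print(latent_size)                   # 24 -> matches image_size: 24 above
```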
/configs/moire-blending/fhdmi/blending_fhdmi.yaml:
--------------------------------------------------------------------------------
1 | model:
2 | base_learning_rate: 1e-5
3 | target: unidemoire.models.moire_blending.MoireBlending_Model
4 | params:
5 | model_name: UniDemoire
6 | network_config:
7 | init_blending_args:
8 | # MIB
9 | bl_method_1: multiply
10 | bl_method_1_op: 1.0
11 | bl_method_2: grain_merge
12 | bl_method_2_op: 0.8
13 | bl_final_weight_min: 0.65
14 | bl_final_weight_max: 0.75
15 |
16 | blending_network_args:
17 | # TRN
18 | depths: [1, 1, 1, 1, 1, 1, 1, 1, 1]
19 | embed_dim: 16
20 | win_size: 8
21 | modulator: False
22 | shift_flag: False
23 |
24 | loss_config:
25 | LAM: 1
26 | LAM_P: 1
27 |
28 | optimizer_config:
29 | beta1: 0.9
30 | beta2: 0.999
31 | T_0: 50 # The total epochs for the first learning cycle (learning rate warms up then)
32 | T_mult: 1 # The learning cycle would be (T_0, T_0*T_MULT, T_0*T_MULT^2, T_0*T_MULT^3, ...)
33 | eta_min: 0.000001
34 |
35 | data:
36 | target: main.DataModuleFromConfig
37 | params:
38 | batch_size: 2
39 | num_workers: 8
40 | wrap: True
41 | train:
42 | target: unidemoire.data.moire_blend.moire_blending_datasets
43 | params:
44 | args:
45 | natural_dataset_name: FHDMi
46 | uhdm_dataset_path: # Please set the path to your moire pattern dataset
47 | fhdmi_dataset_path: # Please set the path to your moire pattern dataset
48 | tip_dataset_path: # Please set the path to your moire pattern dataset
49 | moire_pattern_path: # Please set the path to your moire pattern dataset
50 | loader: crop
51 | crop_size: 384
52 | paired: True
53 | mode: train
54 |
55 | lightning:
56 | callbacks:
57 | image_logger:
58 | target: main.ImageLogger
59 | params:
60 | increase_log_steps: False
61 | rescale: False
62 | batch_frequency: 500
63 | max_images: 8
64 |
65 | trainer:
66 | benchmark: True
67 | accumulate_grad_batches: 1
68 | max_epochs: 25
--------------------------------------------------------------------------------
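The `init_blending_args` above drive the initial Moiré Image Blending step: two classic blend modes applied at fixed opacities, combined with a final weight drawn from `[bl_final_weight_min, bl_final_weight_max]`. The repository's actual implementation is in `unidemoire/models/MIB/Blending.py` (not included in this dump); the sketch below only assumes the textbook definitions of multiply and grain-merge blending and an illustrative weighting scheme:

```python
import torch

def multiply_blend(base, layer, opacity=1.0):
    # Textbook "multiply": darkens the base wherever the moiré layer is dark.
    blended = base * layer
    return opacity * blended + (1.0 - opacity) * base

def grain_merge_blend(base, layer, opacity=0.8):
    # Textbook "grain merge": base + layer - 0.5, clamped to the valid range.
    blended = torch.clamp(base + layer - 0.5, 0.0, 1.0)
    return opacity * blended + (1.0 - opacity) * base

def initial_blend(clean, moire_pattern, w_min=0.65, w_max=0.75):
    # Hypothetical combination of the two blend results with a random final
    # weight, mirroring the bl_final_weight_min/max fields above.
    b1 = multiply_blend(clean, moire_pattern, opacity=1.0)     # bl_method_1
    b2 = grain_merge_blend(clean, moire_pattern, opacity=0.8)  # bl_method_2
    w = torch.empty(1).uniform_(w_min, w_max).item()
    return w * b1 + (1.0 - w) * b2

clean = torch.rand(1, 3, 384, 384)    # crop_size x crop_size clean image in [0, 1]
pattern = torch.rand(1, 3, 384, 384)  # sampled moiré pattern of the same size
coarse_moire_image = initial_blend(clean, pattern)
```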
/configs/moire-blending/tip/blending_tip.yaml:
--------------------------------------------------------------------------------
1 | model:
2 | base_learning_rate: 1e-5
3 | target: unidemoire.models.moire_blending.MoireBlending_Model
4 | params:
5 | model_name: UniDemoire
6 | network_config:
7 | init_blending_args:
8 | # MIB
9 | bl_method_1: multiply
10 | bl_method_1_op: 1.0
11 | bl_method_2: grain_merge
12 | bl_method_2_op: 0.8
13 | bl_final_weight_min: 0.65
14 | bl_final_weight_max: 0.75
15 |
16 | blending_network_args:
17 | # TRN
18 | depths: [1, 1, 1, 1, 1, 1, 1, 1, 1]
19 | embed_dim: 16
20 | win_size: 8
21 | modulator: False
22 | shift_flag: False
23 |
24 | loss_config:
25 | LAM: 1
26 | LAM_P: 1
27 |
28 | optimizer_config:
29 | beta1: 0.9
30 | beta2: 0.999
31 | T_0: 50 # The total epochs for the first learning cycle (learning rate warms up then)
32 | T_mult: 1 # The learning cycle would be (T_0, T_0*T_MULT, T_0*T_MULT^2, T_0*T_MULT^3, ...)
33 | eta_min: 0.000001
34 |
35 | data:
36 | target: main.DataModuleFromConfig
37 | params:
38 | batch_size: 2
39 | num_workers: 4
40 | wrap: True
41 | train:
42 | target: unidemoire.data.moire_blend.moire_blending_datasets
43 | params:
44 | args:
45 | natural_dataset_name: TIP
46 | uhdm_dataset_path: # Please set the path to your moire pattern dataset
47 | fhdmi_dataset_path: # Please set the path to your moire pattern dataset
48 | tip_dataset_path: # Please set the path to your moire pattern dataset
49 | moire_pattern_path: # Please set the path to your moire pattern dataset
50 | loader: crop
51 | crop_size: 256
52 | paired: True
53 | mode: train
54 |
55 | lightning:
56 | callbacks:
57 | image_logger:
58 | target: main.ImageLogger
59 | params:
60 | increase_log_steps: False
61 | rescale: False
62 | batch_frequency: 500
63 | max_images: 8
64 |
65 | trainer:
66 | benchmark: True
67 | accumulate_grad_batches: 1
68 | max_epochs: 2
--------------------------------------------------------------------------------
/configs/moire-blending/uhdm/blending_uhdm.yaml:
--------------------------------------------------------------------------------
1 | model:
2 | base_learning_rate: 1e-5
3 | target: unidemoire.models.moire_blending.MoireBlending_Model
4 | params:
5 | model_name: UniDemoire
6 | network_config:
7 | init_blending_args:
8 | # MIB
9 | bl_method_1: multiply
10 | bl_method_1_op: 1.0
11 | bl_method_2: grain_merge
12 | bl_method_2_op: 0.8
13 | bl_final_weight_min: 0.65
14 | bl_final_weight_max: 0.75
15 |
16 | blending_network_args:
17 | # TRN
18 | depths: [1, 1, 1, 1, 1, 1, 1, 1, 1]
19 | embed_dim: 16
20 | win_size: 8
21 | modulator: False
22 | shift_flag: False
23 |
24 | loss_config:
25 | LAM: 1
26 | LAM_P: 1
27 |
28 | optimizer_config:
29 | beta1: 0.9
30 | beta2: 0.999
31 | T_0: 50 # The total epochs for the first learning cycle (learning rate warms up then)
32 | T_mult: 1 # The learning cycle would be (T_0, T_0*T_MULT, T_0*T_MULT^2, T_0*T_MULT^3, ...)
33 | eta_min: 0.000001
34 |
35 | data:
36 | target: main.DataModuleFromConfig
37 | params:
38 | batch_size: 2
39 | num_workers: 4
40 | wrap: True
41 | train:
42 | target: unidemoire.data.moire_blend.moire_blending_datasets
43 | params:
44 | args:
45 | natural_dataset_name: UHDM
46 | uhdm_dataset_path: # Please set the path to your moire pattern dataset
47 | fhdmi_dataset_path: # Please set the path to your moire pattern dataset
48 | tip_dataset_path: # Please set the path to your moire pattern dataset
49 | moire_pattern_path: # Please set the path to your moire pattern dataset
50 | loader: crop
51 | crop_size: 384
52 | paired: True
53 | mode: train
54 |
55 | lightning:
56 | callbacks:
57 | image_logger:
58 | target: main.ImageLogger
59 | params:
60 | increase_log_steps: False
61 | rescale: False
62 | batch_frequency: 500
63 | max_images: 8
64 |
65 | trainer:
66 | benchmark: True
67 | accumulate_grad_batches: 1
68 | max_epochs: 50
--------------------------------------------------------------------------------
/environment.yaml:
--------------------------------------------------------------------------------
1 | name: unidemoire
2 | channels:
3 | - pytorch
4 | - defaults
5 | dependencies:
6 | - python=3.8.5
7 | - pip=20.3
8 | - cudatoolkit=11.0
9 | - pytorch=1.7.0
10 | - torchvision=0.8.1
11 | - numpy=1.19.2
12 | - pip:
13 | - albumentations==0.4.3
14 | - opencv-python==4.1.2.30
15 | - pudb==2019.2
16 | - imageio==2.9.0
17 | - imageio-ffmpeg==0.4.2
18 | - pytorch-lightning==1.4.2
19 | - omegaconf==2.1.1
20 | - test-tube>=0.7.5
21 | - streamlit>=0.73.1
22 | - einops==0.3.0
23 | - torch-fidelity==0.3.0
24 | - transformers==4.3.1
25 | - -e git+https://github.com/CompVis/taming-transformers.git@master#egg=taming-transformers
26 | - -e git+https://github.com/openai/CLIP.git@main#egg=clip
27 | - -e .
28 | - colour-demosaicing==0.2.2
29 | - thop==0.1.1-2209072238
30 | - lpips==0.1.4
31 | - timm==0.9.16
32 | - pillow==9.5.0
33 |
--------------------------------------------------------------------------------
/models/moire_blending/fhdmi/config.yaml:
--------------------------------------------------------------------------------
1 | model:
2 | base_learning_rate: 0.00001
3 | target: unidemoire.models.moire_blending.MoireBlending_Model
4 | params:
5 | model_name: UniDemoire
6 | network_config:
7 | init_blending_args:
8 | bl_method_1: multiply
9 | bl_method_1_op: 1.0
10 | bl_method_2: grain_merge
11 | bl_method_2_op: 0.8
12 | bl_final_weight_min: 0.65
13 | bl_final_weight_max: 0.75
14 | blending_network_args:
15 | depths: [1,1,1,1,1,1,1,1,1]
16 | embed_dim: 16
17 | win_size: 8
18 | modulator: true
19 | shift_flag: false
20 |
--------------------------------------------------------------------------------
/models/moire_blending/tip/config.yaml:
--------------------------------------------------------------------------------
1 | model:
2 | base_learning_rate: 0.00001
3 | target: unidemoire.models.moire_blending.MoireBlending_Model
4 | params:
5 | model_name: UniDemoire
6 | network_config:
7 | init_blending_args:
8 | bl_method_1: multiply
9 | bl_method_1_op: 1.0
10 | bl_method_2: grain_merge
11 | bl_method_2_op: 0.8
12 | bl_final_weight_min: 0.65
13 | bl_final_weight_max: 0.75
14 | blending_network_args:
15 | depths: [1,1,1,1,1,1,1,1,1]
16 | embed_dim: 16
17 | win_size: 8
18 | modulator: true
19 | shift_flag: false
20 |
--------------------------------------------------------------------------------
/models/moire_blending/uhdm/config.yaml:
--------------------------------------------------------------------------------
1 | model:
2 | base_learning_rate: 0.00001
3 | target: unidemoire.models.moire_blending.MoireBlending_Model
4 | params:
5 | model_name: UniDemoire
6 | network_config:
7 | init_blending_args:
8 | bl_method_1: multiply
9 | bl_method_1_op: 1.0
10 | bl_method_2: grain_merge
11 | bl_method_2_op: 0.8
12 | bl_final_weight_min: 0.65
13 | bl_final_weight_max: 0.75
14 | blending_network_args:
15 | depths: [1,1,1,1,1,1,1,1,1]
16 | embed_dim: 16
17 | win_size: 8
18 | modulator: true
19 | shift_flag: false
20 |
--------------------------------------------------------------------------------
/models/moire_generator/diffusion/config.yaml:
--------------------------------------------------------------------------------
1 | model:
2 | base_learning_rate: 0.0001
3 | target: unidemoire.models.diffusion.ddpm.LatentDiffusion
4 | params:
5 | linear_start: 0.0015
6 | linear_end: 0.0155
7 | num_timesteps_cond: 1
8 | log_every_t: 200
9 | timesteps: 1000
10 | loss_type: l1
11 | first_stage_key: image
12 | cond_stage_key: image
13 | image_size: 24
14 | # wide_scale_resolution: false
15 | channels: 64
16 | cond_stage_trainable: false
17 | concat_mode: false
18 | scale_by_std: true
19 | monitor: val/loss_simple_ema
20 | scheduler_config:
21 | target: unidemoire.lr_scheduler.LambdaLinearScheduler
22 | params:
23 | warm_up_steps:
24 | - 10000
25 | cycle_lengths:
26 | - 10000000000000
27 | f_start:
28 | - 1.0e-06
29 | f_max:
30 | - 1.0
31 | f_min:
32 | - 1.0
33 | unet_config:
34 | target: unidemoire.modules.diffusionmodules.openaimodel.UNetModel
35 | params:
36 | image_size: 24
37 | in_channels: 64
38 | out_channels: 64
39 | model_channels: 192
40 | attention_resolutions:
41 | - 1
42 | - 2
43 | - 4
44 | - 8
45 | num_res_blocks: 2
46 | channel_mult:
47 | - 1
48 | - 2
49 | - 2
50 | - 4
51 | num_heads: 8
52 | use_scale_shift_norm: true
53 | resblock_updown: true
54 | first_stage_config:
55 | target: unidemoire.models.autoencoder.AutoencoderKL
56 | params:
57 | embed_dim: 64
58 | monitor: val/rec_loss
59 | # VAE model path
60 | ckpt_path: models/moire_generator/autoencoder/last.ckpt
61 | ddconfig:
62 | double_z: true
63 | z_channels: 64
64 | resolution: 768
65 | in_channels: 3
66 | out_ch: 3
67 | # wide_scale_resolution: false
68 | ch: 64
69 | ch_mult:
70 | - 1
71 | - 1
72 | - 2
73 | - 2
74 | - 4
75 | - 4
76 | num_res_blocks: 2
77 | attn_resolutions:
78 | - 16
79 | - 8
80 | dropout: 0.0
81 | lossconfig:
82 | target: torch.nn.Identity
83 | cond_stage_config: __is_unconditional__
84 | data:
85 | target: main.DataModuleFromConfig
86 | params:
87 | batch_size: 2
88 | wrap: true
89 | train:
90 | target: unidemoire.data.moire.MoirePattern
91 | params:
92 | dataset_path: "/inspurfs/group/mayuexin/yangzemin/data/captured_data"
93 | resolution: 768
94 | validation:
95 | target: unidemoire.data.moire.MoirePattern
96 | params:
97 | dataset_path: "/inspurfs/group/mayuexin/yangzemin/data/captured_data"
98 | resolution: 768
99 |
--------------------------------------------------------------------------------
/scripts/sample_moire_pattern.py:
--------------------------------------------------------------------------------
1 | import argparse, os, sys, glob, datetime, yaml
2 | sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
3 |
4 | import torch
5 | import time
6 | import numpy as np
7 | from tqdm import trange
8 |
9 | from omegaconf import OmegaConf
10 | from PIL import Image
11 |
12 |
13 | from unidemoire.models.diffusion.ddim import DDIMSampler
14 | from unidemoire.util import instantiate_from_config
15 |
16 | rescale = lambda x: (x + 1.) / 2.
17 |
18 | def custom_to_pil(x):
19 | x = x.detach().cpu()
20 | x = torch.clamp(x, -1., 1.)
21 | x = (x + 1.) / 2.
22 | x = x.permute(1, 2, 0).numpy()
23 | x = (255 * x).astype(np.uint8)
24 | if x.shape[2] == 1:
25 | x = x.squeeze()
26 | x = Image.fromarray(x, mode='L')
27 | else:
28 | x = Image.fromarray(x)
29 | x = x.convert("RGB")
30 | return x
31 |
32 | def custom_to_np(x):
33 | # saves the batch in adm style as in https://github.com/openai/guided-diffusion/blob/main/scripts/image_sample.py
34 | sample = x.detach().cpu()
35 | sample = ((sample + 1) * 127.5).clamp(0, 255).to(torch.uint8)
36 | sample = sample.permute(0, 2, 3, 1)
37 | sample = sample.contiguous()
38 | return sample
39 |
40 | def logs2pil(logs, keys=["sample"]):
41 | imgs = dict()
42 | for k in logs:
43 | try:
44 | if len(logs[k].shape) == 4:
45 | img = custom_to_pil(logs[k][0, ...])
46 | elif len(logs[k].shape) == 3:
47 | img = custom_to_pil(logs[k])
48 | else:
49 | print(f"Unknown format for key {k}. ")
50 | img = None
51 | except:
52 | img = None
53 | imgs[k] = img
54 | return imgs
55 |
56 | @torch.no_grad()
57 | def convsample(model, shape, return_intermediates=True,
58 | verbose=True,
59 | make_prog_row=False):
60 | if not make_prog_row:
61 | return model.p_sample_loop(None, shape,
62 | return_intermediates=return_intermediates, verbose=verbose)
63 | else:
64 | return model.progressive_denoising(
65 | None, shape, verbose=True
66 | )
67 |
68 | @torch.no_grad()
69 | def convsample_ddim(model, steps, shape, eta=1.0):
70 | ddim = DDIMSampler(model)
71 | bs = shape[0]
72 | shape = shape[1:]
73 | samples, intermediates = ddim.sample(steps, batch_size=bs, shape=shape, eta=eta, verbose=False,)
74 | return samples, intermediates
75 |
76 | @torch.no_grad()
77 | def make_convolutional_sample(model, batch_size, vanilla=False, custom_steps=None, eta=1.0,):
78 |
79 | log = dict()
80 | shape = [batch_size,
81 | model.model.diffusion_model.in_channels,
82 | model.model.diffusion_model.image_size,
83 | model.model.diffusion_model.image_size]
84 |
85 | with model.ema_scope("Plotting"):
86 | t0 = time.time()
87 | if vanilla:
88 | sample, progrow = convsample(model, shape,
89 | make_prog_row=True)
90 | else:
91 | sample, intermediates = convsample_ddim(model, steps=custom_steps, shape=shape,
92 | eta=eta)
93 | t1 = time.time()
94 |
95 | x_sample = model.decode_first_stage(sample)
96 | log["sample"] = x_sample
97 | log["time"] = t1 - t0
98 | log['throughput'] = sample.shape[0] / (t1 - t0)
99 | return log
100 |
101 | def run(model, logdir, batch_size=50, vanilla=False, custom_steps=None, eta=None, n_samples=50000, nplog=None):
102 | if vanilla:
103 | print(f'Using Vanilla DDPM sampling with {model.num_timesteps} sampling steps.')
104 | else: # Using DDIM sampling with 200 sampling steps and eta=1.0
105 | print(f'Using DDIM sampling with {custom_steps} sampling steps and eta={eta}')
106 |
107 | tstart = time.time()
108 | n_saved = len(glob.glob(os.path.join(logdir,'*.png')))
109 |
110 | if model.cond_stage_model is None:
111 | all_images = []
112 | print(f"Running unconditional sampling for {n_samples} samples")
113 | for _ in trange(n_samples // batch_size, desc="Sampling Batches (unconditional)"):
114 | if n_saved >= n_samples:
115 | print(f'Finish after generating {n_saved} samples')
116 | break
117 | logs = make_convolutional_sample(model, batch_size=batch_size,
118 | vanilla=vanilla, custom_steps=custom_steps,
119 | eta=eta)
120 | n_saved = save_logs(logs, logdir, n_saved=n_saved, key="sample")
121 | all_images.extend([custom_to_np(logs["sample"])])
122 |
123 | else:
124 | raise NotImplementedError('Currently only sampling for unconditional models supported.')
125 |
126 | print(f"sampling of {n_saved} images finished in {(time.time() - tstart) / 60.:.2f} minutes.")
127 |
128 | def save_logs(logs, path, n_saved=0, key="sample", np_path=None):
129 | for k in logs:
130 | if k == key:
131 | batch = logs[key]
132 | if np_path is None:
133 | for x in batch:
134 | img = custom_to_pil(x)
135 | imgpath = os.path.join(path, f"{n_saved:07}.png")
136 | img.save(imgpath)
137 | n_saved += 1
138 | else:
139 | npbatch = custom_to_np(batch)
140 | shape_str = "x".join([str(x) for x in npbatch.shape])
141 | nppath = os.path.join(np_path, f"{n_saved}-{shape_str}-samples.npz")
142 | np.savez(nppath, npbatch)
143 | n_saved += npbatch.shape[0]
144 | return n_saved
145 |
146 | def get_parser():
147 | parser = argparse.ArgumentParser()
148 | parser.add_argument(
149 | "-r",
150 | "--resume",
151 | type=str,
152 | nargs="?",
153 | help="load from logdir or checkpoint in logdir",
154 | default="./data/generated"
155 | )
156 | parser.add_argument(
157 | "-n",
158 | "--n_samples",
159 | type=int,
160 | nargs="?",
161 | help="number of samples to draw",
162 | default=20
163 | )
164 | parser.add_argument(
165 | "-e",
166 | "--eta",
167 | type=float,
168 | nargs="?",
169 | help="eta for ddim sampling (0.0 yields deterministic sampling)",
170 | default=1.0
171 | )
172 | parser.add_argument(
173 | "-v",
174 | "--vanilla_sample",
175 | default=False,
176 | action='store_true',
177 | help="vanilla sampling (default option is DDIM sampling)?",
178 | )
179 | parser.add_argument(
180 | "-l",
181 | "--logdir",
182 | type=str,
183 | nargs="?",
184 | help="extra logdir",
185 | default="./data/generated"
186 | )
187 | parser.add_argument(
188 | "-c",
189 | "--custom_steps",
190 | type=int,
191 | nargs="?",
192 | help="number of steps for ddim and fastdpm sampling",
193 | default=200
194 | )
195 | parser.add_argument(
196 | "--batch_size",
197 | type=int,
198 | nargs="?",
199 | help="the bs",
200 | default=1
201 | )
202 | return parser
203 |
204 | def load_model_from_config(config, sd):
205 | model = instantiate_from_config(config)
206 | model.load_state_dict(sd,strict=False)
207 | model.cuda()
208 | model.eval()
209 | return model
210 |
211 | def load_model(config, ckpt, gpu, eval_mode):
212 | if ckpt:
213 | print(f"Loading model from {ckpt}")
214 | pl_sd = torch.load(ckpt, map_location="cpu")
215 | global_step = pl_sd["global_step"]
216 | else:
217 | pl_sd = {"state_dict": None}
218 | global_step = None
219 | model = load_model_from_config(config.model,
220 | pl_sd["state_dict"])
221 |
222 | return model, global_step
223 |
224 |
225 | if __name__ == "__main__":
226 | now = datetime.datetime.now().strftime("%Y-%m-%d-%H-%M-%S")
227 | sys.path.append(os.getcwd())
228 | command = " ".join(sys.argv)
229 | print(75 * "=")
230 | parser = get_parser()
231 | opt, unknown = parser.parse_known_args()
232 | ckpt = None
233 |
234 | if not os.path.exists(opt.resume):
235 | raise ValueError("Cannot find {}".format(opt.resume))
236 |
237 | if os.path.isfile(opt.resume):
238 | try:
239 | logdir = '/'.join(opt.resume.split('/')[:-1])
240 | print(f'Logdir is {logdir}')
241 | except ValueError:
242 | paths = opt.resume.split("/")
243 | idx = -2 # take a guess: path/to/logdir/checkpoints/model.ckpt
244 | logdir = "/".join(paths[:idx])
245 | ckpt = opt.resume
246 | else:
247 | assert os.path.isdir(opt.resume), f"{opt.resume} is not a directory"
248 | logdir = opt.resume.rstrip("/")
249 | ckpt = os.path.join(logdir, "model.ckpt")
250 |
251 | base_configs = sorted(glob.glob(os.path.join(logdir, "config.yaml")))
252 | opt.base = base_configs
253 |
254 | configs = [OmegaConf.load(cfg) for cfg in opt.base]
255 | cli = OmegaConf.from_dotlist(unknown)
256 | config = OmegaConf.merge(*configs, cli)
257 |
258 | gpu = True
259 | eval_mode = True
260 |
261 | if opt.logdir != "none":
262 | locallog = logdir.split(os.sep)[-1]
263 | if locallog == "": locallog = logdir.split(os.sep)[-2]
264 | print(f"Switching logdir from '{logdir}' to '{os.path.join(opt.logdir, locallog)}'")
265 | logdir = os.path.join(opt.logdir, locallog)
266 |
267 | model, global_step = load_model(config, ckpt, gpu, eval_mode)
268 | print(f"global step: {global_step}")
269 | print(75 * "=")
270 | print("logging to:")
271 | logdir = os.path.join(logdir, now)
272 | imglogdir = os.path.join(logdir, "moire_patterns")
273 | os.makedirs(imglogdir)
274 | print(logdir)
275 | print(75 * "=")
276 |
277 | # write config out
278 | sampling_file = os.path.join(logdir, "sampling_config.yaml")
279 | sampling_conf = vars(opt)
280 |
281 | with open(sampling_file, 'w') as f:
282 | yaml.dump(sampling_conf, f, default_flow_style=False)
283 |
284 | run(model, imglogdir, eta=opt.eta,
285 | vanilla=opt.vanilla_sample, n_samples=opt.n_samples, custom_steps=opt.custom_steps,
286 | batch_size=opt.batch_size)
287 |
288 | print("done.")
--------------------------------------------------------------------------------
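Usage sketch (not part of the repository): the sampler is driven entirely by the flags defined in get_parser() above. A typical invocation, plus a programmatic parse of the same flags, might look like the following; `<path/to/logdir>` is a placeholder that is expected to contain a model.ckpt and a config.yaml.

# Shell (illustrative):
#   python scripts/sample_moire_pattern.py -r <path/to/logdir> -n 100 -c 200 -e 1.0 --batch_size 4

# Python: checking how the same flags are parsed
parser = get_parser()
opt, unknown = parser.parse_known_args(
    ["-r", "<path/to/logdir>", "-n", "100", "-c", "200", "-e", "1.0", "--batch_size", "4"])
print(opt.n_samples, opt.custom_steps, opt.eta, opt.batch_size)   # 100 200 1.0 4
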
/setup.py:
--------------------------------------------------------------------------------
1 | from setuptools import setup, find_packages
2 |
3 | setup(
4 | name='unidemoire',
5 | version='0.0.1',
6 | description='',
7 | packages=find_packages(),
8 | install_requires=[
9 | 'torch',
10 | 'numpy',
11 | 'tqdm',
12 | ],
13 | )
--------------------------------------------------------------------------------
/static/images/Pipeline.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/4DVLab/UniDemoire/4221a0e98f0078c3f72ef1da8128642e7f02662a/static/images/Pipeline.png
--------------------------------------------------------------------------------
/taming/modules/autoencoder/lpips/vgg.pth:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/4DVLab/UniDemoire/4221a0e98f0078c3f72ef1da8128642e7f02662a/taming/modules/autoencoder/lpips/vgg.pth
--------------------------------------------------------------------------------
/unidemoire/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/4DVLab/UniDemoire/4221a0e98f0078c3f72ef1da8128642e7f02662a/unidemoire/__init__.py
--------------------------------------------------------------------------------
/unidemoire/data/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/4DVLab/UniDemoire/4221a0e98f0078c3f72ef1da8128642e7f02662a/unidemoire/data/__init__.py
--------------------------------------------------------------------------------
/unidemoire/data/fhdmi.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import torch
3 | import argparse
4 | import cv2
5 | import torch.utils.data as data
6 | import torchvision.transforms as transforms
7 | import random
8 | from PIL import Image
9 | from PIL import ImageFile
10 | import os
11 |
12 | from .utils import default_loader, crop_loader, resize_loader, resize_then_crop_loader
13 |
14 |
15 | class fhdmi_datasets(data.Dataset):
16 | def __init__(self, args, mode='train'):
17 | self.args = args
18 | self.mode = mode
19 | self.loader = args["LOADER"]
20 | self.image_list = sorted([file for file in os.listdir(self.args["dataset_path"] + '/target') if file.endswith('.png')])
21 |
22 | def __getitem__(self, index):
23 | ImageFile.LOAD_TRUNCATED_IMAGES = True
24 | data = {}
25 | image_in_gt = self.image_list[index]
26 | number = image_in_gt[4:9]
27 | image_in = 'src_' + number + '.png'
28 | if self.mode == 'train':
29 | path_tar = self.args["dataset_path"] + '/target/' + image_in_gt
30 | path_src = self.args["dataset_path"] + '/source/' + image_in
31 | if self.loader == 'crop':
32 | x = random.randint(0, 1920 - self.args["CROP_SIZE"])
33 | y = random.randint(0, 1080 - self.args["CROP_SIZE"])
34 | labels, moire_imgs = crop_loader(self.args["CROP_SIZE"], x, y, [path_tar, path_src])
35 |
36 | elif self.loader == 'resize':
37 | labels, moire_imgs = resize_loader(self.args["RESIZE_SIZE"], [path_tar, path_src])
38 | data['origin_label'] = default_loader([path_tar])[0]
39 |
40 | elif self.loader == 'default':
41 | labels, moire_imgs = default_loader([path_tar, path_src])
42 |
43 | elif self.mode == 'test':
44 | path_tar = self.args["dataset_path"] + '/target/' + image_in_gt
45 | path_src = self.args["dataset_path"] + '/source/' + image_in
46 | if self.loader == 'resize':
47 | labels, moire_imgs = resize_loader(self.args["RESIZE_SIZE"], [path_tar, path_src])
48 | data['origin_label'] = default_loader([path_tar])[0]
49 | else:
50 | labels, moire_imgs = default_loader([path_tar, path_src])
51 |
52 | else:
53 | print('Unrecognized mode! Please select either "train" or "test"')
54 | raise NotImplementedError
55 |
56 | data['in_img'] = moire_imgs
57 | data['label'] = labels
58 | data['number'] = number
59 | data['mode'] = self.mode
60 | return data
61 |
62 | def __len__(self):
63 | return len(self.image_list)
--------------------------------------------------------------------------------
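A minimal usage sketch (not part of the repository): the args dictionary keys below mirror the ones read in __init__ and __getitem__ above; the dataset path is a placeholder and is expected to contain target/ and source/ sub-folders.

from torch.utils.data import DataLoader

args = {
    "dataset_path": "/path/to/FHDMi/train",   # placeholder
    "LOADER": "crop",                         # 'crop', 'resize', or 'default'
    "CROP_SIZE": 384,                         # illustrative value
}
dataset = fhdmi_datasets(args, mode='train')
loader = DataLoader(dataset, batch_size=4, shuffle=True, num_workers=4)
batch = next(iter(loader))
print(batch['in_img'].shape, batch['label'].shape)   # e.g. torch.Size([4, 3, 384, 384]) each
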
/unidemoire/data/moire.py:
--------------------------------------------------------------------------------
1 | import os
2 | import random
3 |
4 | from PIL import Image, ImageFilter, ImageEnhance
5 |
6 | import numpy as np
7 | from torch.utils.data import Dataset
8 | import torchvision.transforms as transforms
9 |
10 | IMG_EXTENSIONS = ['.jpg', '.JPG', '.jpeg', '.JPEG',
11 | '.png', '.PNG', '.ppm', '.PPM', '.bmp', '.BMP']
12 |
13 | def is_image_file(filename):
14 | return any(filename.endswith(extension) for extension in IMG_EXTENSIONS)
15 |
16 | def get_paths_from_images(path):
17 | assert os.path.isdir(path), '{:s} is not a valid directory'.format(path)
18 | images = []
19 | for dirpath, _, fnames in sorted(os.walk(path)):
20 | for fname in sorted(fnames):
21 | if is_image_file(fname):
22 | img_path = os.path.join(dirpath, fname)
23 | images.append(img_path)
24 | assert images, '{:s} has no valid image file'.format(path)
25 | return sorted(images)
26 |
27 |
28 | class MoirePattern(Dataset):
29 | def __init__(self, dataset_path, resolution):
30 | self.resolution = resolution
31 | self.pil_to_tensor = transforms.ToTensor()
32 | self.dataset_path = dataset_path
33 | self.moire_layer_path = get_paths_from_images(self.dataset_path)
34 |
35 | def __len__(self):
36 | return len(self.moire_layer_path)
37 |
38 | def calculate_sharpness(self, image):
39 | image_gray = image.convert('L')
40 | image_laplace = image_gray.filter(ImageFilter.FIND_EDGES)
41 | sharpness = np.std(np.array(image_laplace))
42 | return sharpness
43 |
44 | def calculate_colorfulness(self, image):
45 | image_lab = image.convert('LAB')
46 | l, a, b = image_lab.split()
47 | std_a = np.std(np.array(a))
48 | std_b = np.std(np.array(b))
49 | colorfulness = np.sqrt(std_a ** 2 + std_b ** 2)
50 | return colorfulness
51 |
52 | def calculate_image_quality(self, image):
53 | sharpness = self.calculate_sharpness(image)
54 | colorfulness = self.calculate_colorfulness(image)
55 | return sharpness, colorfulness
56 |
57 | def __getitem__(self, index):
58 | while(True):
59 | ## TODO: try different index moire patterns
60 | for i in range(3):
61 | ## TODO: [Multi crop] + [Sharpness & Colorfulness selection]
62 | img_moire_layer = Image.open(self.moire_layer_path[index])
63 | self.transform_init()
64 | img_moire_layer = self.transform(img_moire_layer)
65 | sharpness, colorfulness = self.calculate_image_quality(img_moire_layer)
66 | if sharpness < 15 or colorfulness < 2.0:
67 | continue
68 | else:
69 | img_moire_layer = ImageEnhance.Contrast(img_moire_layer).enhance(2.0)
70 | img_moire_layer = self.pil_to_tensor(img_moire_layer)
71 | return { "image": img_moire_layer }
72 | index = random.randint(0, len(self.moire_layer_path) - 1)
73 |
74 | def transform_init(self):
75 | w = h = self.resolution
76 | base_transforms = [transforms.RandomHorizontalFlip(p=0.5),]
77 |
78 | q = random.randint(0, 2)
79 | r = random.randint(0, 1)
80 | if r == 0: # 4K crop into (w, h)
81 | extra_transforms = [transforms.RandomCrop(size=(h, w))]
82 | elif q == 0: # 4K to 2K, then crop into (w, h)
83 | extra_transforms = [transforms.Resize(size=(1440, 2560)), transforms.RandomCrop(size=(h, w))]
84 | elif q == 1: # 4K to 1080P, then crop into (w, h)
85 | extra_transforms = [transforms.Resize(size=(1080, 1920)), transforms.RandomCrop(size=(h, w))]
86 | elif q == 2: # 4K resize into (w, h)
87 | extra_transforms = [transforms.Resize(size=(h, w))]
88 |
89 | tran_transform = transforms.Compose(extra_transforms + base_transforms)
90 | # test_transform = transforms.Compose([transforms.Resize((h, w))] + base_transforms)
91 | self.transform = tran_transform
--------------------------------------------------------------------------------
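Usage sketch (not part of the repository): MoirePattern keeps re-cropping (and, if needed, re-sampling another index) until a crop passes the sharpness/colorfulness thresholds, then returns a single contrast-enhanced "image" tensor. The path and resolution below are placeholders.

dataset = MoirePattern(dataset_path="/path/to/moire_patterns", resolution=768)   # placeholder path
sample = dataset[0]["image"]        # tensor of shape [3, 768, 768], values in [0, 1]
print(len(dataset), sample.shape)
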
/unidemoire/data/tip.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import torch
3 | import argparse
4 | import cv2
5 | import torch.utils.data as data
6 | import torchvision.transforms as transforms
7 | import random
8 | from PIL import Image
9 | from PIL import ImageFile
10 | import os
11 |
12 | from .utils import default_loader, crop_loader, resize_loader, resize_then_crop_loader
13 |
14 | class tip_datasets(data.Dataset):
15 |
16 | def __init__(self, args, mode='train'):
17 |
18 | data_path = args['dataset_path']
19 | image_list = sorted([file for file in os.listdir(data_path + '/source') if file.endswith('.png')])
20 | self.image_list = image_list
21 | self.args = args
22 | self.mode = mode
23 | t_list = [transforms.ToTensor()]
24 | self.composed_transform = transforms.Compose(t_list)
25 |
26 | def default_loader(self, path):
27 | return Image.open(path).convert('RGB')
28 |
29 | def __getitem__(self, index):
30 | ImageFile.LOAD_TRUNCATED_IMAGES = True
31 | data = {}
32 | image_in = self.image_list[index]
33 | image_in_gt = image_in[:-10] + 'target.png'
34 | number = image_in_gt[:-11]
35 |
36 | if self.mode == 'train':
37 | labels = self.default_loader(self.args['dataset_path'] + '/target/' + image_in_gt)
38 | moire_imgs = self.default_loader(self.args['dataset_path'] + '/source/' + image_in)
39 |
40 | w, h = labels.size
41 | i = random.randint(-6, 6)
42 | j = random.randint(-6, 6)
43 | labels = labels.crop((int(w / 6) + i, int(h / 6) + j, int(w * 5 / 6) + i, int(h * 5 / 6) + j))
44 | moire_imgs = moire_imgs.crop((int(w / 6) + i, int(h / 6) + j, int(w * 5 / 6) + i, int(h * 5 / 6) + j))
45 |
46 | labels = labels.resize((256, 256), Image.BILINEAR)
47 | moire_imgs = moire_imgs.resize((256, 256), Image.BILINEAR)
48 |
49 | elif self.mode == 'test':
50 | labels = self.default_loader(self.args['dataset_path'] + '/target/' + image_in_gt)
51 | moire_imgs = self.default_loader(self.args['dataset_path'] + '/source/' + image_in)
52 |
53 | w, h = labels.size
54 | labels = labels.crop((int(w / 6), int(h / 6), int(w * 5 / 6), int(h * 5 / 6)))
55 | moire_imgs = moire_imgs.crop((int(w / 6), int(h / 6), int(w * 5 / 6), int(h * 5 / 6)))
56 |
57 | labels = labels.resize((256, 256), Image.BILINEAR)
58 | moire_imgs = moire_imgs.resize((256, 256), Image.BILINEAR)
59 |
60 |
61 | else:
62 | print('Unrecognized mode! Please select either "train" or "test"')
63 | raise NotImplementedError
64 |
65 | moire_imgs = self.composed_transform(moire_imgs)
66 | labels = self.composed_transform(labels)
67 |
68 | data['in_img'] = moire_imgs
69 | data['label'] = labels
70 | data['number'] = number
71 | data['mode'] = self.mode
72 | return data
73 |
74 | def __len__(self):
75 | return len(self.image_list)
--------------------------------------------------------------------------------
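Usage sketch (not part of the repository): tip_datasets only needs a dataset path (placeholder below), since every item is taken from the central two-thirds region (with a small random jitter during training) and resized to 256x256.

args = {"dataset_path": "/path/to/TIP2018/trainData"}   # placeholder; expects target/ and source/
dataset = tip_datasets(args, mode='train')
item = dataset[0]
print(item['in_img'].shape, item['label'].shape)        # torch.Size([3, 256, 256]) for both
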
/unidemoire/data/uhdm.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import torch
3 | import argparse
4 | import cv2
5 | import torch.utils.data as data
6 | import torchvision.transforms as transforms
7 | import random
8 | from PIL import Image
9 | from PIL import ImageFile
10 | import os
11 |
12 | from .utils import default_loader, crop_loader, resize_loader, resize_then_crop_loader
13 |
14 |
15 | class uhdm_datasets(data.Dataset):
16 |
17 | def __init__(self, args, mode='train'):
18 | self.args = args
19 | self.mode = mode
20 | self.loader = args["LOADER"]
21 | self.image_list = self._list_image_files_recursively(data_dir=self.args["dataset_path"])
22 |
23 | def _list_image_files_recursively(self, data_dir):
24 | file_list = []
25 | for home, dirs, files in os.walk(data_dir):
26 | for filename in files:
27 | if filename.endswith('gt.jpg'):
28 | file_list.append(os.path.join(home, filename))
29 | file_list.sort()
30 | return file_list
31 |
32 | def __getitem__(self, index):
33 | ImageFile.LOAD_TRUNCATED_IMAGES = True
34 | data = {}
35 | path_tar = self.image_list[index]
36 | number = os.path.split(path_tar)[-1][0:4]
37 | path_src = os.path.split(path_tar)[0] + '/' + os.path.split(path_tar)[-1][0:4] + '_moire.jpg'
38 | if self.mode == 'train':
39 | if self.loader == 'crop':
40 | if os.path.split(path_tar)[0][-5:-3] == 'mi':
41 | w = 4624
42 | h = 3472
43 | else:
44 | w = 4032
45 | h = 3024
46 | x = random.randint(0, w - self.args["CROP_SIZE"])
47 | y = random.randint(0, h - self.args["CROP_SIZE"])
48 | labels, moire_imgs = crop_loader(self.args["CROP_SIZE"], x, y, [path_tar, path_src])
49 |
50 | elif self.loader == 'resize_then_crop':
51 | labels, moire_imgs = resize_then_crop_loader(self.args["CROP_SIZE"], self.args["RESIZE_SIZE"], [path_tar, path_src])
52 | data['origin_label'] = default_loader([path_tar])[0]
53 |
54 | elif self.loader == 'resize':
55 | labels, moire_imgs = resize_loader(self.args["RESIZE_SIZE"], [path_tar, path_src])
56 | data['origin_label'] = default_loader([path_tar])[0]
57 |
58 | elif self.loader == 'default':
59 | labels, moire_imgs = default_loader([path_tar, path_src])
60 |
61 | elif self.mode == 'test':
62 | if self.loader == 'resize':
63 | labels, moire_imgs = resize_loader(self.args["RESIZE_SIZE"], [path_tar, path_src])
64 | data['origin_label'] = default_loader([path_tar])[0]
65 | elif self.loader == 'resize_then_crop':
66 | labels, moire_imgs = resize_loader(self.args["RESIZE_SIZE"], [path_tar, path_src])
67 | data['origin_label'] = default_loader([path_tar])[0]
68 | else:
69 | labels, moire_imgs = default_loader([path_tar, path_src])
70 |
71 | else:
72 | print('Unrecognized mode! Please select either "train" or "test"')
73 | raise NotImplementedError
74 |
75 | data['in_img'] = moire_imgs
76 | data['label'] = labels
77 | data['number'] = number
78 |
79 | data['mode'] = self.mode
80 |
81 | return data
82 |
83 | def __len__(self):
84 | # return 10 # debug
85 | return len(self.image_list)
--------------------------------------------------------------------------------
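Usage sketch (not part of the repository): uhdm_datasets walks dataset_path for *gt.jpg files and pairs each with its *_moire.jpg counterpart, so only a root path (placeholder) plus the loader settings are needed.

args = {
    "dataset_path": "/path/to/UHDM/train",    # placeholder
    "LOADER": "crop",
    "CROP_SIZE": 768,                         # illustrative value
}
dataset = uhdm_datasets(args, mode='train')
item = dataset[0]
print(item['number'], item['in_img'].shape)   # e.g. '0001' torch.Size([3, 768, 768])
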
/unidemoire/data/utils.py:
--------------------------------------------------------------------------------
1 | import os
2 | import random
3 | from PIL import Image
4 | import torchvision.transforms as transforms
5 |
6 | IMG_EXTENSIONS = ['.jpg', '.JPG', '.jpeg', '.JPEG',
7 | '.png', '.PNG', '.ppm', '.PPM', '.bmp', '.BMP']
8 |
9 | def is_image_file(filename):
10 | return any(filename.endswith(extension) for extension in IMG_EXTENSIONS)
11 |
12 | def get_natural_image_list_and_moire_pattern_list(args, mode='train', add_clean_percent=1.0):
13 | moire_pattern_files = _list_moire_pattern_files_recursively(data_dir=args["moire_pattern_path"])
14 | if args.natural_dataset_name == 'UHDM':
15 | uhdm_natural_files = _list_image_files_recursively(data_dir=args["uhdm_dataset_path"])
16 | return uhdm_natural_files, moire_pattern_files
17 |
18 | elif args.natural_dataset_name == 'FHDMi':
19 | fhdmi_natural_files = sorted([file for file in os.listdir(args["fhdmi_dataset_path"] + '/target') if file.endswith('.png')])
20 | return fhdmi_natural_files, moire_pattern_files
21 |
22 | elif args.natural_dataset_name == 'TIP':
23 | tip_natural_files = sorted([file for file in os.listdir(args["tip_dataset_path"] + '/source') if file.endswith('.png')])
24 | return tip_natural_files, moire_pattern_files
25 |
26 | elif args.natural_dataset_name == 'AIM':
27 | if mode=='train':
28 | aim_natural_files = sorted([file for file in os.listdir(args["aim_dataset_path"] + '/moire') if file.endswith('.jpg')])
29 | else:
30 | aim_natural_files = sorted([file for file in os.listdir(args["aim_dataset_path"] + '/moire') if file.endswith('.png')])
31 | return aim_natural_files, moire_pattern_files
32 |
33 | elif args.natural_dataset_name == 'MHRNID':
34 | mhrnid_files = get_paths_from_images(path=args["mhrnid_dataset_path"])
35 | return mhrnid_files, moire_pattern_files
36 |
37 | elif args.natural_dataset_name == 'UHDM and FHDMi':
38 | uhdm_natural_files = _list_image_files_recursively(data_dir=args["uhdm_dataset_path"])
39 | fhdmi_natural_files = sorted([file for file in os.listdir(args["fhdmi_dataset_path"] + '/target') if file.endswith('.png')])
40 |
41 | print(f'Clean image percentage: {add_clean_percent*100}%')
42 | fhdmi_size = len(fhdmi_natural_files)
43 | fhdmi_sublist_size = int(fhdmi_size * add_clean_percent)
44 | fhdmi_sublist_files = fhdmi_natural_files[:fhdmi_sublist_size]
45 |
46 | return uhdm_natural_files + fhdmi_sublist_files, moire_pattern_files
47 |
48 | else:
49 | print('Unrecognized data_type!')
50 | raise NotImplementedError
51 |
52 |
53 | def get_unpaired_moire_images(args):
54 | if args.unpaired_real_moire_dataset == 'TIP':
55 | tip_real_moire_files = sorted([file for file in os.listdir(args["tip_dataset_path"] + '/source') if file.endswith('.png')])
56 | return tip_real_moire_files
57 | else:
58 | print('Unrecognized data_type!')
59 | raise NotImplementedError
60 |
61 |
62 | def _list_image_files_recursively(data_dir):
63 | file_list = []
64 | for home, dirs, files in os.walk(data_dir):
65 | for filename in files:
66 | if filename.endswith('gt.jpg'):
67 | file_list.append(os.path.join(home, filename))
68 | file_list.sort()
69 | return file_list
70 |
71 | def get_paths_from_images(path):
72 | assert os.path.isdir(path), '{:s} is not a valid directory'.format(path)
73 | images = []
74 | for dirpath, _, fnames in sorted(os.walk(path)):
75 | for fname in sorted(fnames):
76 | if is_image_file(fname):
77 | img_path = os.path.join(dirpath, fname)
78 | images.append(img_path)
79 | assert images, '{:s} has no valid image file'.format(path)
80 | return sorted(images)
81 |
82 | def _list_moire_pattern_files_recursively(data_dir):
83 | assert os.path.isdir(data_dir), '{:s} is not a valid directory'.format(data_dir)
84 | images = []
85 | for dirpath, _, fnames in sorted(os.walk(data_dir)):
86 | for fname in sorted(fnames):
87 | if is_image_file(fname):
88 | img_path = os.path.join(dirpath, fname)
89 | images.append(img_path)
90 | assert images, '{:s} has no valid image file'.format(data_dir)
91 | return sorted(images)
92 |
93 |
94 |
95 | def default_loader(path_set=[]):
96 | imgs = []
97 | for path in path_set:
98 | img = Image.open(path).convert('RGB')
99 | img = default_toTensor(img)
100 | imgs.append(img)
101 |
102 | return imgs
103 |
104 | def crop_loader(crop_size, x, y, path_set=[]):
105 | imgs = []
106 | for path in path_set:
107 | img = Image.open(path).convert('RGB')
108 | img = img.crop((x, y, x + crop_size, y + crop_size))
109 | img = default_toTensor(img)
110 | imgs.append(img)
111 | return imgs
112 |
113 | def resize_loader(resize_size, path_set=[]):
114 | imgs = []
115 | for path in path_set:
116 | img = Image.open(path).convert('RGB')
117 | img = img.resize((resize_size,resize_size),Image.BICUBIC)
118 | img = default_toTensor(img)
119 | imgs.append(img)
120 |
121 | return imgs
122 |
123 | def resize_then_crop_loader(crop_size, resize_size, path_set=[]):
124 | imgs = []
125 | for path in path_set:
126 | img = Image.open(path).convert('RGB')
127 | if resize_size == 1920:
128 | img = img.resize((1920,1080),Image.BICUBIC)
129 | x = random.randint(0, 1920 - crop_size)
130 | y = random.randint(0, 1080 - crop_size)
131 | else:
132 | img = img.resize((resize_size,resize_size),Image.BICUBIC)
133 | x = random.randint(0, resize_size - crop_size)
134 | y = random.randint(0, resize_size - crop_size)
135 | img = img.crop((x, y, x + crop_size, y + crop_size))
136 | img = default_toTensor(img)
137 | imgs.append(img)
138 | return imgs
139 |
140 |
141 | def default_toTensor(img):
142 | t_list = [transforms.ToTensor()]
143 | composed_transform = transforms.Compose(t_list)
144 | return composed_transform(img)
--------------------------------------------------------------------------------
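A small sketch (not part of the repository) of the loader helpers above: each takes a list of image paths and returns a list of tensors in the same order, which is how the dataset classes keep (target, source) pairs aligned. The paths are placeholders.

paths = ["/path/to/target.png", "/path/to/source.png"]            # placeholders
label, moire = crop_loader(crop_size=384, x=100, y=50, path_set=paths)
label_r, moire_r = resize_loader(resize_size=512, path_set=paths)
print(label.shape, label_r.shape)   # torch.Size([3, 384, 384]) torch.Size([3, 512, 512])
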
/unidemoire/lr_scheduler.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 |
3 |
4 | class LambdaWarmUpCosineScheduler:
5 | """
6 | note: use with a base_lr of 1.0
7 | """
8 | def __init__(self, warm_up_steps, lr_min, lr_max, lr_start, max_decay_steps, verbosity_interval=0):
9 | self.lr_warm_up_steps = warm_up_steps
10 | self.lr_start = lr_start
11 | self.lr_min = lr_min
12 | self.lr_max = lr_max
13 | self.lr_max_decay_steps = max_decay_steps
14 | self.last_lr = 0.
15 | self.verbosity_interval = verbosity_interval
16 |
17 | def schedule(self, n, **kwargs):
18 | if self.verbosity_interval > 0:
19 | if n % self.verbosity_interval == 0: print(f"current step: {n}, recent lr-multiplier: {self.last_lr}")
20 | if n < self.lr_warm_up_steps:
21 | lr = (self.lr_max - self.lr_start) / self.lr_warm_up_steps * n + self.lr_start
22 | self.last_lr = lr
23 | return lr
24 | else:
25 | t = (n - self.lr_warm_up_steps) / (self.lr_max_decay_steps - self.lr_warm_up_steps)
26 | t = min(t, 1.0)
27 | lr = self.lr_min + 0.5 * (self.lr_max - self.lr_min) * (
28 | 1 + np.cos(t * np.pi))
29 | self.last_lr = lr
30 | return lr
31 |
32 | def __call__(self, n, **kwargs):
33 | return self.schedule(n,**kwargs)
34 |
35 |
36 | class LambdaWarmUpCosineScheduler2:
37 | """
38 | supports repeated iterations, configurable via lists
39 | note: use with a base_lr of 1.0.
40 | """
41 | def __init__(self, warm_up_steps, f_min, f_max, f_start, cycle_lengths, verbosity_interval=0):
42 | assert len(warm_up_steps) == len(f_min) == len(f_max) == len(f_start) == len(cycle_lengths)
43 | self.lr_warm_up_steps = warm_up_steps
44 | self.f_start = f_start
45 | self.f_min = f_min
46 | self.f_max = f_max
47 | self.cycle_lengths = cycle_lengths
48 | self.cum_cycles = np.cumsum([0] + list(self.cycle_lengths))
49 | self.last_f = 0.
50 | self.verbosity_interval = verbosity_interval
51 |
52 | def find_in_interval(self, n):
53 | interval = 0
54 | for cl in self.cum_cycles[1:]:
55 | if n <= cl:
56 | return interval
57 | interval += 1
58 |
59 | def schedule(self, n, **kwargs):
60 | cycle = self.find_in_interval(n)
61 | n = n - self.cum_cycles[cycle]
62 | if self.verbosity_interval > 0:
63 | if n % self.verbosity_interval == 0: print(f"current step: {n}, recent lr-multiplier: {self.last_f}, "
64 | f"current cycle {cycle}")
65 | if n < self.lr_warm_up_steps[cycle]:
66 | f = (self.f_max[cycle] - self.f_start[cycle]) / self.lr_warm_up_steps[cycle] * n + self.f_start[cycle]
67 | self.last_f = f
68 | return f
69 | else:
70 | t = (n - self.lr_warm_up_steps[cycle]) / (self.cycle_lengths[cycle] - self.lr_warm_up_steps[cycle])
71 | t = min(t, 1.0)
72 | f = self.f_min[cycle] + 0.5 * (self.f_max[cycle] - self.f_min[cycle]) * (
73 | 1 + np.cos(t * np.pi))
74 | self.last_f = f
75 | return f
76 |
77 | def __call__(self, n, **kwargs):
78 | return self.schedule(n, **kwargs)
79 |
80 |
81 | class LambdaLinearScheduler(LambdaWarmUpCosineScheduler2):
82 |
83 | def schedule(self, n, **kwargs):
84 | cycle = self.find_in_interval(n)
85 | n = n - self.cum_cycles[cycle]
86 | if self.verbosity_interval > 0:
87 | if n % self.verbosity_interval == 0: print(f"current step: {n}, recent lr-multiplier: {self.last_f}, "
88 | f"current cycle {cycle}")
89 |
90 | if n < self.lr_warm_up_steps[cycle]:
91 | f = (self.f_max[cycle] - self.f_start[cycle]) / self.lr_warm_up_steps[cycle] * n + self.f_start[cycle]
92 | self.last_f = f
93 | return f
94 | else:
95 | f = self.f_min[cycle] + (self.f_max[cycle] - self.f_min[cycle]) * (self.cycle_lengths[cycle] - n) / (self.cycle_lengths[cycle])
96 | self.last_f = f
97 | return f
98 |
99 |
--------------------------------------------------------------------------------
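Usage sketch (not part of the repository): both schedulers return a multiplier for a base learning rate of 1.0, so one way to wire them is through torch.optim.lr_scheduler.LambdaLR. The optimizer and hyper-parameters below are illustrative only.

import torch

params = [torch.nn.Parameter(torch.zeros(1))]       # stand-in for model.parameters()
optimizer = torch.optim.AdamW(params, lr=1.0e-4)
schedule = LambdaWarmUpCosineScheduler(
    warm_up_steps=1000, lr_min=0.1, lr_max=1.0, lr_start=0.0, max_decay_steps=100000)
scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=schedule)

for step in range(3):
    optimizer.step()
    scheduler.step()    # multiplies the base lr by schedule(current step)
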
/unidemoire/models/MIB/Blending.py:
--------------------------------------------------------------------------------
1 |
2 | import random
3 | import torch
4 | import torch.nn as nn
5 |
6 | class Blending(nn.Module):
7 | def __init__(self, args):
8 | super(Blending, self).__init__()
9 | self.args = args
10 | self.final_weight_range = (self.args["bl_final_weight_min"], self.args["bl_final_weight_max"])
11 | self.bl_method_1_op = torch.Tensor([self.args["bl_method_1_op"]])
12 | self.bl_method_2_op = torch.Tensor([self.args["bl_method_2_op"]])
13 |
14 | def forward(self, img_background, img_foreground):
15 | # bs,c,h,w = img_background.shape
16 | self.device = img_background.device
17 |
18 | self.img_background = self.RGB_to_RGBA(img_background)
19 | self.img_foreground = self.RGB_to_RGBA(img_foreground)
20 |
21 | img_result_1 = self.get_blending_result(method=self.args["bl_method_1"], opacity=self.bl_method_1_op)
22 | img_result_2 = self.get_blending_result(method=self.args["bl_method_2"], opacity=self.bl_method_2_op)
23 | self.weight = torch.Tensor([random.uniform(*self.final_weight_range)]).to(self.device)
24 | result = img_result_1 * self.weight + img_result_2 * (1 - self.weight)
25 |
26 | return result, self.weight
27 |
28 | def init_from_ckpt(self, path, ignore_keys=list()):
29 | sd = torch.load(path, map_location="cpu")["state_dict"]
30 | keys = list(sd.keys())
31 | for k in keys:
32 | for ik in ignore_keys:
33 | if k.startswith(ik):
34 | print("Deleting key {} from state_dict.".format(k))
35 | del sd[k]
36 | self.load_state_dict(sd, strict=False)
37 |         print(f"MIB Module Restored from {path}, weight = {getattr(self, 'mib_weight', None)}")  # 'mib_weight' may not be set on this module
38 |
39 | def RGBA_to_RGB(self, image):
40 | return image[:,:3,:,:]
41 |
42 | def RGB_to_RGBA(self, image):
43 | b, c, w, h = image.shape
44 | img = torch.ones([b, c + 1, w, h]).to(self.device)
45 | img[:,:3,:,:] = image
46 |
47 | return img
48 |
49 | def soft_light(self):
50 | """
51 | if A ≤ 0.5: C = (2A-1)(B-B^2) + B
52 | if A > 0.5: C = (2A-1)(sqrt(B)-B) + B
53 | """
54 | A = self.img_foreground[:, :3, :, :]
55 | B = self.img_background[:, :3, :, :]
56 | C = torch.where(A <= 0.5,
57 | (2 * A - 1.0)*(B - torch.pow(B,2)) + B,
58 | (2 * A - 1.0)*(torch.sqrt(B) - B) + B
59 | )
60 | return C
61 |
62 | def hard_light(self):
63 | """
64 | if A ≤ 0.5: C = 2*A*B
65 | if A > 0.5: C = 1-2*(1-A)(1-B)
66 | """
67 | A = self.img_foreground[:, :3, :, :]
68 | B = self.img_background[:, :3, :, :]
69 | C = torch.where(A <= 0.5,
70 | 2 * A * B,
71 | 1 - 2 * (1.0 - A)*(1.0 - B)
72 | )
73 | return C
74 |
75 | def lighten(self):
76 | """
77 | if B ≤ A: C = A
78 | if B > A: C = B
79 | """
80 | A = self.img_foreground[:, :3, :, :]
81 | B = self.img_background[:, :3, :, :]
82 | C = torch.maximum(A, B)
83 | return C
84 |
85 | def darken(self):
86 | """
87 | if B ≤ A: C = B
88 | if B > A: C = A
89 | """
90 | A = self.img_foreground[:, :3, :, :]
91 | B = self.img_background[:, :3, :, :]
92 | C = torch.minimum(A, B)
93 | return C
94 |
95 | def multiply(self):
96 | """
97 | C = A * B
98 | """
99 | A = self.img_foreground[:, :3, :, :]
100 | B = self.img_background[:, :3, :, :]
101 | C = A * B
102 | return C
103 |
104 | def grain_merge(self):
105 | """
106 | C = A + B - 0.5
107 | """
108 | A = self.img_foreground[:, :3, :, :]
109 | B = self.img_background[:, :3, :, :]
110 | C = A + B - 0.5
111 | return C
112 |
113 | def _compose_alpha(self, opacity):
114 | comp = self.img_foreground[:,3,:,:]
115 |
116 | comp_alpha = comp * opacity
117 | new_alpha = comp_alpha + (1.0 - comp_alpha) * self.img_background[:,3,:,:]
118 |
119 | ratio = comp_alpha / new_alpha
120 | ratio[torch.isnan(ratio)] = 0.0
121 | ratio[torch.isinf(ratio)] = 0.0
122 |
123 | return ratio
124 |
125 | def get_blending_result(self, method, opacity):
126 | opacity = opacity.to(self.device)
127 | ratio = self._compose_alpha(opacity)
128 | comp = torch.clip(getattr(self, method)(), 0.0, 1.0)
129 | ratio_rs = torch.stack([ratio,ratio,ratio],dim=1).to(self.device)
130 | img_out = comp * ratio_rs + self.img_background[:,:3,:,:] * (1.0 - ratio_rs)
131 |
132 | alpha_channel = self.img_background[:,3,:,:]
133 | alpha_channel = alpha_channel.unsqueeze(dim=1)
134 | img_out = torch.nan_to_num(torch.cat((img_out, alpha_channel),dim=1)) # add alpha channel and replace nans
135 |
136 | return self.RGBA_to_RGB(img_out).to(self.device)
137 |
--------------------------------------------------------------------------------
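Sketch (not part of the repository): Blending mixes two blend modes of a clean background and a moire-pattern foreground and returns the blend plus the random mixing weight. The args keys below mirror the ones read in __init__; the values are illustrative.

import torch

args = {
    "bl_method_1": "multiply",        # any method defined above
    "bl_method_2": "grain_merge",
    "bl_method_1_op": 1.0,            # opacities (illustrative)
    "bl_method_2_op": 1.0,
    "bl_final_weight_min": 0.4,       # range of the random mixing weight (illustrative)
    "bl_final_weight_max": 0.6,
}
blender = Blending(args)
clean = torch.rand(2, 3, 256, 256)    # background (natural image)
moire = torch.rand(2, 3, 256, 256)    # foreground (moire pattern)
blended, weight = blender(clean, moire)
print(blended.shape, weight.item())   # torch.Size([2, 3, 256, 256]), weight in [0.4, 0.6]
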
/unidemoire/models/MIB/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/4DVLab/UniDemoire/4221a0e98f0078c3f72ef1da8128642e7f02662a/unidemoire/models/MIB/__init__.py
--------------------------------------------------------------------------------
/unidemoire/models/TRN/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/4DVLab/UniDemoire/4221a0e98f0078c3f72ef1da8128642e7f02662a/unidemoire/models/TRN/__init__.py
--------------------------------------------------------------------------------
/unidemoire/models/cycle/Models/utils.py:
--------------------------------------------------------------------------------
1 | import random
2 | import time
3 | import datetime
4 | import sys
5 |
6 | from torch.autograd import Variable
7 | import torch
8 | import numpy as np
9 |
10 | import torch.nn as nn
11 | from torchvision.utils import save_image
12 | from math import log10, exp, sqrt, cos, pi
13 | import torch.nn.functional as F
14 |
15 | class ReplayBuffer:
16 | def __init__(self, max_size=50):
17 | assert max_size > 0, "Empty buffer or trying to create a black hole. Be careful."
18 | self.max_size = max_size
19 | self.data = []
20 |
21 | def push_and_pop(self, data):
22 | to_return = []
23 | for element in data.data:
24 | element = torch.unsqueeze(element, 0)
25 | if len(self.data) < self.max_size:
26 | self.data.append(element)
27 | to_return.append(element)
28 | else:
29 | if random.uniform(0, 1) > 0.5:
30 | i = random.randint(0, self.max_size - 1)
31 | to_return.append(self.data[i].clone())
32 | self.data[i] = element
33 | else:
34 | to_return.append(element)
35 | return Variable(torch.cat(to_return))
36 |
37 |
38 | class LambdaLR:
39 | def __init__(self, n_epochs, offset, decay_start_epoch):
40 | assert (n_epochs - decay_start_epoch) > 0, "Decay must start before the training session ends!"
41 | self.n_epochs = n_epochs
42 | self.offset = offset
43 | self.decay_start_epoch = decay_start_epoch
44 |
45 | def step(self, epoch):
46 | return 1.0 - max(0, epoch + self.offset - self.decay_start_epoch) / (self.n_epochs - self.decay_start_epoch)
47 |
48 | ## DCT Transform
49 | class DCT(nn.Module):
50 | def __init__(self):
51 | super(DCT, self).__init__()
52 |
53 | conv_shape = (1, 1, 64, 64)
54 | kernel = np.zeros(conv_shape)
55 | r1 = sqrt(1.0/8)
56 | r2 = sqrt(2.0/8)
57 | for i in range(8):
58 | _u = 2*i+1
59 | for j in range(8):
60 | _v = 2*j+1
61 | index = i*8+j
62 | for u in range(8):
63 | for v in range(8):
64 | index2 = u*8+v
65 | t = cos(_u*u*pi/16)*cos(_v*v*pi/16)
66 | t = t*r1 if u==0 else t*r2
67 | t = t*r1 if v==0 else t*r2
68 | kernel[0,0,index2,index] = t
69 |
70 | self.kernel = torch.tensor(kernel, requires_grad = False, dtype=torch.float32)
71 |
72 | def forward(self, inputs):
73 |
74 | device = inputs.device
75 | kernel = self.kernel.to(device)
76 | k = kernel.permute(3, 1, 2, 0)
77 | k = torch.reshape(k, (64, 1, 8, 8))
78 |
79 | b, c, h, w = inputs.size()
80 | scale_r = h // 8
81 | new_inputs = torch.reshape(inputs, (b, c * scale_r * scale_r, 8, 8))
82 |
83 | outputs = torch.zeros_like(new_inputs)
84 |
85 | num_of_p = c * scale_r * scale_r
86 |
87 | for i in range(num_of_p):
88 | patch = new_inputs[:, i, :, :]
89 | patch = patch.unsqueeze(dim=1)
90 | patch = patch.to(device).float()
91 |
92 | new_patch = F.conv2d(patch, k, stride = 8)
93 | new_patch = torch.reshape(new_patch, (b, 1, 8, 8)).squeeze(dim = 1)
94 |
95 | outputs[:, i, :, :] = new_patch
96 |
97 | outputs = torch.reshape(outputs, (b, c, h, w))
98 |
99 | return outputs
100 |
101 | class Local_DCT(nn.Module):
102 | def __init__(self):
103 | super(Local_DCT, self).__init__()
104 |
105 | conv_shape = (1, 1, 64, 64)
106 | kernel = np.zeros(conv_shape)
107 | r1 = sqrt(1.0 / 8)
108 | r2 = sqrt(2.0 / 8)
109 | for i in range(8):
110 | _u = 2 * i + 1
111 | for j in range(8):
112 | _v = 2 * j + 1
113 | index = i * 8 + j
114 | for u in range(8):
115 | for v in range(8):
116 | index2 = u * 8 + v
117 | t = cos(_u * u * pi / 16) * cos(_v * v * pi / 16)
118 | t = t * r1 if u == 0 else t * r2
119 | t = t * r1 if v == 0 else t * r2
120 | kernel[0, 0, index2, index] = t
121 |
122 | self.kernel = torch.tensor(kernel, requires_grad=False, dtype=torch.float32)
123 |
124 | def forward(self, inputs):
125 |
126 | device = inputs.device
127 | kernel = self.kernel.to(device)
128 | k = kernel.permute(3, 1, 2, 0)
129 | k = torch.reshape(k, (64, 1, 8, 8))
130 |
131 | b, c, h, w = inputs.size()
132 | scale_r = h // 8
133 | new_inputs = torch.reshape(inputs, (b, c * scale_r * scale_r, 8, 8))
134 |
135 | outputs = torch.zeros_like(new_inputs)
136 |
137 | num_of_p = c * scale_r * scale_r
138 |
139 | for i in range(num_of_p):
140 | patch = new_inputs[:, i, :, :]
141 | patch = patch.unsqueeze(dim=1)
142 | patch = patch.to(device).float()
143 |
144 | new_patch = F.conv2d(patch, k, stride=8)
145 | new_patch = torch.reshape(new_patch, (b, 1, 8, 8)).squeeze(dim=1)
146 |
147 | outputs[:, i, :, :] = new_patch
148 |
149 | outputs = torch.reshape(outputs, (b, c, h, w))
150 |
151 | return outputs
152 |
153 | class Inverse_DCT(nn.Module):
154 | def __init__(self):
155 | super(Inverse_DCT, self).__init__()
156 |
157 | conv_shape = (1, 1, 64, 64)
158 | kernel = np.zeros(conv_shape)
159 | r1 = sqrt(1.0/8)
160 | r2 = sqrt(2.0/8)
161 | for i in range(8):
162 | _u = 2*i+1
163 | for j in range(8):
164 | _v = 2*j+1
165 | index = i*8+j
166 | for u in range(8):
167 | for v in range(8):
168 | index2 = u*8+v
169 | t = cos(_u*u*pi/16)*cos(_v*v*pi/16)
170 | t = t*r1 if u==0 else t*r2
171 | t = t*r1 if v==0 else t*r2
172 | kernel[0,0,index2,index] = t
173 |
174 | self.kernel = torch.tensor(kernel, requires_grad = False, dtype=torch.float32)
175 |
176 | self.kernel = self.kernel.permute(0, 1, 3, 2)
177 |
178 | def forward(self, inputs):
179 |
180 | device = inputs.device
181 | kernel = self.kernel.to(device)
182 | k = kernel.permute(3, 1, 2, 0)
183 | k = torch.reshape(k, (64, 1, 8, 8))
184 |
185 | b, c, h, w = inputs.size()
186 | scale_r = h // 8
187 | new_inputs = torch.reshape(inputs, (b, c * scale_r * scale_r, 8, 8))
188 |
189 | outputs = torch.zeros_like(new_inputs)
190 |
191 | num_of_p = c * scale_r * scale_r
192 |
193 | for i in range(num_of_p):
194 | patch = new_inputs[:, i, :, :]
195 | patch = patch.unsqueeze(dim=1)
196 | patch = patch.to(device).float()
197 |
198 | new_patch = F.conv2d(patch, k, stride = 8)
199 | new_patch = torch.reshape(new_patch, (b, 1, 8, 8)).squeeze(dim = 1)
200 |
201 | outputs[:, i, :, :] = new_patch
202 |
203 | outputs = torch.reshape(outputs, (b, c, h, w))
204 |
205 | return outputs.clamp(min = 0, max = 1)
206 |
207 | ## block wise mapping
208 | def block_wise_mapping(net, input, input_size, pad):
209 | b, c, _, _ = input.size()
210 | window = create_window(pad, b, c, pad // 2)
211 |
212 | pad_in = padarray(input, pad)
213 |
214 | pad_out = torch.zeros_like(pad_in)
215 | pnorm = torch.zeros_like(pad_in)
216 |
217 | device = input.device
218 |
219 | i = 0
220 | j = 0
221 |
222 | stride = pad // 2
223 |
224 | _,_, height, width = pad_in.size()
225 |
226 | while(i < height - input_size + 1):
227 | while(j < width - input_size + 1):
228 | patch = pad_in[:,:,i : i + input_size, j : j + input_size]
229 | patch = patch.to(device).float()
230 |
231 | pout = net(patch)
232 |
233 | if i < height - input_size and j < width - input_size:
234 | pout = pout[:,:,0 : 0 + pad, 0 : 0 + pad]
235 |
236 | mask = window.to(device)
237 | p_after = pout * mask
238 |
239 | pad_out[:,:,i : i + pad, j : j + pad] = pad_out[:,:,i : i + pad, j : j + pad] + p_after
240 | pnorm[:,:,i : i + pad, j : j + pad] = pnorm[:,:,i : i + pad, j : j + pad] + mask
241 | else:
242 | pad_out[:, :, i : i + input_size, j : j + input_size] = pad_out[:, :, i : i + input_size, j : j + input_size] + pout
243 | pnorm[:, :, i : i + input_size, j : j + input_size] = pnorm[:, :, i : i + input_size, j : j + input_size] + 1.0
244 |
245 | j = j + stride
246 |
247 | i = i + stride
248 | j = 0
249 |
250 | output = pad_out[:,:,0 : 1024, 0 : 1024] / pnorm[:,:,0 : 1024, 0 : 1024]
251 |
252 | return output
253 |
254 | def create_window(window_size, batch, channel, sigma):
255 | _1D_window = gaussian(window_size, sigma).unsqueeze(1)
256 | _2D_window = _1D_window.mm(_1D_window.t()).float().unsqueeze(0).unsqueeze(0)
257 | window = _2D_window.expand(batch, channel, window_size, window_size)
258 | return window
259 |
260 | def padarray(input, size_pad):
261 | b,c,h,w = input.size()
262 | device = input.device
263 |
264 | new_h = h + size_pad
265 | new_w = w + size_pad
266 | output = torch.zeros((b, c, new_h, new_w)).to(device)
267 |
268 | output[:,:,0 : h, 0 : w] = input[:,:,:,:]
269 | # output[:,:,h : new_h, w : new_w] = 0.0
270 |
271 | return output
272 |
273 | def gaussian(window_size, sigma):
274 | gauss = torch.Tensor([exp(-(x - window_size//2)**2 / float(2*sigma**2)) for x in range(window_size)])
275 | return gauss / gauss.sum()
--------------------------------------------------------------------------------
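Sketch (not part of the repository): DCT and Inverse_DCT share the same 8x8 basis kernel (Inverse_DCT uses its transpose), so on a block-aligned input in [0, 1] the round trip should reconstruct the tensor up to numerical error. This is a quick sanity check, not part of the training code.

import torch

x = torch.rand(1, 3, 64, 64)          # spatial size must be a multiple of 8
dct, idct = DCT(), Inverse_DCT()
err = (idct(dct(x)) - x).abs().max()
print(err.item())                     # expected to be close to 0
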
/unidemoire/models/cycle/nets.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.optim as optim
4 | from torch.autograd import Variable
5 |
6 | from Models import GM2_UNet5_256, GM2_UNet5_128, GM2_UNet5_64, TMB, Discriminator, L1_ASL
7 |
8 | class CycleModel(nn.Module):
9 | def __init__(self):
10 | super(CycleModel, self).__init__()
11 |
12 | self.resolution_dict = {
13 | '256': {'Net_Demoire':'256', 'G_Artifact':'256_2', 'D_moire':'256', 'D_clear':'256'},
14 | '128': {'Net_Demoire':'128', 'G_Artifact':'128_2', 'D_moire':'128', 'D_clear':'128'},
15 | '64': {'Net_Demoire':'64', 'G_Artifact':['64_2', '64_1'], 'D_moire':'64', 'D_clear':'64'},
16 | }
17 |
18 | self.Net_Demoire = {
19 | '256': GM2_UNet5_256(6, 3),
20 | '128': GM2_UNet5_128(6, 3),
21 | '64': GM2_UNet5_64(3, 3),
22 | 'TMB': TMB(256, 1)
23 | }
24 |
25 | self.G_Artifact = {
26 | '256_2': GM2_UNet5_256(6, 3),
27 | '128_2': GM2_UNet5_128(6, 3),
28 | '64_2': GM2_UNet5_64(3, 3),
29 | '64_1': TMB(256, 1),
30 | }
31 |
32 | self.D_moire = {
33 | '256': Discriminator(6, 256, 256),
34 | '128': Discriminator(6, 128, 128),
35 | '64': Discriminator(6, 64, 64),
36 | }
37 |
38 | self.D_clear = {
39 | '256': Discriminator(6, 256, 256),
40 | '128': Discriminator(6, 128, 128),
41 | '64': Discriminator(6, 64, 64),
42 | }
43 |
44 | self.downx2 = nn.UpsamplingNearest2d(scale_factor = 0.5)
45 | self.upx2 = nn.UpsamplingNearest2d(scale_factor = 2)
46 |
47 |
48 | # LOSS FUNCTIONS
49 | self.criterion_GAN = torch.nn.MSELoss()
50 | self.criterion_cycle = torch.nn.L1Loss()
51 | self.criterion_MSE = torch.nn.MSELoss()
52 | self.criterion_content = L1_ASL()
53 | self.Loss = L1_ASL()
54 |
55 | # Initialize weights
56 | for key in self.Net_Demoire.keys():
57 | self.Net_Demoire[key].apply(self.weights_init)
58 |
59 | for key in self.G_Artifact.keys():
60 | self.G_Artifact[key].apply(self.weights_init)
61 |
62 | for key in self.D_moire.keys():
63 | self.D_moire[key].apply(self.weights_init)
64 |
65 | for key in self.D_clear.keys():
66 | self.D_clear[key].apply(self.weights_init)
67 |
68 |
69 | # Custom weights initialization called on network
70 |     def weights_init(self, m):   # called via self.weights_init in .apply(), so the module arrives as `m`
71 | if isinstance(m, nn.Conv2d):
72 | nn.init.kaiming_uniform_(m.weight)
73 | if m.bias is not None:
74 | m.bias.data.zero_()
75 |
76 |
77 | def forward(self, MOIRE, CLEAR, historgram, device):
78 |
79 | Tensor = torch.cuda.FloatTensor
80 |
81 | # load data
82 | MOIRE_256 = MOIRE
83 | MOIRE_128 = self.downx2(MOIRE_256)
84 | MOIRE_64 = self.downx2(MOIRE_128)
85 |
86 | CLEAR_256 = CLEAR
87 | CLEAR_128 = self.downx2(CLEAR_256)
88 | CLEAR_64 = self.downx2(CLEAR_128)
89 |
90 | historgram = historgram.float()
91 |
92 | valid_256 = Variable(Tensor(MOIRE_256.size(0), 1, 1, 1).fill_(1.0).to(device), requires_grad=False)
93 | fake_256 = Variable(Tensor(MOIRE_256.size(0), 1, 1, 1).fill_(0.0).to(device), requires_grad=False)
94 |
95 | valid_128 = Variable(Tensor(MOIRE_128.size(0), 1, 1, 1).fill_(1.0).to(device), requires_grad=False)
96 | fake_128 = Variable(Tensor(MOIRE_128.size(0), 1, 1, 1).fill_(0.0).to(device), requires_grad=False)
97 |
98 | valid_64 = Variable(Tensor(MOIRE_64.size(0), 1, 1, 1).fill_(1.0).to(device), requires_grad=False)
99 | fake_64 = Variable(Tensor(MOIRE_64.size(0), 1, 1, 1).fill_(0.0).to(device), requires_grad=False)
100 |
101 | for resolution in self.resolution_dict.keys():
102 |
103 | Net_Demoire = self.Net_Demoire[self.resolution_dict[resolution]['Net_Demoire']]
104 |
105 | if resolution == '64':
106 | G_Artifact_1 = self.G_Artifact[self.resolution_dict[resolution]['G_Artifact'][0]]
107 |             G_Artifact_2 = self.G_Artifact[self.resolution_dict[resolution]['G_Artifact'][1] if resolution == '64' else self.resolution_dict[resolution]['G_Artifact']]
108 |
109 | D_moire = self.D_moire[self.resolution_dict[resolution]['D_moire']]
110 | D_clear = self.D_clear[self.resolution_dict[resolution]['D_clear']]
111 |
112 |
--------------------------------------------------------------------------------
/unidemoire/models/diffusion/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/4DVLab/UniDemoire/4221a0e98f0078c3f72ef1da8128642e7f02662a/unidemoire/models/diffusion/__init__.py
--------------------------------------------------------------------------------
/unidemoire/models/esdnet/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/4DVLab/UniDemoire/4221a0e98f0078c3f72ef1da8128642e7f02662a/unidemoire/models/esdnet/__init__.py
--------------------------------------------------------------------------------
/unidemoire/models/esdnet/nets.py:
--------------------------------------------------------------------------------
1 | """
2 | Implementation of ESDNet for image demoireing
3 | """
4 |
5 |
6 | import torch
7 | import torch.nn as nn
8 | import torch.nn.functional as F
9 | import torchvision
10 | from torch.nn.parameter import Parameter
11 |
12 | class ESDNet(nn.Module):
13 | def __init__(self,
14 | en_feature_num,
15 | en_inter_num,
16 | de_feature_num,
17 | de_inter_num,
18 | sam_number=1,
19 | ):
20 | super(ESDNet, self).__init__()
21 | self.encoder = Encoder(feature_num=en_feature_num, inter_num=en_inter_num, sam_number=sam_number)
22 | self.decoder = Decoder(en_num=en_feature_num, feature_num=de_feature_num, inter_num=de_inter_num,
23 | sam_number=sam_number)
24 |
25 | def forward(self, x):
26 | y_1, y_2, y_3 = self.encoder(x)
27 | out_1, out_2, out_3 = self.decoder(y_1, y_2, y_3)
28 |
29 | return out_1, out_2, out_3
30 |
31 | def _initialize_weights(self):
32 | for m in self.modules():
33 | if isinstance(m, nn.Conv2d):
34 | m.weight.data.normal_(0.0, 0.02)
35 | if m.bias is not None:
36 | m.bias.data.normal_(0.0, 0.02)
37 | if isinstance(m, nn.ConvTranspose2d):
38 | m.weight.data.normal_(0.0, 0.02)
39 |
40 |
41 | class Decoder(nn.Module):
42 | def __init__(self, en_num, feature_num, inter_num, sam_number):
43 | super(Decoder, self).__init__()
44 | self.preconv_3 = conv_relu(4 * en_num, feature_num, 3, padding=1)
45 | self.decoder_3 = Decoder_Level(feature_num, inter_num, sam_number)
46 |
47 | self.preconv_2 = conv_relu(2 * en_num + feature_num, feature_num, 3, padding=1)
48 | self.decoder_2 = Decoder_Level(feature_num, inter_num, sam_number)
49 |
50 | self.preconv_1 = conv_relu(en_num + feature_num, feature_num, 3, padding=1)
51 | self.decoder_1 = Decoder_Level(feature_num, inter_num, sam_number)
52 |
53 | def forward(self, y_1, y_2, y_3):
54 | x_3 = y_3
55 | x_3 = self.preconv_3(x_3)
56 | out_3, feat_3 = self.decoder_3(x_3)
57 |
58 | x_2 = torch.cat([y_2, feat_3], dim=1)
59 | x_2 = self.preconv_2(x_2)
60 | out_2, feat_2 = self.decoder_2(x_2)
61 |
62 | x_1 = torch.cat([y_1, feat_2], dim=1)
63 | x_1 = self.preconv_1(x_1)
64 | out_1 = self.decoder_1(x_1, feat=False)
65 |
66 | return out_1, out_2, out_3
67 |
68 |
69 | class Encoder(nn.Module):
70 | def __init__(self, feature_num, inter_num, sam_number):
71 | super(Encoder, self).__init__()
72 | self.conv_first = nn.Sequential(
73 | nn.Conv2d(12, feature_num, kernel_size=5, stride=1, padding=2, bias=True),
74 | nn.ReLU(inplace=True)
75 | )
76 | self.encoder_1 = Encoder_Level(feature_num, inter_num, level=1, sam_number=sam_number)
77 | self.encoder_2 = Encoder_Level(2 * feature_num, inter_num, level=2, sam_number=sam_number)
78 | self.encoder_3 = Encoder_Level(4 * feature_num, inter_num, level=3, sam_number=sam_number)
79 |
80 | def forward(self, x):
81 | x = F.pixel_unshuffle(x, 2)
82 | x = self.conv_first(x)
83 |
84 | out_feature_1, down_feature_1 = self.encoder_1(x)
85 | out_feature_2, down_feature_2 = self.encoder_2(down_feature_1)
86 | out_feature_3 = self.encoder_3(down_feature_2)
87 |
88 | return out_feature_1, out_feature_2, out_feature_3
89 |
90 |
91 | class Encoder_Level(nn.Module):
92 | def __init__(self, feature_num, inter_num, level, sam_number):
93 | super(Encoder_Level, self).__init__()
94 | self.rdb = RDB(in_channel=feature_num, d_list=(1, 2, 1), inter_num=inter_num)
95 | self.sam_blocks = nn.ModuleList()
96 | for _ in range(sam_number):
97 | sam_block = SAM(in_channel=feature_num, d_list=(1, 2, 3, 2, 1), inter_num=inter_num)
98 | self.sam_blocks.append(sam_block)
99 |
100 | if level < 3:
101 | self.down = nn.Sequential(
102 | nn.Conv2d(feature_num, 2 * feature_num, kernel_size=3, stride=2, padding=1, bias=True),
103 | nn.ReLU(inplace=True)
104 | )
105 | self.level = level
106 |
107 | def forward(self, x):
108 | out_feature = self.rdb(x)
109 | for sam_block in self.sam_blocks:
110 | out_feature = sam_block(out_feature)
111 | if self.level < 3:
112 | down_feature = self.down(out_feature)
113 | return out_feature, down_feature
114 | return out_feature
115 |
116 |
117 | class Decoder_Level(nn.Module):
118 | def __init__(self, feature_num, inter_num, sam_number):
119 | super(Decoder_Level, self).__init__()
120 | self.rdb = RDB(feature_num, (1, 2, 1), inter_num)
121 | self.sam_blocks = nn.ModuleList()
122 | for _ in range(sam_number):
123 | sam_block = SAM(in_channel=feature_num, d_list=(1, 2, 3, 2, 1), inter_num=inter_num)
124 | self.sam_blocks.append(sam_block)
125 | self.conv = conv(in_channel=feature_num, out_channel=12, kernel_size=3, padding=1)
126 |
127 | def forward(self, x, feat=True):
128 | x = self.rdb(x)
129 | for sam_block in self.sam_blocks:
130 | x = sam_block(x)
131 | out = self.conv(x)
132 | out = F.pixel_shuffle(out, 2)
133 |
134 | if feat:
135 | feature = F.interpolate(x, scale_factor=2, mode='bilinear')
136 | return out, feature
137 | else:
138 | return out
139 |
140 |
141 | class DB(nn.Module):
142 | def __init__(self, in_channel, d_list, inter_num):
143 | super(DB, self).__init__()
144 | self.d_list = d_list
145 | self.conv_layers = nn.ModuleList()
146 | c = in_channel
147 | for i in range(len(d_list)):
148 | dense_conv = conv_relu(in_channel=c, out_channel=inter_num, kernel_size=3, dilation_rate=d_list[i],
149 | padding=d_list[i])
150 | self.conv_layers.append(dense_conv)
151 | c = c + inter_num
152 | self.conv_post = conv(in_channel=c, out_channel=in_channel, kernel_size=1)
153 |
154 | def forward(self, x):
155 | t = x
156 | for conv_layer in self.conv_layers:
157 | _t = conv_layer(t)
158 | t = torch.cat([_t, t], dim=1)
159 | t = self.conv_post(t)
160 | return t
161 |
162 |
163 | class SAM(nn.Module):
164 | def __init__(self, in_channel, d_list, inter_num):
165 | super(SAM, self).__init__()
166 | self.basic_block = DB(in_channel=in_channel, d_list=d_list, inter_num=inter_num)
167 | self.basic_block_2 = DB(in_channel=in_channel, d_list=d_list, inter_num=inter_num)
168 | self.basic_block_4 = DB(in_channel=in_channel, d_list=d_list, inter_num=inter_num)
169 | self.fusion = CSAF(3 * in_channel)
170 |
171 | def forward(self, x):
172 | x_0 = x
173 | x_2 = F.interpolate(x, scale_factor=0.5, mode='bilinear')
174 | x_4 = F.interpolate(x, scale_factor=0.25, mode='bilinear')
175 |
176 | y_0 = self.basic_block(x_0)
177 | y_2 = self.basic_block_2(x_2)
178 | y_4 = self.basic_block_4(x_4)
179 |
180 | y_2 = F.interpolate(y_2, scale_factor=2, mode='bilinear')
181 | y_4 = F.interpolate(y_4, scale_factor=4, mode='bilinear')
182 |
183 | y = self.fusion(y_0, y_2, y_4)
184 | y = x + y
185 |
186 | return y
187 |
188 |
189 | class CSAF(nn.Module):
190 | def __init__(self, in_chnls, ratio=4):
191 | super(CSAF, self).__init__()
192 | self.squeeze = nn.AdaptiveAvgPool2d((1, 1))
193 | self.compress1 = nn.Conv2d(in_chnls, in_chnls // ratio, 1, 1, 0)
194 | self.compress2 = nn.Conv2d(in_chnls // ratio, in_chnls // ratio, 1, 1, 0)
195 | self.excitation = nn.Conv2d(in_chnls // ratio, in_chnls, 1, 1, 0)
196 |
197 | def forward(self, x0, x2, x4):
198 | out0 = self.squeeze(x0)
199 | out2 = self.squeeze(x2)
200 | out4 = self.squeeze(x4)
201 | out = torch.cat([out0, out2, out4], dim=1)
202 | out = self.compress1(out)
203 | out = F.relu(out)
204 | out = self.compress2(out)
205 | out = F.relu(out)
206 | out = self.excitation(out)
207 | out = F.sigmoid(out)
208 | w0, w2, w4 = torch.chunk(out, 3, dim=1)
209 | x = x0 * w0 + x2 * w2 + x4 * w4
210 |
211 | return x
212 |
213 |
214 | class RDB(nn.Module):
215 | def __init__(self, in_channel, d_list, inter_num):
216 | super(RDB, self).__init__()
217 | self.d_list = d_list
218 | self.conv_layers = nn.ModuleList()
219 | c = in_channel
220 | for i in range(len(d_list)):
221 | dense_conv = conv_relu(in_channel=c, out_channel=inter_num, kernel_size=3, dilation_rate=d_list[i],
222 | padding=d_list[i])
223 | self.conv_layers.append(dense_conv)
224 | c = c + inter_num
225 | self.conv_post = conv(in_channel=c, out_channel=in_channel, kernel_size=1)
226 |
227 | def forward(self, x):
228 | t = x
229 | for conv_layer in self.conv_layers:
230 | _t = conv_layer(t)
231 | t = torch.cat([_t, t], dim=1)
232 |
233 | t = self.conv_post(t)
234 | return t + x
235 |
236 |
237 | class conv(nn.Module):
238 | def __init__(self, in_channel, out_channel, kernel_size, dilation_rate=1, padding=0, stride=1):
239 | super(conv, self).__init__()
240 | self.conv = nn.Conv2d(in_channels=in_channel, out_channels=out_channel, kernel_size=kernel_size, stride=stride,
241 | padding=padding, bias=True, dilation=dilation_rate)
242 |
243 | def forward(self, x_input):
244 | out = self.conv(x_input)
245 | return out
246 |
247 |
248 | class conv_relu(nn.Module):
249 | def __init__(self, in_channel, out_channel, kernel_size, dilation_rate=1, padding=0, stride=1):
250 | super(conv_relu, self).__init__()
251 | self.conv = nn.Sequential(
252 | nn.Conv2d(in_channels=in_channel, out_channels=out_channel, kernel_size=kernel_size, stride=stride,
253 | padding=padding, bias=True, dilation=dilation_rate),
254 | nn.ReLU(inplace=True)
255 | )
256 |
257 | def forward(self, x_input):
258 | out = self.conv(x_input)
259 | return out
260 |
--------------------------------------------------------------------------------
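Sketch (not part of the repository): ESDNet predicts demoireing results at three scales (full, 1/2 and 1/4 resolution). The channel widths below are illustrative, not necessarily the exact values used in the released configs.

import torch

net = ESDNet(en_feature_num=48, en_inter_num=32, de_feature_num=64, de_inter_num=32, sam_number=1)
x = torch.rand(1, 3, 256, 256)
out_1, out_2, out_3 = net(x)
print(out_1.shape, out_2.shape, out_3.shape)
# torch.Size([1, 3, 256, 256]) torch.Size([1, 3, 128, 128]) torch.Size([1, 3, 64, 64])
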
/unidemoire/models/mbcnn/MBCNN.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 |
5 | from .MBCNN_class import *
6 | import torch.nn.functional as F
7 | import pdb
8 |
9 | class MBCNN(nn.Module):
10 | def __init__(self, nFilters, multi=True):
11 | super().__init__()
12 | self.imagesize = 256
13 | self.sigmoid = nn.Sigmoid()
14 | self.Space2Depth1 = nn.PixelUnshuffle(2)
15 | self.Depth2space1 = nn.PixelShuffle(2)
16 |
17 | self.conv_func1 = conv_relu1(12, nFilters * 2, 3, padding=1)
18 | self.pre_block1 = pre_block((1, 2, 3, 2, 1))
19 | self.conv_func2 = conv_relu1(128, nFilters * 2, 3, padding=0, stride=2)
20 | self.pre_block2 = pre_block((1, 2, 3, 2, 1))
21 |
22 | self.conv_func3 = conv_relu1(128, nFilters * 2, 3, padding=0, stride=2)
23 | self.pre_block3 = pre_block((1, 2, 2, 2, 1))
24 | self.global_block1 = global_block(self.imagesize // 8)
25 | self.pos_block1 = pos_block((1, 2, 2, 2, 1))
26 | self.conv1 = conv1(128, 12, 3,us=[True,False])
27 |
28 | self.conv_func4 = conv_relu1(131, nFilters * 2, 1, padding=0,cat_shape=(3,nFilters*2),set_cat_mul=(False,True))
29 | self.global_block2 = global_block(self.imagesize // 4)
30 | self.pre_block4 = pre_block((1, 2, 3, 2, 1))
31 | self.global_block3 = global_block(self.imagesize // 4)
32 | self.pos_block2 = pos_block((1, 2, 3, 2, 1))
33 | self.conv2 = conv1(128, 12, 3,us=[True,False])
34 |
35 | self.conv_func5 = conv_relu1(131, nFilters * 2, 1, padding=0,cat_shape=(3,nFilters*2),set_cat_mul=(False,True))
36 |
37 | self.global_block4 = global_block(self.imagesize // 2)
38 | self.pre_block5 = pre_block((1, 2, 3, 2, 1))
39 | self.global_block5 = global_block(self.imagesize // 2)
40 | self.pos_block3 = pos_block((1, 2, 3, 2, 1))
41 | self.conv3 = conv1(128, 12, 3,us=[True,False])
42 |
43 | def forward(self, x):
44 | output_list = []
45 | shape = list(x.shape) # [2, 3, 512, 512]
46 | # batch, channel, height, width = shape
47 | _x = self.Space2Depth1(x)
48 | t1 = self.conv_func1(_x)
49 | t1 = self.pre_block1(t1)
50 |
51 | t2 = F.pad(t1, (1, 1, 1, 1))
52 | t2 = self.conv_func2(t2)
53 | t2 = self.pre_block2(t2)
54 | t3 = F.pad(t2, (1, 1, 1, 1))
55 | t3 = self.conv_func3(t3)
56 | t3 = self.pre_block3(t3)
57 | t3 = self.global_block1(t3)
58 | t3 = self.pos_block1(t3)
59 | t3_out = self.conv1(t3)
60 | t3_out = self.Depth2space1(t3_out)
61 | t3_out = F.sigmoid(t3_out)
62 | output_list.append(t3_out)
63 |
64 | _t2 = torch.cat([t3_out, t2], dim=-3)
65 | _t2 = self.conv_func4(_t2)
66 | _t2 = self.global_block2(_t2)
67 | _t2 = self.pre_block4(_t2)
68 | _t2 = self.global_block3(_t2)
69 | _t2 = self.pos_block2(_t2)
70 | t2_out = self.conv2(_t2)
71 | t2_out = self.Depth2space1(t2_out)
72 | t2_out = F.sigmoid(t2_out)
73 | output_list.append(t2_out)
74 |
75 | _t1 = torch.cat([t1, t2_out], dim=-3)
76 | _t1 = self.conv_func5(_t1)
77 | _t1 = self.global_block4(_t1)
78 | _t1 = self.pre_block5(_t1)
79 | _t1 = self.global_block5(_t1)
80 | _t1 = self.pos_block3(_t1)
81 | _t1 = self.conv3(_t1)
82 | y = self.Depth2space1(_t1)
83 |
84 | y = self.sigmoid(y) + torch.Tensor([1e-10]).to(_t1.device)
85 | output_list.append(y)
86 | return t3_out,t2_out,y
87 | #return output_list
88 |
89 | # import os
90 | # from torchinfo import summary
91 | # from rich import print
92 | # GPU_ID = 5
93 | # os.environ["CUDA_VISIBLE_DEVICES"] = "%d" % GPU_ID
94 | # device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
95 | # net = MBCNN(64).to(device)
96 | # #print(summary(net, input_size=(1, 3, 512, 512)))
97 | # # print(summary(net))
98 |
99 | # # model_stats = summary(
100 | # # net,
101 | # # input_size=(1, 3, 512, 512),
102 | # # verbose=1,
103 | # # col_names=["kernel_size", "output_size", "num_params"],
104 | # # row_settings=["var_names"],
105 | # # )
106 |
--------------------------------------------------------------------------------
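Usage note (not part of the repository): a minimal shape-check sketch for the MBCNN model defined above, assuming the helper layers imported from MBCNN_class accept the arguments shown in __init__. MBCNN returns its prediction at three scales, which the sketch makes explicit; the input size matches self.imagesize = 256.

import torch
from unidemoire.models.mbcnn.MBCNN import MBCNN

net = MBCNN(nFilters=64).eval()              # 64 filters, as in the commented-out example above
x = torch.rand(1, 3, 256, 256)               # input size matching self.imagesize = 256
with torch.no_grad():
    out_quarter, out_half, out_full = net(x)
print(out_quarter.shape)                     # expected: torch.Size([1, 3, 64, 64])   (1/4 resolution)
print(out_half.shape)                        # expected: torch.Size([1, 3, 128, 128]) (1/2 resolution)
print(out_full.shape)                        # expected: torch.Size([1, 3, 256, 256]) (full resolution)
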
/unidemoire/models/mbcnn/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/4DVLab/UniDemoire/4221a0e98f0078c3f72ef1da8128642e7f02662a/unidemoire/models/mbcnn/__init__.py
--------------------------------------------------------------------------------
/unidemoire/models/mbcnn/arch_util.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.init as init
4 | import torch.nn.functional as F
5 | import math
6 | import numpy as np
7 | from torch import Tensor
8 | from typing import Optional, List
9 | import pdb
10 |
11 | def make_divisible(v, divisor=8, min_value=8):
12 | if min_value is None:
13 | min_value = divisor
14 | new_v = max(min_value, int(v + divisor / 2) // divisor * divisor)
15 | # Make sure that round down does not go down by more than 10%.
16 | if new_v < 0.9 * v:
17 | new_v += divisor
18 | return int(new_v)
19 |
20 | def initialize_weights(net_l, scale=1):
21 | if not isinstance(net_l, list):
22 | net_l = [net_l]
23 | for net in net_l:
24 | for m in net.modules():
25 | if isinstance(m, nn.Conv2d):
26 | init.kaiming_normal_(m.weight, a=0, mode='fan_in')
27 | m.weight.data *= scale # for residual block
28 | if m.bias is not None:
29 | m.bias.data.zero_()
30 | elif isinstance(m, nn.Linear):
31 | init.kaiming_normal_(m.weight, a=0, mode='fan_in')
32 | m.weight.data *= scale
33 | if m.bias is not None:
34 | m.bias.data.zero_()
35 | elif isinstance(m, nn.BatchNorm2d):
36 | init.constant_(m.weight, 1)
37 | init.constant_(m.bias.data, 0.0)
38 |
39 | class MeanShift(nn.Conv2d):
40 | def __init__(self, rgb_range, rgb_mean, rgb_std, sign=-1):
41 | super(MeanShift, self).__init__(3, 3, kernel_size=1)
42 | std = torch.Tensor(rgb_std)
43 | self.weight.data = torch.eye(3).view(3, 3, 1, 1)
44 | self.weight.data.div_(std.view(3, 1, 1, 1))
45 | self.bias.data = sign * rgb_range * torch.Tensor(rgb_mean)
46 | self.bias.data.div_(std)
47 | # self.requires_grad = False
48 | for p in self.parameters():
49 | p.requires_grad = False
50 |
51 | class USConv2d(nn.Conv2d):
52 | def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=0, dilation=1, groups=1, bias=True, us=[False, False],cat_shape=None,set_cat_mul=None):
53 | super(USConv2d, self).__init__(in_channels, out_channels, kernel_size, stride=stride, padding=padding, dilation=dilation, groups=groups, bias=bias)
54 | self.width_mult = 1
55 | self.us = us
56 | self.cat_shape = cat_shape
57 | self.set_cat_mul = set_cat_mul
58 | self.in_channels_index_list = [None]*4
59 |
60 | self.unrank = False
61 |
62 | def forward(self, inputs):
63 | in_channels = inputs.shape[1] // self.groups if self.us[0] else self.in_channels // self.groups
64 | out_channels = int(self.out_channels * self.width_mult) if self.us[1] else self.out_channels
65 | if self.width_mult < 0.3:
66 | in_channels_index=self.in_channels_index_list[-1]
67 | elif self.width_mult < 0.6:
68 | in_channels_index=self.in_channels_index_list[-2]
69 | elif self.width_mult < 0.8:
70 | in_channels_index=self.in_channels_index_list[-3]
71 | else:
72 | in_channels_index=self.in_channels_index_list[-4]
73 |
74 | if in_channels == self.in_channels:
75 | weight = self.weight[:out_channels, :in_channels, :, :]
76 | elif in_channels_index is None and self.cat_shape is not None:
77 | if self.set_cat_mul is None:
78 | cat_num = len(self.cat_shape)
79 | inchannel_index = np.zeros(self.in_channels)
80 | start = 0
81 | for i in range(cat_num):
82 | inchannel_index[start:start+int(self.width_mult*self.cat_shape[i])]=1
83 | start += self.cat_shape[i]
84 | else:
85 |                     assert len(self.set_cat_mul) == len(self.cat_shape), 'USConv2d with a concatenated input and partial pruning requires len(self.set_cat_mul) == len(self.cat_shape)'
86 | inchannel_index = np.zeros(self.in_channels)
87 | start=0
88 | for i in range(len(self.set_cat_mul)):
89 | if self.set_cat_mul[i] == True:
90 | inchannel_index[start:start+int(self.width_mult*self.cat_shape[i])]=1
91 | else:
92 | inchannel_index[start:start+int(self.cat_shape[i])]=1
93 | start += self.cat_shape[i]
94 | # pdb.set_trace()
95 | inchannel_index = np.squeeze(np.argwhere(inchannel_index))
96 | in_channels_index = inchannel_index
97 | weight = self.weight[:out_channels,inchannel_index, :, :]
98 | elif in_channels_index is not None and self.cat_shape is not None:
99 | inchannel_index = in_channels_index
100 | weight = self.weight[:out_channels,inchannel_index, :, :]
101 | else:
102 | weight = self.weight[:out_channels, :in_channels, :, :]
103 |
104 | if self.bias is not None:
105 | bias = self.bias[:out_channels]
106 | else:
107 | bias = self.bias
108 | y = F.conv2d(inputs, weight, bias, self.stride, self.padding, self.dilation, self.groups)
109 | # self.y = y
110 | return y
111 |
112 |
113 | class USConvTranspose2d(nn.ConvTranspose2d):
114 | def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=0, output_padding=0, us=[False, False]):
115 | super(USConvTranspose2d, self).__init__(in_channels, out_channels, kernel_size, stride, padding, output_padding = output_padding)
116 | self.width_mult = None
117 | self.us = us
118 |
119 | def forward(self, input: Tensor, output_size: Optional[List[int]] = None) -> Tensor:
120 | # in_channels = make_divisible(self.in_channels * self.width_mult) if self.us[0] else self.in_channels
121 | in_channels = int(self.in_channels * self.width_mult) if self.us[0] else self.in_channels
122 | out_channels = input.shape[1] if self.us[1] else self.out_channels
123 |
124 |
125 | weight = self.weight[:in_channels, :out_channels, :, :]
126 |
127 | assert isinstance(self.padding, tuple)
128 | output_padding = self._output_padding(input, output_size, self.stride, self.padding, self.kernel_size, self.dilation) # type: ignore[arg-type]
129 |
130 | return F.conv_transpose2d(input, weight, self.bias, self.stride, self.padding, output_padding, self.groups, self.dilation)
131 |
132 |
133 | class USBatchNorm2d(nn.BatchNorm2d):
134 | def __init__(self, num_features, width_list = None):
135 | super(USBatchNorm2d, self).__init__(num_features, affine=True, track_running_stats=False)
136 | self.width_id = None
137 |
138 | self.bn = nn.ModuleList([
139 | nn.BatchNorm2d(self.num_features, affine=False) for _ in range(len(width_list))
140 | ])
141 | # raise NotImplementedError
142 |
143 | def forward(self, inputs):
144 | num_features = inputs.size(1)
145 | y = F.batch_norm(
146 | inputs,
147 | self.bn[self.width_id].running_mean[:num_features],
148 | self.bn[self.width_id].running_var[:num_features],
149 | self.weight[:num_features],
150 | self.bias[:num_features],
151 | self.training,
152 | self.momentum,
153 | self.eps)
154 | return y
155 |
--------------------------------------------------------------------------------
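A worked example (not repository code) of how make_divisible above rounds a channel count to a multiple of the divisor while never dropping more than 10% below the requested value; the import path is assumed from this file's location in the package.

from unidemoire.models.mbcnn.arch_util import make_divisible

print(make_divisible(30))         # 32: rounded to the nearest multiple of 8
print(make_divisible(11))         # 16: 11 would round down to 8, but 8 < 0.9 * 11, so one divisor is added back
print(make_divisible(100, 16))    # 96: nearest multiple of 16
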
/unidemoire/models/moire_blending.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import pytorch_lightning as pl
4 | # from torch.nn.parameter import Parameter
5 | from torch.optim import lr_scheduler
6 | import torch.optim as optim
7 | from omegaconf import OmegaConf
8 | import os, glob   # os is used by get_config_from_ckpt_path below
9 |
10 | from .MIB.Blending import Blending
11 | from .TRN.model import Uformer
12 |
13 | from .utils.loss_util import *
14 | from .utils.common import *
15 |
16 | torch.autograd.set_detect_anomaly(True)
17 |
18 | class MoireBlending_Model(pl.LightningModule):
19 | def __init__(self, model_name, network_config, loss_config=None, optimizer_config=None, ckpt_path=None, ignore_keys=[]):
20 | super().__init__()
21 | self.model_name = model_name
22 | self.network_config = network_config
23 | self.loss_config = loss_config
24 | self.optimizer_config = optimizer_config
25 | self.init_blending_args = network_config["init_blending_args"]
26 | self.blending_network_args = network_config["blending_network_args"]
27 |
28 | # model
29 | self.model = self.build_up_models()
30 | self.loss_fn = self.loss_function()
31 |
32 | if self.model_name == "UniDemoire":
33 | self.init_blend, self.refine_net = self.model
34 | if ckpt_path is not None:
35 | print(f"Loading Checkpoint from {ckpt_path}")
36 | self.init_from_ckpt(ckpt_path, ignore_keys=ignore_keys)
37 |
38 |
39 | def on_load_checkpoint(self, checkpoint):
40 | print("Loading checkpoint...")
41 |
42 |
43 | def on_train_batch_start(self, batch, batch_idx, dataloader_idx):
44 | # only for very first batch
45 | if self.current_epoch == 0 and self.global_step == 0 and batch_idx == 0:
46 | self.moire_image_encoder = None
47 |
48 | def get_config_from_ckpt_path(self, ckpt_path):
49 | if os.path.isfile(ckpt_path):
50 | # paths = opt.resume.split("/")
51 | try:
52 | logdir = '/'.join(ckpt_path.split('/')[:-1])
53 | print(f'Encoder dir is {logdir}')
54 | except ValueError:
55 | paths = ckpt_path.split("/")
56 | idx = -2 # take a guess: path/to/logdir/checkpoints/model.ckpt
57 | logdir = "/".join(paths[:idx])
58 | ckpt = ckpt_path
59 | else:
60 | assert os.path.isdir(ckpt_path), f"{ckpt_path} is not a directory"
61 | logdir = ckpt_path.rstrip("/")
62 | ckpt = os.path.join(logdir, "model.ckpt")
63 |
64 | base_configs = sorted(glob.glob(os.path.join(logdir, "config.yaml")))
65 | base = base_configs
66 | configs = [OmegaConf.load(cfg) for cfg in base]
67 | return configs[0]['model']
68 |
69 | def val_mode(self, model):
70 | model = model.eval()
71 | for param in model.parameters():
72 | param.requires_grad = False
73 | return model
74 |
75 | def init_from_ckpt(self, path, ignore_keys=list()):
76 | sd = torch.load(path, map_location="cpu")["state_dict"]
77 | keys = list(sd.keys())
78 | for k in keys:
79 | for ik in ignore_keys:
80 | if k.startswith(ik):
81 | print("Deleting key {} from state_dict.".format(k))
82 | del sd[k]
83 | self.load_state_dict(sd, strict=False)
84 | # print(f"Restored from {path}, self.mib_weight = {self.mib_weight}, self.init_blend.weight = {self.init_blend.weight}", sd['init_blend.weight'])
85 |
86 | def build_up_models(self):
87 | if self.model_name == "UniDemoire":
88 | init_blend = Blending(self.init_blending_args)
89 | refine_net = Uformer(
90 | embed_dim=self.blending_network_args['embed_dim'],
91 | depths=self.blending_network_args['depths'],
92 | win_size=self.blending_network_args['win_size'],
93 | modulator=self.blending_network_args['modulator'],
94 | shift_flag=self.blending_network_args['shift_flag']
95 | )
96 | model = [init_blend, refine_net]
97 | else:
98 | model = None
99 | return model
100 |
101 | def loss_function(self):
102 | if self.model_name == "UniDemoire":
103 | Perceptual_Loss = PerceptualLoss()
104 | TV_Loss = TVLoss()
105 | ColorHistogram_Loss = ColorHistogramMatchingLoss()
106 | loss_fn = [Perceptual_Loss, TV_Loss, ColorHistogram_Loss]
107 | else:
108 | loss_fn = []
109 |
110 | return loss_fn
111 |
112 | def setup_optimizer(self):
113 | if self.model_name == "UniDemoire":
114 | optimizer = optim.Adam(
115 | [{
116 | 'params':
117 | list(self.model[1].parameters()), # self.refine_net
118 | 'initial_lr':
119 | self.learning_rate,
120 | 'lr': self.learning_rate
121 | }],
122 | betas=(self.optimizer_config["beta1"], self.optimizer_config["beta2"])
123 | )
124 | else:
125 | optimizer = optim.Adam(params=self.model.parameters(), lr=self.learning_rate)
126 |
127 | return optimizer
128 |
129 | def setup_scheduler(self):
130 | if self.model_name == "UniDemoire":
131 | scheduler = lr_scheduler.CosineAnnealingWarmRestarts(
132 | self.optimizer,
133 | T_0=self.optimizer_config["T_0"],
134 | T_mult=self.optimizer_config["T_mult"],
135 | eta_min=self.optimizer_config["eta_min"],
136 | )
137 | else:
138 | scheduler = None
139 |
140 | return scheduler
141 |
142 | def configure_optimizers(self):
143 | self.optimizer = self.setup_optimizer()
144 | self.scheduler = self.setup_scheduler()
145 | if self.scheduler is not None:
146 | return [self.optimizer],[self.scheduler]
147 | else:
148 | return self.optimizer
149 |
150 | def training_epoch_end(self, outputs):
151 | if self.scheduler is not None:
152 | self.scheduler.step()
153 |
154 | def get_input(self, batch):
155 | moire_pattern = batch['moire_pattern']
156 | natural = batch['natural']
157 | real_moire = batch['real_moire']
158 | number = batch['number']
159 |
160 | return moire_pattern, natural, real_moire, number
161 |
162 | def forward(self, moire_pattern, natural, real_moire):
163 | if self.model_name == "UniDemoire":
164 | moire_pattern = moire_pattern.to(self.device)
165 | natural = natural.to(self.device)
166 | real_moire = real_moire.to(self.device)
167 | self.init_blend.to(self.device)
168 |
169 | ##* Here's the MIB:
170 | mib_result, weight = self.init_blend(natural, moire_pattern)
171 | mib_result = mib_result.to(self.device)
172 | self.log('w_mib', weight, prog_bar=True, logger=True)
173 |
174 | ##* And here's the TRN:
175 | refine_result = mib_result * self.refine_net(mib_result, real_moire)
176 | min_val = torch.min(refine_result)
177 | max_val = torch.max(refine_result)
178 | refine_result = (refine_result - min_val) / (max_val - min_val)
179 | refine_result = refine_result
180 |
181 | return mib_result, refine_result
182 | else:
183 | return None
184 |
185 | def training_step(self, batch, batch_idx):
186 | if self.model_name == "UniDemoire":
187 | # Get data
188 | moire_pattern, natural, real_moire, number = self.get_input(batch)
189 | # Get Loss function
190 | Perceptual_Loss, TV_Loss, ColorHistogram_Loss = self.loss_fn
191 | #* Get the result
192 | mib_result, refine_result = self(moire_pattern, natural, real_moire)
193 |
194 | #* Calculate the losses
195 | content_loss = Perceptual_Loss(input=refine_result, target=mib_result, device=self.device, feature_layers=[0,1,2])
196 | color_loss = ColorHistogram_Loss(x=refine_result, y=real_moire, device=self.device)
197 | tv_loss = TV_Loss(refine_result)
198 |
199 | #** Total Loss:
200 | loss = color_loss + content_loss + 0.1 * tv_loss
201 |
202 | # Logging
203 | self.log('L_p', content_loss, prog_bar=True, logger=True)
204 | self.log('L_c', color_loss, prog_bar=True, logger=True)
205 | self.log('L_tv', tv_loss, prog_bar=True, logger=True)
206 | self.log('L_total', loss, prog_bar=False, logger=True)
207 |
208 | lr = self.optimizer.param_groups[0]['lr']
209 | self.log('lr', lr, prog_bar=True, logger=False)
210 |
211 | return loss
212 |
213 | def feature_norm(self, feature):
214 | normed_feature = feature / feature.norm(dim=-1, keepdim=True)
215 | return normed_feature
216 |
217 | @torch.no_grad()
218 | def log_images(self, batch, only_inputs=False, **kwargs):
219 | log = dict()
220 | moire_pattern, natural, real_moire, number = self.get_input(batch)
221 | log["natural"] = natural
222 | log["moire_pattern"] = moire_pattern
223 | log["real_moire"] = real_moire
224 | if not only_inputs:
225 | mib_result, trn_result = self(moire_pattern, natural, real_moire)
226 | log["init_blending_result"] = mib_result
227 | log["fusion_result"] = trn_result
228 |
229 | return log
--------------------------------------------------------------------------------
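A minimal sketch (illustration only, with random tensors standing in for the network outputs) of how forward() above composes the two stages: the TRN output acts as a multiplicative refinement of the MIB blending result, which is then min-max rescaled back into [0, 1]; training_step then combines the losses as L = L_c + L_p + 0.1 * L_tv.

import torch

mib_result = torch.rand(1, 3, 256, 256)         # stand-in for the MIB output
trn_gain   = torch.rand(1, 3, 256, 256) * 2.0   # stand-in for refine_net(mib_result, real_moire)

refine_result = mib_result * trn_gain           # multiplicative refinement, as in forward()
refine_result = (refine_result - refine_result.min()) / (refine_result.max() - refine_result.min())
print(float(refine_result.min()), float(refine_result.max()))   # 0.0 1.0
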
/unidemoire/models/pmtnet/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/4DVLab/UniDemoire/4221a0e98f0078c3f72ef1da8128642e7f02662a/unidemoire/models/pmtnet/__init__.py
--------------------------------------------------------------------------------
/unidemoire/models/shooting/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/4DVLab/UniDemoire/4221a0e98f0078c3f72ef1da8128642e7f02662a/unidemoire/models/shooting/__init__.py
--------------------------------------------------------------------------------
/unidemoire/models/shooting/image_transformer.py:
--------------------------------------------------------------------------------
1 | #from utils import *
2 | import numpy as np
3 | import cv2
4 | from math import pi
5 | import torch
6 | import time
7 | import torch.nn.functional as F
8 |
9 | class ImageTransformer(object):
10 |     """ Perspective transformation class for an image
11 |         tensor of shape (#channels, height, width) """
12 |
13 | def __init__(self, img):
14 |         self.image = img                    # (c, h, w)
15 |         self.image = self.image.unsqueeze(0)  # (1, c, h, w)
16 | # self.image = self.image.permute(0, 3, 1, 2) # (1, c, h, w)
17 | self.batchsize = self.image.shape[0]
18 | self.num_channels = self.image.shape[1]
19 | self.height = self.image.shape[2]
20 | self.width = self.image.shape[3]
21 | self.device = img.device
22 |
23 | def get_rad(self, theta, phi, gamma):
24 | return (self.deg_to_rad(theta),
25 | self.deg_to_rad(phi),
26 | self.deg_to_rad(gamma))
27 |
28 | def get_deg(self, rtheta, rphi, rgamma):
29 | return (self.rad_to_deg(rtheta),
30 | self.rad_to_deg(rphi),
31 | self.rad_to_deg(rgamma))
32 |
33 | def deg_to_rad(self, deg):
34 | return deg * pi / 180.0
35 |
36 | def rad_to_deg(self, rad):
37 | return rad * 180.0 / pi
38 |
39 |     """ Wrapper for rotating an image """
40 | def rotate_along_axis(self, random_f, theta=0, phi=0, gamma=0, dx=0, dy=0, dz=0):
41 |
42 | # Get radius of rotation along 3 axes
43 | if random_f:
44 | theta = np.random.randint(-20, 20)
45 | phi = np.random.randint(-20, 20)
46 | gamma = np.random.randint(-20, 20)
47 |
48 | # theta = 0
49 | # phi = 0
50 | # gamma = 0
51 | rtheta, rphi, rgamma =self.get_rad(theta, phi, gamma)
52 |
53 | # Get ideal focal length on z axis
54 | # NOTE: Change this section to other axis if needed
55 | d = np.sqrt(self.height**2 + self.width**2)
56 |
57 | self.focal = d / (2 * np.sin(rgamma) if np.sin(rgamma) != 0 else 1)
58 | dz = self.focal
59 |
60 | # Get projection matrix
61 | mat = self.get_M(rtheta, rphi, rgamma, dx, dy, dz)
62 |
63 | # print(type(mat), mat.shape)
64 | # mat_inv = np.linalg.pinv(mat)
65 | mat_inv = mat
66 |
67 | time.sleep(0.1)
68 | rotate_img = cv2.warpPerspective(self.image.cpu().numpy(), mat, (self.width, self.height))
69 | # rotate_img = self.image.cpu()
70 |
71 | rotate_img = torch.from_numpy(rotate_img)
72 | return theta, phi, gamma, rotate_img, mat_inv, mat
73 |
74 | def Perspective(self, random_f, theta=0, phi=0, gamma=0, dx=0, dy=0, dz=0):
75 |
76 | # Get radius of rotation along 3 axes
77 | if random_f:
78 | theta = torch.randint(-20,20,(1,))
79 | phi = torch.randint(-20,20,(1,))
80 | gamma = torch.randint(-20,20,(1,))
81 | rtheta, rphi, rgamma =self.get_rad(theta, phi, gamma)
82 |
83 | # Get ideal focal length on z axis
84 | # NOTE: Change this section to other axis if needed
85 | d = torch.sqrt(torch.tensor(self.height**2) + torch.tensor(self.width**2))
86 | self.focal = d / (2 * torch.sin(rgamma) if torch.sin(rgamma) != 0 else 1)
87 | dz = self.focal
88 |
89 | # Get projection matrix
90 | mat = self.get_M_2(rtheta, rphi, rgamma, dx, dy, dz)
91 |
92 | # rotate_img = cv2.warpPerspective(self.image.cpu().numpy(), mat, (self.width, self.height))
93 | rotate_img = self.warpPerspective(image=self.image, M=mat)
94 |
95 | return theta, phi, gamma, rotate_img
96 |
97 |
98 | def warpPerspective(self, image, M):
99 | M_norm = self.matrix_normalization(M)
100 | grid = F.affine_grid(torch.eye(2, 3).unsqueeze(0), image.size(), align_corners=False).to(self.device)
101 | homogeneous_grid = torch.cat([grid, torch.ones(self.batchsize, self.height, self.width, 1, device=self.device)], dim=-1)
102 |
103 | warped_grid = torch.matmul(homogeneous_grid, M_norm.transpose(1, 2))
104 | warped_grid_xy = warped_grid[..., :2] / warped_grid[..., 2:3]
105 |
106 | transformed_image = F.grid_sample(image, warped_grid_xy, align_corners=False, padding_mode='zeros')
107 |
108 | return transformed_image
109 |
110 |
111 | def matrix_normalization(self, M_cv):
112 | M_cv = M_cv.unsqueeze(0)
113 | B = M_cv.shape[0]
114 | H = self.height
115 | W = self.width
116 | device = self.device
117 |
118 | norm_matrix = torch.tensor([
119 | [2.0/W, 0, -1],
120 | [ 0, 2.0/H, -1],
121 | [ 0, 0, 1]
122 | ], dtype=torch.float32, device=device).unsqueeze(0).repeat(B, 1, 1)
123 |
124 | inv_norm_matrix = torch.tensor([
125 | [W/2.0, 0, W/2.0],
126 | [ 0, H/2.0, H/2.0],
127 | [ 0, 0, 1]
128 | ], dtype=torch.float32, device=device).unsqueeze(0).repeat(B, 1, 1)
129 |
130 | M_norm = torch.bmm(torch.bmm(norm_matrix, torch.inverse(M_cv)), inv_norm_matrix)
131 |
132 | return M_norm
133 |
134 | def get_M_2(self, theta, phi, gamma, dx, dy, dz):
135 | w = self.width
136 | h = self.height
137 | f = self.focal
138 |
139 | # Projection 2D -> 3D matrix
140 | A1 = torch.tensor([ [1, 0, -w/2],
141 | [0, 1, -h/2],
142 | [0, 0, 1],
143 | [0, 0, 1] ], dtype=torch.float32, device=self.device)
144 |
145 | # Rotation matrices around the X, Y, and Z axis
146 | RX = torch.tensor([ [1, 0, 0, 0],
147 | [0, torch.cos(theta), -torch.sin(theta), 0],
148 | [0, torch.sin(theta), torch.cos(theta), 0],
149 | [0, 0, 0, 1] ], dtype=torch.float32, device=self.device)
150 |
151 | RY = torch.tensor([ [torch.cos(phi), 0, -torch.sin(phi), 0],
152 | [0, 1, 0, 0],
153 | [torch.sin(phi), 0, torch.cos(phi), 0],
154 | [0, 0, 0, 1] ], dtype=torch.float32, device=self.device)
155 |
156 | RZ = torch.tensor([ [torch.cos(gamma), -torch.sin(gamma), 0, 0],
157 | [torch.sin(gamma), torch.cos(gamma), 0, 0],
158 | [0, 0, 1, 0],
159 | [0, 0, 0, 1] ], dtype=torch.float32, device=self.device)
160 |
161 | # Composed rotation matrix with (RX, RY, RZ)
162 | R = torch.matmul(torch.matmul(RX, RY), RZ)
163 |
164 | # Translation matrix
165 | T = torch.tensor([ [1, 0, 0, dx],
166 | [0, 1, 0, dy],
167 | [0, 0, 1, dz],
168 | [0, 0, 0, 1] ], dtype=torch.float32, device=self.device)
169 |
170 | # Projection 3D -> 2D matrix
171 | A2 = torch.tensor([ [f, 0, w/2, 0],
172 | [0, f, h/2, 0],
173 | [0, 0, 1, 0] ], dtype=torch.float32, device=self.device)
174 |
175 | # Final transformation matrix
176 | M = torch.matmul(A2, torch.matmul(T, torch.matmul(R, A1)))
177 |
178 | return M
179 |
180 |
181 | """ Get Perspective Projection Matrix """
182 | def get_M(self, theta, phi, gamma, dx, dy, dz):
183 | w = self.width
184 | h = self.height
185 | f = self.focal
186 |
187 | # Projection 2D -> 3D matrix
188 | A1 = np.array([ [1, 0, -w/2],
189 | [0, 1, -h/2],
190 | [0, 0, 1],
191 | [0, 0, 1]])
192 |
193 | # Rotation matrices around the X, Y, and Z axis
194 | RX = np.array([ [1, 0, 0, 0],
195 | [0, np.cos(theta), -np.sin(theta), 0],
196 | [0, np.sin(theta), np.cos(theta), 0],
197 | [0, 0, 0, 1]])
198 |
199 | RY = np.array([ [np.cos(phi), 0, -np.sin(phi), 0],
200 | [0, 1, 0, 0],
201 | [np.sin(phi), 0, np.cos(phi), 0],
202 | [0, 0, 0, 1]])
203 |
204 | RZ = np.array([ [np.cos(gamma), -np.sin(gamma), 0, 0],
205 | [np.sin(gamma), np.cos(gamma), 0, 0],
206 | [0, 0, 1, 0],
207 | [0, 0, 0, 1]])
208 |
209 | # Composed rotation matrix with (RX, RY, RZ)
210 | R = np.dot(np.dot(RX, RY), RZ)
211 |
212 | # Translation matrix
213 | T = np.array([ [1, 0, 0, dx],
214 | [0, 1, 0, dy],
215 | [0, 0, 1, dz],
216 | [0, 0, 0, 1]])
217 |
218 | # Projection 3D -> 2D matrix
219 | A2 = np.array([ [f, 0, w/2, 0],
220 | [0, f, h/2, 0],
221 | [0, 0, 1, 0]])
222 |
223 | # Final transformation matrix
224 | M = np.dot(A2, np.dot(T, np.dot(R, A1)))
225 |
226 | return M
--------------------------------------------------------------------------------
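A hypothetical usage sketch for ImageTransformer above (not repository code): the constructor takes a single (channels, height, width) tensor, as it is called from method.py, and Perspective with random_f=True draws the three rotation angles from [-20, 20) itself.

import torch
from unidemoire.models.shooting.image_transformer import ImageTransformer

img = torch.rand(3, 256, 256)                        # (c, h, w); a batch dimension is added internally
transformer = ImageTransformer(img)
theta, phi, gamma, warped = transformer.Perspective(random_f=True)
print(theta, phi, gamma)                             # randomly drawn angles in degrees
print(warped.shape)                                  # torch.Size([1, 3, 256, 256])
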
/unidemoire/models/shooting/method.py:
--------------------------------------------------------------------------------
1 | import os
2 | import sys
3 |
4 | import random
5 |
6 | import torch
7 | import torch.nn as nn
8 | import torch.nn.functional as F
9 | import torchvision.utils
10 | import torchvision.transforms as transforms
11 |
12 | from PIL import Image
13 |
14 | from unidemoire.models.shooting.mosaicing_demosaicing_v2 import *
15 | from unidemoire.models.shooting.image_transformer import ImageTransformer
16 |
17 | def adjust_contrast_and_brightness(input_img, beta = 30):
18 |     beta = beta / 255.0 #* brightness strength
19 | input_img = torch.clamp(input_img + beta, 0, 1)
20 |
21 | return input_img
22 |
23 | def simulate_LCD_display(input_img, device):
24 |     """ Simulate the display of raw images on an LCD screen
25 | Input:
26 | original images (tensor): batch x channel x height x width
27 | Output:
28 | LCD images (tensor): batch x channel x (height x scale_factor) x (width x scale_factor)
29 | """
30 | b, c, h, w = input_img.shape
31 |
32 | simulate_imgs = torch.zeros((b, c, h * 3, w * 3), dtype=torch.float32, device=device)
33 | red = input_img[:, 0, :, :].repeat_interleave(3, dim=1)
34 | green = input_img[:, 1, :, :].repeat_interleave(3, dim=1)
35 | blue = input_img[:, 2, :, :].repeat_interleave(3, dim=1)
36 |
37 | simulate_imgs[:, 0, :, 0::3] = red
38 | simulate_imgs[:, 1, :, 1::3] = green
39 | simulate_imgs[:, 2, :, 2::3] = blue
40 |
41 | return simulate_imgs
42 |
43 |
44 | def demosaic_and_denoise(input_img, device):
45 | """ Apply demosaicing to the images
46 | Input:
47 | images (tensor): batch x (height x scale_factor) x (width x scale_factor)
48 | Output:
49 | demosaicing images (tensor): batch x channel x (height x scale_factor) x (width x scale_factor)
50 | """
51 | input_img = input_img.double()
52 | demosaicing_imgs = demosaicing_CFA_Bayer_bilinear(input_img)
53 | demosaicing_imgs = demosaicing_imgs.permute(0, 3, 1, 2)
54 | return demosaicing_imgs
55 |
56 | def simulate_CFA(input_img, device):
57 |     """ Simulate the raw reading of the camera sensor using a Bayer CFA
58 | Input:
59 | images (tensor): batch x channel x (height x scale_factor) x (width x scale_factor)
60 | Output:
61 | mosaicing images (tensor): batch x (height x scale_factor) x (width x scale_factor)
62 | """
63 | input_img = input_img.permute(0, 2, 3, 1)
64 | mosaicing_imgs = mosaicing_CFA_Bayer(input_img)
65 | return mosaicing_imgs
66 |
67 | def random_rotation_3(org_images, lcd_images, device):
68 |     """ Simulate the 3D rotation during shooting
69 | Input:
70 | images (tensor): batch x channel x height x width
71 | Rotate angle:
72 | theta (int): (-20, 20)
73 | phi (int): (-20, 20)
74 | gamma (int): (-20, 20)
75 | Output:
76 | rotated original images (tensor): batch x channel x height x width
77 | rotated LCD images (tensor): batch x channel x (height x scale_factor) x (width x scale_factor)
78 | """
79 | rotate_images = torch.zeros_like(org_images).to(device) # (bs, c, h, w)
80 | rotate_lcd_images = torch.zeros_like(lcd_images).to(device) # (bs, c, 3h, 3w)
81 |
82 | for n, img in enumerate(org_images):
83 |
84 | Trans_org = ImageTransformer(img)
85 | Trans_lcd = ImageTransformer(lcd_images[n])
86 |
87 | theta, phi, gamma, rotate_img = Trans_org.Perspective(random_f=True)
88 | _, _, _, rotate_lcd_img = Trans_lcd.Perspective(random_f=False, theta=theta, phi=phi, gamma=gamma)
89 |
90 | rotate_img = rotate_img.squeeze(0)
91 | rotate_lcd_img = rotate_lcd_img.squeeze(0)
92 |
93 | rotate_images[n, :] = rotate_img
94 | rotate_lcd_images[n, :] = rotate_lcd_img
95 |
96 | return rotate_images, rotate_lcd_images
97 |
98 |
99 | def Shooting(org_imgs, device):
100 | batch_size, channel, img_h, img_w = org_imgs.shape
101 | alpha = random.randint(1,4)
102 | crop_ratio = 0.7
103 |
104 | noise = torch.randn([batch_size, img_h * alpha * 3, img_w * alpha * 3]).to(device)
105 | noise = noise / 256.0
106 |
107 | resize_before_lcd = F.interpolate(org_imgs, scale_factor=alpha, mode="bilinear", align_corners=True)
108 | lcd_images = simulate_LCD_display(resize_before_lcd, device)
109 | rotate_images, rotate_lcd_images = random_rotation_3(org_imgs, lcd_images, device)
110 |
111 | cfa_img = simulate_CFA(rotate_lcd_images, device)
112 | cfa_img_noise = cfa_img + noise
113 | # cfa_img_noise = cfa_img_noise.double()
114 | demosaic_img = demosaic_and_denoise(cfa_img_noise, device)
115 | brighter_img = adjust_contrast_and_brightness(demosaic_img, beta=20)
116 |
117 | at_images = F.interpolate(brighter_img, [img_h, img_w], mode='bilinear', align_corners=True)
118 | at_images = torch.clamp(at_images, min=0, max=1)
119 |
120 | crop_edges = transforms.Compose([
121 | transforms.CenterCrop((int(img_h*crop_ratio), int(img_w*crop_ratio))),
122 | transforms.Resize((img_h, img_w)),
123 |
124 | ])
125 | rotate_images = crop_edges(rotate_images)
126 | at_images = crop_edges(at_images)
127 |
128 | return at_images, rotate_images
129 |
130 |
131 |
132 | trans = transforms.Compose([
133 | transforms.Resize((384,384)),
134 | transforms.ToTensor()
135 | ])
136 |
--------------------------------------------------------------------------------
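An end-to-end sketch of the Shooting pipeline above (illustration only; batch and image sizes are placeholders). It chains the steps defined in this file: LCD sub-pixel simulation, random 3D rotation, Bayer CFA mosaicing with sensor noise, demosaicing, brightness adjustment and edge cropping. The CFA helpers pick CUDA internally whenever it is available, so the same device choice is mirrored here.

import torch
from unidemoire.models.shooting.method import Shooting

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
clean = torch.rand(1, 3, 128, 128, device=device)       # clean images in [0, 1]
simulated_moire, rotated_reference = Shooting(clean, device)
print(simulated_moire.shape, rotated_reference.shape)   # both torch.Size([1, 3, 128, 128])
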
/unidemoire/models/shooting/mosaicing_demosaicing_v2.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import numpy as np
3 | from colour_demosaicing.bayer import masks_CFA_Bayer
4 | import torch.nn as nn
5 | import torch.nn.functional as F
6 |
7 | def demosaicing_CFA_Bayer_bilinear(CFA, pattern='RGGB'):
8 | """
9 | Returns the demosaiced *RGB* colourspace array from given *Bayer* CFA using
10 | bilinear interpolation.
11 |
12 | Parameters
13 | ----------
14 | CFA : array_like
15 | *Bayer* CFA.
16 | pattern : unicode, optional
17 | **{'RGGB', 'BGGR', 'GRBG', 'GBRG'}**,
18 | Arrangement of the colour filters on the pixel array.
19 |
20 | Returns
21 | -------
22 | ndarray
23 | *RGB* colourspace array.
24 |
25 | Notes
26 | -----
27 | - The definition output is not clipped in range [0, 1] : this allows for
28 | direct HDRI / radiance image generation on *Bayer* CFA data and post
29 | demosaicing of the high dynamic range data as showcased in this
30 | `Jupyter Notebook `__.
33 |
34 | References
35 | ----------
36 | :cite:`Losson2010c`
37 |
38 | Examples
39 | --------
40 | >>> import numpy as np
41 | >>> CFA = np.array(
42 | ... [[0.30980393, 0.36078432, 0.30588236, 0.3764706],
43 | ... [0.35686275, 0.39607844, 0.36078432, 0.40000001]])
44 | >>> demosaicing_CFA_Bayer_bilinear(CFA)
45 | array([[[ 0.69705884, 0.17941177, 0.09901961],
46 | [ 0.46176472, 0.4509804 , 0.19803922],
47 | [ 0.45882354, 0.27450981, 0.19901961],
48 | [ 0.22941177, 0.5647059 , 0.30000001]],
49 |
50 | [[ 0.23235295, 0.53529412, 0.29705883],
51 | [ 0.15392157, 0.26960785, 0.59411766],
52 | [ 0.15294118, 0.4509804 , 0.59705884],
53 | [ 0.07647059, 0.18431373, 0.90000002]]])
54 | >>> CFA = np.array(
55 | ... [[0.3764706, 0.360784320, 0.40784314, 0.3764706],
56 | ... [0.35686275, 0.30980393, 0.36078432, 0.29803923]])
57 | >>> demosaicing_CFA_Bayer_bilinear(CFA, 'BGGR')
58 | array([[[ 0.07745098, 0.17941177, 0.84705885],
59 | [ 0.15490197, 0.4509804 , 0.5882353 ],
60 | [ 0.15196079, 0.27450981, 0.61176471],
61 | [ 0.22352942, 0.5647059 , 0.30588235]],
62 |
63 | [[ 0.23235295, 0.53529412, 0.28235295],
64 | [ 0.4647059 , 0.26960785, 0.19607843],
65 | [ 0.45588237, 0.4509804 , 0.20392157],
66 | [ 0.67058827, 0.18431373, 0.10196078]]])
67 | """
68 |
69 |     ## The docstring above describes the original colour_demosaicing implementation, which operates on NumPy arrays; the code below adapts it to torch tensors:
70 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
71 | batch, h, w= CFA.size()
72 |
73 | R_m, G_m, B_m = masks_CFA_Bayer([h, w], pattern)
74 |
75 | R_m = R_m[np.newaxis, np.newaxis, :]
76 | R_m = np.repeat(R_m, batch, axis = 0)
77 | G_m = G_m[np.newaxis, np.newaxis, :]
78 | G_m = np.repeat(G_m, batch, axis=0)
79 | B_m = B_m[np.newaxis, np.newaxis, :]
80 | B_m = np.repeat(B_m, batch, axis=0)
81 |
82 | R_m = torch.from_numpy(R_m).to(device)
83 | G_m = torch.from_numpy(G_m).to(device)
84 | B_m = torch.from_numpy(B_m).to(device)
85 |
86 | H_G = np.array(
87 | [[0, 1, 0],
88 | [1, 4, 1],
89 | [0, 1, 0]]) / 4 # yapf: disable
90 |
91 | H_G = H_G[np.newaxis, np.newaxis, :]
92 | H_G = torch.from_numpy(H_G).to(device)
93 |
94 | H_RB = np.array(
95 | [[1, 2, 1],
96 | [2, 4, 2],
97 | [1, 2, 1]]) / 4 # yapf: disable
98 |
99 | H_RB = H_RB[np.newaxis, np.newaxis, :]
100 | H_RB = torch.from_numpy(H_RB).to(device)
101 | CFA = CFA.unsqueeze(1)
102 |
103 | R = F.conv2d(CFA * R_m, H_RB, stride=1, padding=1)
104 | G = F.conv2d(CFA * G_m, H_G, stride=1, padding=1)
105 | B = F.conv2d(CFA * B_m, H_RB, stride=1, padding=1)
106 |
107 | R = R.squeeze(1)
108 | G = G.squeeze(1)
109 | B = B.squeeze(1)
110 |
111 | del R_m, G_m, B_m, H_RB, H_G
112 | torch.cuda.empty_cache()
113 |
114 | return torch.stack((R, G, B), dim = 3)
115 |
116 | def mosaicing_CFA_Bayer(RGB, pattern = 'RGGB'):
117 | """
118 | Returns the *Bayer* CFA mosaic for a given *RGB* colourspace array.
119 |
120 | Parameters
121 | ----------
122 | RGB : array_like
123 | *RGB* colourspace array.
124 | pattern : unicode, optional
125 | **{'RGGB', 'BGGR', 'GRBG', 'GBRG'}**,
126 | Arrangement of the colour filters on the pixel array.
127 |
128 | Returns
129 | -------
130 | ndarray
131 | *Bayer* CFA mosaic.
132 |
133 | Examples
134 | --------
135 | >>> import numpy as np
136 | >>> RGB = np.array([[[0, 1, 2],
137 | ... [0, 1, 2]],
138 | ... [[0, 1, 2],
139 | ... [0, 1, 2]]])
140 | >>> mosaicing_CFA_Bayer(RGB)
141 | array([[ 0., 1.],
142 | [ 1., 2.]])
143 | >>> mosaicing_CFA_Bayer(RGB, pattern='BGGR')
144 | array([[ 2., 1.],
145 | [ 1., 0.]])
146 | """
147 |
148 |     ## The docstring above describes the original colour_demosaicing implementation, which operates on NumPy arrays; the code below adapts it to torch tensors:
149 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
150 |
151 | R = RGB[:, :, :, 0]
152 | G = RGB[:, :, :, 1]
153 | B = RGB[:, :, :, 2]
154 |
155 | batch, _, _, _ = RGB.shape
156 | R_m, G_m, B_m = masks_CFA_Bayer(RGB.shape[1:3], pattern)
157 |
158 | G_m = G_m[np.newaxis, :]
159 | G_m = np.repeat(G_m, batch, axis = 0)
160 | B_m = B_m[np.newaxis, :]
161 | B_m = np.repeat(B_m, batch, axis = 0)
162 | R_m = R_m[np.newaxis, :]
163 | R_m = np.repeat(R_m, batch, axis = 0)
164 |
165 | R_m = torch.from_numpy(R_m).to(device)
166 | G_m = torch.from_numpy(G_m).to(device)
167 | B_m = torch.from_numpy(B_m).to(device)
168 |
169 | CFA = R * R_m + G * G_m + B * B_m
170 | del R_m, G_m, B_m
171 | torch.cuda.empty_cache()
172 |
173 | return CFA
174 |
175 |
--------------------------------------------------------------------------------
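A minimal round-trip sketch for the two tensor-adapted CFA utilities above (illustration only). mosaicing_CFA_Bayer expects a channel-last batch, and the demosaicing kernels are float64, so the mosaic is cast to double before reconstruction (demosaic_and_denoise in method.py does the same).

import torch
from unidemoire.models.shooting.mosaicing_demosaicing_v2 import (
    mosaicing_CFA_Bayer, demosaicing_CFA_Bayer_bilinear)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
rgb = torch.rand(1, 64, 64, 3, device=device)            # (batch, h, w, 3), channel-last
cfa = mosaicing_CFA_Bayer(rgb)                           # (1, 64, 64) single-plane Bayer mosaic
rgb_hat = demosaicing_CFA_Bayer_bilinear(cfa.double())   # (1, 64, 64, 3) bilinear reconstruction
print(cfa.shape, rgb_hat.shape)
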
/unidemoire/models/undem/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/4DVLab/UniDemoire/4221a0e98f0078c3f72ef1da8128642e7f02662a/unidemoire/models/undem/__init__.py
--------------------------------------------------------------------------------
/unidemoire/models/utils/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/4DVLab/UniDemoire/4221a0e98f0078c3f72ef1da8128642e7f02662a/unidemoire/models/utils/__init__.py
--------------------------------------------------------------------------------
/unidemoire/models/utils/matlab_ssim.py:
--------------------------------------------------------------------------------
1 | """
2 | A PyTorch implementation that reproduces MATLAB SSIM results, slightly modified from
3 | https://github.com/mayorx/matlab_ssim_pytorch_implementation.
4 | """
5 |
6 | import torch
7 | import cv2
8 | import numpy as np
9 |
10 | def generate_1d_gaussian_kernel():
11 | return cv2.getGaussianKernel(11, 1.5)
12 |
13 | def generate_2d_gaussian_kernel():
14 | kernel = generate_1d_gaussian_kernel()
15 | return np.outer(kernel, kernel.transpose())
16 |
17 | def generate_3d_gaussian_kernel():
18 | kernel = generate_1d_gaussian_kernel()
19 | window = generate_2d_gaussian_kernel()
20 | return np.stack([window * k for k in kernel], axis=0)
21 |
22 | class MATLAB_SSIM(torch.nn.Module):
23 | def __init__(self, device='cpu'):
24 | super(MATLAB_SSIM, self).__init__()
25 | self.device = device
26 | conv3d = torch.nn.Conv3d(1, 1, (11, 11, 11), stride=1, padding=(5, 5, 5), bias=False, padding_mode='replicate')
27 | conv3d.weight.requires_grad = False
28 | conv3d.weight[0, 0, :, :, :] = torch.tensor(generate_3d_gaussian_kernel())
29 | self.conv3d = conv3d.to(device)
30 |
31 | conv2d = torch.nn.Conv2d(1, 1, (11, 11), stride=1, padding=(5, 5), bias=False, padding_mode='replicate')
32 | conv2d.weight.requires_grad = False
33 | conv2d.weight[0, 0, :, :] = torch.tensor(generate_2d_gaussian_kernel())
34 | self.conv2d = conv2d.to(device)
35 |
36 | def forward(self, img1, img2, device='cuda'):
37 | assert len(img1.shape) == len(img2.shape)
38 | self.device = device
39 | self.conv2d = self.conv2d.to(self.device)
40 | self.conv3d = self.conv3d.to(self.device)
41 | with torch.no_grad():
42 | img1 = torch.tensor(img1).to(self.device).float()
43 | img2 = torch.tensor(img2).to(self.device).float()
44 |
45 | if len(img1.shape) == 2:
46 | conv = self.conv2d
47 | elif len(img1.shape) == 3:
48 | conv = self.conv3d
49 | else:
50 |                 raise NotImplementedError('only 2D / 3D images are supported.')
51 | return self._ssim(img1, img2, conv)
52 |
53 | def _ssim(self, img1, img2, conv):
54 | img1 = img1.unsqueeze(0).unsqueeze(0)
55 | img2 = img2.unsqueeze(0).unsqueeze(0)
56 |
57 | C1 = (0.01 * 255) ** 2
58 | C2 = (0.03 * 255) ** 2
59 |
60 | mu1 = conv(img1)
61 | mu2 = conv(img2)
62 |
63 | mu1_sq = mu1 ** 2
64 | mu2_sq = mu2 ** 2
65 | mu1_mu2 = mu1 * mu2
66 | sigma1_sq = conv(img1 ** 2) - mu1_sq
67 | sigma2_sq = conv(img2 ** 2) - mu2_sq
68 | sigma12 = conv(img1 * img2) - mu1_mu2
69 |
70 | ssim_map = ((2 * mu1_mu2 + C1) *
71 | (2 * sigma12 + C2)) / ((mu1_sq + mu2_sq + C1) *
72 | (sigma1_sq + sigma2_sq + C2))
73 |
74 | return float(ssim_map.mean())
75 |
--------------------------------------------------------------------------------
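A hypothetical usage sketch for MATLAB_SSIM above (illustration only): it takes plain arrays in the 0-255 range; 2-D inputs use the Gaussian conv2d path and 3-D inputs the conv3d path.

import numpy as np
from unidemoire.models.utils.matlab_ssim import MATLAB_SSIM

ssim_fn = MATLAB_SSIM(device='cpu')
img_a = np.random.randint(0, 256, (128, 128)).astype(np.float64)       # grayscale image, 0-255 range
img_b = np.clip(img_a + np.random.normal(0, 5, img_a.shape), 0, 255)   # mildly noisy copy
print(ssim_fn(img_a, img_b, device='cpu'))                             # SSIM close to 1.0
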
/unidemoire/models/utils/metric.py:
--------------------------------------------------------------------------------
1 | from .common import SSIM, PSNR, tensor2img
2 | from skimage.metrics import peak_signal_noise_ratio as ski_psnr
3 | from skimage.metrics import structural_similarity as ski_ssim
4 | from unidemoire.models.utils.matlab_ssim import MATLAB_SSIM
5 | import lpips
6 | import torch
7 | import numpy as np
8 | from math import log10
9 |
10 | class create_metrics():
11 | """
12 |     Note that previous works compute metrics in different ways on different benchmarks, which can
13 |     lead to inconsistent SSIM (and slightly different PSNR) results, so we follow each dataset's
14 |     individual protocol for fair comparisons.
15 |     For our 4K dataset, computing metrics on 4K images is time-consuming, so we benchmark
16 |     all methods with a fast PyTorch SSIM implementation adapted from
17 |     "https://github.com/jorge-pessoa/pytorch-msssim/blob/master/pytorch_msssim/__init__.py".
18 | """
19 | def __init__(self, dataset, device):
20 | self.data_type = dataset
21 | # self.lpips_fn = lpips.LPIPS(net='alex').cuda()
22 | self.lpips_fn = lpips.LPIPS(net='alex')
23 | # self.lpips_fn = self.lpips_fn.to(device)
24 | self.fast_ssim = SSIM()
25 | self.fast_psnr = PSNR()
26 | self.matlab_ssim = MATLAB_SSIM(device=device)
27 |
28 | def compute(self, out_img, gt, device=None):
29 | if self.data_type == 'UHDM':
30 | res_psnr, res_ssim = self.fast_psnr_ssim(out_img, gt)
31 | elif self.data_type == 'FHDMi':
32 | res_psnr, res_ssim = self.skimage_psnr_ssim(out_img, gt)
33 | elif self.data_type == 'TIP':
34 | res_psnr, res_ssim = self.matlab_psnr_ssim(out_img, gt, device)
35 | elif self.data_type == 'AIM':
36 | res_psnr, res_ssim = self.aim_psnr_ssim(out_img, gt)
37 | else:
38 | print('Unrecognized data_type for evaluation!')
39 | raise NotImplementedError
40 | pre = torch.clamp(out_img, min=0, max=1)
41 | tar = torch.clamp(gt, min=0, max=1)
42 | self.lpips_fn = self.lpips_fn.to(device)
43 | res_lpips = self.lpips_fn.forward(pre, tar, normalize=True).item()
44 |
45 | return res_lpips, res_psnr, res_ssim
46 |
47 |
48 | def fast_psnr_ssim(self, out_img, gt):
49 | pre = torch.clamp(out_img, min=0, max=1)
50 | tar = torch.clamp(gt, min=0, max=1)
51 | psnr = self.fast_psnr(pre, tar)
52 | ssim = self.fast_ssim(pre, tar)
53 | return psnr, ssim
54 |
55 | def skimage_psnr_ssim(self, out_img, gt):
56 | """
57 |         Same as the previous SOTA FHDe2Net: https://github.com/PKU-IMRE/FHDe2Net/blob/main/test.py
58 | """
59 | mi1 = tensor2img(out_img)
60 | mt1 = tensor2img(gt)
61 |
62 | psnr = ski_psnr(mt1, mi1)
63 | ssim = ski_ssim(mt1, mi1, multichannel=True, channel_axis=2)
64 | return psnr, ssim
65 |
66 | def matlab_psnr_ssim(self, out_img, gt, device):
67 | """
68 |         A PyTorch implementation that reproduces MATLAB SSIM results,
69 |         the same as the previous SOTA MopNet: https://github.com/PKU-IMRE/MopNet/blob/master/test_with_matlabcode.m
70 | """
71 | mi1 = tensor2img(out_img)
72 | mt1 = tensor2img(gt)
73 | psnr = ski_psnr(mt1, mi1)
74 | ssim = self.matlab_ssim(mt1, mi1, device)
75 | return psnr, ssim
76 |
77 | def aim_psnr_ssim(self, out_img, gt):
78 | """
79 |         Same as the previous SOTA MBCNN: https://github.com/zhenngbolun/Learnbale_Bandpass_Filter/blob/master/main_multiscale.py
80 | """
81 | mi1 = tensor2img(out_img)
82 | mt1 = tensor2img(gt)
83 | mi1 = mi1.astype(np.float32) / 255.0
84 | mt1 = mt1.astype(np.float32) / 255.0
85 | psnr = 10 * log10(1 / np.mean((mt1 - mi1) ** 2))
86 | ssim = ski_ssim(mt1, mi1, multichannel=True)
87 | return psnr, ssim
--------------------------------------------------------------------------------
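A hypothetical usage sketch for create_metrics above (illustration only; random tensors and a CPU device are placeholders, and the SSIM/PSNR helpers imported from utils.common are assumed to take (B, C, H, W) tensors in [0, 1]). Instantiating the class loads the LPIPS AlexNet weights, which are downloaded on first use.

import torch
from unidemoire.models.utils.metric import create_metrics

device = torch.device('cpu')
evaluator = create_metrics(dataset='UHDM', device=device)    # 'UHDM' selects the fast PSNR/SSIM path
pred = torch.rand(1, 3, 256, 256)
gt   = torch.rand(1, 3, 256, 256)
res_lpips, res_psnr, res_ssim = evaluator.compute(pred, gt, device=device)
print('LPIPS', res_lpips, 'PSNR', float(res_psnr), 'SSIM', float(res_ssim))
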
/unidemoire/modules/attention.py:
--------------------------------------------------------------------------------
1 | from inspect import isfunction
2 | import math
3 | import torch
4 | import torch.nn.functional as F
5 | from torch import nn, einsum
6 | from einops import rearrange, repeat
7 |
8 | from unidemoire.modules.diffusionmodules.util import checkpoint
9 |
10 |
11 | def exists(val):
12 | return val is not None
13 |
14 |
15 | def uniq(arr):
16 | return{el: True for el in arr}.keys()
17 |
18 |
19 | def default(val, d):
20 | if exists(val):
21 | return val
22 | return d() if isfunction(d) else d
23 |
24 |
25 | def max_neg_value(t):
26 | return -torch.finfo(t.dtype).max
27 |
28 |
29 | def init_(tensor):
30 | dim = tensor.shape[-1]
31 | std = 1 / math.sqrt(dim)
32 | tensor.uniform_(-std, std)
33 | return tensor
34 |
35 |
36 | # feedforward
37 | class GEGLU(nn.Module):
38 | def __init__(self, dim_in, dim_out):
39 | super().__init__()
40 | self.proj = nn.Linear(dim_in, dim_out * 2)
41 |
42 | def forward(self, x):
43 | x, gate = self.proj(x).chunk(2, dim=-1)
44 | return x * F.gelu(gate)
45 |
46 |
47 | class FeedForward(nn.Module):
48 | def __init__(self, dim, dim_out=None, mult=4, glu=False, dropout=0.):
49 | super().__init__()
50 | inner_dim = int(dim * mult)
51 | dim_out = default(dim_out, dim)
52 | project_in = nn.Sequential(
53 | nn.Linear(dim, inner_dim),
54 | nn.GELU()
55 | ) if not glu else GEGLU(dim, inner_dim)
56 |
57 | self.net = nn.Sequential(
58 | project_in,
59 | nn.Dropout(dropout),
60 | nn.Linear(inner_dim, dim_out)
61 | )
62 |
63 | def forward(self, x):
64 | return self.net(x)
65 |
66 |
67 | def zero_module(module):
68 | """
69 | Zero out the parameters of a module and return it.
70 | """
71 | for p in module.parameters():
72 | p.detach().zero_()
73 | return module
74 |
75 |
76 | def Normalize(in_channels):
77 | return torch.nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True)
78 |
79 |
80 | class LinearAttention(nn.Module):
81 | def __init__(self, dim, heads=4, dim_head=32):
82 | super().__init__()
83 | self.heads = heads
84 | hidden_dim = dim_head * heads
85 | self.to_qkv = nn.Conv2d(dim, hidden_dim * 3, 1, bias = False)
86 | self.to_out = nn.Conv2d(hidden_dim, dim, 1)
87 |
88 | def forward(self, x):
89 | b, c, h, w = x.shape
90 | qkv = self.to_qkv(x)
91 | q, k, v = rearrange(qkv, 'b (qkv heads c) h w -> qkv b heads c (h w)', heads = self.heads, qkv=3)
92 | k = k.softmax(dim=-1)
93 | context = torch.einsum('bhdn,bhen->bhde', k, v)
94 | out = torch.einsum('bhde,bhdn->bhen', context, q)
95 | out = rearrange(out, 'b heads c (h w) -> b (heads c) h w', heads=self.heads, h=h, w=w)
96 | return self.to_out(out)
97 |
98 |
99 | class SpatialSelfAttention(nn.Module):
100 | def __init__(self, in_channels):
101 | super().__init__()
102 | self.in_channels = in_channels
103 |
104 | self.norm = Normalize(in_channels)
105 | self.q = torch.nn.Conv2d(in_channels,
106 | in_channels,
107 | kernel_size=1,
108 | stride=1,
109 | padding=0)
110 | self.k = torch.nn.Conv2d(in_channels,
111 | in_channels,
112 | kernel_size=1,
113 | stride=1,
114 | padding=0)
115 | self.v = torch.nn.Conv2d(in_channels,
116 | in_channels,
117 | kernel_size=1,
118 | stride=1,
119 | padding=0)
120 | self.proj_out = torch.nn.Conv2d(in_channels,
121 | in_channels,
122 | kernel_size=1,
123 | stride=1,
124 | padding=0)
125 |
126 | def forward(self, x):
127 | h_ = x
128 | h_ = self.norm(h_)
129 | q = self.q(h_)
130 | k = self.k(h_)
131 | v = self.v(h_)
132 |
133 | # compute attention
134 | b,c,h,w = q.shape
135 | q = rearrange(q, 'b c h w -> b (h w) c')
136 | k = rearrange(k, 'b c h w -> b c (h w)')
137 | w_ = torch.einsum('bij,bjk->bik', q, k)
138 |
139 | w_ = w_ * (int(c)**(-0.5))
140 | w_ = torch.nn.functional.softmax(w_, dim=2)
141 |
142 | # attend to values
143 | v = rearrange(v, 'b c h w -> b c (h w)')
144 | w_ = rearrange(w_, 'b i j -> b j i')
145 | h_ = torch.einsum('bij,bjk->bik', v, w_)
146 | h_ = rearrange(h_, 'b c (h w) -> b c h w', h=h)
147 | h_ = self.proj_out(h_)
148 |
149 | return x+h_
150 |
151 |
152 | class CrossAttention(nn.Module):
153 | def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0.):
154 | super().__init__()
155 | inner_dim = dim_head * heads
156 | context_dim = default(context_dim, query_dim)
157 |
158 | self.scale = dim_head ** -0.5
159 | self.heads = heads
160 |
161 | self.to_q = nn.Linear(query_dim, inner_dim, bias=False)
162 | self.to_k = nn.Linear(context_dim, inner_dim, bias=False)
163 | self.to_v = nn.Linear(context_dim, inner_dim, bias=False)
164 |
165 | self.to_out = nn.Sequential(
166 | nn.Linear(inner_dim, query_dim),
167 | nn.Dropout(dropout)
168 | )
169 |
170 | def forward(self, x, context=None, mask=None):
171 | h = self.heads
172 |
173 | q = self.to_q(x)
174 | context = default(context, x)
175 | k = self.to_k(context)
176 | v = self.to_v(context)
177 |
178 | q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=h), (q, k, v))
179 |
180 | sim = einsum('b i d, b j d -> b i j', q, k) * self.scale
181 |
182 | if exists(mask):
183 | mask = rearrange(mask, 'b ... -> b (...)')
184 | max_neg_value = -torch.finfo(sim.dtype).max
185 | mask = repeat(mask, 'b j -> (b h) () j', h=h)
186 | sim.masked_fill_(~mask, max_neg_value)
187 |
188 | # attention, what we cannot get enough of
189 | attn = sim.softmax(dim=-1)
190 |
191 | out = einsum('b i j, b j d -> b i d', attn, v)
192 | out = rearrange(out, '(b h) n d -> b n (h d)', h=h)
193 | return self.to_out(out)
194 |
195 |
196 | class BasicTransformerBlock(nn.Module):
197 | def __init__(self, dim, n_heads, d_head, dropout=0., context_dim=None, gated_ff=True, checkpoint=True):
198 | super().__init__()
199 | self.attn1 = CrossAttention(query_dim=dim, heads=n_heads, dim_head=d_head, dropout=dropout) # is a self-attention
200 | self.ff = FeedForward(dim, dropout=dropout, glu=gated_ff)
201 | self.attn2 = CrossAttention(query_dim=dim, context_dim=context_dim,
202 | heads=n_heads, dim_head=d_head, dropout=dropout) # is self-attn if context is none
203 | self.norm1 = nn.LayerNorm(dim)
204 | self.norm2 = nn.LayerNorm(dim)
205 | self.norm3 = nn.LayerNorm(dim)
206 | self.checkpoint = checkpoint
207 |
208 | def forward(self, x, context=None):
209 | return checkpoint(self._forward, (x, context), self.parameters(), self.checkpoint)
210 |
211 | def _forward(self, x, context=None):
212 | x = self.attn1(self.norm1(x)) + x
213 | x = self.attn2(self.norm2(x), context=context) + x
214 | x = self.ff(self.norm3(x)) + x
215 | return x
216 |
217 |
218 | class SpatialTransformer(nn.Module):
219 | """
220 | Transformer block for image-like data.
221 | First, project the input (aka embedding)
222 | and reshape to b, t, d.
223 | Then apply standard transformer action.
224 | Finally, reshape to image
225 | """
226 | def __init__(self, in_channels, n_heads, d_head,
227 | depth=1, dropout=0., context_dim=None):
228 | super().__init__()
229 | self.in_channels = in_channels
230 | inner_dim = n_heads * d_head
231 | self.norm = Normalize(in_channels)
232 |
233 | self.proj_in = nn.Conv2d(in_channels,
234 | inner_dim,
235 | kernel_size=1,
236 | stride=1,
237 | padding=0)
238 |
239 | self.transformer_blocks = nn.ModuleList(
240 | [BasicTransformerBlock(inner_dim, n_heads, d_head, dropout=dropout, context_dim=context_dim)
241 | for d in range(depth)]
242 | )
243 |
244 | self.proj_out = zero_module(nn.Conv2d(inner_dim,
245 | in_channels,
246 | kernel_size=1,
247 | stride=1,
248 | padding=0))
249 |
250 | def forward(self, x, context=None):
251 | # note: if no context is given, cross-attention defaults to self-attention
252 | b, c, h, w = x.shape
253 | x_in = x
254 | x = self.norm(x)
255 | x = self.proj_in(x)
256 | x = rearrange(x, 'b c h w -> b (h w) c')
257 | for block in self.transformer_blocks:
258 | x = block(x, context=context)
259 | x = rearrange(x, 'b (h w) c -> b c h w', h=h, w=w)
260 | x = self.proj_out(x)
261 | return x + x_in
--------------------------------------------------------------------------------
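A minimal shape-check sketch for the attention blocks above (illustration only; all sizes are arbitrary). SpatialTransformer applies cross-attention over an image-like feature map against an optional token sequence; with no context it reduces to self-attention.

import torch
from unidemoire.modules.attention import CrossAttention, SpatialTransformer

feats   = torch.randn(2, 32, 16, 16)                   # (b, c, h, w) feature map
context = torch.randn(2, 77, 64)                       # (b, tokens, context_dim) conditioning sequence

st = SpatialTransformer(in_channels=32, n_heads=4, d_head=8, context_dim=64)
print(st(feats, context=context).shape)                # torch.Size([2, 32, 16, 16])

ca = CrossAttention(query_dim=32, context_dim=64, heads=4, dim_head=8)
tokens = torch.randn(2, 10, 32)                        # (b, n, query_dim)
print(ca(tokens, context=context).shape)               # torch.Size([2, 10, 32])
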
/unidemoire/modules/diffusionmodules/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/4DVLab/UniDemoire/4221a0e98f0078c3f72ef1da8128642e7f02662a/unidemoire/modules/diffusionmodules/__init__.py
--------------------------------------------------------------------------------
/unidemoire/modules/diffusionmodules/util.py:
--------------------------------------------------------------------------------
1 | # adopted from
2 | # https://github.com/openai/improved-diffusion/blob/main/improved_diffusion/gaussian_diffusion.py
3 | # and
4 | # https://github.com/lucidrains/denoising-diffusion-pytorch/blob/7706bdfc6f527f58d33f84b7b522e61e6e3164b3/denoising_diffusion_pytorch/denoising_diffusion_pytorch.py
5 | # and
6 | # https://github.com/openai/guided-diffusion/blob/0ba878e517b276c45d1195eb29f6f5f72659a05b/guided_diffusion/nn.py
7 | #
8 | # thanks!
9 |
10 |
11 | import os
12 | import math
13 | import torch
14 | import torch.nn as nn
15 | import numpy as np
16 | from einops import repeat
17 |
18 | from unidemoire.util import instantiate_from_config
19 |
20 |
21 | def make_beta_schedule(schedule, n_timestep, linear_start=1e-4, linear_end=2e-2, cosine_s=8e-3):
22 | if schedule == "linear":
23 | betas = (
24 | torch.linspace(linear_start ** 0.5, linear_end ** 0.5, n_timestep, dtype=torch.float64) ** 2
25 | )
26 |
27 | elif schedule == "cosine":
28 | timesteps = (
29 | torch.arange(n_timestep + 1, dtype=torch.float64) / n_timestep + cosine_s
30 | )
31 | alphas = timesteps / (1 + cosine_s) * np.pi / 2
32 | alphas = torch.cos(alphas).pow(2)
33 | alphas = alphas / alphas[0]
34 | betas = 1 - alphas[1:] / alphas[:-1]
35 | betas = np.clip(betas, a_min=0, a_max=0.999)
36 |
37 | elif schedule == "sqrt_linear":
38 | betas = torch.linspace(linear_start, linear_end, n_timestep, dtype=torch.float64)
39 | elif schedule == "sqrt":
40 | betas = torch.linspace(linear_start, linear_end, n_timestep, dtype=torch.float64) ** 0.5
41 | else:
42 | raise ValueError(f"schedule '{schedule}' unknown.")
43 | return betas.numpy()
44 |
45 |
46 | def make_ddim_timesteps(ddim_discr_method, num_ddim_timesteps, num_ddpm_timesteps, verbose=True):
47 | if ddim_discr_method == 'uniform':
48 | c = num_ddpm_timesteps // num_ddim_timesteps
49 | ddim_timesteps = np.asarray(list(range(0, num_ddpm_timesteps, c)))
50 | elif ddim_discr_method == 'quad':
51 | ddim_timesteps = ((np.linspace(0, np.sqrt(num_ddpm_timesteps * .8), num_ddim_timesteps)) ** 2).astype(int)
52 | else:
53 | raise NotImplementedError(f'There is no ddim discretization method called "{ddim_discr_method}"')
54 |
55 | # assert ddim_timesteps.shape[0] == num_ddim_timesteps
56 | # add one to get the final alpha values right (the ones from first scale to data during sampling)
57 | steps_out = ddim_timesteps + 1
58 | if verbose:
59 | print(f'Selected timesteps for ddim sampler: {steps_out}')
60 | return steps_out
61 |
62 |
63 | def make_ddim_sampling_parameters(alphacums, ddim_timesteps, eta, verbose=True):
64 | # select alphas for computing the variance schedule
65 | alphas = alphacums[ddim_timesteps]
66 | alphas_prev = np.asarray([alphacums[0]] + alphacums[ddim_timesteps[:-1]].tolist())
67 |
68 |     # according to the formula provided in https://arxiv.org/abs/2010.02502
69 | sigmas = eta * np.sqrt((1 - alphas_prev) / (1 - alphas) * (1 - alphas / alphas_prev))
70 | if verbose:
71 | print(f'Selected alphas for ddim sampler: a_t: {alphas}; a_(t-1): {alphas_prev}')
72 | print(f'For the chosen value of eta, which is {eta}, '
73 | f'this results in the following sigma_t schedule for ddim sampler {sigmas}')
74 | return sigmas, alphas, alphas_prev
75 |
76 |
77 | def betas_for_alpha_bar(num_diffusion_timesteps, alpha_bar, max_beta=0.999):
78 | """
79 | Create a beta schedule that discretizes the given alpha_t_bar function,
80 | which defines the cumulative product of (1-beta) over time from t = [0,1].
81 | :param num_diffusion_timesteps: the number of betas to produce.
82 | :param alpha_bar: a lambda that takes an argument t from 0 to 1 and
83 | produces the cumulative product of (1-beta) up to that
84 | part of the diffusion process.
85 | :param max_beta: the maximum beta to use; use values lower than 1 to
86 | prevent singularities.
87 | """
88 | betas = []
89 | for i in range(num_diffusion_timesteps):
90 | t1 = i / num_diffusion_timesteps
91 | t2 = (i + 1) / num_diffusion_timesteps
92 | betas.append(min(1 - alpha_bar(t2) / alpha_bar(t1), max_beta))
93 | return np.array(betas)
94 |
95 |
96 | def extract_into_tensor(a, t, x_shape):
97 | b, *_ = t.shape
98 | out = a.gather(-1, t)
99 | return out.reshape(b, *((1,) * (len(x_shape) - 1)))
100 |
101 |
102 | def checkpoint(func, inputs, params, flag):
103 | """
104 | Evaluate a function without caching intermediate activations, allowing for
105 | reduced memory at the expense of extra compute in the backward pass.
106 | :param func: the function to evaluate.
107 | :param inputs: the argument sequence to pass to `func`.
108 | :param params: a sequence of parameters `func` depends on but does not
109 | explicitly take as arguments.
110 | :param flag: if False, disable gradient checkpointing.
111 | """
112 | if flag:
113 | args = tuple(inputs) + tuple(params)
114 | return CheckpointFunction.apply(func, len(inputs), *args)
115 | else:
116 | return func(*inputs)
117 |
118 |
119 | class CheckpointFunction(torch.autograd.Function):
120 | @staticmethod
121 | def forward(ctx, run_function, length, *args):
122 | ctx.run_function = run_function
123 | ctx.input_tensors = list(args[:length])
124 | ctx.input_params = list(args[length:])
125 |
126 | with torch.no_grad():
127 | output_tensors = ctx.run_function(*ctx.input_tensors)
128 | return output_tensors
129 |
130 | @staticmethod
131 | def backward(ctx, *output_grads):
132 | ctx.input_tensors = [x.detach().requires_grad_(True) for x in ctx.input_tensors]
133 | with torch.enable_grad():
134 | # Fixes a bug where the first op in run_function modifies the
135 | # Tensor storage in place, which is not allowed for detach()'d
136 | # Tensors.
137 | shallow_copies = [x.view_as(x) for x in ctx.input_tensors]
138 | output_tensors = ctx.run_function(*shallow_copies)
139 | input_grads = torch.autograd.grad(
140 | output_tensors,
141 | ctx.input_tensors + ctx.input_params,
142 | output_grads,
143 | allow_unused=True,
144 | )
145 | del ctx.input_tensors
146 | del ctx.input_params
147 | del output_tensors
148 | return (None, None) + input_grads
149 |
150 |
151 | def timestep_embedding(timesteps, dim, max_period=10000, repeat_only=False):
152 | """
153 | Create sinusoidal timestep embeddings.
154 | :param timesteps: a 1-D Tensor of N indices, one per batch element.
155 | These may be fractional.
156 | :param dim: the dimension of the output.
157 | :param max_period: controls the minimum frequency of the embeddings.
158 | :return: an [N x dim] Tensor of positional embeddings.
159 | """
160 | if not repeat_only:
161 | half = dim // 2
162 | freqs = torch.exp(
163 | -math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half
164 | ).to(device=timesteps.device)
165 | args = timesteps[:, None].float() * freqs[None]
166 | embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
167 | if dim % 2:
168 | embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1)
169 | else:
170 | embedding = repeat(timesteps, 'b -> b d', d=dim)
171 | return embedding
172 |
173 |
174 | def zero_module(module):
175 | """
176 | Zero out the parameters of a module and return it.
177 | """
178 | for p in module.parameters():
179 | p.detach().zero_()
180 | return module
181 |
182 |
183 | def scale_module(module, scale):
184 | """
185 | Scale the parameters of a module and return it.
186 | """
187 | for p in module.parameters():
188 | p.detach().mul_(scale)
189 | return module
190 |
191 |
192 | def mean_flat(tensor):
193 | """
194 | Take the mean over all non-batch dimensions.
195 | """
196 | return tensor.mean(dim=list(range(1, len(tensor.shape))))
197 |
198 |
199 | def normalization(channels):
200 | """
201 | Make a standard normalization layer.
202 | :param channels: number of input channels.
203 | :return: an nn.Module for normalization.
204 | """
205 | return GroupNorm32(32, channels)
206 |
207 |
208 | # PyTorch 1.7 has SiLU, but we support PyTorch 1.5.
209 | class SiLU(nn.Module):
210 | def forward(self, x):
211 | return x * torch.sigmoid(x)
212 |
213 |
214 | class GroupNorm32(nn.GroupNorm):
215 | def forward(self, x):
216 | return super().forward(x.float()).type(x.dtype)
217 |
218 | def conv_nd(dims, *args, **kwargs):
219 | """
220 | Create a 1D, 2D, or 3D convolution module.
221 | """
222 | if dims == 1:
223 | return nn.Conv1d(*args, **kwargs)
224 | elif dims == 2:
225 | return nn.Conv2d(*args, **kwargs)
226 | elif dims == 3:
227 | return nn.Conv3d(*args, **kwargs)
228 | raise ValueError(f"unsupported dimensions: {dims}")
229 |
230 |
231 | def linear(*args, **kwargs):
232 | """
233 | Create a linear module.
234 | """
235 | return nn.Linear(*args, **kwargs)
236 |
237 |
238 | def avg_pool_nd(dims, *args, **kwargs):
239 | """
240 | Create a 1D, 2D, or 3D average pooling module.
241 | """
242 | if dims == 1:
243 | return nn.AvgPool1d(*args, **kwargs)
244 | elif dims == 2:
245 | return nn.AvgPool2d(*args, **kwargs)
246 | elif dims == 3:
247 | return nn.AvgPool3d(*args, **kwargs)
248 | raise ValueError(f"unsupported dimensions: {dims}")
249 |
250 |
251 | class HybridConditioner(nn.Module):
252 |
253 | def __init__(self, c_concat_config, c_crossattn_config):
254 | super().__init__()
255 | self.concat_conditioner = instantiate_from_config(c_concat_config)
256 | self.crossattn_conditioner = instantiate_from_config(c_crossattn_config)
257 |
258 | def forward(self, c_concat, c_crossattn):
259 | c_concat = self.concat_conditioner(c_concat)
260 | c_crossattn = self.crossattn_conditioner(c_crossattn)
261 | return {'c_concat': [c_concat], 'c_crossattn': [c_crossattn]}
262 |
263 |
264 | def noise_like(shape, device, repeat=False):
265 | repeat_noise = lambda: torch.randn((1, *shape[1:]), device=device).repeat(shape[0], *((1,) * (len(shape) - 1)))
266 | noise = lambda: torch.randn(shape, device=device)
267 | return repeat_noise() if repeat else noise()
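268 |
269 |
270 | # Minimal usage sketch (illustrative only): build a cosine-style beta schedule with
271 | # betas_for_alpha_bar and embed a batch of timestep indices with timestep_embedding.
272 | # The schedule length and embedding width below are arbitrary example values.
273 | if __name__ == "__main__":
274 |     def cosine_alpha_bar(t):
275 |         # cumulative product of (1 - beta) as a function of t in [0, 1]
276 |         return math.cos((t + 0.008) / 1.008 * math.pi / 2) ** 2
277 |
278 |     betas = betas_for_alpha_bar(1000, cosine_alpha_bar)  # numpy array of shape (1000,)
279 |     t = torch.randint(0, 1000, (4,))                      # one timestep index per batch element
280 |     emb = timestep_embedding(t, dim=128)                  # (4, 128) sinusoidal embeddings
281 |     print(betas.shape, emb.shape)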
--------------------------------------------------------------------------------
/unidemoire/modules/distributions/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/4DVLab/UniDemoire/4221a0e98f0078c3f72ef1da8128642e7f02662a/unidemoire/modules/distributions/__init__.py
--------------------------------------------------------------------------------
/unidemoire/modules/distributions/distributions.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import numpy as np
3 |
4 |
5 | class AbstractDistribution:
6 | def sample(self):
7 | raise NotImplementedError()
8 |
9 | def mode(self):
10 | raise NotImplementedError()
11 |
12 |
13 | class DiracDistribution(AbstractDistribution):
14 | def __init__(self, value):
15 | self.value = value
16 |
17 | def sample(self):
18 | return self.value
19 |
20 | def mode(self):
21 | return self.value
22 |
23 |
24 | class DiagonalGaussianDistribution(object):
25 | def __init__(self, parameters, deterministic=False):
26 | self.parameters = parameters
27 | self.mean, self.logvar = torch.chunk(parameters, 2, dim=1)
28 | self.logvar = torch.clamp(self.logvar, -30.0, 20.0)
29 | self.deterministic = deterministic
30 | self.std = torch.exp(0.5 * self.logvar)
31 | self.var = torch.exp(self.logvar)
32 | if self.deterministic:
33 | self.var = self.std = torch.zeros_like(self.mean).to(device=self.parameters.device)
34 |
35 | def sample(self):
36 | x = self.mean + self.std * torch.randn(self.mean.shape).to(device=self.parameters.device)
37 | return x
38 |
39 | def kl(self, other=None):
40 | if self.deterministic:
41 | return torch.Tensor([0.])
42 | else:
43 | if other is None:
44 | return 0.5 * torch.sum(torch.pow(self.mean, 2)
45 | + self.var - 1.0 - self.logvar,
46 | dim=[1, 2, 3])
47 | else:
48 | return 0.5 * torch.sum(
49 | torch.pow(self.mean - other.mean, 2) / other.var
50 | + self.var / other.var - 1.0 - self.logvar + other.logvar,
51 | dim=[1, 2, 3])
52 |
53 | def nll(self, sample, dims=[1,2,3]):
54 | if self.deterministic:
55 | return torch.Tensor([0.])
56 | logtwopi = np.log(2.0 * np.pi)
57 | return 0.5 * torch.sum(
58 | logtwopi + self.logvar + torch.pow(sample - self.mean, 2) / self.var,
59 | dim=dims)
60 |
61 | def mode(self):
62 | return self.mean
63 |
64 |
65 | def normal_kl(mean1, logvar1, mean2, logvar2):
66 | """
67 | source: https://github.com/openai/guided-diffusion/blob/27c20a8fab9cb472df5d6bdd6c8d11c8f430b924/guided_diffusion/losses.py#L12
68 | Compute the KL divergence between two gaussians.
69 | Shapes are automatically broadcasted, so batches can be compared to
70 | scalars, among other use cases.
71 | """
72 | tensor = None
73 | for obj in (mean1, logvar1, mean2, logvar2):
74 | if isinstance(obj, torch.Tensor):
75 | tensor = obj
76 | break
77 | assert tensor is not None, "at least one argument must be a Tensor"
78 |
79 | # Force variances to be Tensors. Broadcasting helps convert scalars to
80 | # Tensors, but it does not work for torch.exp().
81 | logvar1, logvar2 = [
82 | x if isinstance(x, torch.Tensor) else torch.tensor(x).to(tensor)
83 | for x in (logvar1, logvar2)
84 | ]
85 |
86 | return 0.5 * (
87 | -1.0
88 | + logvar2
89 | - logvar1
90 | + torch.exp(logvar1 - logvar2)
91 | + ((mean1 - mean2) ** 2) * torch.exp(-logvar2)
92 | )
93 |
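94 |
95 | # Minimal usage sketch (illustrative only): an encoder would typically produce a tensor with
96 | # mean and log-variance stacked along the channel dimension; random values stand in here.
97 | if __name__ == "__main__":
98 |     params = torch.randn(2, 8, 16, 16)    # 4 mean channels + 4 logvar channels
99 |     posterior = DiagonalGaussianDistribution(params)
100 |     z = posterior.sample()                # (2, 4, 16, 16) reparameterized sample
101 |     kl = posterior.kl()                   # (2,) KL divergence against a standard normal
102 |     nll = posterior.nll(z)                # (2,) negative log-likelihood of the sample
103 |     print(z.shape, kl.shape, nll.shape)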
--------------------------------------------------------------------------------
/unidemoire/modules/ema.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch import nn
3 |
4 |
5 | class LitEma(nn.Module):
6 | def __init__(self, model, decay=0.9999, use_num_upates=True):
7 | super().__init__()
8 | if decay < 0.0 or decay > 1.0:
9 | raise ValueError('Decay must be between 0 and 1')
10 |
11 | self.m_name2s_name = {}
12 | self.register_buffer('decay', torch.tensor(decay, dtype=torch.float32))
13 | self.register_buffer('num_updates', torch.tensor(0,dtype=torch.int) if use_num_upates
14 | else torch.tensor(-1,dtype=torch.int))
15 |
16 | for name, p in model.named_parameters():
17 | if p.requires_grad:
18 | # remove '.' since it is not allowed in buffer names
19 | s_name = name.replace('.','')
20 | self.m_name2s_name.update({name:s_name})
21 | self.register_buffer(s_name,p.clone().detach().data)
22 |
23 | self.collected_params = []
24 |
25 | def forward(self,model):
26 | decay = self.decay
27 |
28 | if self.num_updates >= 0:
29 | self.num_updates += 1
30 | decay = min(self.decay,(1 + self.num_updates) / (10 + self.num_updates))
31 |
32 | one_minus_decay = 1.0 - decay
33 |
34 | with torch.no_grad():
35 | m_param = dict(model.named_parameters())
36 | shadow_params = dict(self.named_buffers())
37 |
38 | for key in m_param:
39 | if m_param[key].requires_grad:
40 | sname = self.m_name2s_name[key]
41 | shadow_params[sname] = shadow_params[sname].type_as(m_param[key])
42 | shadow_params[sname].sub_(one_minus_decay * (shadow_params[sname] - m_param[key]))
43 | else:
44 | assert key not in self.m_name2s_name
45 |
46 | def copy_to(self, model):
47 | m_param = dict(model.named_parameters())
48 | shadow_params = dict(self.named_buffers())
49 | for key in m_param:
50 | if m_param[key].requires_grad:
51 | m_param[key].data.copy_(shadow_params[self.m_name2s_name[key]].data)
52 | else:
53 | assert key not in self.m_name2s_name
54 |
55 | def store(self, parameters):
56 | """
57 | Save the current parameters for restoring later.
58 | Args:
59 | parameters: Iterable of `torch.nn.Parameter`; the parameters to be
60 | temporarily stored.
61 | """
62 | self.collected_params = [param.clone() for param in parameters]
63 |
64 | def restore(self, parameters):
65 | """
66 | Restore the parameters stored with the `store` method.
67 | Useful to validate the model with EMA parameters without affecting the
68 | original optimization process. Store the parameters before the
69 | `copy_to` method. After validation (or model saving), use this to
70 | restore the former parameters.
71 | Args:
72 | parameters: Iterable of `torch.nn.Parameter`; the parameters to be
73 | updated with the stored parameters.
74 | """
75 | for c_param, param in zip(self.collected_params, parameters):
76 | param.data.copy_(c_param.data)
77 |
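78 |
79 | # Minimal usage sketch (illustrative only): keep an EMA copy of a toy model's weights, swap
80 | # them in for evaluation, then restore the raw weights. The model and step count are arbitrary.
81 | if __name__ == "__main__":
82 |     model = nn.Linear(4, 4)
83 |     ema = LitEma(model)
84 |     for _ in range(3):               # an optimizer step would update `model` here
85 |         ema(model)                   # update the shadow (EMA) parameters
86 |     ema.store(model.parameters())    # stash the raw weights
87 |     ema.copy_to(model)               # evaluate / save checkpoints with EMA weights
88 |     ema.restore(model.parameters())  # put the raw weights back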
--------------------------------------------------------------------------------
/unidemoire/modules/encoders/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/4DVLab/UniDemoire/4221a0e98f0078c3f72ef1da8128642e7f02662a/unidemoire/modules/encoders/__init__.py
--------------------------------------------------------------------------------
/unidemoire/modules/encoders/modules.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | from functools import partial
4 | import clip
5 | from einops import rearrange, repeat
6 | import kornia
7 |
8 |
9 | from unidemoire.modules.x_transformer import Encoder, TransformerWrapper  # TODO: can we directly rely on lucidrains code and simply add this as a requirement? --> test
10 |
11 |
12 | class AbstractEncoder(nn.Module):
13 | def __init__(self):
14 | super().__init__()
15 |
16 | def encode(self, *args, **kwargs):
17 | raise NotImplementedError
18 |
19 |
20 |
21 | class ClassEmbedder(nn.Module):
22 | def __init__(self, embed_dim, n_classes=1000, key='class'):
23 | super().__init__()
24 | self.key = key
25 | self.embedding = nn.Embedding(n_classes, embed_dim)
26 |
27 | def forward(self, batch, key=None):
28 | if key is None:
29 | key = self.key
30 | # this is for use in crossattn
31 | c = batch[key][:, None]
32 | c = self.embedding(c)
33 | return c
34 |
35 |
36 | class TransformerEmbedder(AbstractEncoder):
37 | """Some transformer encoder layers"""
38 | def __init__(self, n_embed, n_layer, vocab_size, max_seq_len=77, device="cuda"):
39 | super().__init__()
40 | self.device = device
41 | self.transformer = TransformerWrapper(num_tokens=vocab_size, max_seq_len=max_seq_len,
42 | attn_layers=Encoder(dim=n_embed, depth=n_layer))
43 |
44 | def forward(self, tokens):
45 | tokens = tokens.to(self.device) # meh
46 | z = self.transformer(tokens, return_embeddings=True)
47 | return z
48 |
49 | def encode(self, x):
50 | return self(x)
51 |
52 |
53 | class BERTTokenizer(AbstractEncoder):
54 | """ Uses a pretrained BERT tokenizer by huggingface. Vocab size: 30522 (?)"""
55 | def __init__(self, device="cuda", vq_interface=True, max_length=77):
56 | super().__init__()
57 | from transformers import BertTokenizerFast  # TODO: add to requirements
58 | self.tokenizer = BertTokenizerFast.from_pretrained("models/bert")
59 |
60 | self.device = device
61 | self.vq_interface = vq_interface
62 | self.max_length = max_length
63 |
64 | def forward(self, text):
65 | batch_encoding = self.tokenizer(text, truncation=True, max_length=self.max_length, return_length=True,
66 | return_overflowing_tokens=False, padding="max_length", return_tensors="pt")
67 | tokens = batch_encoding["input_ids"].to(self.device)
68 | return tokens
69 |
70 | @torch.no_grad()
71 | def encode(self, text):
72 | tokens = self(text)
73 | if not self.vq_interface:
74 | return tokens
75 | return None, None, [None, None, tokens]
76 |
77 | def decode(self, text):
78 | return text
79 |
80 |
81 | class BERTEmbedder(AbstractEncoder):
82 | """Uses the BERT tokenizr model and add some transformer encoder layers"""
83 | def __init__(self, n_embed, n_layer, vocab_size=30522, max_seq_len=77,
84 | device="cuda",use_tokenizer=True, embedding_dropout=0.0):
85 | super().__init__()
86 | self.use_tknz_fn = use_tokenizer
87 | if self.use_tknz_fn:
88 | self.tknz_fn = BERTTokenizer(vq_interface=False, max_length=max_seq_len)
89 | self.device = device
90 | self.transformer = TransformerWrapper(num_tokens=vocab_size, max_seq_len=max_seq_len,
91 | attn_layers=Encoder(dim=n_embed, depth=n_layer),
92 | emb_dropout=embedding_dropout)
93 |
94 | def forward(self, text):
95 | if self.use_tknz_fn:
96 | tokens = self.tknz_fn(text)#.to(self.device)
97 | else:
98 | tokens = text
99 | z = self.transformer(tokens, return_embeddings=True)
100 | return z
101 |
102 | def encode(self, text):
103 | # output of length 77
104 | return self(text)
105 |
106 |
107 | class SpatialRescaler(nn.Module):
108 | def __init__(self,
109 | n_stages=1,
110 | method='bilinear',
111 | multiplier=0.5,
112 | in_channels=3,
113 | out_channels=None,
114 | bias=False):
115 | super().__init__()
116 | self.n_stages = n_stages
117 | assert self.n_stages >= 0
118 | assert method in ['nearest','linear','bilinear','trilinear','bicubic','area']
119 | self.multiplier = multiplier
120 | self.interpolator = partial(torch.nn.functional.interpolate, mode=method)
121 | self.remap_output = out_channels is not None
122 | if self.remap_output:
123 | print(f'Spatial Rescaler mapping from {in_channels} to {out_channels} channels after resizing.')
124 | self.channel_mapper = nn.Conv2d(in_channels,out_channels,1,bias=bias)
125 |
126 | def forward(self,x):
127 | for stage in range(self.n_stages):
128 | x = self.interpolator(x, scale_factor=self.multiplier)
129 |
130 |
131 | if self.remap_output:
132 | x = self.channel_mapper(x)
133 | return x
134 |
135 | def encode(self, x):
136 | return self(x)
137 |
138 |
139 | class FrozenCLIPTextEmbedder(nn.Module):
140 | """
141 | Uses the CLIP transformer encoder for text.
142 | """
143 | def __init__(self, version='ViT-L/14', device="cuda", max_length=77, n_repeat=1, normalize=True):
144 | super().__init__()
145 | self.model, _ = clip.load(version, jit=False, device="cpu")
146 | self.device = device
147 | self.max_length = max_length
148 | self.n_repeat = n_repeat
149 | self.normalize = normalize
150 |
151 | def freeze(self):
152 | self.model = self.model.eval()
153 | for param in self.parameters():
154 | param.requires_grad = False
155 |
156 | def forward(self, text):
157 | tokens = clip.tokenize(text).to(self.device)
158 | z = self.model.encode_text(tokens)
159 | if self.normalize:
160 | z = z / torch.linalg.norm(z, dim=1, keepdim=True)
161 | return z
162 |
163 | def encode(self, text):
164 | z = self(text)
165 | if z.ndim==2:
166 | z = z[:, None, :]
167 | z = repeat(z, 'b 1 d -> b k d', k=self.n_repeat)
168 | return z
169 |
170 |
171 | class FrozenClipImageEmbedder(nn.Module):
172 | """
173 | Uses the CLIP image encoder.
174 | """
175 | def __init__(
176 | self,
177 | model,
178 | jit=False,
179 | device='cuda' if torch.cuda.is_available() else 'cpu',
180 | antialias=False,
181 | ):
182 | super().__init__()
183 | self.model, _ = clip.load(name=model, device=device, jit=jit)
184 |
185 | self.antialias = antialias
186 |
187 | self.register_buffer('mean', torch.Tensor([0.48145466, 0.4578275, 0.40821073]), persistent=False)
188 | self.register_buffer('std', torch.Tensor([0.26862954, 0.26130258, 0.27577711]), persistent=False)
189 |
190 | def preprocess(self, x):
191 | # resize to CLIP's input resolution, then map from [-1,1] to [0,1]
192 | x = kornia.geometry.resize(x, (224, 224),
193 | interpolation='bicubic',align_corners=True,
194 | antialias=self.antialias)
195 | x = (x + 1.) / 2.
196 | # renormalize according to clip
197 | x = kornia.enhance.normalize(x, self.mean, self.std)
198 | return x
199 |
200 | def forward(self, x):
201 | # x is assumed to be in range [-1,1]
202 | return self.model.encode_image(self.preprocess(x))
203 |
204 |
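205 | # Minimal usage sketch (illustrative only) for SpatialRescaler: downsample a conditioning image
206 | # by 0.5x per stage and remap it to a single channel. Run from the repo root; the module-level
207 | # imports (clip, kornia) must be installed even though only SpatialRescaler is exercised here.
208 | if __name__ == "__main__":
209 |     rescaler = SpatialRescaler(n_stages=2, method='bilinear', multiplier=0.5,
210 |                                in_channels=3, out_channels=1)
211 |     x = torch.rand(2, 3, 256, 256)
212 |     y = rescaler.encode(x)    # (2, 1, 64, 64) after two 0.5x stages and a 1x1 conv
213 |     print(y.shape)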
--------------------------------------------------------------------------------
/unidemoire/modules/losses/__init__.py:
--------------------------------------------------------------------------------
1 | from unidemoire.modules.losses.contperceptual import LPIPSWithDiscriminator
--------------------------------------------------------------------------------
/unidemoire/modules/losses/contperceptual.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 |
4 | from taming.modules.losses.vqperceptual import * # TODO: taming dependency yes/no?
5 |
6 |
7 | class LPIPSWithDiscriminator(nn.Module):
8 | def __init__(self, disc_start, logvar_init=0.0, kl_weight=1.0, pixelloss_weight=1.0,
9 | disc_num_layers=3, disc_in_channels=3, disc_factor=1.0, disc_weight=1.0,
10 | perceptual_weight=1.0, use_actnorm=False, disc_conditional=False,
11 | disc_loss="hinge"):
12 |
13 | super().__init__()
14 | assert disc_loss in ["hinge", "vanilla"]
15 | self.kl_weight = kl_weight
16 | self.pixel_weight = pixelloss_weight
17 | self.perceptual_loss = LPIPS().eval()
18 | self.perceptual_weight = perceptual_weight
19 | # output log variance
20 | self.logvar = nn.Parameter(torch.ones(size=()) * logvar_init)
21 |
22 | self.discriminator = NLayerDiscriminator(input_nc=disc_in_channels,
23 | n_layers=disc_num_layers,
24 | use_actnorm=use_actnorm
25 | ).apply(weights_init)
26 | self.discriminator_iter_start = disc_start
27 | self.disc_loss = hinge_d_loss if disc_loss == "hinge" else vanilla_d_loss
28 | self.disc_factor = disc_factor
29 | self.discriminator_weight = disc_weight
30 | self.disc_conditional = disc_conditional
31 |
32 | def calculate_adaptive_weight(self, nll_loss, g_loss, last_layer=None):
33 | if last_layer is not None:
34 | nll_grads = torch.autograd.grad(nll_loss, last_layer, retain_graph=True)[0]
35 | g_grads = torch.autograd.grad(g_loss, last_layer, retain_graph=True)[0]
36 | else:
37 | nll_grads = torch.autograd.grad(nll_loss, self.last_layer[0], retain_graph=True)[0]
38 | g_grads = torch.autograd.grad(g_loss, self.last_layer[0], retain_graph=True)[0]
39 |
40 | d_weight = torch.norm(nll_grads) / (torch.norm(g_grads) + 1e-4)
41 | d_weight = torch.clamp(d_weight, 0.0, 1e4).detach()
42 | d_weight = d_weight * self.discriminator_weight
43 | return d_weight
44 |
45 | def forward(self, inputs, reconstructions, posteriors, optimizer_idx,
46 | global_step, last_layer=None, cond=None, split="train",
47 | weights=None):
48 |
49 | rec_loss = torch.abs(inputs.contiguous() - reconstructions.contiguous())
50 | if self.perceptual_weight > 0:
51 | p_loss = self.perceptual_loss(inputs.contiguous(), reconstructions.contiguous())
52 | rec_loss = rec_loss + self.perceptual_weight * p_loss
53 |
54 | nll_loss = rec_loss / torch.exp(self.logvar) + self.logvar
55 | weighted_nll_loss = nll_loss
56 | if weights is not None:
57 | weighted_nll_loss = weights*nll_loss
58 | weighted_nll_loss = torch.sum(weighted_nll_loss) / weighted_nll_loss.shape[0]
59 | nll_loss = torch.sum(nll_loss) / nll_loss.shape[0]
60 |
61 | kl_loss = posteriors.kl()
62 | kl_loss = torch.sum(kl_loss) / kl_loss.shape[0]
63 |
64 | # now the GAN part
65 | if optimizer_idx == 0:
66 | # generator update
67 | if cond is None:
68 | assert not self.disc_conditional
69 | logits_fake = self.discriminator(reconstructions.contiguous())
70 | else:
71 | assert self.disc_conditional
72 | logits_fake = self.discriminator(torch.cat((reconstructions.contiguous(), cond), dim=1))
73 | g_loss = -torch.mean(logits_fake)
74 |
75 | if self.disc_factor > 0.0:
76 | try:
77 | d_weight = self.calculate_adaptive_weight(nll_loss, g_loss, last_layer=last_layer)
78 | except RuntimeError:
79 | assert not self.training
80 | d_weight = torch.tensor(0.0)
81 | else:
82 | d_weight = torch.tensor(0.0)
83 |
84 | disc_factor = adopt_weight(self.disc_factor, global_step, threshold=self.discriminator_iter_start)
85 | loss = weighted_nll_loss + self.kl_weight * kl_loss + d_weight * disc_factor * g_loss
86 |
87 | log = {"{}/total_loss".format(split): loss.clone().detach().mean(), "{}/logvar".format(split): self.logvar.detach(),
88 | "{}/kl_loss".format(split): kl_loss.detach().mean(), "{}/nll_loss".format(split): nll_loss.detach().mean(),
89 | "{}/rec_loss".format(split): rec_loss.detach().mean(),
90 | "{}/d_weight".format(split): d_weight.detach(),
91 | "{}/disc_factor".format(split): torch.tensor(disc_factor),
92 | "{}/g_loss".format(split): g_loss.detach().mean(),
93 | }
94 | return loss, log
95 |
96 | if optimizer_idx == 1:
97 | # second pass for discriminator update
98 | if cond is None:
99 | logits_real = self.discriminator(inputs.contiguous().detach())
100 | logits_fake = self.discriminator(reconstructions.contiguous().detach())
101 | else:
102 | logits_real = self.discriminator(torch.cat((inputs.contiguous().detach(), cond), dim=1))
103 | logits_fake = self.discriminator(torch.cat((reconstructions.contiguous().detach(), cond), dim=1))
104 |
105 | disc_factor = adopt_weight(self.disc_factor, global_step, threshold=self.discriminator_iter_start)
106 | d_loss = disc_factor * self.disc_loss(logits_real, logits_fake)
107 |
108 | log = {"{}/disc_loss".format(split): d_loss.clone().detach().mean(),
109 | "{}/logits_real".format(split): logits_real.detach().mean(),
110 | "{}/logits_fake".format(split): logits_fake.detach().mean()
111 | }
112 | return d_loss, log
113 |
114 |
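115 | # Minimal usage sketch (illustrative only): the same loss module is called twice per step, once
116 | # per optimizer (0 = autoencoder/generator, 1 = discriminator). Shapes are arbitrary, `last_conv`
117 | # merely stands in for the autoencoder's output layer, and the taming-transformers LPIPS weights
118 | # must be available; run from the repo root.
119 | if __name__ == "__main__":
120 |     from unidemoire.modules.distributions.distributions import DiagonalGaussianDistribution
121 |
122 |     loss_fn = LPIPSWithDiscriminator(disc_start=5000)
123 |     last_conv = nn.Conv2d(3, 3, 3, padding=1)      # stand-in for the decoder's last layer
124 |     inputs = torch.rand(2, 3, 64, 64) * 2.0 - 1.0  # images in [-1, 1]
125 |     recon = last_conv(inputs)                      # "reconstruction" that depends on last_conv
126 |     posterior = DiagonalGaussianDistribution(torch.randn(2, 8, 8, 8))
127 |
128 |     ae_loss, ae_log = loss_fn(inputs, recon, posterior, 0, global_step=0,
129 |                               last_layer=last_conv.weight, split="train")
130 |     d_loss, d_log = loss_fn(inputs, recon, posterior, 1, global_step=0, split="train")
131 |     print(ae_loss.item(), d_loss.item())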
--------------------------------------------------------------------------------
/unidemoire/modules/losses/vqperceptual.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch import nn
3 | import torch.nn.functional as F
4 | from einops import repeat
5 |
6 | from taming.modules.discriminator.model import NLayerDiscriminator, weights_init
7 | from taming.modules.losses.lpips import LPIPS
8 | from taming.modules.losses.vqperceptual import hinge_d_loss, vanilla_d_loss
9 |
10 | from unidemoire.util import exists
11 |
12 |
13 | def hinge_d_loss_with_exemplar_weights(logits_real, logits_fake, weights):
14 | assert weights.shape[0] == logits_real.shape[0] == logits_fake.shape[0]
15 | loss_real = torch.mean(F.relu(1. - logits_real), dim=[1,2,3])
16 | loss_fake = torch.mean(F.relu(1. + logits_fake), dim=[1,2,3])
17 | loss_real = (weights * loss_real).sum() / weights.sum()
18 | loss_fake = (weights * loss_fake).sum() / weights.sum()
19 | d_loss = 0.5 * (loss_real + loss_fake)
20 | return d_loss
21 |
22 | def adopt_weight(weight, global_step, threshold=0, value=0.):
23 | if global_step < threshold:
24 | weight = value
25 | return weight
26 |
27 |
28 | def measure_perplexity(predicted_indices, n_embed):
29 | # src: https://github.com/karpathy/deep-vector-quantization/blob/main/model.py
30 | # eval cluster perplexity. when perplexity == num_embeddings then all clusters are used exactly equally
31 | encodings = F.one_hot(predicted_indices, n_embed).float().reshape(-1, n_embed)
32 | avg_probs = encodings.mean(0)
33 | perplexity = (-(avg_probs * torch.log(avg_probs + 1e-10)).sum()).exp()
34 | cluster_use = torch.sum(avg_probs > 0)
35 | return perplexity, cluster_use
36 |
37 | def l1(x, y):
38 | return torch.abs(x-y)
39 |
40 |
41 | def l2(x, y):
42 | return torch.pow((x-y), 2)
43 |
44 |
45 | class VQLPIPSWithDiscriminator(nn.Module):
46 | def __init__(self, disc_start, codebook_weight=1.0, pixelloss_weight=1.0,
47 | disc_num_layers=3, disc_in_channels=3, disc_factor=1.0, disc_weight=1.0,
48 | perceptual_weight=1.0, use_actnorm=False, disc_conditional=False,
49 | disc_ndf=64, disc_loss="hinge", n_classes=None, perceptual_loss="lpips",
50 | pixel_loss="l1"):
51 | super().__init__()
52 | assert disc_loss in ["hinge", "vanilla"]
53 | assert perceptual_loss in ["lpips", "clips", "dists"]
54 | assert pixel_loss in ["l1", "l2"]
55 | self.codebook_weight = codebook_weight
56 | self.pixel_weight = pixelloss_weight
57 | if perceptual_loss == "lpips":
58 | print(f"{self.__class__.__name__}: Running with LPIPS.")
59 | self.perceptual_loss = LPIPS().eval()
60 | else:
61 | raise ValueError(f"Unknown perceptual loss: >> {perceptual_loss} <<")
62 | self.perceptual_weight = perceptual_weight
63 |
64 | if pixel_loss == "l1":
65 | self.pixel_loss = l1
66 | else:
67 | self.pixel_loss = l2
68 |
69 | self.discriminator = NLayerDiscriminator(input_nc=disc_in_channels,
70 | n_layers=disc_num_layers,
71 | use_actnorm=use_actnorm,
72 | ndf=disc_ndf
73 | ).apply(weights_init)
74 | self.discriminator_iter_start = disc_start
75 | if disc_loss == "hinge":
76 | self.disc_loss = hinge_d_loss
77 | elif disc_loss == "vanilla":
78 | self.disc_loss = vanilla_d_loss
79 | else:
80 | raise ValueError(f"Unknown GAN loss '{disc_loss}'.")
81 | print(f"VQLPIPSWithDiscriminator running with {disc_loss} loss.")
82 | self.disc_factor = disc_factor
83 | self.discriminator_weight = disc_weight
84 | self.disc_conditional = disc_conditional
85 | self.n_classes = n_classes
86 |
87 | def calculate_adaptive_weight(self, nll_loss, g_loss, last_layer=None):
88 | if last_layer is not None:
89 | nll_grads = torch.autograd.grad(nll_loss, last_layer, retain_graph=True)[0]
90 | g_grads = torch.autograd.grad(g_loss, last_layer, retain_graph=True)[0]
91 | else:
92 | nll_grads = torch.autograd.grad(nll_loss, self.last_layer[0], retain_graph=True)[0]
93 | g_grads = torch.autograd.grad(g_loss, self.last_layer[0], retain_graph=True)[0]
94 |
95 | d_weight = torch.norm(nll_grads) / (torch.norm(g_grads) + 1e-4)
96 | d_weight = torch.clamp(d_weight, 0.0, 1e4).detach()
97 | d_weight = d_weight * self.discriminator_weight
98 | return d_weight
99 |
100 | def forward(self, codebook_loss, inputs, reconstructions, optimizer_idx,
101 | global_step, last_layer=None, cond=None, split="train", predicted_indices=None):
102 | if not exists(codebook_loss):
103 | codebook_loss = torch.tensor([0.]).to(inputs.device)
104 | #rec_loss = torch.abs(inputs.contiguous() - reconstructions.contiguous())
105 | rec_loss = self.pixel_loss(inputs.contiguous(), reconstructions.contiguous())
106 | if self.perceptual_weight > 0:
107 | p_loss = self.perceptual_loss(inputs.contiguous(), reconstructions.contiguous())
108 | rec_loss = rec_loss + self.perceptual_weight * p_loss
109 | else:
110 | p_loss = torch.tensor([0.0])
111 |
112 | nll_loss = rec_loss
113 | #nll_loss = torch.sum(nll_loss) / nll_loss.shape[0]
114 | nll_loss = torch.mean(nll_loss)
115 |
116 | # now the GAN part
117 | if optimizer_idx == 0:
118 | # generator update
119 | if cond is None:
120 | assert not self.disc_conditional
121 | logits_fake = self.discriminator(reconstructions.contiguous())
122 | else:
123 | assert self.disc_conditional
124 | logits_fake = self.discriminator(torch.cat((reconstructions.contiguous(), cond), dim=1))
125 | g_loss = -torch.mean(logits_fake)
126 |
127 | try:
128 | d_weight = self.calculate_adaptive_weight(nll_loss, g_loss, last_layer=last_layer)
129 | except RuntimeError:
130 | assert not self.training
131 | d_weight = torch.tensor(0.0)
132 |
133 | disc_factor = adopt_weight(self.disc_factor, global_step, threshold=self.discriminator_iter_start)
134 | loss = nll_loss + d_weight * disc_factor * g_loss + self.codebook_weight * codebook_loss.mean()
135 |
136 | log = {"{}/total_loss".format(split): loss.clone().detach().mean(),
137 | "{}/quant_loss".format(split): codebook_loss.detach().mean(),
138 | "{}/nll_loss".format(split): nll_loss.detach().mean(),
139 | "{}/rec_loss".format(split): rec_loss.detach().mean(),
140 | "{}/p_loss".format(split): p_loss.detach().mean(),
141 | "{}/d_weight".format(split): d_weight.detach(),
142 | "{}/disc_factor".format(split): torch.tensor(disc_factor),
143 | "{}/g_loss".format(split): g_loss.detach().mean(),
144 | }
145 | if predicted_indices is not None:
146 | assert self.n_classes is not None
147 | with torch.no_grad():
148 | perplexity, cluster_usage = measure_perplexity(predicted_indices, self.n_classes)
149 | log[f"{split}/perplexity"] = perplexity
150 | log[f"{split}/cluster_usage"] = cluster_usage
151 | return loss, log
152 |
153 | if optimizer_idx == 1:
154 | # second pass for discriminator update
155 | if cond is None:
156 | logits_real = self.discriminator(inputs.contiguous().detach())
157 | logits_fake = self.discriminator(reconstructions.contiguous().detach())
158 | else:
159 | logits_real = self.discriminator(torch.cat((inputs.contiguous().detach(), cond), dim=1))
160 | logits_fake = self.discriminator(torch.cat((reconstructions.contiguous().detach(), cond), dim=1))
161 |
162 | disc_factor = adopt_weight(self.disc_factor, global_step, threshold=self.discriminator_iter_start)
163 | d_loss = disc_factor * self.disc_loss(logits_real, logits_fake)
164 |
165 | log = {"{}/disc_loss".format(split): d_loss.clone().detach().mean(),
166 | "{}/logits_real".format(split): logits_real.detach().mean(),
167 | "{}/logits_fake".format(split): logits_fake.detach().mean()
168 | }
169 | return d_loss, log
170 |
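171 |
172 | # Minimal usage sketch (illustrative only) for the two helpers above: adopt_weight keeps the
173 | # adversarial term switched off until `threshold` steps, and measure_perplexity reports how
174 | # evenly a codebook is used. The numbers below are arbitrary.
175 | if __name__ == "__main__":
176 |     print(adopt_weight(1.0, global_step=100, threshold=500))  # 0.0 -> discriminator still off
177 |     print(adopt_weight(1.0, global_step=900, threshold=500))  # 1.0 -> adversarial loss active
178 |
179 |     indices = torch.randint(0, 16, (4096,))                   # fake codebook assignments
180 |     perplexity, used = measure_perplexity(indices, n_embed=16)
181 |     print(perplexity.item(), used.item())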
--------------------------------------------------------------------------------
/unidemoire/util.py:
--------------------------------------------------------------------------------
1 | import importlib
2 | import sys
3 | import torch
4 | import numpy as np
5 | from collections import abc
6 | from einops import rearrange
7 |
8 | import multiprocessing as mp
9 | from threading import Thread
10 | from queue import Queue
11 | from inspect import isfunction
12 | from PIL import Image, ImageDraw, ImageFont
13 |
14 | def log_txt_as_img(wh, xc, size=10):
15 | # wh a tuple of (width, height)
16 | # xc a list of captions to plot
17 | b = len(xc)
18 | txts = list()
19 | for bi in range(b):
20 | txt = Image.new("RGB", wh, color="white")
21 | draw = ImageDraw.Draw(txt)
22 | font = ImageFont.truetype('data/DejaVuSans.ttf', size=size)
23 | nc = int(40 * (wh[0] / 256))
24 | lines = "\n".join(xc[bi][start:start + nc] for start in range(0, len(xc[bi]), nc))
25 |
26 | try:
27 | draw.text((0, 0), lines, fill="black", font=font)
28 | except UnicodeEncodeError:
29 | print("Cant encode string for logging. Skipping.")
30 |
31 | txt = np.array(txt).transpose(2, 0, 1) / 127.5 - 1.0
32 | txts.append(txt)
33 | txts = np.stack(txts)
34 | txts = torch.tensor(txts)
35 | return txts
36 |
37 |
38 | def ismap(x):
39 | if not isinstance(x, torch.Tensor):
40 | return False
41 | return (len(x.shape) == 4) and (x.shape[1] > 3)
42 |
43 |
44 | def isimage(x):
45 | if not isinstance(x, torch.Tensor):
46 | return False
47 | return (len(x.shape) == 4) and (x.shape[1] == 3 or x.shape[1] == 1)
48 |
49 |
50 | def exists(x):
51 | return x is not None
52 |
53 |
54 | def default(val, d):
55 | if exists(val):
56 | return val
57 | return d() if isfunction(d) else d
58 |
59 |
60 | def mean_flat(tensor):
61 | """
62 | https://github.com/openai/guided-diffusion/blob/27c20a8fab9cb472df5d6bdd6c8d11c8f430b924/guided_diffusion/nn.py#L86
63 | Take the mean over all non-batch dimensions.
64 | """
65 | return tensor.mean(dim=list(range(1, len(tensor.shape))))
66 |
67 |
68 | def count_params(model, verbose=False):
69 | total_params = sum(p.numel() for p in model.parameters())
70 | if verbose:
71 | print(f"{model.__class__.__name__} has {total_params * 1.e-6:.2f} M params.")
72 | return total_params
73 |
74 | def instantiate_from_config(config):
75 | if not "target" in config:
76 | if config == '__is_first_stage__':
77 | return None
78 | elif config == "__is_unconditional__":
79 | return None
80 | raise KeyError("Expected key `target` to instantiate.")
81 | target_cls = get_obj_from_str(config["target"])
82 | if target_cls is None:
83 | print(f"Warning: Target class {config['target']} not found, skipping instantiation.")
84 | return None
85 | return target_cls(**config.get("params", dict()))
86 |
87 |
88 |
89 | def get_obj_from_str(string, reload=False, silent=True):
90 | # module, cls = string.rsplit(".", 1)
91 | # if reload:
92 | # module_imp = importlib.port_module(module)
93 | # importlib.reload(module_imp)
94 | # return getattr(importlib.import_module(module, package=None), cls)
95 | try:
96 | module, cls = string.rsplit(".", 1)
97 | if reload:
98 | if module in sys.modules:
99 | importlib.reload(sys.modules[module])
100 | return getattr(importlib.import_module(module, package=None), cls)
101 | except (ModuleNotFoundError, AttributeError) as e:
102 | if not silent:
103 | raise
104 | print(f"Warning: Could not import {string} - {str(e)}. Skipping...")
105 | return None
106 |
107 |
108 | def _do_parallel_data_prefetch(func, Q, data, idx, idx_to_fn=False):
109 | # create dummy dataset instance
110 |
111 | # run prefetching
112 | if idx_to_fn:
113 | res = func(data, worker_id=idx)
114 | else:
115 | res = func(data)
116 | Q.put([idx, res])
117 | Q.put("Done")
118 |
119 |
120 | def parallel_data_prefetch(
121 | func: callable, data, n_proc, target_data_type="ndarray", cpu_intensive=True, use_worker_id=False
122 | ):
123 | # if target_data_type not in ["ndarray", "list"]:
124 | # raise ValueError(
125 | # "Data, which is passed to parallel_data_prefetch has to be either of type list or ndarray."
126 | # )
127 | if isinstance(data, np.ndarray) and target_data_type == "list":
128 | raise ValueError("list expected but function got ndarray.")
129 | elif isinstance(data, abc.Iterable):
130 | if isinstance(data, dict):
131 | print(
132 | f'WARNING:"data" argument passed to parallel_data_prefetch is a dict: Using only its values and disregarding keys.'
133 | )
134 | data = list(data.values())
135 | if target_data_type == "ndarray":
136 | data = np.asarray(data)
137 | else:
138 | data = list(data)
139 | else:
140 | raise TypeError(
141 | f"The data, that shall be processed parallel has to be either an np.ndarray or an Iterable, but is actually {type(data)}."
142 | )
143 |
144 | if cpu_intensive:
145 | Q = mp.Queue(1000)
146 | proc = mp.Process
147 | else:
148 | Q = Queue(1000)
149 | proc = Thread
150 | # spawn processes
151 | if target_data_type == "ndarray":
152 | arguments = [
153 | [func, Q, part, i, use_worker_id]
154 | for i, part in enumerate(np.array_split(data, n_proc))
155 | ]
156 | else:
157 | step = (
158 | int(len(data) / n_proc + 1)
159 | if len(data) % n_proc != 0
160 | else int(len(data) / n_proc)
161 | )
162 | arguments = [
163 | [func, Q, part, i, use_worker_id]
164 | for i, part in enumerate(
165 | [data[i: i + step] for i in range(0, len(data), step)]
166 | )
167 | ]
168 | processes = []
169 | for i in range(n_proc):
170 | p = proc(target=_do_parallel_data_prefetch, args=arguments[i])
171 | processes += [p]
172 |
173 | # start processes
174 | print(f"Start prefetching...")
175 | import time
176 |
177 | start = time.time()
178 | gather_res = [[] for _ in range(n_proc)]
179 | try:
180 | for p in processes:
181 | p.start()
182 |
183 | k = 0
184 | while k < n_proc:
185 | # get result
186 | res = Q.get()
187 | if res == "Done":
188 | k += 1
189 | else:
190 | gather_res[res[0]] = res[1]
191 |
192 | except Exception as e:
193 | print("Exception: ", e)
194 | for p in processes:
195 | p.terminate()
196 |
197 | raise e
198 | finally:
199 | for p in processes:
200 | p.join()
201 | print(f"Prefetching complete. [{time.time() - start} sec.]")
202 |
203 | if target_data_type == 'ndarray':
204 | if not isinstance(gather_res[0], np.ndarray):
205 | return np.concatenate([np.asarray(r) for r in gather_res], axis=0)
206 |
207 | # order outputs
208 | return np.concatenate(gather_res, axis=0)
209 | elif target_data_type == 'list':
210 | out = []
211 | for r in gather_res:
212 | out.extend(r)
213 | return out
214 | else:
215 | return gather_res
216 |
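217 |
218 | # Minimal usage sketch (illustrative only): the YAML configs in this repo are resolved via
219 | # instantiate_from_config; any importable class can be a target, and torch.nn.Conv2d is used
220 | # here purely as a stand-in.
221 | if __name__ == "__main__":
222 |     cfg = {
223 |         "target": "torch.nn.Conv2d",
224 |         "params": {"in_channels": 3, "out_channels": 8, "kernel_size": 3},
225 |     }
226 |     layer = instantiate_from_config(cfg)  # equivalent to torch.nn.Conv2d(3, 8, 3)
227 |     count_params(layer, verbose=True)     # prints the parameter count in millions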
--------------------------------------------------------------------------------